How to get the table name from Spark SQL Query [PySpark]?

Submitted by 早过忘川 on 2019-12-22 04:49:08

Question


I want to get the table names from a SQL query such as:

select *
from table1 as t1
full outer join table2 as t2
  on t1.id = t2.id

I found a Scala solution in "How to get table names from SQL query?":

def getTables(query: String): Seq[String] = {
  val logicalPlan = spark.sessionState.sqlParser.parsePlan(query)
  import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
  logicalPlan.collect { case r: UnresolvedRelation => r.tableName }
}

which gives me the correct table names when I iterate over the returned sequence with getTables(query).foreach(println):

table1
table2

What would be the equivalent syntax in PySpark? The closest I came across was "How to extract column name and column type from SQL in pyspark":

plan = spark_session._jsparkSession.sessionState().sqlParser().parsePlan(query)
print(f"table: {plan.tableDesc().identifier().table()}")

which fails with the traceback

Py4JError: An error occurred while calling o78.tableDesc. Trace:
py4j.Py4JException: Method tableDesc([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
    at py4j.Gateway.invoke(Gateway.java:274)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.base/java.lang.Thread.run(Thread.java:835)

I understand that the problem stems from the fact that I need to filter the plan items of type UnresolvedRelation, but I cannot find an equivalent notation in Python/PySpark.


Answer 1:


I have an approach, but it is rather convoluted. It dumps the Java object as JSON (a poor man's serialization), deserializes that into Python objects, then filters the plan items and extracts the table names:

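import json

def get_tables(query: str):
    # `spark` is the active SparkSession; parse the query into an unresolved logical plan.
    plan = spark._jsparkSession.sessionState().sqlParser().parsePlan(query)
    # toJSON() dumps the plan as a list of nodes, each tagged with its class name.
    plan_items = json.loads(plan.toJSON())
    for plan_item in plan_items:
        if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
            yield plan_item['tableIdentifier']['table']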

For my own query, list(get_tables(query)) returns ['fast_track_gv_nexus', 'buybox_gv_nexus'].
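
To see what the filter relies on, the same loop can be run against a hand-written plan dump with no Spark session. This is only a minimal sketch: SAMPLE_PLAN_JSON is a hypothetical, heavily abridged stand-in for what plan.toJSON() returns (a real dump has more nodes and more fields per node), and tables_from_plan_json is an illustrative name, not a Spark API.

```python
import json

UNRESOLVED_RELATION = "org.apache.spark.sql.catalyst.analysis.UnresolvedRelation"

# Hypothetical, abridged stand-in for plan.toJSON(): a JSON array of plan nodes.
SAMPLE_PLAN_JSON = json.dumps([
    {"class": "org.apache.spark.sql.catalyst.plans.logical.Project"},
    {"class": "org.apache.spark.sql.catalyst.plans.logical.Join"},
    {"class": UNRESOLVED_RELATION, "tableIdentifier": {"table": "table1"}},
    {"class": UNRESOLVED_RELATION, "tableIdentifier": {"table": "table2"}},
])


def tables_from_plan_json(plan_json: str):
    """Yield the table name of every UnresolvedRelation node in the dump."""
    for node in json.loads(plan_json):
        # Nodes of any other class (Project, Join, ...) carry no table name.
        if node.get("class") == UNRESOLVED_RELATION:
            yield node["tableIdentifier"]["table"]


print(list(tables_from_plan_json(SAMPLE_PLAN_JSON)))
```

Only the 'class' and 'tableIdentifier' keys matter to the filter, so any change in how a Spark version serializes those two fields would require adjusting the lookup.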

Note: unfortunately, this breaks for CTEs, because a reference to a CTE name is reported as if it were a table.

Example

with delta as (
   select *
    group by id
    cluster by id
 )
select   *
  from ( select  *
         FROM
          (select   *
            from dmm
            inner join delta on dmm.id = delta.id
           )
  )

To resolve it, I have to hack around it with a regular expression:

import json
import re

def get_tables(query: str):
    plan = spark._jsparkSession.sessionState().sqlParser().parsePlan(query)
    plan_items = json.loads(plan.toJSON())
    plan_string = plan.toString()
    # Assumes the plan string lists CTE aliases as "CTE [a, b]"; split so each alias is matched on its own.
    cte = {name.strip()
           for group in re.findall(r"CTE \[(.*?)\]", plan_string)
           for name in group.split(",")}
    for plan_item in plan_items:
        if plan_item['class'] == 'org.apache.spark.sql.catalyst.analysis.UnresolvedRelation':
            table_identifier = plan_item['tableIdentifier']
            table = table_identifier['table']
            database = table_identifier.get('database', '')
            table_name = "{}.{}".format(database, table) if database else table
            # Skip names that refer to a CTE rather than a real table.
            if table_name not in cte:
                yield table_name
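
The CTE-name extraction can be exercised on its own, without a Spark session. Note that re.findall returns the bracket contents as one string, so with several CTEs the match has to be split on commas before the membership test. A minimal sketch, assuming the plan string renders a WITH clause as "CTE [name1, name2]"; the helper name cte_names and the sample strings are hypothetical, written by hand for illustration.

```python
import re


def cte_names(plan_string: str) -> set:
    """Collect CTE aliases from every 'CTE [a, b]' marker in a plan string."""
    names = set()
    for group in re.findall(r"CTE \[(.*?)\]", plan_string):
        # One match holds all aliases of a WITH clause, separated by commas.
        names.update(name.strip() for name in group.split(",") if name.strip())
    return names


# Hypothetical plan-string fragments, not captured from a real Spark run.
single = "CTE [delta]\n:  +- 'SubqueryAlias delta"
multiple = "CTE [delta, gamma]\n:  +- 'SubqueryAlias delta"

print(cte_names(single))
print(cte_names(multiple))
```

Returning a set keeps the later "table_name not in cte" check cheap and makes duplicates across several WITH clauses harmless.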


Source: https://stackoverflow.com/questions/58556062/how-to-get-the-table-name-from-spark-sql-query-pyspark
