spark sql查询hive表实现脱敏

情到浓时终转凉″ 提交于 2020-01-15 02:09:47

关于hive表查询脱敏,原理是select的时候在后台对sql进行处理,对每个要查询的字段都加一个自定义的mask脱敏函数。

一开始的实现思路是类似 select col1,col2,col3 from mask_table,后台处理后sql改造成select m.col1,m.col2,m.col3 from (select mask(col1),mask(col2),mask(col3) from mask_table) m , 优点是实现比较简单,根据spark sql执行的物理计划,获取到sql涉及到的根表,对根表的每个字段进行脱敏。

但以上实现存在问题,假设原sql为select mobile from mask_table where mobile=‘13111111111’,通过上述方法改造后,sql变成select mobile from (select mask(mobile) from mask_table) a where mobile=‘13111111111’,当前这种情况已经无法查询结果

最后实现的方式其实一开始就想到了,但是最开始的时候觉得要追溯字段的依赖关系有点不太好弄,所以用了上述方式,不过最后还是改造成了下面的方法:

实现原理一样,也是根据spark sql的物理计划对sql重新进行拼接,不同的是,原来是对根表的每个字段进行脱敏处理,现在对查询结果的每个字段进行脱敏处理。

优点:
1、能根据具体条件查询
2、查询性能也提高很多

大概实现是:
1、将原始sql改成,sql = “select * from (” + sql + “) m”,原因是用户执行的sql可能会比较复杂,所以在外面又嵌套了一个查询,这样不管原sql多复杂,经过一层嵌套后查询的返回结果字段获取就会比较简单

2、根据spark sql的物理计划获得字段的依赖关系

3、重新拼接sql

最复杂的就是第二步了,代码如下:

import com.alibaba.fastjson.JSON
import com.alibaba.fastjson.serializer.SerializeFilter
import com.google.common.collect.{Lists, Maps}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Project, SubqueryAlias}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.{FileSourceScanExec, ProjectExec}

import scala.collection.JavaConversions._

/**
  * Created by zheng  
  */
object MaskUtils {
    def getMaskSql(sparkSession: SparkSession, oldSql: String, userId: String): String = {
        val sql = "select * from (" + oldSql + ") m"
        val df = sparkSession.sql(sql)
        val query = df.queryExecution


        val colMap = Maps.newHashMap[String, java.util.ArrayList[String]]();
        val sparkPlan = query.sparkPlan

        val newSql = new StringBuilder(sql)
        var colStr = ""

        val children = sparkPlan.children;
        if (children == null || children.size == 0) {
            val columns = Lists.newArrayList[String]();
            var exce = query.sparkPlan.asInstanceOf[FileSourceScanExec];
            exce.requiredSchema.foreach(field => columns.add(field.name))
            exce.relation.partitionSchema.foreach(field => columns.add(field.name))
            val table = query.sparkPlan.asInstanceOf[FileSourceScanExec].tableIdentifier.get

            var tableName = table.table + ".";
            if (table.database.isDefined) {
                tableName = table.database.get + "." + tableName;
            }

            for (col <- columns) {
                val json = JSON.toJSONString(Array(tableName + "." + col), new Array[SerializeFilter](0))
                colStr += "mask(m.`" + col + "`,'" + json + "','" + userId + "') `" + col + "`,"
            }
        } else {
            val anylized = query.analyzed

            val tableMap = Maps.newHashMap[String, String]();
            getFullTabs(anylized.children, tableMap) //平铺获得所有字段对应的表
            getColRelation(sparkPlan.children, colMap) //获取字段血缘关系

            anylized.asInstanceOf[Project].projectList.map(item => {
                val tabs = getTabs(item.name + "#" + item.exprId.id, colMap, tableMap)
                if (tabs.size() > 0) {
                    val json = JSON.toJSONString(tabs, new Array[SerializeFilter](0))
                    colStr += "mask(m.`" + item.name + "`,'" + json + "','" + userId + "') `" + item.name + "`,"
                } else {
                    colStr += "`" + item.name + "`,"
                }
            })
        }
        newSql.replace(sql.indexOf("select ") + 7, sql.indexOf(" from "), colStr.substring(0, colStr.length - 1))
        print(newSql.toString())

        newSql.toString()
    }

    def getTabs(colName: String, colMap: java.util.HashMap[String, java.util.ArrayList[String]], tabMap: java.util.HashMap[String, String]): java.util.concurrent.CopyOnWriteArrayList[String] = {
        val colList = new java.util.concurrent.ConcurrentHashMap[String, String]
        getList(colName, colMap, colList)
        val tabList = new java.util.concurrent.CopyOnWriteArrayList[String]()

        for (col <- colList) {
            val tab = tabMap.get(col._1);
            if (tab != null) {
                tabList.add(tab + "." + col._1.split("#")(0))
            }
        }

        tabList
    }

    /**
      * allMap 这里应该用多线程的hashset,用了ConcurrentHashMap来去重
      * @param colName
      * @param colMap
      * @param allMap
      */
    def getList(colName: String, colMap: java.util.HashMap[String, java.util.ArrayList[String]], allMap: java.util.concurrent.ConcurrentHashMap[String, String]): Unit = {
        allMap.put(colName, colName)
        val colList = colMap.get(colName)
        if (colList != null) {
            for (col: String <- colList) {
                allMap.put(col, col)
                getList(col, colMap, allMap)

            }

        }
    }

    def getColRelation(childs: Seq[TreeNode[_]], colMap: java.util.HashMap[String, java.util.ArrayList[String]]): Unit = {
        childs.foreach(child => {
            if (child.isInstanceOf[ProjectExec]) {
                val source = child.asInstanceOf[ProjectExec]
                getColRelationDetail(source.projectList, "-1", colMap)
            } else {
                getColRelation(child.children.asInstanceOf[Seq[TreeNode[_]]], colMap)
            }
        })
    }

    def getColRelationDetail(childs: Seq[Expression], parent: String, colMap: java.util.HashMap[String, java.util.ArrayList[String]]): Unit = {
        childs.foreach(ss => {
            if (ss.isInstanceOf[NamedExpression]) {
                val sub = ss.asInstanceOf[NamedExpression]
                var list = colMap.get(parent)
                if (list == null) {
                    list = Lists.newArrayList[String]()
                }
                val key = sub.name + "#" + sub.exprId.id
                list.add(key)
                colMap.put(parent, list)

                getColRelationDetail(sub.children, key, colMap)
            } else {
                var list = colMap.get(parent)
                if (list == null) {
                    list = Lists.newArrayList[String]()
                }
                list.add(ss.toString())
                colMap.put(parent, list)

                getColRelationDetail(ss.children, ss.toString(), colMap)

            }
        })
    }

    /**
      * 获取所有表的所有字段
      *
      * @param childs
      * @param tableMap
      */
    def getFullTabs(childs: Seq[TreeNode[_]], tableMap: java.util.HashMap[String, String]): Unit = {
        childs.foreach(sub => {
            if (sub.isInstanceOf[SubqueryAlias] && sub.asInstanceOf[SubqueryAlias].child.isInstanceOf[LogicalRelation]) {
                val subquery = sub.asInstanceOf[SubqueryAlias];
                val table = subquery.name
                var tableName = table.identifier
                if (table.database != null) {
                    tableName = table.database.get + "." + tableName
                }

                subquery.child.asInstanceOf[LogicalRelation].output.foreach(item => {
                    tableMap.put(item.name + "#" + item.exprId.id, tableName)
                })
            } else {
                getFullTabs(sub.children.asInstanceOf[Seq[TreeNode[_]]], tableMap)
            }
        })
    }
}


易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!