DataFrame filtering based on second DataFrame

僤鯓⒐⒋嵵緔 submitted on 2019-12-05 23:59:34
zero323

There are at least a few problems with your code:

  • You cannot execute an action or a transformation inside another action or transformation. Filtering a broadcasted DataFrame this way therefore cannot work, and you should get an exception.
  • The join you use is executed as a Cartesian product followed by a filter. Since Spark uses hashing for joins, only equality-based joins can be executed efficiently without a Cartesian product. This is loosely related to "Why using a UDF in a SQL query leads to cartesian product?".
  • If both DataFrames are relatively large and similar in size, broadcasting is unlikely to be useful. See "Why my BroadcastHashJoin is slower than ShuffledHashJoin in Spark".
  • Not important for performance, but isPrefix seems to be wrong. In particular, it looks like it can match both a prefix and a suffix.
  • The col("first.path").lt(col("second.path")) condition looks wrong. I assume you want a/a/a/b/c from df1 to match a/a/a from df2. If so, it should be gt, not lt.
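To make the last two points concrete, here is a minimal plain-Java sketch (no Spark involved). The helper `isPathPrefix` is hypothetical and is not the `isPrefix` from the question, which is not shown here; it only illustrates what a segment-aligned prefix check could look like, and why a longer path compares greater than its own prefix.

```java
import java.util.Arrays;
import java.util.List;

public class PathPrefixSketch {
    // Hypothetical helper (not the asker's isPrefix): true when `prefix` is a
    // strict, segment-aligned prefix of `path`. "a/a" matches "a/a/b",
    // but not "a/ab" (partial segment) or "b/a/a" (suffix).
    public static boolean isPathPrefix(String prefix, String path) {
        List<String> p = Arrays.asList(prefix.split("/"));
        List<String> q = Arrays.asList(path.split("/"));
        return p.size() < q.size() && q.subList(0, p.size()).equals(p);
    }

    public static void main(String[] args) {
        System.out.println(isPathPrefix("a/a/a", "a/a/a/b/c")); // true
        System.out.println(isPathPrefix("b/c", "a/a/a/b/c"));   // false: it is a suffix
        // A string that extends another compares greater lexicographically,
        // which is why the join condition would need gt rather than lt.
        System.out.println("a/a/a/b/c".compareTo("a/a/a") > 0); // true
    }
}
```

Note that a gt comparison alone does not establish a prefix relationship; it only orders the strings, so a prefix or equality condition is still needed.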

Probably the best thing you can do is something similar to this:

import org.apache.spark.sql.functions.{col, regexp_extract}

val df = sc.parallelize(Seq(
    ("a/a/a/b/c", "abc"), ("a/a/a","qwe"),
    ("a/b/c/d/e", "abc"), ("a/b/c", "qwe"),
    ("a/b/b/k", "foo"), ("a/b/f/a", "bar")
)).toDF("path", "value")

val df1 = df
    .where(col("value") === "abc")    
    .withColumn("path_short", regexp_extract(col("path"), "^(.*)(/.){2}$", 1))
    .as("df1")

val df2 = df.where(col("value") === "qwe").as("df2")
val joined = df1.join(df2, col("df1.path_short") === col("df2.path"))
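To see what path_short ends up containing, the same pattern can be tried with plain java.util.regex, which is the regex dialect regexp_extract relies on. This is only a sketch outside Spark: the class and method names are made up for illustration, and the second pattern is my own assumption about how one might allow segments longer than one character, not part of the answer above.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PathShortSketch {
    // Pattern from the snippet above: drops the last two single-character segments.
    private static final Pattern TWO_SHORT = Pattern.compile("^(.*)(/.){2}$");
    // Assumed alternative (not from the answer): drops the last two segments of any length.
    private static final Pattern TWO_ANY = Pattern.compile("^(.*)(/[^/]+){2}$");

    // Returns group 1, or "" when nothing matches, which is also what
    // regexp_extract returns for a non-matching input.
    private static String extract(Pattern p, String path) {
        Matcher m = p.matcher(path);
        return m.find() ? m.group(1) : "";
    }

    public static String pathShort(String path) {
        return extract(TWO_SHORT, path);
    }

    public static String pathShortAnyLength(String path) {
        return extract(TWO_ANY, path);
    }

    public static void main(String[] args) {
        System.out.println(pathShort("a/a/a/b/c"));           // a/a/a
        System.out.println(pathShort("a/b/c/d/e"));           // a/b/c
        System.out.println(pathShort("a/b/long/c"));          // empty: "long" is not one character
        System.out.println(pathShortAnyLength("a/b/long/c")); // a/b
    }
}
```

The equality join then only has to compare path_short with df2.path, which is what lets Spark use a hash-based join instead of a Cartesian product.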

You can try to broadcast one of the tables like this (Spark >= 1.5.0 only):

import org.apache.spark.sql.functions.broadcast

df1.join(broadcast(df2), col("df1.path_short") === col("df2.path"))

and increase the auto-broadcast limit (spark.sql.autoBroadcastJoinThreshold), but as I've mentioned above, it will most likely be less efficient than a plain HashJoin.

As a possible way of implementing IN with a subquery, a LEFT SEMI JOIN can be used:

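    // Imports needed by this snippet (Spark 1.x API)
    import java.util.ArrayList;
    import java.util.List;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.Column;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.functions;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    JavaSparkContext javaSparkContext = new JavaSparkContext("local", "testApp");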
    SQLContext sqlContext = new SQLContext(javaSparkContext);
    StructType schema = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("path", DataTypes.StringType, false),
            DataTypes.createStructField("value", DataTypes.StringType, false)
    });
    // Prepare First DataFrame
    List<Row> dataForFirstDF = new ArrayList<>();
    dataForFirstDF.add(RowFactory.create("a/a/a/b/c", "abc"));
    dataForFirstDF.add(RowFactory.create("a/b/c/d/e", "abc"));
    dataForFirstDF.add(RowFactory.create("x/y/z", "xyz"));
    DataFrame df1 = sqlContext.createDataFrame(javaSparkContext.parallelize(dataForFirstDF), schema);
    // Show the contents of the first DataFrame
    df1.show();
    // Output:
    // +---------+-----+
    // |     path|value|
    // +---------+-----+
    // |a/a/a/b/c|  abc|
    // |a/b/c/d/e|  abc|
    // |    x/y/z|  xyz|
    // +---------+-----+

    // Prepare Second DataFrame
    List<Row> dataForSecondDF = new ArrayList<>();
    dataForSecondDF.add(RowFactory.create("a/a/a", "qwe"));
    dataForSecondDF.add(RowFactory.create("a/b/c", "qwe"));
    DataFrame df2 = sqlContext.createDataFrame(javaSparkContext.parallelize(dataForSecondDF), schema);

    // Use a left semi join to keep only the df1 rows whose path contains a path from df2
    Column pathContains = functions.column("firstDF.path").contains(functions.column("secondDF.path"));
    DataFrame result = df1.as("firstDF").join(df2.as("secondDF"), pathContains, "leftsemi");

    // Show the filtered result
    result.show();
    // Output:
    // +---------+-----+
    // |     path|value|
    // +---------+-----+
    // |a/a/a/b/c|  abc|
    // |a/b/c/d/e|  abc|
    // +---------+-----+

The physical plan of such a query will look like this:

== Physical Plan ==
Limit 21
 ConvertToSafe
  LeftSemiJoinBNL Some(Contains(path#0, path#2))
   ConvertToUnsafe
    Scan PhysicalRDD[path#0,value#1]
   TungstenProject [path#2]
    Scan PhysicalRDD[path#2,value#3]

It will use LeftSemiJoinBNL for the actual join operation, which should broadcast values internally. For more details, refer to the actual implementation in Spark: LeftSemiJoinBNL.scala.
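Conceptually, a left semi join with an arbitrary condition behaves like the nested loop below. This is a plain-Java sketch over in-memory lists, not Spark's implementation; the class and method names are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiPredicate;

public class LeftSemiSketch {
    // Keeps each left element that has at least one matching right element.
    // Each left element is emitted at most once, and no right values are returned.
    public static <L, R> List<L> leftSemi(List<L> left, List<R> right, BiPredicate<L, R> cond) {
        List<L> out = new ArrayList<>();
        for (L l : left) {
            for (R r : right) {
                if (cond.test(l, r)) {
                    out.add(l);
                    break; // one match is enough for a semi join
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> first = Arrays.asList("a/a/a/b/c", "a/b/c/d/e", "x/y/z");
        List<String> second = Arrays.asList("a/a/a", "a/b/c");
        // Same condition as the join above: the first path contains the second.
        System.out.println(leftSemi(first, second, (l, r) -> l.contains(r)));
        // prints [a/a/a/b/c, a/b/c/d/e]
    }
}
```

Keep in mind that contains matches anywhere in the string, so a path such as x/a/a/a would also be kept. If only prefixes should match, a stricter condition such as `(l, r) -> l.startsWith(r + "/")` could be substituted in this sketch.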

P.S. I didn't quite understand the need for removing the last two characters, but if that's needed, it can be done as @zero323 proposed (using regexp_extract).
