Filtering rows based on column values in a Spark DataFrame (Scala)

时光说笑 · 2020-12-09 21:16

I have a Spark DataFrame:

id  value 
3     0
3     1
3     0
4     1
4     0
4     0

I want to create a new DataFrame that, for each id, keeps only the rows up to and including the first row where value is 1:

id  value
3     0
3     1
4     1

4 Answers
  • 2020-12-09 21:43

    You can simply use groupBy like this:

    val df2 = df1.groupBy("id","value").count().select("id","value")
    

    Here, your df1 is:

    id  value 
    3     0
    3     1
    3     0
    4     1
    4     0
    4     0
    

    And the resultant DataFrame df2 is:

    id  value 
    3     0
    3     1
    4     1
    4     0
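
    Note that this keeps every distinct (id, value) pair, which is why (4, 0) still appears. Equivalently, a minimal sketch using the built-in dropDuplicates method (same df1 as above):

    // deduplicate on the (id, value) pair; equivalent to the groupBy above
    val df2 = df1.dropDuplicates("id", "value")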
    
  • 2020-12-09 21:44
    Use the isin method and filter as below:
    
    val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id", "value","sorted")
    val idFilter = List(1, 2)
    data.filter($"id".isin(idFilter:_*)).show
    +---+-----+------+
    | id|value|sorted|
    +---+-----+------+
    |  1|    0|     7|
    |  1|    1|     8|
    |  1|    0|     9|
    |  2|    1|    10|
    |  2|    0|    11|
    |  2|    0|    12|
    +---+-----+------+
    
    Example: filter based on value:
    val valFilter = List(0)
    data.filter($"value".isin(valFilter:_*)).show
    +---+-----+------+
    | id|value|sorted|
    +---+-----+------+
    |  3|    0|     2|
    |  3|    0|     1|
    |  4|    0|     5|
    |  4|    0|     4|
    |  1|    0|     7|
    |  1|    0|     9|
    |  2|    0|    11|
    |  2|    0|    12|
    +---+-----+------+
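
    The two predicates can also be combined into a single filter (an illustrative sketch reusing the idFilter and valFilter lists defined above):

    // keep rows whose id is in idFilter AND whose value is in valFilter
    data.filter($"id".isin(idFilter:_*) && $"value".isin(valFilter:_*)).show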
    
  • 2020-12-09 21:54

    One way is to use monotonically_increasing_id() and a self-join:

    val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
    data.show
    +---+-----+
    | id|value|
    +---+-----+
    |  3|    0|
    |  3|    1|
    |  3|    0|
    |  4|    1|
    |  4|    0|
    |  4|    0|
    +---+-----+
    

    Now we generate a column named idx with an increasing Long:

    val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
    // dataWithIndex.cache()
    

    Now we get the min(idx) for each id where value = 1:

    val minIdx = dataWithIndex
                   .filter($"value" === 1)
                   .groupBy($"id")
                   .agg(min($"idx"))
                   .toDF("r_id", "min_idx")
    

    Now we join the min(idx) back to the original DataFrame:

    dataWithIndex.join(
      minIdx,
      ($"r_id" === $"id") && ($"idx" <= $"min_idx")
    ).select($"id", $"value").show
    +---+-----+
    | id|value|
    +---+-----+
    |  3|    0|
    |  3|    1|
    |  4|    1|
    +---+-----+
    

    Note: monotonically_increasing_id() generates its value based on the partition of the row. This value may change each time dataWithIndex is re-evaluated. In my code above, because of lazy evaluation, it's only when I call the final show that monotonically_increasing_id() is evaluated.

    If you want to force the value to stay the same, for example so you can use show to evaluate the above step-by-step, uncomment this line above:

    //  dataWithIndex.cache()
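
    For reference, here is the same approach as one self-contained snippet (a sketch assuming a spark-shell session, where spark.implicits._ is already in scope; cache() pins the generated ids so repeated actions see the same values):

    import org.apache.spark.sql.functions.{min, monotonically_increasing_id}

    // tag each row with a unique, increasing Long and pin the result
    val dataWithIndex = data
      .withColumn("idx", monotonically_increasing_id())
      .cache()

    // smallest idx per id among the rows where value = 1
    val minIdx = dataWithIndex
      .filter($"value" === 1)
      .groupBy($"id")
      .agg(min($"idx"))
      .toDF("r_id", "min_idx")

    // keep every row at or before the first value = 1 in its group
    dataWithIndex
      .join(minIdx, ($"r_id" === $"id") && ($"idx" <= $"min_idx"))
      .select($"id", $"value")
      .show()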
    
  • 2020-12-09 21:56

    Hi, I found a solution using Window and a self-join.

    val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id", "value","sorted")
    
    scala> data.show
    +---+-----+------+
    | id|value|sorted|
    +---+-----+------+
    |  3|    0|     2|
    |  3|    1|     3|
    |  3|    0|     1|
    |  4|    1|     6|
    |  4|    0|     5|
    |  4|    0|     4|
    |  1|    0|     7|
    |  1|    1|     8|
    |  1|    0|     9|
    |  2|    1|    10|
    |  2|    0|    11|
    |  2|    0|    12|
    +---+-----+------+
    
    
    
    
    val sort_df = data.sort($"sorted")
    
    scala> sort_df.show
    +---+-----+------+
    | id|value|sorted|
    +---+-----+------+
    |  3|    0|     1|
    |  3|    0|     2|
    |  3|    1|     3|
    |  4|    0|     4|
    |  4|    0|     5|
    |  4|    1|     6|
    |  1|    0|     7|
    |  1|    1|     8|
    |  1|    0|     9|
    |  2|    1|    10|
    |  2|    0|    11|
    |  2|    0|    12|
    +---+-----+------+
    
    
    
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{row_number, min}

    val window = Window.partitionBy("id").orderBy($"sorted")

    val sort_idx = sort_df.select($"*", row_number().over(window).as("count_index"))

    val minIdx = sort_idx.filter($"value" === 1).groupBy("id").agg(min("count_index")).toDF("idx", "min_idx")

    val result_id = sort_idx.join(minIdx, ($"id" === $"idx") && ($"count_index" <= $"min_idx"))
    
    result_id.show
    
    +---+-----+------+-----------+---+-------+
    | id|value|sorted|count_index|idx|min_idx|
    +---+-----+------+-----------+---+-------+
    |  1|    0|     7|          1|  1|      2|
    |  1|    1|     8|          2|  1|      2|
    |  2|    1|    10|          1|  2|      1|
    |  3|    0|     1|          1|  3|      3|
    |  3|    0|     2|          2|  3|      3|
    |  3|    1|     3|          3|  3|      3|
    |  4|    0|     4|          1|  4|      3|
    |  4|    0|     5|          2|  4|      3|
    |  4|    1|     6|          3|  4|      3|
    +---+-----+------+-----------+---+-------+
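
    To reduce this to the expected id/value output, drop the helper columns and restore the original ordering (reusing the sorted column from above):

    result_id.sort($"sorted").select($"id", $"value").show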
    

    Still looking for a more optimized solution. Thanks!
