How to retrieve all columns using the PySpark collect_list function

梦如初夏  2021-01-14 05:42

I have PySpark 2.0.1. I'm trying to group my data frame and retrieve the values for all of its fields. I found that

z=data1.groupby('


        
3 Answers
  •  失恋的感觉
    2021-01-14 05:50

    Use struct to combine the columns before calling groupBy.

    Suppose you have a DataFrame:

    from pyspark.sql import functions as f

    # small example frame: "a" is the grouping key, "b" and "c" are value columns
    df = spark.createDataFrame(sc.parallelize([(0,1,2),(0,4,5),(1,7,8),(1,8,7)])).toDF("a","b","c")

    # pack b and c into one struct column so they can be collected together
    df = df.select("a", f.struct(["b","c"]).alias("newcol"))
    df.show()
    +---+------+
    |  a|newcol|
    +---+------+
    |  0| [1,2]|
    |  0| [4,5]|
    |  1| [7,8]|
    |  1| [8,7]|
    +---+------+
    # collect one struct per row into a list for each group
    df = df.groupBy("a").agg(f.collect_list("newcol").alias("collected_col"))
    df.show()
    +---+--------------+
    |  a| collected_col|
    +---+--------------+
    |  0|[[1,2], [4,5]]|
    |  1|[[7,8], [8,7]]|
    +---+--------------+
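
    The struct call above lists "b" and "c" by hand. To collect every
    non-grouping column, as the question title asks, a minimal sketch is to
    build the struct from the frame's columns instead; df0 is a hypothetical
    name for the original three-column frame, re-created here:

    # hypothetical df0: same data as the original (a, b, c) frame above
    df0 = spark.createDataFrame([(0,1,2),(0,4,5),(1,7,8),(1,8,7)], ["a","b","c"])
    key = "a"
    # pack all columns except the grouping key into one struct, then collect
    z = (df0.select(key, f.struct([c for c in df0.columns if c != key]).alias("newcol"))
            .groupBy(key)
            .agg(f.collect_list("newcol").alias("collected_col")))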
    

    Each aggregation function, including collect_list, operates on a single column, which is why b and c are packed into one struct column first.
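
    You can also run one collect_list per column in a single agg call, as
    sketched below on the hypothetical df0 from above; note, though, that
    collect_list makes no ordering guarantee, so nothing in the API promises
    that the two lists stay pairwise aligned the way the struct version does:

    # one single-column aggregation per value column
    df0.groupBy("a").agg(
        f.collect_list("b").alias("b"),
        f.collect_list("c").alias("c")).show()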

    After the struct aggregation, you can collect the result to the driver and iterate over it to separate the combined columns, or you can write a UDF to split them back into one array column per field:

    from pyspark.sql.types import *
    from pyspark.sql.functions import udf, col

    # split a list of (b, c) structs into two parallel lists
    def foo(x):
        x1 = [y[0] for y in x]
        x2 = [y[1] for y in x]
        return (x1, x2)

    st = StructType([StructField("b", ArrayType(LongType())), StructField("c", ArrayType(LongType()))])
    udf_foo = udf(foo, st)

    # apply the UDF, then pull the two arrays back out as top-level columns
    df = df.withColumn("ncol",
                      udf_foo("collected_col")).select("a",
                      col("ncol").getItem("b").alias("b"),
                      col("ncol").getItem("c").alias("c"))
    df.show()
    
    +---+------+------+
    |  a|     b|     c|
    +---+------+------+
    |  0|[1, 4]|[2, 5]|
    |  1|[7, 8]|[8, 7]|
    +---+------+------+
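
    Depending on the Spark version, the UDF may not be needed at all:
    selecting a field of an array-of-structs column yields an array of that
    field's values. A hedged sketch, where grouped stands for the
    intermediate frame with columns a and collected_col produced by the
    collect_list step (grouped is a hypothetical name, not one used above):

    # grouped has schema (a, collected_col: array<struct<b, c>>)
    grouped.select(
        "a",
        col("collected_col.b").alias("b"),   # all b values per group
        col("collected_col.c").alias("c")).show()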
    
