Spark dataframe transform multiple rows to column

孤街浪徒 2020-12-28 16:32

I am a novice to Spark, and I want to transform the source dataframe below (loaded from a JSON file):

+---+-----+-----+
|  A|count|major|
+---+-----+-----+
|  a|    1|   m1|
|  a|    1|   m2|
|  a|    2|   m3|
|...|  ...|  ...|
+---+-----+-----+

into a dataframe with one row per value of A and one column per distinct major, holding the corresponding count (the expected output is shown in the answers below).

2 Answers
  • 2020-12-28 16:58

    Using zero323's dataframe,

    df = sqlContext.createDataFrame([
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5")], 
    ("a", "cnt", "major"))
    

    you could also use

    reshaped_df = df.groupby('a').pivot('major').max('cnt').fillna(0)
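
    If the set of majors is known ahead of time, you can pass it to pivot so Spark skips the extra job that computes the distinct values first. A minimal sketch; the list below is simply the values from the example data:

    majors = ["m1", "m2", "m3", "m4", "m5"]  # known pivot values from the example data
    reshaped_df = df.groupby('a').pivot('major', majors).max('cnt').fillna(0)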
    
  • 2020-12-28 17:22

    Let's start with example data:

    df = sqlContext.createDataFrame([
        ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
        ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
        ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
        ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
        ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
        ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
        ("e", 1, "m4"), ("e", 1, "m5")], 
        ("a", "cnt", "major"))
    

    Please note that I've changed count to cnt: count is a reserved keyword in most SQL dialects and not a good choice for a column name.
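
    If your dataframe still carries the original count column (for example, straight from the JSON file), one way to rename it is withColumnRenamed. A small sketch, where raw_df stands for that hypothetical dataframe:

    df = raw_df.withColumnRenamed("count", "cnt")  # raw_df: hypothetical dataframe with the original column name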

    There are at least two ways to reshape this data:

    • aggregating over DataFrame

      from pyspark.sql.functions import col, when, max

      # distinct majors become the new column names
      majors = sorted(df.select("major")
          .distinct()
          .map(lambda row: row[0])
          .collect())

      # one conditional column per major: cnt where the row matches, NULL otherwise
      cols = [when(col("major") == m, col("cnt")).otherwise(None).alias(m)
          for m in majors]
      # max picks the single non-NULL value in each group
      maxs = [max(col(m)).alias(m) for m in majors]

      reshaped1 = (df
          .select(col("a"), *cols)
          .groupBy("a")
          .agg(*maxs)
          .na.fill(0))
      
      reshaped1.show()
      
      ## +---+---+---+---+---+---+
      ## |  a| m1| m2| m3| m4| m5|
      ## +---+---+---+---+---+---+
      ## |  a|  1|  1|  2|  3|  0|
      ## |  b|  4|  1|  2|  0|  0|
      ## |  c|  3|  0|  4|  5|  0|
      ## |  d|  6|  1|  2|  3|  4|
      ## |  e|  4|  5|  1|  1|  1|
      ## +---+---+---+---+---+---+
      
    • groupBy over RDD

      from pyspark.sql import Row

      # key each row by a, keeping (major, cnt) pairs as values
      grouped = (df
          .map(lambda row: (row.a, (row.major, row.cnt)))
          .groupByKey())

      # build one Row per key, filling missing majors with 0
      def make_row(kv):
          k, vs = kv
          tmp = dict(list(vs) + [("a", k)])
          return Row(**{k: tmp.get(k, 0) for k in ["a"] + majors})

      reshaped2 = sqlContext.createDataFrame(grouped.map(make_row))
      
      reshaped2.show()
      
      ## +---+---+---+---+---+---+
      ## |  a| m1| m2| m3| m4| m5|
      ## +---+---+---+---+---+---+
      ## |  a|  1|  1|  2|  3|  0|
      ## |  e|  4|  5|  1|  1|  1|
      ## |  c|  3|  0|  4|  5|  0|
      ## |  b|  4|  1|  2|  0|  0|
      ## |  d|  6|  1|  2|  3|  4|
      ## +---+---+---+---+---+---+
      
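    Note that in Spark 2.x a PySpark DataFrame no longer has a map method, so the calls above (both the majors computation and grouped) have to go through the underlying RDD. A sketch of the adjusted call, with everything else unchanged:

    grouped = (df.rdd  # drop down to the RDD before calling map in Spark 2.x
        .map(lambda row: (row.a, (row.major, row.cnt)))
        .groupByKey())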