PySpark: How to use a row value from one column to access another column that has the same name as that value

猫巷女王i 2020-12-17 05:12

I have a PySpark df:

+---+---+---+---+---+---+---+---+
| id| a1| b1| c1| d1| e1| f1|ref|
+---+---+---+---+---+---+---+---+
|  0|  1| 23|  4|  8|  9|  5| b1|
2 answers
  • 2020-12-17 05:46

    Regardless of the Spark version, you can convert to an RDD, map over the rows, and convert back to a DataFrame:

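    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    
    df = spark.createDataFrame(
        [(0, 1, 23, 4, 8, 9, 5, "b1"), (1, 2, 43, 8, 10, 20, 43, "e1")],
        ("id", "a1", "b1", "c1", "d1", "e1", "f1", "ref")
    )
    
    # Append the value of the column named in `ref` to each row
    df.rdd.map(lambda row: row + (row[row.ref], )).toDF(df.columns + ["out"]).show()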
    
    +---+---+---+---+---+---+---+---+---+
    | id| a1| b1| c1| d1| e1| f1|ref|out|
    +---+---+---+---+---+---+---+---+---+
    |  0|  1| 23|  4|  8|  9|  5| b1| 23|
    |  1|  2| 43|  8| 10| 20| 43| e1| 20|
    +---+---+---+---+---+---+---+---+---+
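The lambda in that map works because a PySpark Row is a tuple subclass whose fields can also be looked up by name, so appending a one-element tuple yields the extended row. A minimal plain-Python sketch of the same per-row logic, assuming rows are modeled as collections.namedtuple instances (the names Record and add_out are hypothetical, and no Spark is needed):

```python
from collections import namedtuple

# Hypothetical stand-in for a PySpark Row: a tuple with named fields
Record = namedtuple("Record", ["id", "a1", "b1", "c1", "d1", "e1", "f1", "ref"])


def add_out(row):
    """Append the value of the field named in row.ref, like the RDD map above."""
    # getattr plays the role of row[row.ref] on a Spark Row
    return tuple(row) + (getattr(row, row.ref),)


rows = [
    Record(0, 1, 23, 4, 8, 9, 5, "b1"),
    Record(1, 2, 43, 8, 10, 20, 43, "e1"),
]
result = [add_out(r) for r in rows]
print(result)
```

Each resulting tuple has one extra trailing element, which is why the Spark version passes df.columns + ["out"] to toDF.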
    

    You could also preserve the original schema:

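    from pyspark.sql.types import LongType, StructField, StructType
    
    # Build a new schema; df.schema.add(...) would mutate df's cached schema in place
    out_schema = StructType(df.schema.fields + [StructField("out", LongType())])
    
    spark.createDataFrame(
        df.rdd.map(lambda row: row + (row[row.ref], )),
        out_schema)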
    

    With DataFrames you can compose complex Column expressions. In Spark 1.6:

    from pyspark.sql.functions import array, col, udf
    from pyspark.sql.types import LongType, MapType, StringType
    
    data_cols = [x for x in df.columns if x not in {"id", "ref"}]
    
    # Literal map from column name to index
    name_to_index = udf(
        lambda: {x: i for i, x in enumerate(data_cols)},
        MapType(StringType(), LongType())
    )()
    
    # Array of data
    data_array = array(*[col(c) for c in data_cols])
    df.withColumn("out", data_array[name_to_index[col("ref")]]).show()
    
    +---+---+---+---+---+---+---+---+---+
    | id| a1| b1| c1| d1| e1| f1|ref|out|
    +---+---+---+---+---+---+---+---+---+
    |  0|  1| 23|  4|  8|  9|  5| b1| 23|
    |  1|  2| 43|  8| 10| 20| 43| e1| 20|
    +---+---+---+---+---+---+---+---+---+
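Per row, that expression packs the data columns into an array and uses the literal name-to-position map to pick the element named by ref. A plain-Python sketch of the same lookup, assuming rows are plain dicts (pick is a hypothetical helper, not part of PySpark):

```python
data_cols = ["a1", "b1", "c1", "d1", "e1", "f1"]

# Same literal map the zero-argument UDF returns: column name -> array position
name_to_index = {x: i for i, x in enumerate(data_cols)}


def pick(row):
    # Equivalent of array(*[col(c) for c in data_cols]) for one row
    data_array = [row[c] for c in data_cols]
    # Equivalent of data_array[name_to_index[col("ref")]]
    return data_array[name_to_index[row["ref"]]]


rows = [
    {"id": 0, "a1": 1, "b1": 23, "c1": 4, "d1": 8, "e1": 9, "f1": 5, "ref": "b1"},
    {"id": 1, "a1": 2, "b1": 43, "c1": 8, "d1": 10, "e1": 20, "f1": 43, "ref": "e1"},
]
print([pick(r) for r in rows])  # [23, 20]
```

One difference to keep in mind: for a ref value that is not a column name, this dict lookup raises KeyError, whereas Spark's map lookup in its default (non-ANSI) mode returns null.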
    

    In Spark 2.x you can skip the intermediate objects and map column names directly to column values:

    from pyspark.sql.functions import create_map, lit, col
    from itertools import chain
    
    # Map from column name to column value
    name_to_value = create_map(*chain.from_iterable(
        (lit(c), col(c)) for c in data_cols
    ))
    
    df.withColumn("out", name_to_value[col("ref")]).show()
    
    +---+---+---+---+---+---+---+---+---+
    | id| a1| b1| c1| d1| e1| f1|ref|out|
    +---+---+---+---+---+---+---+---+---+
    |  0|  1| 23|  4|  8|  9|  5| b1| 23|
    |  1|  2| 43|  8| 10| 20| 43| e1| 20|
    +---+---+---+---+---+---+---+---+---+
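create_map takes its arguments as alternating keys and values, which is why chain.from_iterable flattens the (name, value) pairs. A plain-Python sketch of that flattening and the resulting per-row lookup, with strings standing in for the lit and col expressions (build_map is a hypothetical helper):

```python
from itertools import chain

data_cols = ["a1", "b1", "c1", "d1", "e1", "f1"]

# Flattened argument list in the order create_map receives it: key, value, key, value, ...
args = list(chain.from_iterable((f"lit({c})", f"col({c})") for c in data_cols))
print(args[:4])  # ['lit(a1)', 'col(a1)', 'lit(b1)', 'col(b1)']


def build_map(row):
    """Per-row equivalent of name_to_value: pair each name with this row's value."""
    flat = list(chain.from_iterable((c, row[c]) for c in data_cols))
    # Keys sit at even positions, values at odd positions
    return dict(zip(flat[0::2], flat[1::2]))


row = {"id": 1, "a1": 2, "b1": 43, "c1": 8, "d1": 10, "e1": 20, "f1": 43, "ref": "e1"}
print(build_map(row)[row["ref"]])  # 20
```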
    

    Finally, you can use a chain of when expressions:

    from pyspark.sql.functions import col, lit, when
    from functools import reduce
    
    out = reduce(
        lambda acc, x: when(col("ref") == x, col(x)).otherwise(acc),
        data_cols,
        lit(None)
    )
    
    df.withColumn("out", out).show()
    
    +---+---+---+---+---+---+---+---+---+
    | id| a1| b1| c1| d1| e1| f1|ref|out|
    +---+---+---+---+---+---+---+---+---+
    |  0|  1| 23|  4|  8|  9|  5| b1| 23|
    |  1|  2| 43|  8| 10| 20| 43| e1| 20|
    +---+---+---+---+---+---+---+---+---+
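reduce nests one when per column, so the column listed last becomes the outermost CASE and lit(None) ends up as the innermost fallback. A plain-Python sketch that renders the nesting as SQL-like text and evaluates the same fold on a dict row; the rendered text is only an illustration and is not the exact string Spark prints, and evaluate is a hypothetical helper:

```python
from functools import reduce

data_cols = ["a1", "b1", "c1"]

# Mirror the fold: each step wraps the accumulator in a new CASE WHEN
expr = reduce(
    lambda acc, x: f"CASE WHEN ref = '{x}' THEN {x} ELSE {acc} END",
    data_cols,
    "NULL",
)
print(expr)


def evaluate(row):
    # Same fold, but each step builds a function of the row instead of text
    fn = reduce(
        lambda acc, x: (lambda r, acc=acc, x=x: r[x] if r["ref"] == x else acc(r)),
        data_cols,
        lambda r: None,  # plays the role of lit(None)
    )
    return fn(row)


print(evaluate({"a1": 1, "b1": 23, "c1": 4, "ref": "b1"}))  # 23
print(evaluate({"a1": 1, "b1": 23, "c1": 4, "ref": "z9"}))  # None
```

Because each ref matches at most one column, the nesting order does not change the result; an unmatched ref falls through to null.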
    
  • 2020-12-17 06:10

    The OP asked for a Python solution. I'm answering the same in Spark 2.x Scala for reference; hope it helps somebody.

    scala> val df = Seq((0, 1, 23, 4, 8, 9, 5, "b1"), (1, 2, 43, 8, 10, 20, 43, "e1"), (2,  3, 15,  0,  1, 23,  7, "b1"),(3,  4,  2,  6, 11,  5,  8, "d1"),(4,  5,  6,  7,  2,  8,  1, "f1")).toDF("id", "a1", "b1", "c1", "d1", "e1", "f1", "ref")
    df: org.apache.spark.sql.DataFrame = [id: int, a1: int ... 6 more fields]
    
    scala> df.show(false)
    +---+---+---+---+---+---+---+---+
    |id |a1 |b1 |c1 |d1 |e1 |f1 |ref|
    +---+---+---+---+---+---+---+---+
    |0  |1  |23 |4  |8  |9  |5  |b1 |
    |1  |2  |43 |8  |10 |20 |43 |e1 |
    |2  |3  |15 |0  |1  |23 |7  |b1 |
    |3  |4  |2  |6  |11 |5  |8  |d1 |
    |4  |5  |6  |7  |2  |8  |1  |f1 |
    +---+---+---+---+---+---+---+---+
    
    
    scala> val colx = df.columns.filter(x=>x!="ref").filter(x=>x!="id")
    colx: Array[String] = Array(a1, b1, c1, d1, e1, f1)
    
    scala> val colm = colx.map( x=> when(col("ref")===lit(x),col(x)) )
    colm: Array[org.apache.spark.sql.Column] = Array(CASE WHEN (ref = a1) THEN a1 END, CASE WHEN (ref = b1) THEN b1 END, CASE WHEN (ref = c1) THEN c1 END, CASE WHEN (ref = d1) THEN d1 END, CASE WHEN (ref = e1) THEN e1 END, CASE WHEN (ref = f1) THEN f1 END)
    
    scala> df.select(col("*"),concat_ws("",array(colm:_*)).as("res1")).show(false)
    +---+---+---+---+---+---+---+---+----+
    |id |a1 |b1 |c1 |d1 |e1 |f1 |ref|res1|
    +---+---+---+---+---+---+---+---+----+
    |0  |1  |23 |4  |8  |9  |5  |b1 |23  |
    |1  |2  |43 |8  |10 |20 |43 |e1 |20  |
    |2  |3  |15 |0  |1  |23 |7  |b1 |15  |
    |3  |4  |2  |6  |11 |5  |8  |d1 |11  |
    |4  |5  |6  |7  |2  |8  |1  |f1 |1   |
    +---+---+---+---+---+---+---+---+----+
    
    
    scala>
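This works because each CASE WHEN without an ELSE yields null unless its column is the one named in ref, and concat_ws skips nulls, so joining the array leaves only the matching value, as a string. A plain-Python sketch of that null-skipping join (case_when_values and join_non_null are hypothetical helpers):

```python
data_cols = ["a1", "b1", "c1", "d1", "e1", "f1"]


def case_when_values(row):
    # One entry per CASE WHEN: the column value if ref names it, else None (null)
    return [row[c] if row["ref"] == c else None for c in data_cols]


def join_non_null(values, sep=""):
    # concat_ws-style join: nulls are dropped and the rest are joined as strings
    return sep.join(str(v) for v in values if v is not None)


row = {"id": 3, "a1": 4, "b1": 2, "c1": 6, "d1": 11, "e1": 5, "f1": 8, "ref": "d1"}
values = case_when_values(row)
print(values)                       # [None, None, None, 11, None, None]
print(repr(join_non_null(values)))  # '11'
```

If a numeric result is preferred, Spark's coalesce, which returns its first non-null argument, applied to the same CASE WHEN columns keeps the integer type instead of producing a string.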
    