PySpark - Adding a Column from a list of values using a UDF

臣服心动 2021-01-05 00:32

I have to add a column to a PySpark DataFrame based on a list of values.

a = spark.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")],["Animal", "Enemy"])


        
5 Answers
  •  被撕碎了的回忆
    2021-01-05 01:17

    As @Tw UxTLi51Nus mentioned, if you can order the DataFrame (say, by Animal) without changing your results, you can do the following:

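    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    rating = [5, 4, 1]  # the list of values (as implied by the Rating output below)

    def add_labels(indx):
        return rating[indx - 1]  # row_number() starts at 1, Python lists at 0
    labels_udf = udf(add_labels, IntegerType())

    a = spark.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")], ["Animal", "Enemy"])
    a.createOrReplaceTempView('a')
    # "Animal" in double quotes is a string literal in Spark SQL, so num follows whatever order Spark
    # reads the rows in (here, the input order); write order by Animal (unquoted) to sort by the column.
    a = spark.sql('select row_number() over (order by "Animal") as num, * from a')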
    
    a.show()
    
    
    +---+------+-----+
    |num|Animal|Enemy|
    +---+------+-----+
    |  1|   Dog|  Cat|
    |  2|   Cat|  Dog|
    |  3| Mouse|  Cat|
    +---+------+-----+
    
    new_df = a.withColumn('Rating', labels_udf('num'))
    new_df.show()
    +---+------+-----+------+
    |num|Animal|Enemy|Rating|
    +---+------+-----+------+
    |  1|   Dog|  Cat|     5|
    |  2|   Cat|  Dog|     4|
    |  3| Mouse|  Cat|     1|
    +---+------+-----+------+
    

    And then drop the num column:

    new_df.drop('num').show()
    +------+-----+------+
    |Animal|Enemy|Rating|
    +------+-----+------+
    |   Dog|  Cat|     5|
    |   Cat|  Dog|     4|
    | Mouse|  Cat|     1|
    +------+-----+------+
    

    Edit:

    If you cannot sort by a column, another way (perhaps ugly and a bit inefficient) is to go back to the RDD API and do the following:

    a = spark.createDataFrame([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")],["Animal", "Enemy"])
    
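    # or create the rdd from the start (then call a.zipWithIndex() directly below;
    # the fields will be named _1 and _2 instead of Animal and Enemy):
    # a = spark.sparkContext.parallelize([("Dog", "Cat"), ("Cat", "Dog"), ("Mouse", "Cat")])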
    
    a = a.rdd.zipWithIndex()
    a = a.toDF()
    a.show()
    
    +-----------+---+
    |         _1| _2|
    +-----------+---+
    |  [Dog,Cat]|  0|
    |  [Cat,Dog]|  1|
    |[Mouse,Cat]|  2|
    +-----------+---+
    
    # _1 is a struct holding the original row, _2 is the 0-based index
    a = a.select(a._1.getField('Animal').alias('Animal'), a._1.getField('Enemy').alias('Enemy'), a._2.alias('num'))
    
    def add_labels(indx):
        return rating[indx] # indx here will start from zero
    
    labels_udf = udf(add_labels, IntegerType())
    
    new_df = a.withColumn('Rating', labels_udf('num'))
    
    new_df.show()
    
    +------+-----+---+------+
    |Animal|Enemy|num|Rating|
    +------+-----+---+------+
    |   Dog|  Cat|  0|     5|
    |   Cat|  Dog|  1|     4|
    | Mouse|  Cat|  2|     1|
    +------+-----+---+------+
    

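    (I would not recommend this if you have a lot of data: converting to an RDD and calling a Python UDF both ship every row between the JVM and Python workers.)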

    Hope this helps, good luck!
