PySpark: forward fill with last observation for a DataFrame

不思量自难忘° 2020-12-03 09:12

Using Spark 1.5.1,

I've been trying to forward fill null values with the last known observation for one column of my DataFrame.

5 Answers
  •  日久生厌
    2020-12-03 09:45

    Below is the partitioned example code from "Spark / Scala: forward fill with last observation in pyspark". This approach only works for data that can be partitioned.

    Load the data

    values = [
        (1, "2015-12-01", None),
        (1, "2015-12-02", "U1"),
        (1, "2015-12-02", "U1"),
        (1, "2015-12-03", "U2"),
        (1, "2015-12-04", None),
        (1, "2015-12-05", None),
        (2, "2015-12-04", None),
        (2, "2015-12-03", None),
        (2, "2015-12-02", "U3"),
        (2, "2015-12-05", None),
    ]
    rdd = sc.parallelize(values)
    df = rdd.toDF(["cookie_id", "c_date", "user_id"])
    df = df.withColumn("c_date", df.c_date.cast("date"))
    df.show()
    

    The DataFrame is

    +---------+----------+-------+
    |cookie_id|    c_date|user_id|
    +---------+----------+-------+
    |        1|2015-12-01|   null|
    |        1|2015-12-02|     U1|
    |        1|2015-12-02|     U1|
    |        1|2015-12-03|     U2|
    |        1|2015-12-04|   null|
    |        1|2015-12-05|   null|
    |        2|2015-12-04|   null|
    |        2|2015-12-03|   null|
    |        2|2015-12-02|     U3|
    |        2|2015-12-05|   null|
    +---------+----------+-------+
    

    The key used to sort the rows within each group

    # get the sort key
    def getKey(item):
        return item.c_date
    

    The fill function. It can be extended to fill multiple columns if necessary; a generalized sketch follows the block below.

    # fill function
    def fill(x):
        out = []
        last_val = None
        for v in x:
            if v["user_id"] is None:
                data = [v["cookie_id"], v["c_date"], last_val]
            else:
                data = [v["cookie_id"], v["c_date"], v["user_id"]]
                last_val = v["user_id"]
            out.append(data)
        return out
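
    A generalized variant, as referenced above (a sketch, not part of the original answer; fill_multi and its fill_cols parameter are made-up names), carries one last-seen value per column:

    # generalized fill: forward-fills every column named in fill_cols
    def fill_multi(rows, fill_cols=("user_id",)):
        out = []
        last_vals = {c: None for c in fill_cols}
        for r in rows:
            d = r.asDict()
            for c in fill_cols:
                if d[c] is None:
                    d[c] = last_vals[c]    # fill from the last observation
                else:
                    last_vals[c] = d[c]    # update the last observation
            # preserve the original column order of the Row
            out.append([d[f] for f in r.__fields__])
        return out

    It drops into the pipeline below in place of fill, e.g. rdd.mapValues(lambda rows: fill_multi(rows, fill_cols=("user_id",))).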
    

    Convert to an RDD, group by cookie_id, sort each group by date, and fill the missing values

    # Group (partition) the data by cookie_id
    rdd = df.rdd.groupBy(lambda x: x.cookie_id).mapValues(list)
    # Sort the data by date
    rdd = rdd.mapValues(lambda x: sorted(x, key=getKey))
    # fill missing value and flatten
    rdd = rdd.mapValues(fill).flatMapValues(lambda x: x)
    # discard the key
    rdd = rdd.map(lambda v: v[1])
    

    Convert back to a DataFrame

    df_out = sqlContext.createDataFrame(rdd)
    df_out.show()
    

    The output is

    +---+----------+----+
    | _1|        _2|  _3|
    +---+----------+----+
    |  1|2015-12-01|null|
    |  1|2015-12-02|  U1|
    |  1|2015-12-02|  U1|
    |  1|2015-12-03|  U2|
    |  1|2015-12-04|  U2|
    |  1|2015-12-05|  U2|
    |  2|2015-12-02|  U3|
    |  2|2015-12-03|  U3|
    |  2|2015-12-04|  U3|
    |  2|2015-12-05|  U3|
    +---+----------+----+
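
    The generic column names (_1, _2, _3) appear because no schema was passed when the DataFrame was rebuilt. A small addition, assuming the column order is unchanged, keeps the original names:

    df_out = sqlContext.createDataFrame(rdd, ["cookie_id", "c_date", "user_id"])
    df_out.show()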
    

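    For reference only, since the question targets Spark 1.5.1: on newer Spark versions the same forward fill can be written without dropping to RDDs, using last(..., ignorenulls=True) over a window. A minimal sketch:

    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # all rows from the start of each cookie_id group up to the current row
    w = (Window.partitionBy("cookie_id")
               .orderBy("c_date")
               .rowsBetween(Window.unboundedPreceding, 0))

    # take the last non-null user_id seen so far in the window
    df_filled = df.withColumn("user_id", F.last("user_id", ignorenulls=True).over(w))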