How to split a Vector column into separate columns using PySpark

后端 未结 5 1393
夕颜
夕颜 2020-11-22 16:23

Context: I have a DataFrame with two columns, word and vector, where the column type of "vector" is VectorUDT.

An Example:

5条回答
  •  天涯浪人
    2020-11-22 17:09

    It is much faster to use the `i_th` UDF from the answer to "How to access element of a VectorUDT column in a Spark DataFrame?"

    The `extract` function given in zero323's solution above uses `toArray().tolist()`, which creates a Python list object, populates it with Python float objects, and finds the desired element by traversing the list; that element then has to be converted back to a Java double, and all of this is repeated for every row. Using the RDD is much slower than the `to_array` UDF, which also calls `tolist()`, but both are much slower than a UDF that lets Spark SQL handle most of the work.

    Timing code comparing the RDD-based `extract` and the `to_array` UDF proposed here with the `i_th` UDF from answer 3955864:

    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext, SparkSession
    from pyspark.sql.functions import lit, udf, col
    from pyspark.sql.types import ArrayType, DoubleType
    import pyspark.sql.dataframe
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    
    # Benchmark setup: a local Spark context with 4 worker threads, and a
    # SparkSession built on top of it.
    sc = SparkContext('local[4]', 'FlatTestTime')
    
    spark = SparkSession(sc)
    # NOTE(review): Arrow is, as far as I know, only used for pandas conversions
    # and pandas UDFs. Nothing in this script calls pandas_udf or toPandas, so
    # this setting likely has no effect on the timings. Spark 3.x also
    # deprecates this key in favour of
    # "spark.sql.execution.arrow.pyspark.enabled" -- confirm for your version.
    spark.conf.set("spark.sql.execution.arrow.enabled", True)
    
    from pyspark.ml.linalg import Vectors
    
    # copy the two rows in the test dataframe a bunch of times,
    # make this small enough for testing, or go for "big data" and be prepared to wait
    REPS = 20000
    
    # 2 * REPS rows in total.
    # - "vector" always has 3 elements: dense [1, 2, 3] in one row, and
    #   sparse size 3 with only index 1 set (i.e. [0, 2, 0]) in the other.
    # - "vorpal" is always a dense 2-element vector.
    # The element counts hard-coded below (3 for "vector", 2 for "vorpal")
    # rely on these lengths.
    df = sc.parallelize([
        ("assert", Vectors.dense([1, 2, 3]), 1, Vectors.dense([4.1, 5.1])),
        ("require", Vectors.sparse(3, {1: 2}), 2, Vectors.dense([6.2, 7.2])),
    ] * REPS).toDF(["word", "vector", "more", "vorpal"])
    
    def extract(row):
        """Flatten one Row into a single tuple of scalars.

        The output layout is: ``word``, every element of ``vector``, ``more``,
        then every element of ``vorpal``. This matches the column names that
        ``test_extract`` assigns.
        """
        return (
            row.word,
            *row.vector.toArray().tolist(),
            row.more,
            *row.vorpal.toArray().tolist(),
        )
    
    def test_extract():
        """Build the flattened DataFrame by mapping ``extract`` over the RDD."""
        flat_columns = ['word', 'vector__0', 'vector__1', 'vector__2',
                        'more', 'vorpal__0', 'vorpal__1']
        flat_rdd = df.rdd.map(extract)
        return flat_rdd.toDF(flat_columns)
    
    def vector_to_list(v):
        """Convert a pyspark.ml vector to a plain list of floats, or None.

        Spark passes None to a Python UDF for a null cell. Without this guard,
        ``None.toArray()`` raises AttributeError inside the executor, and the
        whole job fails instead of producing a null.
        """
        if v is None:
            return None
        return v.toArray().tolist()

    def to_array(col):
        """Return a Column that converts the vector column ``col`` to array<double>.

        Null vectors become null arrays. Callers here pass a ``Column``. Note
        that the parameter shadows ``pyspark.sql.functions.col`` inside this
        function only.
        """
        return udf(vector_to_list, ArrayType(DoubleType()))(col)
    
    def test_to_array():
        """Flatten both vector columns via the ``to_array`` UDF and array indexing."""
        # Stage 1: expand "vector" into xs[0..2], carrying "vorpal" along.
        with_xs = df.withColumn("xs", to_array(col("vector")))
        stage_one = with_xs.select(
            ["word"] + [col("xs")[i] for i in range(3)] + ["more", "vorpal"])
        # Stage 2: expand "vorpal" into xx[0..1]. The stage-1 element columns
        # are re-selected by their generated names "xs[0]" .. "xs[2]".
        with_xx = stage_one.withColumn("xx", to_array(col("vorpal")))
        return with_xx.select(
            ["word"]
            + ["xs[{}]".format(i) for i in range(3)]
            + ["more"]
            + [col("xx")[i] for i in range(2)])
    
    # Package the to_array approach as a reusable function.
    def flatten(df, vector, vlen):
        """Replace vector column ``vector`` with ``vlen`` per-element columns.

        The element columns appear at the position ``vector`` occupied, and all
        other columns keep their order. If ``df`` has no column named
        ``vector``, it is returned unchanged.
        """
        field_names = df.schema.fieldNames()
        if vector not in field_names:
            return df
        projection = []
        for name in field_names:
            if name == vector:
                projection.extend(col(vector)[i] for i in range(vlen))
            else:
                projection.append(col(name))
        return df.withColumn(vector, to_array(col(vector))).select(projection)
    
    def test_flatten():
        """Flatten ``vector`` (3 elements), then ``vorpal`` (2 elements)."""
        return flatten(flatten(df, "vector", 3), "vorpal", 2)
    
    def ith_(v, i):
        """Return element ``i`` of vector ``v`` as a Python float, or None.

        Returns None when the vector or the index is null, when the index is
        out of range, or when the element cannot be converted to float. This
        mirrors how Spark's own array indexing (``col("xs")[i]``, used by the
        other variants in this script) yields null instead of failing the job.
        """
        if v is None or i is None:
            return None
        try:
            return float(v[i])
        except (IndexError, ValueError):
            return None
    
    ith = udf(ith_, DoubleType())
    
    # Column list for test_ith: word, vector[0..2], more, vorpal[0..1].
    # Each vector element is pulled out through the ith UDF.
    select = (
        ["word"]
        + [ith("vector", lit(i)) for i in range(3)]
        + ["more"]
        + [ith("vorpal", lit(i)) for i in range(2)]
    )
    
    # %% timeit ...
    def test_ith():
        """Project word/more plus each vector element via the ``ith`` UDF."""
        return df.select(*select)
    
    if __name__ == '__main__':
        import timeit
    
        # make sure these work as intended
        test_ith().show(4)
        test_flatten().show(4)
        test_to_array().show(4)
        test_extract().show(4)
    
        # DataFrame transformations are lazy. Timing test_*() on its own
        # measures only plan construction (plus, for test_extract, the job that
        # toDF runs to infer the schema), never the UDF work itself.
        # Appending .collect() forces a Spark action, so every variant is
        # actually executed. The driver-side transfer of the rows is the same
        # for all variants.
        benchmarks = (
            ("i_th\t\t", "test_ith"),
            ("flatten\t\t", "test_flatten"),
            ("to_array\t", "test_to_array"),
            ("extract\t\t", "test_extract"),
        )
        for label, fn_name in benchmarks:
            print(label,
                  timeit.timeit("{}().collect()".format(fn_name),
                                setup="from __main__ import {}".format(fn_name),
                                number=7)
                  )
    

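    Results (total seconds for 7 runs each). Caveat: these numbers come from timing only the test_* calls, which build DataFrames lazily. No Spark action runs in them, apart from the schema-inference job that extract's toDF triggers, so the numbers mostly reflect plan construction rather than UDF execution cost: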

    i_th         0.05964796099999958
    flatten      0.4842299350000001
    to_array     0.42978780299999997
    extract      2.9254476840000017
    

提交回复
热议问题