Skewed dataset join in Spark?


I am joining two big datasets using Spark RDD. One dataset is very skewed, so a few of the executor tasks take a long time to finish the job. How can I solve this scenario?

5 Answers

  •  北荒 (OP)
     2020-12-01 02:06

    Taking reference from https://datarus.wordpress.com/2015/05/04/fighting-the-skew-in-spark/, below is the code for fighting the skew in Spark using the PySpark DataFrame API.

    Creating the two DataFrames:

    from math import exp
    from datetime import datetime

    # assumes an interactive session where `sc` is the SparkContext
    # and `spark` is the SparkSession

    # helper that yields (partition index, row count) -- handy for inspecting skew
    def count_elements(splitIndex, iterator):
        n = sum(1 for _ in iterator)
        yield (splitIndex, n)

    # tag every row with the index of the partition it came from
    def get_part_index(splitIndex, iterator):
        for it in iterator:
            yield (splitIndex, it)

    num_parts = 18
    # create the large skewed rdd: partition i holds exp(i) rows, so the
    # last partitions are exponentially larger than the first
    skewed_large_rdd = sc.parallelize(range(0, num_parts), num_parts).flatMap(lambda x: range(0, int(exp(x))))
    skewed_large_rdd = skewed_large_rdd.mapPartitionsWithIndex(lambda ind, x: get_part_index(ind, x))

    skewed_large_df = spark.createDataFrame(skewed_large_rdd, ['x', 'y'])

    # small, evenly distributed side: one (key, value) pair per partition
    small_rdd = sc.parallelize(range(0, num_parts), num_parts).map(lambda x: (x, x))

    small_df = spark.createDataFrame(small_rdd, ['a', 'b'])
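
    The count_elements helper can be used to confirm the skew before joining. A minimal sketch, reusing skewed_large_rdd from above:

    # sketch: print the rows-per-partition distribution of the skewed side
    part_counts = skewed_large_rdd.mapPartitionsWithIndex(count_elements).collect()
    for idx, n in sorted(part_counts):
        print('partition %d: %d rows' % (idx, n))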
    

    Dividing the large DataFrame into 100 salt bins and replicating the small DataFrame 100 times:

    from pyspark.sql import functions as F

    salt_bins = 100

    # add a random salt in [0, salt_bins) to each row of the skewed side,
    # spreading every hot key across up to 100 distinct (key, salt) pairs
    skewed_transformed_df = skewed_large_df.withColumn('salt', (F.rand() * salt_bins).cast('int')).cache()

    # replicate each row of the small side once per salt value, so every
    # (key, salt) combination on the large side still finds its match
    small_transformed_df = small_df.withColumn('replicate', F.array([F.lit(i) for i in range(salt_bins)]))

    small_transformed_df = small_transformed_df.select('*', F.explode('replicate').alias('salt')).drop('replicate').cache()
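
    To sanity-check the salting, verify that rows are spread roughly evenly across the bins; a minimal sketch using skewed_transformed_df from above:

    # sketch: row counts per salt bin should be roughly uniform
    skewed_transformed_df.groupBy('salt').count().orderBy('salt').show(5)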
    

    Finally, the salted join that avoids the skew:

    t0 = datetime.now()
    # join on both the key and the salt, so each hot key is shuffled
    # across 100 (key, salt) partitions instead of one
    result2 = skewed_transformed_df.join(small_transformed_df,
        (skewed_transformed_df['x'] == small_transformed_df['a']) &
        (skewed_transformed_df['salt'] == small_transformed_df['salt']))
    result2.count()  # action to force execution so the timing is meaningful
    print("The salted join takes %s" % (datetime.now() - t0))
    
