I am joining two big datasets using Spark RDD. One dataset is heavily skewed, so a few of the executor tasks take a long time to finish the job. How can I solve this scenario?
Taking https://datarus.wordpress.com/2015/05/04/fighting-the-skew-in-spark/ as a reference, below is code for fighting the skew in Spark using the PySpark DataFrame API.
Creating the two DataFrames:
from math import exp
from random import randint
from datetime import datetime
def count_elements(splitIndex, iterator):
    n = sum(1 for _ in iterator)
    yield (splitIndex, n)

def get_part_index(splitIndex, iterator):
    for it in iterator:
        yield (splitIndex, it)

num_parts = 18
# create the large skewed rdd
skewed_large_rdd = sc.parallelize(range(0,num_parts), num_parts).flatMap(lambda x: range(0, int(exp(x))))
skewed_large_rdd = skewed_large_rdd.mapPartitionsWithIndex(lambda ind, x: get_part_index(ind, x))
skewed_large_df = spark.createDataFrame(skewed_large_rdd,['x','y'])
small_rdd = sc.parallelize(range(0,num_parts), num_parts).map(lambda x: (x, x))
small_df = spark.createDataFrame(small_rdd,['a','b'])
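
To get a feel for how skewed this test data is, the row counts per key can be worked out without Spark. The following is only a rough sketch in plain Python: it assumes, as in the setup above, that key k (column x) receives int(exp(k)) rows. The names rows_per_key and hottest_share are introduced here for illustration only.

```python
from math import exp

num_parts = 18  # same value as in the setup above

# Rows generated per key k: int(exp(k)), mirroring the flatMap above.
rows_per_key = {k: int(exp(k)) for k in range(num_parts)}
total_rows = sum(rows_per_key.values())

# Share of all rows that carry the single hottest key.
hottest_key = max(rows_per_key, key=rows_per_key.get)
hottest_share = rows_per_key[hottest_key] / total_rows

print("total rows:", total_rows)
print("hottest key: %d, share of rows: %.1f%%" % (hottest_key, 100 * hottest_share))
```

Under these assumptions a single key holds well over half of all rows, so in a plain join on x the task that receives that key does most of the work.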
Dividing the large DataFrame into 100 salt bins and replicating the small DataFrame 100 times, once per salt value:
salt_bins = 100
from pyspark.sql import functions as F
skewed_transformed_df = skewed_large_df.withColumn('salt', (F.rand()*salt_bins).cast('int')).cache()
small_transformed_df = small_df.withColumn('replicate', F.array([F.lit(i) for i in range(salt_bins)]))
small_transformed_df = small_transformed_df.select('*', F.explode('replicate').alias('salt')).drop('replicate').cache()
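
The salting step above can be illustrated without Spark. The sketch below uses plain Python lists as stand-ins for the two DataFrames (the toy data and the helper names plain_join and salted_join are hypothetical) and checks that joining on the composite key (key, salt) returns the same rows as a plain join on the key, provided the small side is replicated once per salt value.

```python
import random
from collections import defaultdict

salt_bins = 4  # small value for illustration; the code above uses 100

# Toy stand-ins for the two DataFrames (hypothetical data).
large_rows = [(0, "a"), (1, "b"), (1, "c"), (1, "d"), (1, "e")]  # (x, y); key 1 is hot
small_rows = [(0, "zero"), (1, "one")]                            # (a, b)

def plain_join(large, small):
    """Inner join on the key alone."""
    lookup = defaultdict(list)
    for a, b in small:
        lookup[a].append(b)
    return sorted((x, y, b) for x, y in large for b in lookup[x])

def salted_join(large, small, bins, seed=0):
    """Inner join on (key, salt); the salt is dropped from the output."""
    rng = random.Random(seed)
    # Large side: attach one random salt per row.
    salted_large = [(x, rng.randrange(bins), y) for x, y in large]
    # Small side: replicate every row once per salt value.
    lookup = defaultdict(list)
    for a, b in small:
        for salt in range(bins):
            lookup[(a, salt)].append(b)
    return sorted((x, y, b) for x, salt, y in salted_large
                  for b in lookup[(x, salt)])

# The salted join returns exactly the rows of the plain join.
assert salted_join(large_rows, small_rows, salt_bins) == plain_join(large_rows, small_rows)
```

The price of this trick is that the small side grows by a factor of salt_bins, so it is worth it mainly when the replicated side stays small compared with the skewed side.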
Finally, the join that avoids the skew:
t0 = datetime.now()
result2 = skewed_transformed_df.join(small_transformed_df, (skewed_transformed_df['x'] == small_transformed_df['a']) & (skewed_transformed_df['salt'] == small_transformed_df['salt']) )
result2.count()
print("The salted join takes %s" % (datetime.now() - t0))
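
Why this helps can also be seen by simulating how rows are spread over shuffle partitions. The sketch below is plain Python, not Spark: it assumes a simple modulo rule as a stand-in for Spark's hash partitioner, uses a hypothetical num_partitions of 8, and shrinks the data to 10 keys so the loop stays fast. It compares the largest partition when rows are placed by key alone with the largest partition when they are placed by (key, salt).

```python
import random
from collections import Counter
from math import exp

num_keys = 10        # smaller than the 18 used above so the loop stays fast
salt_bins = 100      # same as above
num_partitions = 8   # hypothetical shuffle partition count

rng = random.Random(42)
plain_load = Counter()
salted_load = Counter()

for key in range(num_keys):
    for _ in range(int(exp(key))):  # int(exp(key)) rows per key, as in the setup
        # Plain join: the target partition depends on the key only.
        plain_load[key % num_partitions] += 1
        # Salted join: the target partition depends on (key, salt).
        salt = rng.randrange(salt_bins)
        salted_load[(key * salt_bins + salt) % num_partitions] += 1

print("largest partition, plain :", max(plain_load.values()))
print("largest partition, salted:", max(salted_load.values()))
```

In this simulation the plain layout sends most rows to one partition, while the salted layout spreads the hot key's rows across all of them. Real Spark partitioning differs in detail, so treat the numbers as illustrative only.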