I have a script in PySpark like the one below, and I want to unit test a function in this script:

```python
def rename_chars(column_name):
    chars = (('
```
Assuming you have pyspark installed (e.g. `pip install pyspark` in a venv), you can use the class below to unit test it with unittest:
```python
import unittest

import pyspark


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # One local Spark context, shared by all tests in the class
        conf = pyspark.SparkConf().setMaster("local[*]").setAppName("testing")
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.spark = pyspark.SQLContext(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
```
Example:
```python
class SimpleTestCase(PySparkTestCase):

    def test_with_rdd(self):
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]
        input_rdd = self.sc.parallelize(test_input, 1)

        from operator import add

        # Classic word count: split into words, pair each word with 1,
        # then sum the counts per word
        results = (input_rdd
                   .flatMap(lambda x: x.split())
                   .map(lambda x: (x, 1))
                   .reduceByKey(add)
                   .collect())
        self.assertEqual(results, [('hello', 2), ('spark', 3), ('again', 1)])

    def test_with_df(self):
        df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']],
                                        schema=['c1', 'c2'])
        self.assertEqual(df.count(), 2)
```
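To unit test the function from the question with the same base class, a sketch like the one below could be added. Since the body of `rename_chars` is truncated above, both the module name (`my_script`) and the expected output (spaces replaced by underscores) are assumptions to adjust to your actual code:

```python
from my_script import rename_chars  # hypothetical module name


class RenameCharsTestCase(PySparkTestCase):

    def test_rename_chars(self):
        # Assumed behaviour: rename_chars replaces spaces with underscores.
        # Adjust the expected value to the real replacement rules.
        self.assertEqual(rename_chars('col name'), 'col_name')
```

Note that since `rename_chars` operates on a plain string, a test like this doesn't strictly need a Spark context; inheriting from plain `unittest.TestCase` would be enough. The base class only pays off once the function under test touches RDDs or DataFrames.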
Note that this creates one context per class. Use setUp instead of setUpClass to get a context per test, but be aware that this typically adds a lot of overhead to the execution of the tests, since creating a new Spark context is currently expensive.
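For completeness, a per-test variant would look like the sketch below, with the cost caveat above:

```python
import unittest

import pyspark


class PySparkPerTestCase(unittest.TestCase):

    def setUp(self):
        # A fresh context for every test: fully isolated, but much slower
        conf = pyspark.SparkConf().setMaster("local[*]").setAppName("testing")
        self.sc = pyspark.SparkContext(conf=conf)
        self.spark = pyspark.SQLContext(self.sc)

    def tearDown(self):
        self.sc.stop()
```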