My current Java/Spark unit test approach works (detailed here) by instantiating a SparkContext using "local" and running unit tests using JUnit.
The code has to be organized so that I/O happens in one function and the transformation logic in another, so the logic can be tested on its own.
You can test PySpark code by running your code on DataFrames in the test suite and comparing DataFrame column equality or equality of two entire DataFrames.
The quinn project has several examples.
Create SparkSession for test suite
Create a tests/conftest.py file with this fixture, so you can easily access the SparkSession in your tests.
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope='session')
def spark():
    return SparkSession.builder \
        .master("local") \
        .appName("chispa") \
        .getOrCreate()
Column equality example
Suppose you'd like to test the following function that removes all non-word characters from a string.
from pyspark.sql import functions as F


def remove_non_word_characters(col):
    return F.regexp_replace(col, "[^\\w\\s]+", "")
You can test this function with the assert_column_equality function that's defined in the chispa library.
from chispa.column_comparer import assert_column_equality


def test_remove_non_word_characters(spark):
    data = [
        ("jo&&se", "jose"),
        ("**li**", "li"),
        ("#::luisa", "luisa"),
        (None, None)
    ]
    df = spark.createDataFrame(data, ["name", "expected_name"])\
        .withColumn("clean_name", remove_non_word_characters(F.col("name")))
    assert_column_equality(df, "clean_name", "expected_name")
DataFrame equality example
Some functions need to be tested by comparing entire DataFrames. Here's a function that sorts the columns in a DataFrame.
def sort_columns(df, sort_order):
    sorted_col_names = None
    if sort_order == "asc":
        sorted_col_names = sorted(df.columns)
    elif sort_order == "desc":
        sorted_col_names = sorted(df.columns, reverse=True)
    else:
        raise ValueError("['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format(
            sort_order=sort_order
        ))
    return df.select(*sorted_col_names)
Here's one test you'd write for this function.
from chispa.dataframe_comparer import assert_df_equality


def test_sort_columns_asc(spark):
    source_data = [
        ("jose", "oak", "switch"),
        ("li", "redwood", "xbox"),
        ("luisa", "maple", "ps4"),
    ]
    source_df = spark.createDataFrame(source_data, ["name", "tree", "gaming_system"])
    actual_df = T.sort_columns(source_df, "asc")  # T is the module that defines sort_columns
    expected_data = [
        ("switch", "jose", "oak"),
        ("xbox", "li", "redwood"),
        ("ps4", "luisa", "maple"),
    ]
    expected_df = spark.createDataFrame(expected_data, ["gaming_system", "name", "tree"])
    assert_df_equality(actual_df, expected_df)
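The error branch of sort_columns can be covered as well; here's a minimal sketch (the test name and the invalid value are my own) using pytest.raises:

import pytest


def test_sort_columns_invalid_order(spark):
    source_df = spark.createDataFrame([("jose", "oak")], ["name", "tree"])
    # Any value other than "asc" or "desc" should raise.
    with pytest.raises(ValueError):
        T.sort_columns(source_df, "sideways")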
Testing I/O
It's generally best to abstract code logic from I/O functions, so they're easier to test.
Suppose you have a function like this.
def your_big_function():
    df = spark.read.parquet("some_directory")
    df2 = df.withColumn(...).transform(function1).transform(function2)
    df2.write.parquet("other directory")
It's better to refactor the code like this:
def all_logic(df):
    return df.withColumn(...).transform(function1).transform(function2)


def your_formerly_big_function():
    df = spark.read.parquet("some_directory")
    df2 = df.transform(all_logic)
    df2.write.parquet("other directory")
Designing your code like this lets you easily test the all_logic function with the column equality or DataFrame equality functions mentioned above. You can use mocking to test your_formerly_big_function, as sketched below. It's generally best to avoid I/O in test suites, but sometimes it's unavoidable.
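For example, here's a rough sketch of mocking out the reads and writes when testing your_formerly_big_function; the module name your_module and the patch target are assumptions, not part of the original code:

from unittest.mock import MagicMock, patch

from your_module import all_logic, your_formerly_big_function  # your_module is a placeholder


def test_your_formerly_big_function():
    fake_df = MagicMock()
    # Replace the module-level spark object so no Parquet files are read or written.
    with patch("your_module.spark") as mock_spark:
        mock_spark.read.parquet.return_value = fake_df
        your_formerly_big_function()
        # The function should read once, apply all_logic once, and write once.
        mock_spark.read.parquet.assert_called_once_with("some_directory")
        fake_df.transform.assert_called_once_with(all_logic)
        fake_df.transform.return_value.write.parquet.assert_called_once_with("other directory")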
This blog post has more details on how to test PySpark code.
Assuming you have pyspark installed, you can use the class below to unit test it in unittest:
import unittest
import pyspark


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.spark = pyspark.SQLContext(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
Example:
class SimpleTestCase(PySparkTestCase):

    def test_with_rdd(self):
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]
        input_rdd = self.sc.parallelize(test_input, 1)

        from operator import add
        results = input_rdd.flatMap(lambda x: x.split()).map(lambda x: (x, 1)).reduceByKey(add).collect()
        self.assertEqual(results, [('hello', 2), ('spark', 3), ('again', 1)])

    def test_with_df(self):
        df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']],
                                        schema=['c1', 'c2'])
        self.assertEqual(df.count(), 2)
Note that this creates a context per class. Use setUp instead of setUpClass to get a context per test (see the per-test sketch below). That typically adds a lot of overhead to the execution of the tests, as creating a new Spark context is currently expensive.
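If you do want that per-test isolation despite the cost, a minimal sketch of the per-test variant (my own, not from the original answer) would be:

class PySparkPerTestCase(unittest.TestCase):

    def setUp(self):
        # A fresh SparkContext for every test; noticeably slower than setUpClass.
        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        self.sc = pyspark.SparkContext(conf=conf)
        self.spark = pyspark.SQLContext(self.sc)

    def tearDown(self):
        self.sc.stop()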
Some time ago I also faced the same issue, and after reading through several articles, forums, and some Stack Overflow answers, I ended up writing a small plugin for pytest: pytest-spark. I have been using it for a few months and the general workflow looks good on Linux.
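Roughly, and assuming the plugin's documented setup (pip install pytest-spark, then point spark_home at your Spark installation in pytest.ini or via the SPARK_HOME environment variable; check the pytest-spark README for the exact option names), tests simply request the fixtures the plugin provides. The test body below is a sketch of mine, not taken from the plugin's docs:

def test_with_plugin_fixture(spark_session):
    # spark_session is a session-scoped SparkSession fixture supplied by pytest-spark
    df = spark_session.createDataFrame([(1, 'a'), (2, 'b')], ['c1', 'c2'])
    assert df.count() == 2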
I use pytest, which allows test fixtures, so you can instantiate a pyspark context and inject it into all of your tests that require it. Something along the lines of:
import pytest
from pyspark import SparkConf, SparkContext


@pytest.fixture(scope="session",
                params=[pytest.mark.spark_local('local'),
                        pytest.mark.spark_yarn('yarn')])
def spark_context(request):
    if request.param == 'local':
        conf = (SparkConf()
                .setMaster("local[2]")
                .setAppName("pytest-pyspark-local-testing"))
    elif request.param == 'yarn':
        conf = (SparkConf()
                .setMaster("yarn-client")
                .setAppName("pytest-pyspark-yarn-testing")
                .set("spark.executor.memory", "1g")
                .set("spark.executor.instances", 2))
    sc = SparkContext(conf=conf)
    # Make sure the context is stopped when the test session ends.
    request.addfinalizer(lambda: sc.stop())
    return sc


def test_my_case_that_requires_sc(spark_context):
    assert spark_context.textFile('/path/to/a/file').count() == 10
Then you can run the tests in local mode by calling py.test -m spark_local, or in YARN with py.test -m spark_yarn. This has worked pretty well for me.
I'd recommend using py.test as well. py.test makes it easy to create reusable SparkContext test fixtures and use them to write concise test functions. You can also specialize fixtures (to create a StreamingContext, for example) and use one or more of them in your tests.
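As a hedged sketch of what such a specialized fixture could look like (the fixture name and the 1-second batch interval are my own choices), a StreamingContext fixture can be layered on top of the spark_context fixture:

import pytest
from pyspark.streaming import StreamingContext


@pytest.fixture(scope="function")
def streaming_context(spark_context):
    # Build a StreamingContext on the shared SparkContext with a 1-second batch interval.
    ssc = StreamingContext(spark_context, 1)
    yield ssc
    # Stop the streaming context but keep the underlying SparkContext alive.
    ssc.stop(stopSparkContext=False)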
I wrote a blog post on Medium on this topic:
https://engblog.nextdoor.com/unit-testing-apache-spark-with-py-test-3b8970dc013b
Here is a snippet from the post:
pytestmark = pytest.mark.usefixtures("spark_context")


def test_do_word_counts(spark_context):
    """ test word counting

    Args:
        spark_context: test fixture SparkContext
    """
    test_input = [
        ' hello spark ',
        ' hello again spark spark'
    ]
    input_rdd = spark_context.parallelize(test_input, 1)
    results = wordcount.do_word_counts(input_rdd)
    expected_results = {'hello': 2, 'spark': 3, 'again': 1}
    assert results == expected_results
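The wordcount.do_word_counts function under test isn't shown in the snippet; a minimal implementation consistent with the expected results might look like this (my reconstruction, not the post's code):

# wordcount.py (hypothetical reconstruction)
from operator import add


def do_word_counts(lines):
    """Count occurrences of each word in an RDD of lines and return a dict."""
    counts = (lines.flatMap(lambda line: line.split())
                   .map(lambda word: (word, 1))
                   .reduceByKey(add))
    return dict(counts.collect())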
pyspark has a unittest-based test module, which can be used as below:
import pyspark
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase


class MySparkTests(PySparkTestCase):

    def spark_session(self):
        return pyspark.SQLContext(self.sc)

    def createMockDataFrame(self):
        return self.spark_session().createDataFrame(
            [
                ("t1", "t2"),
                ("t1", "t2"),
                ("t1", "t2"),
            ],
            ['col1', 'col2']
        )
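A test method built on that helper might then look like this (a sketch of mine; it assumes createMockDataFrame returns the DataFrame as written above):

    def test_mock_data_frame_count(self):
        # Uses the helper above to build the mock DataFrame and checks its row count.
        df = self.createMockDataFrame()
        self.assertEqual(df.count(), 3)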