How do I unit test PySpark programs?

你的背包 2020-12-12 17:01

My current Java/Spark Unit Test approach works (detailed here) by instantiating a SparkContext using "local" and running unit tests using JUnit.

The code has to be organized to do I/O in one function and then call another with multiple RDDs. How can I do the same with PySpark?

7 Answers
  • 2020-12-12 17:43

    You can test PySpark code by running it on DataFrames in the test suite and comparing either column equality or the equality of two entire DataFrames.

    The quinn project has several examples.

    Create SparkSession for test suite

    Create a tests/conftest.py file with this fixture, so you can easily access the SparkSession in your tests.

    import pytest
    from pyspark.sql import SparkSession
    
    @pytest.fixture(scope='session')
    def spark():
        return SparkSession.builder \
          .master("local") \
          .appName("chispa") \
          .getOrCreate()
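
    pytest automatically discovers fixtures defined in tests/conftest.py, so any test in that directory that declares a spark argument (like the examples below) will receive this SparkSession without an explicit import.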
    

    Column equality example

    Suppose you'd like to test the following function that removes all non-word characters from a string.

    from pyspark.sql import functions as F

    def remove_non_word_characters(col):
        return F.regexp_replace(col, "[^\\w\\s]+", "")
    

    You can test this function with the assert_column_equality function that's defined in the chispa library.

    from chispa import assert_column_equality

    def test_remove_non_word_characters(spark):
        data = [
            ("jo&&se", "jose"),
            ("**li**", "li"),
            ("#::luisa", "luisa"),
            (None, None)
        ]
        df = spark.createDataFrame(data, ["name", "expected_name"])\
            .withColumn("clean_name", remove_non_word_characters(F.col("name")))
        assert_column_equality(df, "clean_name", "expected_name")
    

    DataFrame equality example

    Some functions need to be tested by comparing entire DataFrames. Here's a function that sorts the columns in a DataFrame.

    def sort_columns(df, sort_order):
        sorted_col_names = None
        if sort_order == "asc":
            sorted_col_names = sorted(df.columns)
        elif sort_order == "desc":
            sorted_col_names = sorted(df.columns, reverse=True)
        else:
            raise ValueError("['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format(
                sort_order=sort_order
            ))
        return df.select(*sorted_col_names)
    

    Here's one test you'd write for this function.

    from chispa import assert_df_equality

    def test_sort_columns_asc(spark):
        source_data = [
            ("jose", "oak", "switch"),
            ("li", "redwood", "xbox"),
            ("luisa", "maple", "ps4"),
        ]
        source_df = spark.createDataFrame(source_data, ["name", "tree", "gaming_system"])
    
        actual_df = T.sort_columns(source_df, "asc")  # T is the module where sort_columns is defined
    
        expected_data = [
            ("switch", "jose", "oak"),
            ("xbox", "li", "redwood"),
            ("ps4", "luisa", "maple"),
        ]
        expected_df = spark.createDataFrame(expected_data, ["gaming_system", "name", "tree"])
    
        assert_df_equality(actual_df, expected_df)
    

    Testing I/O

    It's generally best to abstract code logic from I/O functions, so they're easier to test.

    Suppose you have a function like this.

    def your_big_function():
        df = spark.read.parquet("some_directory")
        df2 = df.withColumn(...).transform(function1).transform(function2)
        df2.write.parquet("other directory")
    

    It's better to refactor the code like this:

    def all_logic(df):
        return df.withColumn(...).transform(function1).transform(function2)

    def your_formerly_big_function():
        df = spark.read.parquet("some_directory")
        df2 = df.transform(all_logic)
        df2.write.parquet("other directory")
    

    Designing your code like this lets you easily test the all_logic function with the column equality or DataFrame equality functions mentioned above. You can use mocking to test your_formerly_big_function. It's generally best to avoid I/O in test suites (but sometimes unavoidable).
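
    For the I/O wrapper itself, mocking keeps the test suite free of real reads and writes. Below is a minimal sketch using unittest.mock, assuming the code above lives in a hypothetical module my_etl that defines a module-level spark session:

    from unittest.mock import MagicMock, patch

    import my_etl  # hypothetical module containing your_formerly_big_function


    def test_your_formerly_big_function_wires_io_together():
        fake_spark = MagicMock()
        fake_df = MagicMock()
        fake_spark.read.parquet.return_value = fake_df
        # df.transform(all_logic) returns this mock, which should then be written out
        transformed_df = fake_df.transform.return_value

        # Assumes my_etl defines a module-level `spark` that the function uses.
        with patch.object(my_etl, "spark", fake_spark):
            my_etl.your_formerly_big_function()

        fake_spark.read.parquet.assert_called_once_with("some_directory")
        transformed_df.write.parquet.assert_called_once_with("other directory")

    This only checks that the I/O is wired together correctly; the real logic is covered by the tests for all_logic.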

    This blog post has more details on how to test PySpark code.

  • 2020-12-12 17:45

    Assuming you have pyspark installed, you can use the class below to unit test it with unittest:

    import unittest
    import pyspark
    
    
    class PySparkTestCase(unittest.TestCase):
    
        @classmethod
        def setUpClass(cls):
            conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
            cls.sc = pyspark.SparkContext(conf=conf)
            cls.spark = pyspark.SQLContext(cls.sc)
    
        @classmethod
        def tearDownClass(cls):
            cls.sc.stop()
    

    Example:

    class SimpleTestCase(PySparkTestCase):
    
        def test_with_rdd(self):
            test_input = [
                ' hello spark ',
                ' hello again spark spark'
            ]
    
            input_rdd = self.sc.parallelize(test_input, 1)
    
            from operator import add
    
            results = input_rdd.flatMap(lambda x: x.split()).map(lambda x: (x, 1)).reduceByKey(add).collect()
            self.assertEqual(results, [('hello', 2), ('spark', 3), ('again', 1)])
    
        def test_with_df(self):
            df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']], 
                                            schema=['c1', 'c2'])
            self.assertEqual(df.count(), 2)
    

    Note that this creates a context per class. Use setUp instead of setUpClass to get a context per test. This typically adds a lot of overhead to test execution, as creating a new Spark context is currently expensive.
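
    If you do want per-test isolation despite that cost, a minimal sketch of the per-test variant looks like this:

    import unittest
    import pyspark


    class PySparkPerTestCase(unittest.TestCase):

        def setUp(self):
            # A fresh context for every test: slower, but fully isolated.
            conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
            self.sc = pyspark.SparkContext(conf=conf)
            self.spark = pyspark.SQLContext(self.sc)

        def tearDown(self):
            self.sc.stop()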

  • 2020-12-12 17:49

    Some time ago I also faced the same issue, and after reading through several articles, forums and some StackOverflow answers I ended up writing a small plugin for pytest: pytest-spark

    I have been using it for a few months already, and the general workflow looks good on Linux:

    1. Install Apache Spark (setup JVM + unpack Spark's distribution to some directory)
    2. Install "pytest" + plugin "pytest-spark"
    3. Create "pytest.ini" in your project directory and specify the Spark location there (see the sketch after this list).
    4. Run your tests by pytest as usual.
    5. Optionally you can use the "spark_context" fixture provided by the plugin in your tests - it tries to minimize Spark's logs in the output.
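
    For reference, a minimal sketch of steps 3-5, assuming pytest-spark reads the Spark location from a spark_home option in pytest.ini (the path below is a placeholder; check the plugin's README for the exact option names in your version):

    # pytest.ini (placeholder path -- point it at your Spark installation):
    #
    #   [pytest]
    #   spark_home = /opt/spark

    # test_wordcount.py -- the spark_context fixture is injected by pytest-spark.
    def test_word_count(spark_context):
        rdd = spark_context.parallelize(["a b", "b c"])
        counts = rdd.flatMap(lambda line: line.split()).countByValue()
        assert counts == {"a": 1, "b": 2, "c": 1}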
  • 2020-12-12 17:59

    I use pytest, which allows test fixtures, so you can instantiate a PySpark context and inject it into all of your tests that require it. Something along the lines of:

    import pytest
    from pyspark import SparkConf, SparkContext

    @pytest.fixture(scope="session",
                    params=[pytest.param('local', marks=pytest.mark.spark_local),
                            pytest.param('yarn', marks=pytest.mark.spark_yarn)])
    def spark_context(request):
        if request.param == 'local':
            conf = (SparkConf()
                    .setMaster("local[2]")
                    .setAppName("pytest-pyspark-local-testing")
                    )
        elif request.param == 'yarn':
            conf = (SparkConf()
                    .setMaster("yarn-client")
                    .setAppName("pytest-pyspark-yarn-testing")
                    .set("spark.executor.memory", "1g")
                    .set("spark.executor.instances", 2)
                    )

        sc = SparkContext(conf=conf)
        # Register the finalizer once the context exists so session teardown stops it.
        request.addfinalizer(lambda: sc.stop())
        return sc

    # pytest only collects functions whose names start with "test".
    def test_something_that_requires_sc(spark_context):
        assert spark_context.textFile('/path/to/a/file').count() == 10
    

    Then you can run the tests in local mode by calling py.test -m spark_local or in YARN with py.test -m spark_yarn. This has worked pretty well for me.

  • 2020-12-12 18:00

    I'd recommend using py.test as well. py.test makes it easy to create re-usable SparkContext test fixtures and use them to write concise test functions. You can also specialize fixtures (to create a StreamingContext, for example) and use one or more of them in your tests.

    I wrote a blog post on Medium on this topic:

    https://engblog.nextdoor.com/unit-testing-apache-spark-with-py-test-3b8970dc013b

    Here is a snippet from the post:

    import pytest

    import wordcount  # the module under test from the blog post

    pytestmark = pytest.mark.usefixtures("spark_context")

    def test_do_word_counts(spark_context):
        """ test word counting
        Args:
           spark_context: test fixture SparkContext
        """
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]

        input_rdd = spark_context.parallelize(test_input, 1)
        results = wordcount.do_word_counts(input_rdd)

        expected_results = {'hello': 2, 'spark': 3, 'again': 1}
        assert results == expected_results
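
    The wordcount.do_word_counts function itself isn't shown in the snippet; a minimal sketch consistent with the assertion above (it must return a plain dict of word counts) could look like this:

    # wordcount.py -- a sketch, not necessarily the implementation from the blog post
    from operator import add


    def do_word_counts(lines):
        """Count words in an RDD of text lines and return a dict of word -> count."""
        return (lines.flatMap(lambda line: line.split())
                     .map(lambda word: (word, 1))
                     .reduceByKey(add)
                     .collectAsMap())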
    
  • 2020-12-12 18:00

    pyspark ships with a unittest-based helper class which can be used as below:

    import pyspark
    from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase

    class MySparkTests(PySparkTestCase):
        def spark_session(self):
            return pyspark.SQLContext(self.sc)

        def createMockDataFrame(self):
            return self.spark_session().createDataFrame(
                [
                    ("t1", "t2"),
                    ("t1", "t2"),
                    ("t1", "t2"),
                ],
                ['col1', 'col2']
            )
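
    A test method you might add to the class above (a sketch, relying on the return statement in createMockDataFrame as written above):

        def test_mock_dataframe_row_count(self):
            df = self.createMockDataFrame()
            self.assertEqual(df.count(), 3)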
    