How do I unit test PySpark programs?

你的背包 2020-12-12 17:01

My current Java/Spark unit test approach works (detailed here): I instantiate a SparkContext with "local" as the master and run the unit tests with JUnit.

The code has to be

7 answers
  •  孤街浪徒
    2020-12-12 17:43

    You can test PySpark code by running it on DataFrames in the test suite and then checking either that two columns are equal or that two entire DataFrames are equal.

    The quinn project has several examples.

    Create a SparkSession for the test suite

    Create a tests/conftest.py file with this fixture, so you can easily access the SparkSession in your tests.

    import pytest
    from pyspark.sql import SparkSession
    
    @pytest.fixture(scope='session')
    def spark():
        return SparkSession.builder \
          .master("local") \
          .appName("chispa") \
          .getOrCreate()
    

    Column equality example

    Suppose you'd like to test the following function that removes all non-word characters from a string.

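    import pyspark.sql.functions as F
    
    def remove_non_word_characters(col):
        # Delete every run of characters that is neither a word character nor whitespace
        return F.regexp_replace(col, "[^\\w\\s]+", "")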
    

    You can test this function with the assert_column_equality function that's defined in the chispa library.

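    import pyspark.sql.functions as F
    from chispa.column_comparer import assert_column_equality
    
    def test_remove_non_word_characters(spark):
        # Each row pairs a raw input with the value expected after cleaning
        data = [
            ("jo&&se", "jose"),
            ("**li**", "li"),
            ("#::luisa", "luisa"),
            (None, None)
        ]
        df = spark.createDataFrame(data, ["name", "expected_name"])\
            .withColumn("clean_name", remove_non_word_characters(F.col("name")))
        assert_column_equality(df, "clean_name", "expected_name")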
    

    DataFrame equality example

    Some functions need to be tested by comparing entire DataFrames. Here's a function that sorts the columns in a DataFrame.

    def sort_columns(df, sort_order):
        sorted_col_names = None
        if sort_order == "asc":
            sorted_col_names = sorted(df.columns)
        elif sort_order == "desc":
            sorted_col_names = sorted(df.columns, reverse=True)
        else:
            raise ValueError("['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format(
                sort_order=sort_order
            ))
        return df.select(*sorted_col_names)
    

    Here's one test you'd write for this function.

    from chispa.dataframe_comparer import assert_df_equality
    
    def test_sort_columns_asc(spark):
        source_data = [
            ("jose", "oak", "switch"),
            ("li", "redwood", "xbox"),
            ("luisa", "maple", "ps4"),
        ]
        source_df = spark.createDataFrame(source_data, ["name", "tree", "gaming_system"])
    
        # sort_columns is the function defined above
        actual_df = sort_columns(source_df, "asc")
    
        expected_data = [
            ("switch", "jose", "oak"),
            ("xbox", "li", "redwood"),
            ("ps4", "luisa", "maple"),
        ]
        expected_df = spark.createDataFrame(expected_data, ["gaming_system", "name", "tree"])
    
        # Fails if the schemas or the rows differ
        assert_df_equality(actual_df, expected_df)
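
    The chispa test above is the one that exercises real Spark behavior. If you also want quick checks of the "desc" and error branches without starting a SparkSession, one option (a swapped-in technique, not what this answer uses) is to pass a unittest.mock.MagicMock that exposes a columns list, because sort_columns only reads df.columns and calls df.select. A minimal sketch, assuming pytest is your test runner; the test names are hypothetical:

```python
from unittest import mock

import pytest


def sort_columns(df, sort_order):
    # Same logic as the sort_columns above, repeated so this block runs on its own
    if sort_order == "asc":
        sorted_col_names = sorted(df.columns)
    elif sort_order == "desc":
        sorted_col_names = sorted(df.columns, reverse=True)
    else:
        raise ValueError(
            "['asc', 'desc'] are the only valid sort orders and you entered "
            "a sort order of '{sort_order}'".format(sort_order=sort_order)
        )
    return df.select(*sorted_col_names)


def test_sort_columns_desc_selects_reversed_names():
    # The MagicMock stands in for a DataFrame: only .columns and .select are used
    fake_df = mock.MagicMock(columns=["name", "tree", "gaming_system"])
    sort_columns(fake_df, "desc")
    fake_df.select.assert_called_once_with("tree", "name", "gaming_system")


def test_sort_columns_rejects_unknown_order():
    fake_df = mock.MagicMock(columns=["name"])
    # The error message echoes the bad value, so match on it
    with pytest.raises(ValueError, match="random"):
        sort_columns(fake_df, "random")
```

    These tests only verify which column names reach select; they say nothing about how Spark evaluates the result, so keep at least one real-DataFrame test per function.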
    

    Testing I/O

    It's generally best to separate your transformation logic from I/O code, so the logic is easier to test.

    Suppose you have a function like this.

    def your_big_function():
        df = spark.read.parquet("some_directory")
        df2 = df.withColumn(...).transform(function1).transform(function2)
        df2.write.parquet("other directory")
    

    It's better to refactor the code like this:

    def all_logic(df):
        # Pure transformation logic: DataFrame in, DataFrame out, no I/O
        return df.withColumn(...).transform(function1).transform(function2)
    
    def your_formerly_big_function():
        df = spark.read.parquet("some_directory")
        df2 = df.transform(all_logic)
        df2.write.parquet("other directory")
    

    Designing your code like this lets you easily test the all_logic function with the column equality or DataFrame equality functions mentioned above, and you can use mocking to test your_formerly_big_function. It's generally best to avoid I/O in test suites, though it's sometimes unavoidable.
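
    A minimal sketch of the mocking idea, assuming you change your_formerly_big_function to receive the SparkSession and the paths as parameters (a hypothetical signature; the version above uses a global spark). A MagicMock records every chained call, so the test can check the read path, that all_logic was handed to transform, and the write path, without touching the filesystem or starting Spark:

```python
from unittest import mock


def all_logic(df):
    # Hypothetical placeholder for the real transformation chain; df.transform
    # is mocked in the test below, so this body never actually runs there.
    return df


def your_formerly_big_function(spark, input_path="some_directory",
                               output_path="other directory"):
    # Hypothetical signature: spark and the paths are injected instead of global
    df = spark.read.parquet(input_path)
    df2 = df.transform(all_logic)
    df2.write.parquet(output_path)


def test_your_formerly_big_function():
    spark = mock.MagicMock()  # records every attribute access and call on it

    your_formerly_big_function(spark)

    # The read happened on the expected path...
    spark.read.parquet.assert_called_once_with("some_directory")
    source_df = spark.read.parquet.return_value
    # ...the pure logic was passed to transform...
    source_df.transform.assert_called_once_with(all_logic)
    # ...and the transformed DataFrame was written to the expected path.
    result_df = source_df.transform.return_value
    result_df.write.parquet.assert_called_once_with("other directory")
```

    This only checks the wiring of the I/O wrapper; the behavior of all_logic itself is still covered by the DataFrame-based tests.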

    This blog post has more details on how to test PySpark code.
