Serialize a custom transformer using python to be used within a Pyspark ML pipeline

Asked 2020-12-01 07:04 by 醉梦人生 · 5 answers · 566 views

I found the same discussion in the comments section of Create a custom Transformer in PySpark ML, but there is no clear answer. There is also an unresolved JIRA corresponding to …

5 Answers
  •  孤街浪徒 · 2020-12-01 07:18

    As of Spark 2.3.0 there's a much, much better way to do this.

    Simply extend DefaultParamsWritable and DefaultParamsReadable and your class will automatically have write and read methods that will save your params and will be used by the PipelineModel serialization system.
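    For instance, here is a minimal sketch (the class name and path are made up, and it assumes an active SparkSession) of what mixing in the two traits buys you:

    from pyspark.ml import Transformer
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

    class NoOpTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
        # a do-nothing stage; the two mixins alone provide write()/save() and load()
        def _transform(self, dataset):
            return dataset

    NoOpTransformer().write().overwrite().save("/tmp/noop_stage")  # via DefaultParamsWritable
    restored = NoOpTransformer.load("/tmp/noop_stage")             # via DefaultParamsReadable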

    The docs were not really clear, and I had to do a bit of source reading to understand this was the way that deserialization worked.

    • PipelineModel.read instantiates a PipelineModelReader
    • PipelineModelReader loads metadata and checks if language is 'Python'. If it's not, then the typical JavaMLReader is used (what most of these answers are designed for)
    • Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance

    loadParamsInstance will find the class from the saved metadata, instantiate it, and call .load(path) on it. You can extend DefaultParamsReadable and get the DefaultParamsReader.load method automatically. If you do have specialized deserialization logic to implement, I would look at that load method as a starting place.
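    For example, one way to hook in custom deserialization (a sketch based on my reading of the source; the class names are made up) is to subclass DefaultParamsReader and point your class's read() at it:

    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsReader

    class MyCustomReader(DefaultParamsReader):
        def load(self, path):
            # let DefaultParamsReader restore the Params-based instance first...
            instance = super(MyCustomReader, self).load(path)
            # ...then rehydrate any extra state your writer saved alongside it (hypothetical)
            return instance

    class MyCustomReadable(DefaultParamsReadable):
        @classmethod
        def read(cls):
            # return our reader instead of the stock DefaultParamsReader
            return MyCustomReader(cls)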

    On the write side:

    • PipelineModel.write will check if all stages are Java (implement JavaMLWritable). If so, the typical JavaMLWriter is used (what most of these answers are designed for)
    • Otherwise, PipelineWriter is used, which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
    • PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage.

    You can extend DefaultParamsWritable to get a write method (backed by DefaultParamsWriter) that saves the metadata for your class and its params in the right format. If you have custom serialization logic to implement, I would look at DefaultParamsWriter as a starting point.
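    For the custom-serialization case, a sketch (again with made-up class names) that keeps DefaultParamsWriter's metadata handling and adds a step of its own could look like:

    from pyspark.ml.util import DefaultParamsWritable, DefaultParamsWriter

    class MyCustomWriter(DefaultParamsWriter):
        def saveImpl(self, path):
            # DefaultParamsWriter persists the class metadata and params...
            super(MyCustomWriter, self).saveImpl(path)
            # ...then write any extra state of your own next to it (hypothetical)

    class MyCustomWritable(DefaultParamsWritable):
        def write(self):
            # return our writer instead of the stock DefaultParamsWriter
            return MyCustomWriter(self)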

    Ok, so finally, here is a pretty simple transformer that stores all of its parameters in the typical Params fashion:

    from pyspark import keyword_only
    from pyspark.ml import Transformer
    from pyspark.ml.param.shared import HasOutputCols, Param, Params
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
    from pyspark.sql.functions import lit # for the dummy _transform
    
    class SetValueTransformer(
        Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
    ):
        value = Param(
            Params._dummy(),
            "value",
            "value to fill",
        )
    
        @keyword_only
        def __init__(self, outputCols=None, value=0.0):
            super(SetValueTransformer, self).__init__()
            self._setDefault(value=0.0)
            kwargs = self._input_kwargs
            self._set(**kwargs)
    
        @keyword_only
        def setParams(self, outputCols=None, value=0.0):
            """
            setParams(self, outputCols=None, value=0.0)
            Sets params for this SetValueTransformer.
            """
            kwargs = self._input_kwargs
            return self._set(**kwargs)
    
        def setValue(self, value):
            """
            Sets the value of :py:attr:`value`.
            """
            return self._set(value=value)
    
        def getValue(self):
            """
            Gets the value of :py:attr:`value` or its default value.
            """
            return self.getOrDefault(self.value)
    
        def _transform(self, dataset):
            for col in self.getOutputCols():
                dataset = dataset.withColumn(col, lit(self.getValue()))
            return dataset
    

    Now we can use it:

    from pyspark.ml import Pipeline, PipelineModel  # PipelineModel is needed for load() below
    
    svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)
    
    p = Pipeline(stages=[svt])
    df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
    pm = p.fit(df)
    pm.transform(df).show()
    pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
    pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
    print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
    pm2.transform(df).show()
    

    Result:

    +---+-----+-----+-----+
    |key|value|    a|    b|
    +---+-----+-----+-----+
    |  1| null|123.0|123.0|
    |  2|  1.0|123.0|123.0|
    |  3|  0.5|123.0|123.0|
    +---+-----+-----+-----+
    
    matches? True
    +---+-----+-----+-----+
    |key|value|    a|    b|
    +---+-----+-----+-----+
    |  1| null|123.0|123.0|
    |  2|  1.0|123.0|123.0|
    |  3|  0.5|123.0|123.0|
    +---+-----+-----+-----+
    
