Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Apache Spark ML: Serialize a custom transformer written in Python for use within a PySpark ML pipeline

I found the same discussion in the comments section of "Create a custom Transformer in PySpark ML", but there is no clear answer. There is also an unresolved JIRA ticket for it: https://issues.apache.org/jira/browse/SPARK-17025.

Given that PySpark ML pipelines provide no option for saving a custom transformer written in Python, what other options are there? How can I implement the _to_java method in my Python class so that it returns a compatible Java object?

Fight dragons for too long and you become a dragon yourself; gaze into the abyss for too long and the abyss gazes back…

1 Reply


As of Spark 2.3.0 there's a much, much better way to do this.

Simply extend DefaultParamsWritable and DefaultParamsReadable, and your class will automatically get write and read methods that save your params and that the PipelineModel serialization system uses.

The docs were not very clear, and I had to read some of the source to understand how deserialization works:

  • PipelineModel.read instantiates a PipelineModelReader
  • PipelineModelReader loads the metadata and checks whether the language is 'Python'. If it's not, the typical JavaMLReader is used (what most of these answers are designed for)
  • Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance
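The read-side dispatch in the list above can be modelled in a few lines of plain Python. This is a simplified sketch, not PySpark's code: the function name choose_reader, its returned strings, and the sample metadata (uids included) are hypothetical. It only assumes, as described above, that the saved pipeline metadata is a JSON object whose paramMap may carry a "language" entry.

```python
import json


def choose_reader(metadata_json: str) -> str:
    """Return which reader path a saved PipelineModel would take (simplified model)."""
    metadata = json.loads(metadata_json)
    # Pipelines written by the Java writer carry no "language": "Python" marker.
    language = metadata.get("paramMap", {}).get("language")
    if language != "Python":
        return "JavaMLReader"
    # Python pipelines load each stage via DefaultParamsReader.loadParamsInstance.
    return "PipelineSharedReadWrite"


# Illustrative metadata only; real files contain more fields.
java_meta = '{"class": "org.apache.spark.ml.PipelineModel", "paramMap": {"stageUids": ["tok_1"]}}'
py_meta = '{"class": "pyspark.ml.pipeline.PipelineModel", "paramMap": {"stageUids": ["svt_1"], "language": "Python"}}'

print(choose_reader(java_meta))  # JavaMLReader
print(choose_reader(py_meta))    # PipelineSharedReadWrite
```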

loadParamsInstance finds the class name in the saved metadata, resolves that class, and calls its .load(path) method. If you extend DefaultParamsReadable, you get that load method automatically, backed by DefaultParamsReader. If you do need specialized deserialization logic, I would look at DefaultParamsReader.load as a starting point.
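A minimal sketch of the class-resolution step just described, assuming the metadata's "class" field holds a dotted Python path (or a JVM name that maps onto pyspark by swapping the "org.apache.spark" prefix). The helper names are hypothetical and importlib stands in for whatever PySpark uses internally.

```python
import importlib


def to_python_class_name(saved_class: str) -> str:
    # Metadata written by Scala stages uses JVM package names; map them to pyspark.
    return saved_class.replace("org.apache.spark", "pyspark")


def resolve_class(dotted_name: str) -> type:
    # Split "package.module.ClassName" into the module path and the attribute name.
    module_name, _, class_name = dotted_name.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def load_params_instance(metadata: dict, path: str):
    # Resolve the saved class, then delegate to its load(path) classmethod.
    cls = resolve_class(to_python_class_name(metadata["class"]))
    return cls.load(path)
```

One consequence, as far as I can tell: the saved class name comes from the instance's module, so a transformer defined in a notebook or the pyspark shell is recorded as something like __main__.SetValueTransformer, and the loading process must be able to import it under that same name. Putting the class in a regular module avoids surprises.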

On the opposite side:

  • PipelineModel.write will check if all stages are Java (implement JavaMLWritable). If so, the typical JavaMLWriter is used (what most of these answers are designed for)
  • Otherwise, PipelineModelWriter is used, which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
  • PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage.
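The write side in the list above boils down to "validate every stage, then call write().save(...) on each one under its own directory". Below is a pure-Python sketch of that loop. The function names are hypothetical, and the directory layout (a "stages" folder containing entries named zero-padded index + "_" + uid) is my reading of the PySpark source, so treat it as an assumption.

```python
import os


def stage_path(stages_dir: str, stage_uid: str, stage_idx: int, num_stages: int) -> str:
    # Zero-pad the index to the width of num_stages so directories sort in stage order.
    digits = len(str(num_stages))
    return os.path.join(stages_dir, str(stage_idx).zfill(digits) + "_" + stage_uid)


def save_stages(stages, path: str):
    # Fail up front if any stage cannot be written (no callable write method).
    for stage in stages:
        if not callable(getattr(stage, "write", None)):
            raise ValueError(f"stage {stage!r} is not writable")
    stages_dir = os.path.join(path, "stages")
    written = []
    for idx, stage in enumerate(stages):
        target = stage_path(stages_dir, stage.uid, idx, len(stages))
        stage.write().save(target)  # each stage persists itself
        written.append(target)
    return written
```

Validating before writing anything means a non-writable stage fails the save before any stage directory is created.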

Extending DefaultParamsWritable gives you a write method that returns a DefaultParamsWriter, which saves the metadata for your class and params in the right format. If you need custom serialization logic, I would look at DefaultParamsWritable and DefaultParamsWriter as a starting point.
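To make "the right format" concrete, here is a sketch of the single-line JSON metadata that DefaultParamsWriter writes, as I understand the source: class, timestamp, sparkVersion, uid and paramMap, plus optional extra fields. The function build_metadata and the example uid are hypothetical, and newer Spark releases may record additional fields, so treat the exact field set as an assumption.

```python
import json
import time


def build_metadata(class_name, uid, param_map, spark_version, extra=None):
    """Build a DefaultParamsWriter-style metadata line (simplified sketch)."""
    metadata = {
        "class": class_name,                            # module path + class name
        "timestamp": int(round(time.time() * 1000)),    # milliseconds since the epoch
        "sparkVersion": spark_version,
        "uid": uid,
        "paramMap": dict(param_map),                    # explicitly set params only
    }
    if extra:
        metadata.update(extra)
    # Compact separators keep the whole document on one line of text.
    return json.dumps(metadata, separators=(",", ":"))


print(build_metadata(
    "__main__.SetValueTransformer",
    "SetValueTransformer_abc123",  # hypothetical uid
    {"outputCols": ["a", "b"], "value": 123.0},
    "2.3.0",
))
```

Because only the params themselves end up in paramMap, anything your transformer holds outside of Params (fitted arrays, lookup tables) would not be saved by this path, which is where custom serialization logic comes in.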

Finally, suppose you have a simple transformer that extends Params, with all of its parameters stored in the usual Params fashion:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit # for the dummy _transform

class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset

Now we can use it:

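from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession

# In the pyspark shell `spark` already exists; getOrCreate() simply returns it.
spark = SparkSession.builder.getOrCreate()

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = spark.createDataFrame([(1, None), (2, 1.0), (3, 0.5)], ["key", "value"])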
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()

Result:

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+
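If you want to see what was actually written, a small helper can read each stage's metadata back from disk. This sketch assumes the pipeline was saved to the local filesystem and that each stage's metadata sits in text "part-*" files under stages/<stage dir>/metadata/, which is how I understand the layout; the helper name read_stage_metadata is hypothetical.

```python
import glob
import json
import os


def read_stage_metadata(pipeline_path: str) -> dict:
    """Return {stage_dir_name: metadata_dict} for each stage saved under pipeline_path."""
    result = {}
    pattern = os.path.join(pipeline_path, "stages", "*", "metadata", "part-*")
    for part_file in sorted(glob.glob(pattern)):
        # part_file looks like <path>/stages/<stage_dir>/metadata/part-00000
        stage_dir = os.path.basename(os.path.dirname(os.path.dirname(part_file)))
        with open(part_file) as fh:
            for line in fh:
                if line.strip():  # skip blank lines
                    result[stage_dir] = json.loads(line)
    return result
```

For the example above, iterating over read_stage_metadata('/tmp/example_pyspark_pipeline').items() and printing each entry's "class" and "paramMap" should show the SetValueTransformer class name and its outputCols/value params.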
