Getting labels from StringIndexer stages within pipeline in Spark (pyspark)

Anonymous (unverified), submitted 2019-12-03 02:03:01

Question:

I am using Spark with pyspark, and I have a pipeline set up with a bunch of StringIndexer objects that I use to encode the string columns as columns of indices:

    indexers = [
        StringIndexer(inputCol=column, outputCol=column + '_index').setHandleInvalid('skip')
        for column in list(set(data_frame.columns) - ignore_columns)
    ]
    pipeline = Pipeline(stages=indexers)
    new_data_frame = pipeline.fit(data_frame).transform(data_frame)

The problem is that I need to get the list of labels for each StringIndexer object after it has been fitted. For a single column and a single StringIndexer without a pipeline, this is easy: I can just access the labels attribute after fitting the indexer on the DataFrame:

    indexer = StringIndexer(inputCol="name", outputCol="name_index")
    indexer_fitted = indexer.fit(data_frame)
    labels = indexer_fitted.labels
    new_data_frame = indexer_fitted.transform(data_frame)

However, when I use the pipeline, this doesn't seem possible, or at least I don't know how to do it.

So I guess my question comes down to: Is there a way to access the labels that were used during the indexing process for each individual column?

Or will I have to ditch the pipeline in this use case and, for example, loop through the list of StringIndexer objects and do it manually? (I'm sure that would be possible. However, using the pipeline would just be a lot nicer.)

Answer 1:

Example data and Pipeline:

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, StringIndexerModel

    # Assumes an active SparkSession bound to the name `spark`
    df = spark.createDataFrame([("a", "foo"), ("b", "bar")], ("x1", "x2"))

    # One StringIndexer stage per column
    pipeline = Pipeline(stages=[
        StringIndexer(inputCol=c, outputCol='{}_index'.format(c))
        for c in df.columns
    ])

    model = pipeline.fit(df)

Extract from stages:

    # Accessing _java_obj shouldn't be necessary in Spark 2.3+
    {x._java_obj.getOutputCol(): x.labels
     for x in model.stages if isinstance(x, StringIndexerModel)}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']} 
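
Once the labels are available in a dictionary like the one above, turning an index back into its original string is a plain list lookup, because index i corresponds to labels[i]. The following is a minimal sketch in plain Python, with no Spark session needed; the helper name decode_indices and the sample index values are made up for illustration and are not part of the original answer.

```python
def decode_indices(labels_by_column, column, indices):
    """Map StringIndexer output indices back to their original string labels."""
    labels = labels_by_column[column]
    # StringIndexer writes indices as doubles (0.0, 1.0, ...), so cast before indexing.
    return [labels[int(i)] for i in indices]


# Dictionary in the same shape as the one extracted from the pipeline stages.
labels_by_column = {'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}

# Hypothetical index values, as they might appear in the x2_index column.
decoded = decode_indices(labels_by_column, 'x2_index', [1.0, 0.0, 1.0])
print(decoded)  # ['bar', 'foo', 'bar']
```

For decoding a whole DataFrame column inside Spark, pyspark.ml.feature.IndexToString does the same mapping as a transformer.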

From metadata of the transformed DataFrame:

    indexed = model.transform(df)

    {c.name: c.metadata["ml_attr"]["vals"]
     for c in indexed.schema.fields if c.name.endswith("_index")}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']} 
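
The metadata approach above selects columns by the "_index" suffix, so it would raise a KeyError on a column that happens to end in "_index" but carries no "ml_attr" metadata. A slightly more defensive variant, sketched below, selects columns by the presence of the metadata itself. The helper name labels_from_fields is hypothetical, and the SimpleNamespace objects are only stand-ins for the fields of a real schema so that the sketch runs without Spark; with a real DataFrame you would pass indexed.schema.fields instead.

```python
from types import SimpleNamespace


def labels_from_fields(fields):
    """Collect labels from every field that carries nominal ML attribute metadata."""
    result = {}
    for field in fields:
        # Columns that were not produced by an indexer have no "ml_attr" entry.
        vals = (field.metadata or {}).get("ml_attr", {}).get("vals")
        if vals is not None:
            result[field.name] = list(vals)
    return result


# Stand-ins mimicking the .name and .metadata attributes of schema fields.
fields = [
    SimpleNamespace(name="x1", metadata={}),
    SimpleNamespace(name="x1_index",
                    metadata={"ml_attr": {"type": "nominal", "vals": ["a", "b"]}}),
    SimpleNamespace(name="x2_index",
                    metadata={"ml_attr": {"type": "nominal", "vals": ["foo", "bar"]}}),
]

labels = labels_from_fields(fields)
print(labels)  # {'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}
```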

