Why can't I load a PySpark RandomForestClassifier model?

Submitted by 痞子三分冷 on 2019-12-22 10:50:04

Question


I can't load a RandomForestClassificationModel saved by Spark.

Environment: Apache Spark 2.0.1, standalone mode running on a small (4 machine) cluster. No HDFS - everything is saved to local disks.

Build and save model:

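from pyspark.ml.classification import RandomForestClassifier

# train and test are assumed to be DataFrames with "label" and "features" columns
classifier = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=50)
model = classifier.fit(train)
result = model.transform(test)
model.write().save("/tmp/models/20161030-RF-topics-cats.model")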

Later, in a separate program:

from pyspark.ml.classification import RandomForestClassificationModel
model = RandomForestClassificationModel.load("/tmp/models/20161029-RF-topics-cats.model")

gives:

Py4JJavaError: An error occurred while calling o81.load.
: org.apache.spark.sql.AnalysisException: Unable to infer schema for ParquetFormat at /tmp/models/20161029-RF-topics-cats.model/treesMetadata. It must be specified manually;
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$16.apply(DataSource.scala:411)
    at org.apache.spark.sql.execution.datasources.DataSource$$anonfun$16.apply(DataSource.scala:411)
    at scala.Option.getOrElse(Option.scala:121)
    at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:410)
    at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:149)
    at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:439)
    at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:423)
    at org.apache.spark.ml.tree.EnsembleModelReadWrite$.loadImpl(treeModels.scala:441)
    at org.apache.spark.ml.classification.RandomForestClassificationModel$RandomForestClassificationModelReader.load(RandomForestClassifier.scala:301

I'd note that the same code works when I use a Naive Bayes classifier.


Answer 1:


Saving the model to HDFS (or to any other storage that every node can reach) and later loading it from there might solve your problem.
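As a rough illustration of that suggestion, here is a minimal sketch that refuses to save to a node-local path and otherwise uses the same write().save() and load() calls as the question. The helper names (is_shared_path, save_model, load_model), the list of URI schemes, and the hdfs://namenode:8020 address are assumptions made for the example, not part of the original setup. A shared mount such as NFS would also work but would not be recognised by this simple scheme check.

```python
from urllib.parse import urlparse

# Assumption: URI schemes treated as storage visible to every node.
SHARED_SCHEMES = {"hdfs", "s3a", "gs", "wasbs"}


def is_shared_path(path):
    """Return True if the path carries a scheme from SHARED_SCHEMES."""
    return urlparse(path).scheme in SHARED_SCHEMES


def save_model(model, path):
    """Save a fitted model, rejecting bare local paths such as /tmp/..."""
    if not is_shared_path(path):
        raise ValueError("refusing to save to a node-local path: %s" % path)
    # overwrite() replaces any model already stored at this path.
    model.write().overwrite().save(path)


def load_model(path):
    """Load a RandomForestClassificationModel from shared storage."""
    # Imported here so the path helpers above can be used without PySpark.
    from pyspark.ml.classification import RandomForestClassificationModel
    return RandomForestClassificationModel.load(path)
```

With a hypothetical namenode address this would be used as save_model(model, "hdfs://namenode:8020/models/rf.model") in the training program and load_model("hdfs://namenode:8020/models/rf.model") in the later one.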

You have 4 nodes, each with its own local disk. You save the model with model.write().save("/tmp/xxx").

Later, in a separate program, you call load("/tmp/xxx").

With 4 nodes and 4 different local disks, it isn't clear to me what exactly gets saved (and to which local disk) during write().save(), or what exactly load() reads and from which local disk. The same path refers to a different directory on each machine.
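One way to see what actually landed on a given disk is to inspect the saved model directory on each node. The sketch below is a plain-Python check, not a Spark API. The treesMetadata sub-directory is taken from the error message in the question, while metadata and data, and the assumption that output files start with part-, are guesses about the on-disk layout. The function name missing_parts is hypothetical.

```python
import os

# "treesMetadata" appears in the error message; "metadata" and "data"
# are assumed sibling directories of a saved tree-ensemble model.
EXPECTED_SUBDIRS = ("metadata", "treesMetadata", "data")


def missing_parts(model_dir, subdirs=EXPECTED_SUBDIRS):
    """Return the sub-directories that are absent or hold no part-* files."""
    missing = []
    for name in subdirs:
        path = os.path.join(model_dir, name)
        if not os.path.isdir(path):
            # The directory was never written on this machine.
            missing.append(name)
            continue
        # Assumption: Spark output files are named part-*.
        if not any(f.startswith("part-") for f in os.listdir(path)):
            missing.append(name)
    return missing
```

Running missing_parts on the model directory on every machine should show whether the pieces are spread over several disks. A non-empty result on the machine that calls load() would be consistent with the "Unable to infer schema" error.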



Source: https://stackoverflow.com/questions/40327379/why-cant-i-load-a-pyspark-randomforestclassifier-model
