Export Python scikit-learn models into PMML

礼貌的吻别 2021-01-30 23:24

I want to export python scikit-learn models into PMML.

What python package is best suited?

I read about Augustus, but I was not able to find any example using

3 answers
  •  半阙折子戏
    2021-01-31 00:04

    SkLearn2PMML is

    a thin wrapper around the JPMML-SkLearn command-line application. For a list of supported Scikit-Learn Estimator and Transformer types, please refer to the documentation of the JPMML-SkLearn project.

    As @user1808924 notes, it supports Python 2.7 or 3.4+. It also requires Java 1.7+.
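
Because the converter shells out to Java, it can help to confirm that a `java` executable is reachable before installing anything. The following is only a minimal sketch using the Python 3 standard library; the helper name `java_version_banner` is made up for illustration, and it reports the banner without checking that the version is 1.7 or newer.

```python
import shutil
import subprocess


def java_version_banner():
    """Return the first line printed by `java -version`, or None if java is not on PATH."""
    java = shutil.which("java")
    if java is None:
        return None
    # `java -version` conventionally writes its banner to stderr, so stderr is
    # merged into stdout here to capture it in one place.
    result = subprocess.run(
        [java, "-version"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    lines = result.stdout.splitlines()
    return lines[0] if lines else None


banner = java_version_banner()
print(banner if banner else "No java executable found on PATH")
```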

    Install it with pip (requires git):

    pip install git+https://github.com/jpmml/sklearn2pmml.git
    

    Here is an example of how to export a classification tree to PMML. First, grow the tree:

    # example tree & viz from http://scikit-learn.org/stable/modules/tree.html
    from sklearn import datasets, tree
    iris = datasets.load_iris()
    clf = tree.DecisionTreeClassifier() 
    clf = clf.fit(iris.data, iris.target)
    

    There are two parts to an SkLearn2PMML conversion: an estimator (our clf) and a mapper (for preprocessing steps such as discretization or PCA). Our mapper is pretty basic, since we are not doing any transformations.

    from sklearn_pandas import DataFrameMapper
    default_mapper = DataFrameMapper([(i, None) for i in iris.feature_names + ['Species']])
    
    from sklearn2pmml import sklearn2pmml
    sklearn2pmml(estimator=clf, 
                 mapper=default_mapper, 
                 pmml="D:/workspace/IrisClassificationTree.pmml")
    

    It is possible (though not documented) to pass mapper=None, but the predictor names are then lost: the PMML refers to x1 rather than sepal length, and so on.
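
The estimator-plus-mapper call reflects the package as it stood when this answer was written. More recent sklearn2pmml releases, as far as I know, expect a fitted `PMMLPipeline` instead, and the package can be installed from PyPI with `pip install sklearn2pmml`. The sketch below assumes such a release plus pandas and a working Java install. The function name `export_iris_tree` and the relative output path are illustrative only, and the comment about how names reach the PMML is my understanding rather than something stated in this thread.

```python
def export_iris_tree(pmml_path):
    """Fit a decision tree on the iris data and write it to pmml_path as PMML."""
    # Third-party imports live inside the function so that defining it does not
    # require scikit-learn, pandas, sklearn2pmml, or Java to be installed.
    import pandas as pd
    from sklearn import datasets
    from sklearn.tree import DecisionTreeClassifier
    from sklearn2pmml import sklearn2pmml
    from sklearn2pmml.pipeline import PMMLPipeline

    iris = datasets.load_iris()
    # Assumption: fitting on a named DataFrame and a named Series is what lets
    # the feature and target names carry over into the PMML document.
    X = pd.DataFrame(iris.data, columns=iris.feature_names)
    y = pd.Series(iris.target_names[iris.target], name="Species")

    pipeline = PMMLPipeline([("classifier", DecisionTreeClassifier())])
    pipeline.fit(X, y)

    # This step launches the Java converter and writes the file.
    sklearn2pmml(pipeline, pmml_path)
    return pipeline


# Example call (needs Java on PATH):
# export_iris_tree("IrisClassificationTree.pmml")
```

Any preprocessing that the mapper used to describe would go in as earlier steps of the same pipeline.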

    Let's look at the .pmml file:

    
    
        
    [PMML listing not preserved; the only surviving value is its timestamp, 2016-09-26T19:21:43Z]

    The first split (Node 1) is on petal width at 0.8. Node 2 (petal width <= 0.8) captures all of the setosa, with nothing else.
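
To check splits like these without reading the XML by hand, the tree nodes can be pulled out with the standard library. This is a rough sketch: the fragment below is hand-written and heavily trimmed for illustration (it is not the file generated above and is not schema-valid PMML), and the namespace version, field name, and helper name `list_splits` are assumptions.

```python
import xml.etree.ElementTree as ET

# Hand-written illustrative fragment, NOT the real converter output.
PMML_FRAGMENT = """\
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
  <TreeModel functionName="classification">
    <Node id="1">
      <True/>
      <Node id="2" score="setosa">
        <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="0.8"/>
      </Node>
      <Node id="3">
        <SimplePredicate field="petal width (cm)" operator="greaterThan" value="0.8"/>
      </Node>
    </Node>
  </TreeModel>
</PMML>
"""

# Assumed namespace; a real file declares its own on the root element.
NS = {"pmml": "http://www.dmg.org/PMML-4_3"}


def list_splits(pmml_text):
    """Return (node id, field, operator, value) for each node guarded by a SimplePredicate."""
    root = ET.fromstring(pmml_text)
    splits = []
    for node in root.iter("{%s}Node" % NS["pmml"]):
        # find() only looks at direct children, so each node reports its own predicate.
        predicate = node.find("pmml:SimplePredicate", NS)
        if predicate is not None:
            splits.append((
                node.get("id"),
                predicate.get("field"),
                predicate.get("operator"),
                float(predicate.get("value")),
            ))
    return splits


for split in list_splits(PMML_FRAGMENT):
    print(split)
```

For a real export, read the file contents and pass them to the same function, adjusting the namespace to whatever the root element declares.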

    You can compare the PMML output to the graphviz output:

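    # sklearn.externals.six is no longer shipped with current scikit-learn;
    # on Python 3 the standard library's io.StringIO does the same job here
    from io import StringIO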
    import pydotplus # this might be pydot for python 2.7
    dot_data = StringIO() 
    tree.export_graphviz(clf, 
                         out_file=dot_data,  
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True) 
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("D:/workspace/iris.pdf") 
    # for in-line display, you can also do:
    # from IPython.display import Image  
    # Image(graph.create_png())  
    
