How can I call scikit-learn classifiers from Java?

悲&欢浪女 2020-12-07 17:00

I have a classifier that I trained using Python's scikit-learn. How can I use the classifier from a Java program? Can I use Jython? Is there some way to save the classifier

6 answers
  •  攒了一身酷
    2020-12-07 17:55

    Here is some code for the JPMML solution:

    --PYTHON PART--

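    from decimal import Decimal

    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.preprocessing import LabelBinarizer
    from sklearn_pandas import DataFrameMapper
    from sklearn2pmml import sklearn2pmml
    from sklearn2pmml.pipeline import PMMLPipeline

    # df is assumed to be a pandas DataFrame holding the input features.

    # Helper function to determine the string columns that have to be one-hot encoded before an estimator can be applied.
    def determine_categorical_columns(df):
        categorical_columns = []
        for column, dtype in zip(df.columns, df.dtypes):
            if dtype == 'object':
                val = df[column].iloc[0]
                # object columns that hold Decimal values are numeric, not categorical
                if not isinstance(val, Decimal):
                    categorical_columns.append(column)
        return categorical_columns
    
    categorical_columns = determine_categorical_columns(df)
    other_columns = list(set(df.columns).difference(categorical_columns))
    
    
    # Construction of the transformers for our example
    labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
    nones = [(d, None) for d in other_columns]
    transformators = labelBinarizers + nones
    
    mapper = DataFrameMapper(transformators, df_out=True)
    gbc = GradientBoostingClassifier()
    
    # Construction of the pipeline
    lm = PMMLPipeline([
        ("mapper", mapper),
        ("estimator", gbc)
    ])

    # Assumed final step (not in the original answer): fit the pipeline and export it
    # to the PMML file that the Java part reads. y is a placeholder for the target labels.
    lm.fit(df, y)
    sklearn2pmml(lm, "ScikitLearnNew.pmml")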
    

    --JAVA PART--

    import java.io.FileInputStream;
    import java.util.HashMap;
    import java.util.Iterator;
    import java.util.Map;
    // The JPMML classes come from org.dmg.pmml and org.jpmml.evaluator (exact packages depend on the JPMML-Evaluator version).

    // Initialisation: load the PMML file and build an evaluator for it.
    String pmmlFile = "ScikitLearnNew.pmml";
    PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
    ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
    MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
    
    // Determine which features are required as input (keyed by lower-case field name).
    HashMap<String, Field> inputFieldMap = new HashMap<String, Field>();
    for (InputField curInputField : evaluator.getInputFields()) {
      String fieldName = curInputField.getName().getValue();
      inputFieldMap.put(fieldName.toLowerCase(), curInputField.getField());
    }
    
    
    // Prediction
    
    HashMap<String, String> argsMap = new HashMap<String, String>();
    //... fill argsMap with input (keys must be lower-case to match inputFieldMap)
    
    Map<FieldName, ?> res;
    // Here we keep only the features that are required by the model.
    Map<FieldName, String> args = new HashMap<FieldName, String>();
    Iterator<String> iter = argsMap.keySet().iterator();
    while (iter.hasNext()) {
      String key = iter.next();
      Field f = inputFieldMap.get(key);
      if (f != null) {
        FieldName name = f.getName();
        String value = argsMap.get(key);
        args.put(name, value);
      }
    }
    // The model is applied to the input and a probability distribution is obtained.
    res = evaluator.evaluate(args);
    SegmentResult segmentResult = (SegmentResult) res;
    Object targetValue = segmentResult.getTargetValue();
    ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;
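
The filtering step in the Java part (index the model's input fields by lower-case name, then keep only the supplied arguments that match) can be tried out without JPMML on the classpath. The following is a minimal sketch under that assumption: plain strings stand in for `FieldName` and `Field`, and the class and method names (`InputFilterSketch`, `indexByLowerCase`, `keepRequired`) are hypothetical, not part of any library.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class InputFilterSketch {

    // Maps each lower-cased field name to the name the model actually declares.
    static Map<String, String> indexByLowerCase(List<String> modelFieldNames) {
        Map<String, String> index = new HashMap<>();
        for (String name : modelFieldNames) {
            index.put(name.toLowerCase(), name);
        }
        return index;
    }

    // Keeps only the entries of rawArgs whose key is found in the index.
    // As in the answer's loop, keys are looked up as given, so callers
    // are expected to supply lower-case keys.
    static Map<String, String> keepRequired(Map<String, String> rawArgs,
                                            Map<String, String> index) {
        Map<String, String> args = new HashMap<>();
        for (Map.Entry<String, String> entry : rawArgs.entrySet()) {
            String modelName = index.get(entry.getKey());
            if (modelName != null) {
                args.put(modelName, entry.getValue());
            }
        }
        return args;
    }

    public static void main(String[] argv) {
        // Hypothetical field names; a real model would report its own.
        Map<String, String> index = indexByLowerCase(Arrays.asList("Age", "Country"));

        Map<String, String> rawArgs = new HashMap<>();
        rawArgs.put("age", "42");        // matches the model field "Age"
        rawArgs.put("unused", "x");      // not a model input, dropped

        System.out.println(keepRequired(rawArgs, index));
    }
}
```

With JPMML in place, the index values would be the `Field` objects returned by `getInputFields()` and the resulting map would be keyed by `FieldName`, as in the answer's code.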
    
