I have a classifier that I trained using Python's scikit-learn. How can I use the classifier from a Java program? Can I use Jython? Is there some way to save the classifier in Python and load it in Java?
Here is some code for the JPMML solution:
--PYTHON PART--
# helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator.
def determine_categorical_columns(df):
    """Return the names of the ``object``-dtype columns of ``df`` that hold non-Decimal values.

    A column is reported when its dtype compares equal to ``'object'`` and its
    first value is not a ``decimal.Decimal``. Only the first row is inspected,
    so a column mixing Decimals with other objects is classified by that row.

    :param df: pandas DataFrame to inspect.
    :return: list of column labels, in the DataFrame's own column order.
    """
    # Imported here so the helper does not depend on the surrounding script
    # having imported Decimal.
    from decimal import Decimal

    categorical_columns = []
    # Without rows there is no sample value to inspect; every object column is
    # then reported instead of failing on ``iloc[0]``.
    has_rows = len(df) > 0
    for position, dtype in enumerate(df.dtypes):
        if dtype == 'object':
            # Positional access keeps this correct even if column labels repeat.
            if has_rows and isinstance(df.iloc[0, position], Decimal):
                continue
            categorical_columns.append(df.columns[position])
    return categorical_columns
# Split the columns into the string-valued ones (to be binarized) and the rest.
categorical_columns = determine_categorical_columns(df)
# Filter in the DataFrame's own column order. A set difference would yield an
# arbitrary order that can change between runs, making the feature order fed
# to the estimator non-reproducible.
other_columns = [column for column in df.columns if column not in categorical_columns]
#construction of transformators for our example
# Each categorical column gets its own LabelBinarizer; the remaining columns
# are mapped with None as their transformer.
labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
nones = [(d, None) for d in other_columns]
transformators = labelBinarizers + nones
mapper = DataFrameMapper(transformators, df_out=True)
gbc = GradientBoostingClassifier()
#construction of the pipeline
lm = PMMLPipeline([
    ("mapper", mapper),
    ("estimator", gbc)
])
# NOTE(review): fitting the pipeline and exporting it to a PMML file is not
# shown here - presumably done elsewhere before the Java part runs; confirm.
--JAVA PART --
//Initialisation: read the PMML document and build an evaluator for it.
String pmmlFile = "ScikitLearnNew.pmml";
PMML pmml;
// try-with-resources closes the file stream even if unmarshalling fails.
try (FileInputStream pmmlStream = new FileInputStream(pmmlFile)) {
    pmml = org.jpmml.model.PMMLUtil.unmarshal(pmmlStream);
}
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
// NOTE(review): this cast assumes the factory returns a MiningModelEvaluator
// for this document - confirm for the exported model.
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

//Determine which features are required as input, keyed by lower-cased field name.
Map<String, Field> inputFieldMap = new HashMap<>();
for (InputField curInputField : evaluator.getInputFields()) {
    String fieldName = curInputField.getName().getValue();
    // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
    inputFieldMap.put(fieldName.toLowerCase(java.util.Locale.ROOT), curInputField.getField());
}

//prediction: raw input values keyed by feature name.
Map<String, String> argsMap = new HashMap<>();
//... fill argsMap with input

// here we keep only features that are required by the model
Map<FieldName, Object> args = new HashMap<>();
for (Map.Entry<String, String> entry : argsMap.entrySet()) {
    // inputFieldMap keys were lower-cased on insertion, so the lookup key must
    // be lower-cased the same way or mixed-case input names are silently dropped.
    Field f = inputFieldMap.get(entry.getKey().toLowerCase(java.util.Locale.ROOT));
    if (f != null) {
        args.put(f.getName(), entry.getValue());
    }
}

//the model is applied to input, a probability distribution is obtained
Map<FieldName, ?> res = evaluator.evaluate(args);
// NOTE(review): both casts below are unchecked assumptions about what the
// evaluator returns for this model (a SegmentResult whose target value is a
// ProbabilityDistribution) - confirm against the JPMML version in use.
SegmentResult segmentResult = (SegmentResult) res;
Object targetValue = segmentResult.getTargetValue();
ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;