Spark|ML|Random Forest|Load trained model from .txt of RandomForestClassificationModel.toDebugString

こ雲淡風輕ζ submitted on 2019-12-11 05:42:59

Question


Using Spark 1.6 and the ML library I am saving the results of a trained RandomForestClassificationModel using toDebugString():

val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
val stringModel = rfModel.toDebugString
// save stringModel to a .txt file on the driver

My idea is to read the .txt file later and rebuild the trained random forest from it. Is that possible?

Thanks!


Answer 1:


That won't work. toDebugString is merely debug output that helps you understand how the model arrives at its predictions.

If you want to keep the model for later use, you can do what we do (although we are in pure Java): simply serialise the RandomForestModel object. Default Java serialisation can run into version incompatibilities, so we use Hessian instead. It has survived version updates: we started with Spark 1.6.1 and it still works with Spark 2.0.2.
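
As a rough illustration of the "serialise the object" idea, here is a minimal sketch of a round trip through default JDK serialisation. It does not use Hessian, which is what the answer actually relies on, and `StandInModel` is a hypothetical placeholder for the trained model object; the class and method names are assumptions made for the example.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

/** Round-trips a Serializable object through a byte array (default JDK serialisation). */
final class ModelSerializationSketch {

    /** Hypothetical stand-in for a trained model; the real model object would take its place. */
    static final class StandInModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final int numTrees;
        final double[] thresholds;

        StandInModel(final int numTrees, final double[] thresholds) {
            this.numTrees = numTrees;
            this.thresholds = thresholds;
        }
    }

    /** Serialises the object into memory; a FileOutputStream would work the same way. */
    static byte[] toBytes(final Serializable model) {
        final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(model);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Restores the object; the caller casts it back to the expected model type. */
    static Object fromBytes(final byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Default serialisation ties the stored bytes to the class layout, which is the version-compatibility risk the answer mentions.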




Answer 2:


If you're OK with not sticking to ml, just use mllib's implementation: the RandomForestModel you get from mllib has a save function.




Answer 3:


At least for Spark 2.1.0 you can do this with the following Java (sorry - no Scala) code. However, it may not be the smartest idea to rely on an undocumented format that may change without notice.

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.net.URL;
import java.util.*;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.nio.charset.StandardCharsets.US_ASCII;

/**
 * RandomForest.
 */
public abstract class RandomForest {

    private static final Logger LOG = LoggerFactory.getLogger(RandomForest.class);

    protected final List<Node> trees = new ArrayList<>();

    /**
     * @param model model file (format is Spark's RandomForestClassificationModel toDebugString())
     * @throws IOException
     */
    public RandomForest(final URL model) throws IOException {
        try (final BufferedReader reader = new BufferedReader(new InputStreamReader(model.openStream(), US_ASCII))) {
            Node node;
            while ((node = load(reader)) != null) {
                trees.add(node);
            }
        }
        if (trees.isEmpty()) throw new IOException("Failed to read trees from " + model);
        if (LOG.isDebugEnabled()) LOG.debug("Found " + trees.size() + " trees.");
    }

    private static Node load(final BufferedReader reader) throws IOException {
        final Pattern ifPattern = Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
        // also accept negative values and positive exponents (e.g. regression leaves)
        final Pattern predictPattern = Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");
        Node root = null;
        final List<Node> stack = new ArrayList<>();
        String line;
        while ((line = reader.readLine()) != null) {
            final String trimmed = line.trim();
            //System.out.println(trimmed);
            if (trimmed.startsWith("RandomForest")) {
                // skip the first "Tree ..." line that follows the header
                reader.readLine();
            } else if (trimmed.startsWith("Tree")) {
                break;
            } else if (trimmed.startsWith("If")) {
                // extract feature index
                final Matcher m = ifPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unparseable split line: " + trimmed);
                final int featureIndex = Integer.parseInt(m.group(1));
                final String operator = m.group(2);
                final String operand = m.group(3);
                final Predicate<Float> predicate;
                if ("<=".equals(operator)) {
                    predicate = new LessOrEqual(Float.parseFloat(operand));
                } else if (">".equals(operator)) {
                    predicate = new Greater(Float.parseFloat(operand));
                } else if ("in".equals(operator)) {
                    predicate = new In(parseFloatArray(operand));
                } else if ("not in".equals(operator)) {
                    predicate = new NotIn(parseFloatArray(operand));
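                } else {
                    // fail with a clear message instead of a NullPointerException in Node
                    throw new IOException("Unsupported operator: " + operator);
                }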
                final Node node = new Node(featureIndex, predicate);

                if (stack.isEmpty()) {
                    root = node;
                } else {
                    insert(stack, node);
                }
                stack.add(node);
            } else if (trimmed.startsWith("Predict")) {
                final Matcher m = predictPattern.matcher(trimmed);
                if (!m.matches()) throw new IOException("Unparseable prediction line: " + trimmed);
                final Object node = Float.parseFloat(m.group(1));
                insert(stack, node);
            }
        }
        return root;
    }

    private static void insert(final List<Node> stack, final Object node) {
        Node parent = stack.get(stack.size() - 1);
        while (parent.getLeftChild() != null && parent.getRightChild() != null) {
            stack.remove(stack.size() - 1);
            parent = stack.get(stack.size() - 1);
        }
        if (parent.getLeftChild() == null) parent.setLeftChild(node);
        else parent.setRightChild(node);
    }

    private static float[] parseFloatArray(final String set) {
        final StringTokenizer st = new StringTokenizer(set, "{,}");
        final float[] floats = new float[st.countTokens()];
        for (int i=0; st.hasMoreTokens(); i++) {
            floats[i] = Float.parseFloat(st.nextToken());
        }
        return floats;
    }

    public abstract float predict(final float[] features);

    public String toDebugString() {
        try {
            final StringWriter sw = new StringWriter();
            for (int i=0; i<trees.size(); i++) {
                sw.write("Tree " + i + ":\n");
                print(sw, "", trees.get(i));
            }
            return sw.toString();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private static void print(final Writer w, final String indent, final Object object) throws IOException {
        if (object instanceof Number) {
            w.write(indent + "Predict: " + object + "\n");
        } else if (object instanceof Node) {
            final Node node = (Node) object;
            // split condition, then the left (condition true) subtree, then the right subtree
            w.write(indent + node + "\n");
            print(w, indent + " ", node.getLeftChild());
            w.write(indent + "Else\n");
            print(w, indent + " ", node.getRightChild());
        }
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + "{numTrees=" + trees.size() + "}";
    }

    /**
     * Node.
     */
    protected static class Node {

        private final int featureIndex;
        private final Predicate<Float> predicate;
        private Object leftChild;
        private Object rightChild;

        public Node(final int featureIndex, final Predicate<Float> predicate) {
            Objects.requireNonNull(predicate);
            this.featureIndex = featureIndex;
            this.predicate = predicate;
        }

        public void setLeftChild(final Object leftChild) {
            this.leftChild = leftChild;
        }

        public void setRightChild(final Object rightChild) {
            this.rightChild = rightChild;
        }

        public Object getLeftChild() {
            return leftChild;
        }

        public Object getRightChild() {
            return rightChild;
        }

        public Object eval(final float[] features) {
            Object result = this;
            do {
                final Node node = (Node)result;
                result = node.predicate.test(features[node.featureIndex]) ? node.leftChild : node.rightChild;
            } while (result instanceof Node);

            return result;
        }

        @Override
        public String toString() {
            return "If (feature " + featureIndex + " " + predicate + ")";
        }

    }

    private static class LessOrEqual implements Predicate<Float> {
        private final float value;

        public LessOrEqual(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f <= value;
        }

        @Override
        public String toString() {
            return "<= " + value;
        }
    }

    private static class Greater implements Predicate<Float> {
        private final float value;

        public Greater(final float value) {
            this.value = value;
        }

        @Override
        public boolean test(final Float f) {
            return f > value;
        }

        @Override
        public String toString() {
            return "> " + value;
        }
    }

    private static class In implements Predicate<Float> {
        private final float[] array;

        public In(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i=0; i<array.length; i++) {
                if (array[i] == f) return true;
            }
            return false;
        }

        @Override
        public String toString() {
            return "in " + Arrays.toString(array);
        }
    }

    private static class NotIn implements Predicate<Float> {
        private final float[] array;

        public NotIn(final float[] array) {
            this.array = array;
        }

        @Override
        public boolean test(final Float f) {
            for (int i=0; i<array.length; i++) {
                if (array[i] == f) return false;
            }
            return true;
        }

        @Override
        public String toString() {
            return "not in " + Arrays.toString(array);
        }
    }
}
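
To make the expected input more concrete, here is a small sketch of how individual lines of the debug string are classified. The split pattern is the one used by the parser above. The prediction pattern here additionally accepts a leading minus sign, which is an assumption about regression output. The sample lines in the comments are assumed examples of the format, not output copied from a real model, and the class and method names are hypothetical.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Classifies single lines of a toDebugString dump as split, leaf, or ignored. */
final class DebugStringLineSketch {

    private static final Pattern IF_PATTERN =
            Pattern.compile("If \\(feature (\\d+) (in|not in|<=|>) (.*)\\)");
    private static final Pattern PREDICT_PATTERN =
            Pattern.compile("Predict: (-?\\d+\\.\\d+(E-?\\d+)?)");

    static String describe(final String line) {
        final String trimmed = line.trim();
        // e.g. "If (feature 0 <= 0.5)" or "If (feature 2 not in {0.0,1.0})"
        final Matcher split = IF_PATTERN.matcher(trimmed);
        if (split.matches()) {
            return "split feature=" + split.group(1)
                    + " op=" + split.group(2)
                    + " operand=" + split.group(3);
        }
        // e.g. "Predict: 1.0"
        final Matcher leaf = PREDICT_PATTERN.matcher(trimmed);
        if (leaf.matches()) {
            return "leaf " + Float.parseFloat(leaf.group(1));
        }
        // "Else (...)" lines carry no extra information: the right child is implied
        return "ignored";
    }
}
```

Only the "If" and "Predict" lines drive the tree reconstruction; the indentation is discarded, so the nesting is recovered purely from the order in which the lines appear.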

To use the class for classification, use:

import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;

/**
 * RandomForestClassifier.
 */
public class RandomForestClassifier extends RandomForest {

    public RandomForestClassifier(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        final Map<Object, Integer> counts = new HashMap<>();
        trees.stream().map(node -> node.eval(features))
                .forEach(result -> {
                    Integer count = counts.get(result);
                    if (count == null) {
                        counts.put(result, 1);
                    } else {
                        counts.put(result, count + 1);
                    }
                });
        return (Float)counts.entrySet()
                .stream()
                .sorted((o1, o2) -> Integer.compare(o2.getValue(), o1.getValue()))
                .map(Map.Entry::getKey)
                .findFirst().get();
    }
}
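
The vote counting in predict can also be written with stream collectors. The following is a minimal standalone sketch of plain majority voting over the per-tree results; the class and method names are hypothetical. Note that hard majority voting is the approach taken by the classifier above, and it may not always agree with Spark's own predictions, which, as far as I know, combine each tree's class distribution rather than a single label per tree.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/** Picks the label predicted by the largest number of trees. */
final class MajorityVoteSketch {

    static float majority(final List<Float> votes) {
        // count how many trees voted for each label
        final Map<Float, Long> counts = votes.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        // take the label with the highest count; the winner of a tie is unspecified
        return counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .orElseThrow(() -> new IllegalArgumentException("no votes"))
                .getKey();
    }
}
```

For example, votes of 1.0, 0.0 and 1.0 from three trees yield 1.0.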

For regression:

import java.io.IOException;
import java.net.URL;

/**
 * RandomForestRegressor.
 */
public class RandomForestRegressor extends RandomForest {

    public RandomForestRegressor(final URL model) throws IOException {
        super(model);
    }

    @Override
    public float predict(final float[] features) {
        return (float)trees
                .stream()
                .mapToDouble(node -> ((Number)node.eval(features)).doubleValue())
                .average()
                .getAsDouble();
    }
}


Source: https://stackoverflow.com/questions/41177736/sparkmlrandom-forestload-trained-model-from-txt-of-randomforestclassificatio
