How to compare whether two sklearn estimators are equal?

社会主义新天地 submitted on 2021-01-27 06:35:31

Question


I have two sklearn estimators and want to compare them:

import copy

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X, y = np.random.random((100, 2)), np.random.choice(2, 100)
dt1 = DecisionTreeClassifier()
dt1.fit(X, y)
dt2 = DecisionTreeClassifier()
dt3 = copy.deepcopy(dt1)  # independent copy of the fitted dt1

How can I compare the classifiers so that dt1 != dt2 but dt1 == dt3?
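For context: as far as I know, scikit-learn estimators do not override `__eq__`, so `==` falls back to Python's default identity comparison, and hyperparameters alone (`get_params()`) cannot tell a fitted tree from an unfitted one. A minimal sketch illustrating both points, assuming a recent scikit-learn in which an unfitted tree has no `tree_` attribute at all:

```python
import copy

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X, y = np.random.random((100, 2)), np.random.choice(2, 100)
dt1 = DecisionTreeClassifier().fit(X, y)  # fitted
dt2 = DecisionTreeClassifier()            # unfitted
dt3 = copy.deepcopy(dt1)                  # independent copy of the fitted dt1

# Default == compares object identity, so even a faithful copy is "unequal".
print(dt1 == dt1)  # True
print(dt1 == dt3)  # False

# Hyperparameters alone cannot separate a fitted tree from an unfitted one.
print(dt1.get_params() == dt2.get_params())  # True
print(hasattr(dt1, "tree_"), hasattr(dt2, "tree_"))  # True False
```

This is why a useful comparison has to look at both the hyperparameters and the fitted state.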


Answer 1:


You will want to compare the hyperparameters assigned to each classifier instance and, for trained classifiers, their .tree_.value arrays:

import numpy as np

def compare_trees(tree1, tree2):
    # the trees must have the same hyperparameters
    if tree1.get_params() != tree2.get_params():
        return False
    # a tree only gets a tree_ attribute once it has been fitted
    fitted1, fitted2 = hasattr(tree1, "tree_"), hasattr(tree2, "tree_")
    if fitted1 and fitted2:
        # both trained: the node values must be identical arrays
        # (np.array_equal returns False on a shape mismatch instead of raising)
        return np.array_equal(tree1.tree_.value, tree2.tree_.value)
    elif fitted1 or fitted2:
        # exactly one of the trees has been trained
        return False
    else:  # neither has been trained
        return True


import copy

import numpy as np
from sklearn.tree import DecisionTreeClassifier

dt1 = DecisionTreeClassifier()
X, y = np.random.random((100, 2)), np.random.choice(2, 100)
dt1.fit(X, y)

dt2 = DecisionTreeClassifier()  # untrained

dt3 = copy.deepcopy(dt1)  # copy of the 1st

dt4 = DecisionTreeClassifier()  # trained on different data
X_, y_ = np.random.random((100, 2)), np.random.choice(2, 100)
dt4.fit(X_, y_)

print(compare_trees(dt1, dt1))  # True
print(compare_trees(dt1, dt2))  # False
print(compare_trees(dt1, dt3))  # True
print(compare_trees(dt1, dt4))  # False
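Comparing only `tree_.value` could in principle call two trees equal even when their split features or thresholds differ, and it still does not make `dt1 == dt3` work literally as written in the question. A minimal sketch that addresses both, assuming a recent scikit-learn: `ComparableTree` is a hypothetical subclass (not part of scikit-learn) whose `__eq__` compares hyperparameters plus the node arrays exposed by `tree_` (`children_left`, `children_right`, `feature`, `threshold`, `value`):

```python
import copy

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Per-node arrays of the fitted tree_ that together describe its structure.
_NODE_ARRAYS = ("children_left", "children_right", "feature", "threshold", "value")


class ComparableTree(DecisionTreeClassifier):
    """Hypothetical subclass whose == compares hyperparameters and fitted structure."""

    def __eq__(self, other):
        if not isinstance(other, DecisionTreeClassifier):
            return NotImplemented
        if self.get_params() != other.get_params():
            return False
        fitted_self, fitted_other = hasattr(self, "tree_"), hasattr(other, "tree_")
        if not (fitted_self and fitted_other):
            # equal only if neither tree has been trained
            return fitted_self == fitted_other
        # both trained: every node array must match exactly
        return all(
            np.array_equal(getattr(self.tree_, name), getattr(other.tree_, name))
            for name in _NODE_ARRAYS
        )


X, y = np.random.random((100, 2)), np.random.choice(2, 100)
dt1 = ComparableTree().fit(X, y)
dt2 = ComparableTree()    # untrained
dt3 = copy.deepcopy(dt1)  # copy of dt1

print(dt1 == dt3)  # True
print(dt1 != dt2)  # True
```

Note that in Python, defining `__eq__` without `__hash__` makes instances unhashable, so these trees can no longer be used in sets or as dict keys.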


Source: https://stackoverflow.com/questions/38412526/how-to-compare-if-two-sklearn-estimators-are-equals
