How to seed the random number generator for scikit-learn?

Submitted by 六眼飞鱼酱① on 2019-12-01 09:25:55

Question


I'm trying to write a unit test for some of my code that uses scikit-learn. However, my unit tests seem to be non-deterministic.

AFAIK, the only places in my code where scikit-learn uses randomness are the LogisticRegression model and train_test_split, so I have the following:

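from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

RANDOM_SEED = 5
self.lr = LogisticRegression(random_state=RANDOM_SEED)
X_train, X_test, y_train, test_labels = train_test_split(
    docs, labels, test_size=TEST_SET_PROPORTION, random_state=RANDOM_SEED)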

But this doesn't seem to work: even when I pass the same docs and labels every time, the predicted probabilities on a fixed validation set vary from run to run.

I also tried adding a numpy.random.seed(RANDOM_SEED) call at the top of my code, but that didn't seem to work either.

Is there anything I'm missing? Is there a way to pass a seed to scikit-learn in a single place, so that seed is used throughout all of scikit-learn's invocations?
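As far as I know, scikit-learn has no single global switch that seeds every estimator. Each randomized estimator or function takes its own random_state, and numpy.random.seed only reaches those left at random_state=None. One workaround is a small helper that walks get_params(deep=True) and sets every parameter whose name ends in random_state. This is a minimal sketch: seed_all is a hypothetical helper, and the TruncatedSVD + LogisticRegression pipeline is only an illustration.

```python
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline


def seed_all(estimator, seed):
    """Hypothetical helper: set every (possibly nested) random_state param to seed."""
    # deep=True exposes nested params such as "logisticregression__random_state"
    params = estimator.get_params(deep=True)
    seeded = {name: seed for name in params if name.endswith("random_state")}
    # set_params understands the "step__param" syntax and returns the estimator
    return estimator.set_params(**seeded)


model = seed_all(make_pipeline(TruncatedSVD(n_components=2), LogisticRegression()), 5)
print(model.get_params()["truncatedsvd__random_state"])        # 5
print(model.get_params()["logisticregression__random_state"])  # 5
```

train_test_split is a function rather than an estimator, so it still needs its own random_state argument.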


Answer 1:


from sklearn import datasets, linear_model
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X, y = iris.data, iris.target
RANDOM_SEED = 5
# the same seed controls both the model and the train/test split
lr = linear_model.LogisticRegression(random_state=RANDOM_SEED)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=RANDOM_SEED)
lr.fit(X_train, y_train)
print(lr.score(X_test, y_test))

This produced 0.93333333333333335 on several runs, so the way you did it looks correct. Another option is to call np.random.seed(), or to use Sacred to keep the randomness of your experiments documented. Passing random_state is also what the scikit-learn docs recommend:

If your code relies on a random number generator, it should never use functions like numpy.random.random or numpy.random.normal. This approach can lead to repeatability issues in unit tests. Instead, a numpy.random.RandomState object should be used, which is built from a random_state argument passed to the class or function.
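The pattern quoted above can be sketched with sklearn.utils.check_random_state, which turns a random_state argument into a numpy.random.RandomState: None gives NumPy's global RandomState, an int gives a fresh RandomState seeded with it, and an existing RandomState is returned unchanged. The add_noise function below is a hypothetical example, not part of scikit-learn. It also shows why np.random.seed only helps when random_state is left at None, and only if the global generator is reseeded before each run.

```python
import numpy as np
from sklearn.utils import check_random_state


def add_noise(X, scale=0.1, random_state=None):
    """Hypothetical function following the random_state convention."""
    # None -> global RandomState; int -> new RandomState(int); instance -> as is
    rng = check_random_state(random_state)
    return X + rng.normal(scale=scale, size=X.shape)


X = np.zeros((3, 2))

# Same integer seed: identical output on every call.
a = add_noise(X, random_state=5)
b = add_noise(X, random_state=5)
print(np.array_equal(a, b))  # True

# random_state=None draws from the global RNG, so reseed it before each call.
np.random.seed(5)
c = add_noise(X)
np.random.seed(5)
d = add_noise(X)
print(np.array_equal(c, d))  # True
```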



Source: https://stackoverflow.com/questions/40750394/how-to-seed-the-random-number-generator-for-scikit-learn
