Sklearn SGDClassifier partial fit

六月ゝ 毕业季﹏ submitted on 2019-11-28 14:53:27

Question


I'm trying to use SGD to classify a large dataset. As the data is too large to fit into memory, I'd like to use the partial_fit method to train the classifier. I have selected a sample of the dataset (100,000 rows) that fits into memory to test fit vs. partial_fit:

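import numpy as np
from sklearn.linear_model import SGDClassifier

def batches(n_samples, batch_size):
    # Yield (start, stop) row ranges that cover range(n_samples) in order
    for start in range(0, n_samples, batch_size):
        yield start, min(start + batch_size, n_samples)

# X, Y: the 100,000-row in-memory sample (NumPy arrays)
clf1 = SGDClassifier(shuffle=True, loss='log_loss')  # named 'log' before scikit-learn 1.1
clf1.fit(X, Y)

clf2 = SGDClassifier(shuffle=True, loss='log_loss')
classes = np.unique(Y)  # partial_fit needs the full set of labels
n_iter = 60
for n in range(n_iter):
    # Every pass visits the batches in the same fixed order
    for start, stop in batches(len(X), 10000):
        clf2.partial_fit(X[start:stop], Y[start:stop], classes=classes)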

I then test both classifiers on an identical test set. In the first case I get an accuracy of 100%. As I understand it, fit makes 5 passes over the training data by default (n_iter = 5).

In the second case, I need 60 passes over the data to reach the same accuracy.

Why this difference (5 vs. 60)? Or am I doing something wrong?


Answer 1:


I have finally found the answer. You need to shuffle the training data yourself between iterations: setting shuffle=True when instantiating the model does NOT shuffle the data across partial_fit calls. With fit, the full training set is reshuffled every epoch, but each partial_fit call only sees the batch you pass it, so your batches are always visited in the same order. Note: it would have been helpful to find this information on the sklearn.linear_model.SGDClassifier page.
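The idea in this answer can be packaged as a small helper. In the sketch below, shuffled_batches is a hypothetical name, not a scikit-learn function: it draws a fresh permutation of the row indices and yields index arrays, so calling it once per epoch visits every row exactly once, in a new order, in batches of at most batch_size.

```python
import numpy as np

def shuffled_batches(n_samples, batch_size, rng):
    """Yield index arrays covering range(n_samples) once, in random order.

    Hypothetical helper: call it once per epoch to get a new shuffle each time.
    """
    perm = rng.permutation(n_samples)  # new random order of all row indices
    for start in range(0, n_samples, batch_size):
        yield perm[start:start + batch_size]  # the last batch may be shorter

rng = np.random.default_rng(0)
batches = list(shuffled_batches(25, 10, rng))
print([len(b) for b in batches])  # → [10, 10, 5]
```

Each index array can be used directly with NumPy fancy indexing, e.g. clf.partial_fit(X[idx], Y[idx], classes=classes), which copies only one batch at a time instead of building a full shuffled copy of X.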

The amended code reads as follows:

import numpy as np
from sklearn.linear_model import SGDClassifier

clf2 = SGDClassifier(loss='log_loss')  # shuffle=True cannot reorder rows across partial_fit calls
classes = np.unique(Y)
rng = np.random.default_rng()
batch_size = 10000
n_iter = 5
for n in range(n_iter):
    # Draw a new random order of all rows before every pass (assumes X, Y are NumPy arrays)
    perm = rng.permutation(len(X))
    shuffledX, shuffledY = X[perm], Y[perm]
    for start in range(0, len(shuffledX), batch_size):
        stop = start + batch_size
        clf2.partial_fit(shuffledX[start:stop], shuffledY[start:stop], classes=classes)


Source: https://stackoverflow.com/questions/24617356/sklearn-sgdclassifier-partial-fit
