Using pretrained GloVe word embeddings with scikit-learn


Question


I have used Keras with pre-trained word embeddings, but I am not quite sure how to do the same with a scikit-learn model.

I need to do this in sklearn as well because I am using vecstack to ensemble a Keras Sequential model and an sklearn model.

This is what I have done for the Keras model:

import os
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding

# Parse the GloVe file into a {word: vector} dictionary.
glove_dir = '/home/Documents/Glove'
embeddings_index = {}
f = open(os.path.join(glove_dir, 'glove.6B.200d.txt'), 'r', encoding='utf-8')
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

embedding_dim = 200

# Build the embedding matrix: row i holds the GloVe vector of the word with
# index i in the tokenizer's word_index (rows stay zero for unknown words).
embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():
    if i < max_words:
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector

model = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
# ... (remaining layers omitted)
model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = False
model.compile(----)
model.fit(-----)

I am very new to scikit-learn. From what I have seen, to make a model in sklearn you do:

lr = LogisticRegression()
lr.fit(X_train, y_train)
lr.predict(x_test)

So, my question is: how do I use pre-trained GloVe with this model? Where do I pass the pre-trained GloVe embedding_matrix?

Thank you very much and I really appreciate your help.


Answer 1:


You can simply use the Zeugma library.

You can install it with pip install zeugma, then create and train your model with the following lines of code (assuming corpus_train and corpus_test are lists of strings):

from sklearn.linear_model import LogisticRegression
from zeugma.embeddings import EmbeddingTransformer

glove = EmbeddingTransformer('glove')
x_train = glove.transform(corpus_train)

model = LogisticRegression()
model.fit(x_train, y_train)

x_test = glove.transform(corpus_test)
model.predict(x_test)

You can also use other pre-trained embeddings (the complete list is in Zeugma's documentation) or train your own (see Zeugma's documentation for how to do this).
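If you prefer not to add a dependency, the same idea can be done by hand: average the GloVe vectors of the words in each document to get a fixed-length feature vector for the classifier. Below is a minimal sketch that reuses the embeddings_index dictionary built in the question; the lowercase whitespace tokenization and the zero vector for documents with no known words are simplifying assumptions, not part of the original answer.

import numpy as np
from sklearn.linear_model import LogisticRegression

embedding_dim = 200  # must match the GloVe file used (glove.6B.200d.txt)

def document_vector(doc):
    # Average the GloVe vectors of the words found in the document.
    # Assumption: simple lowercase whitespace tokenization.
    vectors = [embeddings_index[w] for w in doc.lower().split() if w in embeddings_index]
    if not vectors:
        return np.zeros(embedding_dim, dtype='float32')
    return np.mean(vectors, axis=0)

x_train = np.vstack([document_vector(doc) for doc in corpus_train])
x_test = np.vstack([document_vector(doc) for doc in corpus_test])

model = LogisticRegression()
model.fit(x_train, y_train)
model.predict(x_test)

Averaging discards word order, but it is usually a reasonable baseline when feeding text into a linear model such as logistic regression.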



Source: https://stackoverflow.com/questions/55198750/using-pretrained-glove-word-embedding-with-scikit-learn
