Keras TypeError: can't pickle _thread.RLock objects

巧了我就是萌 提交于 2021-02-19 01:48:23

问题


from keras.layers import Embedding, Dense, Input, Dropout, Reshape
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPool2D
from keras.layers import Concatenate, Lambda
from keras.backend import expand_dims
from keras.models import Model
from keras.initializers import constant, random_uniform, TruncatedNormal


class TextCNN(object):
    def __init__(
      self, sequence_length, num_classes, vocab_size,
      embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):

        # input layer
        input_x = Input(shape=(sequence_length, ), dtype='int32')

        # embedding layer
        embedding_layer = Embedding(vocab_size,
                                    embedding_size,
                                    embeddings_initializer=random_uniform(minval=-1.0, maxval=1.0))(input_x)
        embedded_sequences = Lambda(lambda x: expand_dims(embedding_layer, -1))(embedding_layer)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for filter_size in filter_sizes:
            conv = Conv2D(filters=num_filters,
                          kernel_size=[filter_size, embedding_size],
                          strides=1,
                          padding="valid",
                          activation='relu',
                          kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.1),
                          bias_initializer=constant(value=0.1),
                          name=('conv_%d' % filter_size))(embedded_sequences)

            max_pool = MaxPool2D(pool_size=[sequence_length - filter_size + 1, 1],
                                 strides=(1, 1),
                                 padding='valid',
                                 name=('max_pool_%d' % filter_size))(conv)

            pooled_outputs.append(max_pool)

        # combine all the pooled features
        num_filters_total = num_filters * len(filter_sizes)
        h_pool = Concatenate(axis=3)(pooled_outputs)
        h_pool_flat = Reshape([num_filters_total])(h_pool)

        # add dropout
        dropout = Dropout(0.8)(h_pool_flat)

        # output layer
        output = Dense(num_classes,
                       kernel_initializer='glorot_normal',
                       bias_initializer=constant(0.1),
                       activation='softmax',
                       name='scores')(dropout)

        self.model = Model(inputs=input_x, output=output)

# model saver callback
class Saver(Callback):
    def __init__(self, num):
        self.num = num
        self.epoch = 0

    def on_epoch _end(self, epoch, logs={}):
        if self.epoch % self.num == 0:
            name = './model/model.h5'
            self.model.save(name)
        self.epoch += 1


# evaluation callback
class Evaluation(Callback):
    def __init__(self, num):
        self.num = num
        self.epoch = 0

    def on_epoch_end(self, epoch, logs={}):
        if self.epoch % self.num == 0:
            score = model.evaluate(x_train, y_train, verbose=0)
            print('train score:', score[0])
            print('train accuracy:', score[1])
            score = model.evaluate(x_dev, y_dev, verbose=0)
            print('Test score:', score[0])
            print('Test accuracy:', score[1])
        self.epoch += 1


model.fit(x_train, y_train,
          epochs=num_epochs,
          batch_size=batch_size,
          callbacks=[Saver(save_every), Evaluation(evaluate_every)])

Traceback (most recent call last):
  File "D:/Projects/Python Program Design/sentiment-analysis-Keras/train.py", line 107, in <module>
    callbacks=[Saver(save_every), Evaluation(evaluate_every)])
  File "D:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1039, in fit
    validation_steps=validation_steps)
  File "D:\Anaconda3\lib\site-packages\keras\engine\training_arrays.py", line 204, in fit_loop
    callbacks.on_batch_end(batch_index, batch_logs)
  File "D:\Anaconda3\lib\site-packages\keras\callbacks.py", line 115, in on_batch_end
    callback.on_batch_end(batch, logs)
  File "D:/Projects/Python Program Design/sentiment-analysis-Keras/train.py", line 83, in on_batch_end
    self.model.save(name)
  File "D:\Anaconda3\lib\site-packages\keras\engine\network.py", line 1090, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "D:\Anaconda3\lib\site-packages\keras\engine\saving.py", line 382, in save_model
    _serialize_model(model, f, include_optimizer)
  File "D:\Anaconda3\lib\site-packages\keras\engine\saving.py", line 83, in _serialize_model
    model_config['config'] = model.get_config()
  File "D:\Anaconda3\lib\site-packages\keras\engine\network.py", line 931, in get_config
    return copy.deepcopy(config)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 215, in _deepcopy_list
    append(deepcopy(a, memo))
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 220, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "D:\Anaconda3\lib\copy.py", line 220, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 220, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "D:\Anaconda3\lib\copy.py", line 220, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "D:\Anaconda3\lib\copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "D:\Anaconda3\lib\copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "D:\Anaconda3\lib\copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "D:\Anaconda3\lib\copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "D:\Anaconda3\lib\copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "D:\Anaconda3\lib\copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "D:\Anaconda3\lib\copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "D:\Anaconda3\lib\copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "D:\Anaconda3\lib\copy.py", line 169, in deepcopy
    rv = reductor(4)
TypeError: can't pickle _thread.RLock objects

When I tried to use model.save to save my model, it happened. I have read some questions in StackOverflow or GitHub issues, most people think "This exception is raised mainly because you're trying to serialize an unserializable object. In the context, the "unserializable" object is the tf.tensor.So remember this: Don't let raw tf.tensors wandering in your model."However, I can't find any "raw tf.tensor". I'll appreciate if you could give me some help, thanks!


回答1:


It might be due to this layer:

embedded_sequences = Lambda(lambda x: expand_dims(embedding_layer, -1))(embedding_layer)

You should replace this with

embedded_sequences = Lambda(lambda x: expand_dims(x, -1))(embedding_layer)


来源:https://stackoverflow.com/questions/55280201/keras-typeerror-cant-pickle-thread-rlock-objects

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!