Deep Learning: analyzing an error in neural machine translation with attention

Anonymous (unverified), submitted 2019-12-03 00:20:01

In this exercise, after loading the model parameters, running predictions on the examples raised an error.


Here is the code:

EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']
for example in EXAMPLES:

    source = string_to_int(example, Tx, human_vocab)
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
    prediction = model.predict([source, s0, c0])
    prediction = np.argmax(prediction, axis = -1)
    output = [inv_machine_vocab[int(i)] for i in prediction]

    print("source:", example)
    print("output:", ''.join(output))
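To see what the loop actually feeds into the model, here is a minimal NumPy sketch of the preprocessing step. string_to_int_sketch and one_hot are hypothetical stand-ins for the assignment's string_to_int and Keras's to_categorical; the stand-in is assumed to lowercase the text, drop commas, truncate to Tx, map unknown characters to '<unk>' and pad with '<pad>'. The character vocabulary is a toy one (39 symbols rather than the assignment's 37). The only point is that one encoded example comes out with shape (Tx, vocabulary size), with no batch axis.

```python
import numpy as np

TX = 30  # assumed input length, matching Tx in the assignment

# Toy character vocabulary (hypothetical; the assignment's human_vocab differs).
chars = " 0123456789abcdefghijklmnopqrstuvwxyz"
human_vocab = {ch: i for i, ch in enumerate(chars)}
human_vocab['<unk>'] = len(human_vocab)
human_vocab['<pad>'] = len(human_vocab)


def string_to_int_sketch(text, length, vocab):
    """Hypothetical stand-in for string_to_int: text -> fixed-length index list."""
    text = text.lower().replace(',', '')[:length]
    ids = [vocab.get(ch, vocab['<unk>']) for ch in text]
    ids += [vocab['<pad>']] * (length - len(ids))  # right-pad up to `length`
    return ids


def one_hot(ids, num_classes):
    """Stand-in for applying to_categorical to each index: one row per index."""
    return np.eye(num_classes)[ids]


source = string_to_int_sketch('3 May 1979', TX, human_vocab)
x = one_hot(source, len(human_vocab))
print(x.shape)  # (30, 39) with this toy vocabulary: (Tx, vocab), no batch axis
```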

Here is the error output:

ValueError                                Traceback (most recent call last)
<ipython-input-31-5f0a9dfb7249> in <module>()
      4     source = string_to_int(example, Tx, human_vocab)
      5     source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
----> 6     prediction = model.predict([source, s0, c0])
      7     prediction = np.argmax(prediction, axis = -1)
      8     output = [inv_machine_vocab[int(i)] for i in prediction]

E:\Python\lib\site-packages\keras\engine\training.py in predict(self, x, batch_size, verbose, steps)
   1815         x = _standardize_input_data(x, self._feed_input_names,
   1816                                     self._feed_input_shapes,
-> 1817                                     check_batch_axis=False)
   1818         if self.stateful:
   1819             if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:

E:\Python\lib\site-packages\keras\engine\training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    111                         ': expected ' + names[i] + ' to have ' +
    112                         str(len(shape)) + ' dimensions, but got array '
--> 113                         'with shape ' + str(data_shape))
    114                 if not check_batch_axis:
    115                     data_shape = data_shape[1:]

ValueError: Error when checking : expected input_1 to have 3 dimensions, but got array with shape (37, 30)
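The traceback can be reproduced with plain NumPy. This sketch assumes Tx = 30 and a 37-symbol vocabulary (the two numbers in the error message), uses all-zero indices as placeholder data, and replaces Keras's internal check with a hypothetical, simplified check_input function (not Keras's real code). It shows that swapaxes(0,1) turns the (30, 37) array into (37, 30), which fails because the input layer expects three axes (batch, Tx, vocab), and that adding a batch axis to the swapped array still leaves the per-sample shape wrong.

```python
import numpy as np

TX, VOCAB = 30, 37  # sizes taken from the error message

# Stand-in for one encoded example: TX one-hot rows over VOCAB symbols.
ids = np.zeros(TX, dtype=int)      # placeholder indices (all zeros)
source = np.eye(VOCAB)[ids]        # shape (30, 37)
swapped = source.swapaxes(0, 1)    # what the original code does
print(source.shape, swapped.shape)

# Assumed input shape of the model: (batch, Tx, vocab), batch size unspecified.
expected_shape = (None, TX, VOCAB)


def check_input(x, expected):
    """Hypothetical, simplified version of the ndim/shape check Keras performs."""
    if x.ndim != len(expected):
        return f"expected {len(expected)} dimensions, but got array with shape {x.shape}"
    # Skip the batch axis, as predict() does with check_batch_axis=False.
    for got, want in zip(x.shape[1:], expected[1:]):
        if want is not None and got != want:
            return f"expected shape {expected}, but got array with shape {x.shape}"
    return "ok"


print(check_input(swapped, expected_shape))              # 2-D: wrong number of axes
print(check_input(swapped[np.newaxis], expected_shape))  # 3-D, but axes still swapped
print(check_input(source[np.newaxis], expected_shape))   # (1, 30, 37): passes
```

The last line already hints at the fix: keep the (30, 37) layout and only add a leading batch axis.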
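The shape (37, 30) is wrong as well. One one-hot encoded example has shape (Tx, len(human_vocab)) = (30, 37), and .swapaxes(0,1) flips it to (37, 30). On top of that, the model expects a 3-D input of shape (batch, 30, 37), so a batch axis is also missing. The corrected code is below; the old lines are commented out and the new lines carry explanatory comments: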

EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']
for example in EXAMPLES:

    source = string_to_int(example, Tx, human_vocab)
    # source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source))).swapaxes(0,1)
    # prediction = model.predict([source, s0, c0])
    source = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), source)))  # do not swap the axes: keep shape (Tx, len(human_vocab))
    ttt = np.expand_dims(source, axis=0)  # add a new axis at position 0 (the batch axis) to match the expected input dimensions
    prediction = model.predict([ttt, s0, c0])
    prediction = np.argmax(prediction, axis = -1)
    output = [inv_machine_vocab[int(i)] for i in prediction]

    print("source:", example)
    print("output:", ''.join(output))
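A short NumPy sketch of the fix and of the decoding lines that follow it. It shows that np.expand_dims(source, axis=0) is equivalent to indexing with np.newaxis or reshaping to (1, 30, 37). For decoding, it assumes (as in the Coursera assignment) that model.predict returns a list of Ty = 10 arrays, each of shape (1, len(machine_vocab)); the 11-symbol machine_vocab and the fake prediction below are made up for illustration.

```python
import numpy as np

TY = 10  # assumed output length ("YYYY-MM-DD")

source = np.eye(37)[np.zeros(30, dtype=int)]  # stand-in one-hot example, shape (30, 37)

# Three equivalent ways to add the leading batch axis of size 1.
a = np.expand_dims(source, axis=0)
b = source[np.newaxis, ...]
c = source.reshape((1,) + source.shape)
assert a.shape == b.shape == c.shape == (1, 30, 37)

# Hypothetical output vocabulary and a fake prediction: a list of TY arrays,
# each of shape (1, len(machine_vocab)), one-hot on the "true" character.
machine_vocab = {ch: i for i, ch in enumerate("-0123456789")}
inv_machine_vocab = {i: ch for ch, i in machine_vocab.items()}
prediction = [np.eye(len(machine_vocab))[[machine_vocab[ch]]] for ch in "1979-05-03"]
assert len(prediction) == TY

ids = np.argmax(prediction, axis=-1)  # list -> array (TY, 1, 11) -> argmax -> (TY, 1)
output = [inv_machine_vocab[int(i[0])] for i in ids]  # one character per time step
print(''.join(output))  # 1979-05-03
```

Each row of the argmax result is a length-1 array, so the sketch takes i[0] before calling int. The original int(i) also works, but NumPy 1.25 and later emit a DeprecationWarning when converting an array with ndim > 0 to a Python scalar.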
