How to implement LSTM network with vector input in each time step?

浪子不回头ぞ, submitted 2019-12-10 14:53:03

Question


I am trying to create a generative LSTM network in TensorFlow. I have input vectors like this:

[[0 0 1 0 ... 1 0]
 [0 0 1 0 ... 1 0]
 ...
 [0 0 0 1 ... 0 1]]

Each vector in this matrix is one time step; in other words, each vector should be one input to the LSTM. The outputs would be the same, except shifted by one time step to the right (I am trying to predict the next output). Then I have a list of these matrices, say five of them, which is one batch. Finally, I have a list of those batches, which is essentially my training data. So basically I have a 4D tensor.
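A minimal NumPy sketch of how the shifted targets could be built, assuming every matrix in a batch has the same number of time steps; the helper names and the random data are made up for illustration. Each batch comes out as a 3D array of shape [batch_size, time, 88], which is the layout tf.nn.dynamic_rnn expects when time_major=False. The outer list of batches is then looped over during training rather than fed to the network as a fourth dimension.

```python
import numpy as np

def make_pair(song):
    """Split one [time, 88] matrix into inputs and next-step targets (hypothetical helper)."""
    inputs = song[:-1]   # steps 0 .. T-2
    targets = song[1:]   # steps 1 .. T-1, i.e. shifted one step to the right
    return inputs, targets

def make_batch(songs):
    """Stack equally long [time, 88] matrices into two [batch, time - 1, 88] arrays."""
    pairs = [make_pair(s) for s in songs]
    x = np.stack([p[0] for p in pairs])
    y = np.stack([p[1] for p in pairs])
    return x, y

# Hypothetical data: 5 random sequences of 10 steps each, values in {0, 1, 2}
rng = np.random.default_rng(0)
songs = [rng.integers(0, 3, size=(10, 88)) for _ in range(5)]
x, y = make_batch(songs)
print(x.shape, y.shape)  # (5, 9, 88) (5, 9, 88)
```

The target at step t is the input at step t + 1, so the network learns to predict the next vector at every position.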

I have tried something like this, but it obviously doesn't work, and I am not quite sure how I would solve it:

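import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)

def LSTM(x_, weights, biases):
    # x_ has shape [batch_size, max_time, 88] because time_major=False
    cell = tf.contrib.rnn.BasicLSTMCell(RNN_HIDDEN)

    # initial state; the batch dimension is axis 0 when time_major=False
    batch_size = tf.shape(x_)[0]
    initial_state = cell.zero_state(batch_size, tf.float32)

    # rnn_outputs has shape [batch_size, max_time, RNN_HIDDEN]
    rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell,
                                                x_,
                                                initial_state=initial_state,
                                                time_major=False)
    # output of the last time step for every sequence in the batch
    last_output = rnn_outputs[:, -1, :]
    return tf.matmul(last_output, weights['out']) + biases['out']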
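As written, the function returns one projection taken from a single slice of rnn_outputs, while a generative model that predicts the next vector at every step needs one projection per time step. A common pattern is to merge the batch and time axes, apply the dense layer once, and split the axes again. The sketch below shows this in NumPy so it runs without TensorFlow; in TF 1.x the same idea would use tf.reshape and tf.matmul. The hidden size of 32 and the weight values are hypothetical.

```python
import numpy as np

def project_all_steps(rnn_outputs, w_out, b_out):
    """Apply one dense layer to every time step.

    rnn_outputs: [batch, time, hidden]; w_out: [hidden, n_out]; b_out: [n_out]
    returns:     [batch, time, n_out]
    """
    batch, time, hidden = rnn_outputs.shape
    flat = rnn_outputs.reshape(batch * time, hidden)  # merge batch and time axes
    logits = flat @ w_out + b_out                     # [batch * time, n_out]
    return logits.reshape(batch, time, -1)            # split them again

rnn_outputs = np.ones((5, 9, 32))  # hypothetical: batch 5, 9 steps, 32 hidden units
w_out = np.full((32, 88), 0.5)
b_out = np.zeros(88)
logits = project_all_steps(rnn_outputs, w_out, b_out)
print(logits.shape)  # (5, 9, 88)
```

Because the same weights are shared across all steps, this is equivalent to applying the output layer separately at each time step.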

So, how should I represent the data so that the network can process it?
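If the matrices in a batch have different lengths, one common approach is to zero-pad them to the longest one and pass the true lengths to tf.nn.dynamic_rnn through its sequence_length argument, which zeroes the outputs and carries the state through past each sequence's end. A minimal NumPy sketch of the padding step, with a hypothetical helper name and made-up lengths:

```python
import numpy as np

def pad_batch(songs, n_keys=88):
    """Zero-pad variable-length [time, n_keys] matrices into one [batch, max_time, n_keys] array."""
    lengths = np.array([len(s) for s in songs])
    batch = np.zeros((len(songs), lengths.max(), n_keys), dtype=np.float32)
    for i, s in enumerate(songs):
        batch[i, :len(s)] = s  # real steps first, zeros after
    return batch, lengths

# Hypothetical sequences of 4, 7 and 5 steps
songs = [np.ones((n, 88)) for n in (4, 7, 5)]
batch, lengths = pad_batch(songs)
print(batch.shape, lengths)  # (3, 7, 88) [4 7 5]
```

Each padded array is 3D and could be fed, together with lengths, to placeholders of shape [None, None, 88] and [None].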

Also, I am not quite sure how to define the loss in this case. My vectors are 88-dimensional, where each index represents one tone: '1' means the tone is played, '0' means the tone is off. When a specific tone is played and then played again, I mark that with '2' (vectors truncated for brevity):

[0 0 1 0]
[0 0 1 0]
[0 0 2 0]
[0 0 2 0]

If there were only 1s here, I would not be able to tell whether it is one long tone or two (or three, or four) shorter ones. This way I alternate between 1s and 2s, and each alternation means the tone is pressed again.
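The alternation scheme for a single key can be sketched as follows. The helper encode_key and its inputs (a per-step "held" flag and a per-step "onset" flag) are hypothetical, and it assumes the alternation simply continues after a silence, which the description above does not specify.

```python
def encode_key(held, onsets):
    """Encode one key's timeline with the alternating 1/2 scheme (hypothetical helper).

    held:   per-step booleans, True while the key is down
    onsets: per-step booleans, True on steps where the key is (re)pressed
    returns per-step values: 0 = off, 1 and 2 alternate on every new press
    """
    values = []
    current = 2  # so that the first press becomes 1
    for h, o in zip(held, onsets):
        if not h:
            values.append(0)
            continue
        if o:
            current = 1 if current == 2 else 2  # new press: switch symbol
        values.append(current)
    return values

# The third key in the example above: held for 4 steps, pressed again at step 2
print(encode_key([True] * 4, [True, False, True, False]))  # [1, 1, 2, 2]
```

Decoding works the other way round: a change from 1 to 2 (or 2 to 1) marks a new press of the same key.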

Do I need to calculate the cross entropy manually here?
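Since each key takes one of three values (0, 1, 2), one possible modeling choice (an assumption, not the only option) is to treat every key at every step as a 3-class problem: the output layer produces 88 × 3 logits per step, they are reshaped to [batch, time, 88, 3], and in TF 1.x tf.nn.sparse_softmax_cross_entropy_with_logits(labels=..., logits=...) takes the integer targets directly, so no manual computation would be needed. The NumPy sketch below computes the same quantity only to show what that loss is; the function name and data are hypothetical.

```python
import numpy as np

def per_key_cross_entropy(logits, labels):
    """Mean softmax cross entropy, treating each key as a 3-class problem.

    logits: [batch, time, 88, 3] raw scores for the classes 0, 1, 2
    labels: [batch, time, 88] integer targets in {0, 1, 2}
    """
    # numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # pick the log-probability of the correct class for every key and step
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked.mean()

# With all-zero logits every class gets probability 1/3, so the loss is log(3)
logits = np.zeros((2, 4, 88, 3))
labels = np.random.default_rng(0).integers(0, 3, size=(2, 4, 88))
print(per_key_cross_entropy(logits, labels))  # about 1.0986, i.e. log(3)
```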

Source: https://stackoverflow.com/questions/44042272/how-to-implement-lstm-network-with-vector-input-in-each-time-step
