Tensorflow save final state of LSTM in dynamic_rnn for prediction

Submitted by  ̄綄美尐妖づ on 2019-12-05 18:35:06

To save the final state, create a separate TF variable, run an assign op that copies your latest state into that variable, and then save the graph. The one thing to keep in mind is to declare that variable BEFORE you create the Saver; otherwise the Saver won't include it in what it saves.

This is discussed in great detail, including working code, here: TF LSTM: Save State from training session for prediction session later

*** UPDATE: answers to followup questions:

It looks like you are using BasicLSTMCell, with state_is_tuple=True. The prior discussion that I referred you to used GRUCell with state_is_tuple=False. The details between the two are somewhat different, but the overall approach could be similar, so hopefully this should work for you:

During training, you first feed zeros as the initial_state of dynamic_rnn, and from then on keep feeding its own output state back in as the next initial_state. So the LAST output state of your dynamic_rnn call is what you want to save for later. Since it comes from a sess.run() call, it is essentially a NumPy array (not a tensor and not a placeholder). So the question amounts to "how do I save a NumPy array as a TensorFlow variable along with the rest of the variables in the graph?" That's why you assign the final state to a variable whose only purpose is to hold it.
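
The carry-over pattern can be seen without TensorFlow at all. The following is a minimal NumPy sketch, not the actual model: `toy_step` is a hypothetical stand-in for the sess.run() call on dynamic_rnn, and the sizes are made-up values. It only illustrates that the state starts as zeros, is re-fed on every step, and ends up as a plain NumPy array.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration
LAYERS, BATCHSIZE, CELL_SIZE = 2, 4, 3

def toy_step(x, state):
    """Hypothetical stand-in for sess.run([training_step, state_out], ...).

    Returns a new state with the same [LAYERS, 2, BATCHSIZE, CELL_SIZE] shape.
    """
    return np.tanh(0.5 * state + x.mean())

# First step: feed zeros as the initial state
in_state = np.zeros((LAYERS, 2, BATCHSIZE, CELL_SIZE), dtype=np.float32)

rng = np.random.default_rng(0)
for _ in range(5):
    x = rng.normal(size=(BATCHSIZE, 10)).astype(np.float32)
    out_state = toy_step(x, in_state)
    in_state = out_state  # re-feed the output state on the next step

final_state = out_state  # a plain NumPy array, ready to be assigned to a variable
```

After the loop, `final_state` plays the role of the value that gets copied into the dedicated variable before saving.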

So the code is something like this:

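    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API

    # GRAPH DEFINITIONS:
    # One placeholder carries the state of every layer: [LAYERS, 2 (c and h), batch, CELL_SIZE]
    state_in = tf.placeholder(tf.float32, [LAYERS, 2, None, CELL_SIZE], name='state_in')
    l = tf.unstack(state_in, axis=0)  # one [2, batch, CELL_SIZE] tensor per layer
    state_tup = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(LAYERS)])
    # multicell = your BasicLSTMCell / MultiRNNCell definitions
    output, state_out = tf.nn.dynamic_rnn(multicell, X, dtype=tf.float32, initial_state=state_tup)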

    savedState = tf.get_variable('savedState', shape=[LAYERS, 2, BATCHSIZE, CELL_SIZE])
    saver = tf.train.Saver(max_to_keep=1)

    in_state = np.zeros((LAYERS, 2, BATCHSIZE, CELL_SIZE))

    # TRAINING LOOP:
    feed_dict = {X: x, Y_: y_, batchsize: BATCHSIZE, state_in:in_state}
    _, out_state = sess.run([training_step, state_out], feed_dict=feed_dict)
    in_state = out_state

    # ONCE TRAINING IS OVER:
    assignOp = tf.assign(savedState, out_state)
    sess.run(assignOp)
    saver.save(sess, pathModel + '/my_model.ckpt')

    # RECOVERING IN A DIFFERENT PROGRAM:

    new_saver = tf.train.import_meta_graph(pathModel + '/my_model.ckpt.meta')
    # Initialize after importing the graph and BEFORE restoring, so restored values are not overwritten
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    new_saver.restore(sess, pathModel + '/my_model.ckpt')
    # Retrieve the saved state and keep only its LAST batch element (latest observations)
    savedState = sess.run('savedState:0')  # this is the FULL state from training
    state = savedState[:, :, -1, :]  # -1 keeps only the LAST batch element: [LAYERS, 2, CELL_SIZE]
    state = np.reshape(state, [state.shape[0], 2, -1, state.shape[2]])  # [LAYERS, 2, 1 (BATCH), CELL_SIZE]
    # x = .... (YOUR INPUTS)
    feed_dict = {'X:0': x, 'state_in:0': state}
    # PREDICTION LOOP:
    # (fetching by name assumes the tensors were named 'preds' and 'state_out' when the graph was built)
    preds, state = sess.run(['preds:0', 'state_out:0'], feed_dict=feed_dict)
    # so now state will be re-fed into feed_dict on the next loop iteration
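
One detail the code above relies on silently: with state_is_tuple=True, fetching state_out gives a nested tuple (one (c, h) pair per layer), while state_in expects a single [LAYERS, 2, batch, CELL_SIZE] array. As far as I can tell, re-feeding works because the nested tuple converts to an array of exactly that shape. The sketch below shows that conversion with NumPy only; `StateTuple` is a hypothetical stand-in for LSTMStateTuple, and the sizes are made up.

```python
import collections
import numpy as np

LAYERS, BATCH, CELL_SIZE = 2, 4, 3  # hypothetical sizes

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (a namedtuple with fields c and h)
StateTuple = collections.namedtuple("StateTuple", ["c", "h"])

# What a fetched state_out looks like: one (c, h) pair of arrays per layer
out_state = tuple(
    StateTuple(c=np.full((BATCH, CELL_SIZE), layer, dtype=np.float32),
               h=np.full((BATCH, CELL_SIZE), layer + 0.5, dtype=np.float32))
    for layer in range(LAYERS))

# Converting the nested tuple gives the layout the state_in placeholder expects
packed = np.asarray(out_state)

# The reverse direction mirrors tf.unstack plus LSTMStateTuple in the graph definitions
unpacked = tuple(StateTuple(c=layer[0], h=layer[1]) for layer in packed)
```

Here `packed` has shape (LAYERS, 2, BATCH, CELL_SIZE), with index 0 on the second axis holding c and index 1 holding h.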

As mentioned, this is a modified version of what works well for me with GRUCell, where state_is_tuple=False. I adapted it to try BasicLSTMCell with state_is_tuple=True. It works, but not as accurately as the original approach. I don't know yet whether it's just because GRU works better than LSTM for me or for some other reason. See if this works for you...

Also keep in mind that, as you can see in the recovery and prediction code, your predictions will likely be based on a different batch size than your training loop (I guess a batch of 1?). So you have to think through how to handle your recovered state: just take the last batch element? Or something else? This code takes only the last batch element of the saved state (i.e. the most recent observations from training), because that's what was relevant for me...
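
To make the batch-size handling concrete, here is a small NumPy sketch of the "take the last batch element" choice. The sizes and the `saved_state` array are made-up stand-ins for the restored value of savedState; whether the last element is the right one to keep depends on how your training batches are ordered.

```python
import numpy as np

LAYERS, BATCHSIZE, CELL_SIZE = 2, 4, 3  # hypothetical sizes

# Stand-in for the restored savedState value: [LAYERS, 2, BATCHSIZE, CELL_SIZE]
saved_state = np.arange(LAYERS * 2 * BATCHSIZE * CELL_SIZE, dtype=np.float32).reshape(
    LAYERS, 2, BATCHSIZE, CELL_SIZE)

# As in the recovery code: index the last batch element, then restore the batch axis
state = saved_state[:, :, -1, :]  # [LAYERS, 2, CELL_SIZE]
state = np.reshape(state, [state.shape[0], 2, -1, state.shape[2]])

# Equivalent one-step form: slicing with -1: keeps the batch axis at size 1
state_alt = saved_state[:, :, -1:, :]
```

Both forms yield a [LAYERS, 2, 1, CELL_SIZE] array that fits a state_in placeholder whose batch dimension is None.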
