How to average summaries over multiple batches?

刺人心 2020-12-13 09:18

Assuming I have a bunch of summaries defined like:

loss = ...
tf.scalar_summary("loss", loss)
# ...
summaries = tf.merge_all_summaries()


        
9 Answers
  •  一整个雨季
    2020-12-13 10:06

    For future reference, the TensorFlow metrics API now supports this by default. For example, take a look at tf.metrics.mean_squared_error:

    For estimation of the metric over a stream of data, the function creates an update_op operation that updates these variables and returns the mean_squared_error. Internally, a squared_error operation computes the element-wise square of the difference between predictions and labels. Then update_op increments total with the reduced sum of the product of weights and squared_error, and it increments count with the reduced sum of weights.

    These total and count variables are added to the set of metric variables, so in practice what you would do is something like:

    x_batch = tf.placeholder(...)
    y_batch = tf.placeholder(...)
    model_output = ...
    mse, mse_update = tf.metrics.mean_squared_error(y_batch, model_output)
    # This operation resets the metric internal variables to zero
    metrics_init = tf.variables_initializer(
        tf.get_default_graph().get_collection(tf.GraphKeys.METRIC_VARIABLES))
    with tf.Session() as sess:
        # Train...
        # On evaluation step
        sess.run(metrics_init)
        for x_eval_batch, y_eval_batch in ...:
            # Each run of mse_update folds the current batch into the internal
            # total/count variables and returns the MSE over every batch seen
            # since metrics_init was last run
            mse = sess.run(mse_update, feed_dict={x_batch: x_eval_batch, y_batch: y_eval_batch})
        print('Evaluation MSE:', mse)
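
    Since the original question is about getting an averaged value into TensorBoard, one possible follow-up (a sketch under assumed names, not part of the answer above; writer, tag, and step are hypothetical) is to write the final streamed result as a manually built scalar summary:

    import tensorflow as tf

    def log_scalar(writer, tag, value, step):
        # Build a one-off Summary proto holding a single scalar value
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        writer.add_summary(summary, step)
        writer.flush()

    # Usage after the evaluation loop above (assumed names):
    # writer = tf.summary.FileWriter('/tmp/eval_logs')
    # log_scalar(writer, 'eval/mse', mse, step)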
    
