How to window or reset streaming operations in tensorflow?

你的背包 2021-01-26 18:04

Tensorflow provides all sorts of nice streaming operations to aggregate statistics along batches, such as tf.metrics.mean.
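A streaming metric such as tf.metrics.mean keeps two local variables, total and count, and each run of its update op folds one batch into them. The plain-Python class below is a hypothetical sketch of that bookkeeping only (the name StreamingMean is illustrative, not a TensorFlow API). It returns 0.0 before any update, which is a choice made for this sketch.

```python
class StreamingMean:
    """Hypothetical plain-Python model of a streaming mean.

    It mirrors the two accumulators (total and count) that tf.metrics.mean
    keeps as local variables; it is not TensorFlow code.
    """

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        # Like running update_op: fold one batch of values into the accumulators.
        self.total += sum(values)
        self.count += len(values)
        return self.result()

    def result(self):
        # Like reading the metric value: mean over everything seen so far.
        # Returns 0.0 before any update (an assumption of this sketch).
        return self.total / self.count if self.count else 0.0
```

For example, updating with [1.0, 2.0, 3.0] gives 2.0, and a further update with [10.0] gives 4.0, because the accumulators never forget earlier batches.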

However, I find that accumulating

2 Answers
  •  南方客 (OP)
    2021-01-26 18:18

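    One way to do it is to re-run the initializers of the local variables that the streaming op creates (for tf.metrics.mean, these are total and count), which resets the accumulated statistics to zero. For example,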

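    import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs under TF 2.x
    tf.disable_v2_behavior()
    
    x = tf.random_normal(())  # one scalar sample from N(0, 1) per run
    mean_x, update_op = tf.metrics.mean(x, name='mean_x')
    # collect the local variables created by the metric (total and count)
    my_metric_variables = [v for v in tf.local_variables() if v.name.startswith('mean_x/')]
    # or maybe just
    # my_metric_variables = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
    # a single op that re-runs the initializers of those variables, zeroing them
    reset_op = tf.variables_initializer(my_metric_variables)
    
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      for _ in range(100):
        for _ in range(100):
          sess.run(update_op)
        # mean of the 100 samples accumulated since the last reset
        print(sess.run(mean_x))
        # if you comment the following out, the estimate becomes a running mean
        # over all samples so far and converges to 0
        sess.run(reset_op)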
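Resetting as above gives non-overlapping (tumbling) windows. If a sliding window over the most recent k batches is wanted instead, one option is to keep a sum and count per batch and drop the oldest pair once the window is full. The plain-Python sketch below only illustrates that bookkeeping; the WindowedMean class and its window parameter are hypothetical and not part of TensorFlow, and porting the idea into a graph would need extra variables and ops not shown here.

```python
from collections import deque


class WindowedMean:
    """Hypothetical sliding-window mean over the last `window` batches.

    Stores one (sum, count) pair per batch; deque(maxlen=...) discards the
    oldest pair automatically, so no explicit reset is needed.
    """

    def __init__(self, window):
        self.batches = deque(maxlen=window)

    def update(self, values):
        # Record this batch's sum and element count.
        self.batches.append((sum(values), len(values)))

    def result(self):
        # Mean over all elements in the batches still inside the window.
        total = sum(s for s, _ in self.batches)
        count = sum(c for _, c in self.batches)
        return total / count if count else 0.0
```

With window=2, updating with [1.0] and then [3.0] gives 2.0; a third update with [5.0] evicts the first batch and gives 4.0. The trade-off versus resetting is memory proportional to the window length.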
    
