How to count the total number of trainable parameters in a TensorFlow model?

清酒与你 2020-12-04 15:41

Is there a function call or another way to count the total number of parameters in a TensorFlow model?

By parameters I mean: an N dim vector of trainable variables h

7 answers
  •  庸人自扰
    2020-12-04 16:09

    If you prefer to avoid NumPy (many projects can do without it), then:

    all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])
    

    This is a TF translation of the previous answer by Julius Kunze.

    Like any TF operation in graph mode, it requires a session run to evaluate:

    print(sess.run(all_trainable_vars))
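
For reference, here is a minimal self-contained sketch of the approach above. It assumes a TensorFlow install that provides the `tensorflow.compat.v1` module (1.14 or later, including 2.x), since `tf.trainable_variables()` and sessions are graph-mode APIs. The variables `w`, `b`, and `step` are hypothetical and exist only for illustration.

```python
import tensorflow.compat.v1 as tf

# Build everything in an explicit graph so the example behaves the same
# whether or not eager execution is enabled by default.
graph = tf.Graph()
with graph.as_default():
    # Hypothetical toy variables, not taken from any real model.
    w = tf.Variable(tf.zeros([3, 4]), name="w")   # 3 * 4 = 12 parameters
    b = tf.Variable(tf.zeros([4]), name="b")      # 4 parameters
    # Non-trainable variables are not returned by tf.trainable_variables().
    step = tf.Variable(tf.zeros([]), trainable=False, name="step")

    # Product of each variable's static shape, summed over all trainable variables.
    all_trainable_vars = tf.reduce_sum(
        [tf.reduce_prod(v.shape) for v in tf.trainable_variables()]
    )

# Only static shapes are read, so the variables need not be initialized.
with tf.Session(graph=graph) as sess:
    total = int(sess.run(all_trainable_vars))

print(total)
```

With these toy variables the count covers `w` and `b` only; `step` is skipped because it was created with `trainable=False`.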
    
