How to count the total number of trainable parameters in a TensorFlow model?

清酒与你 2020-12-04 15:41

Is there a function call or another way to count the total number of parameters in a TensorFlow model?

By parameters I mean: an N-dimensional vector of trainable variables has N parameters.

7 answers
  •  既然无缘
    2020-12-04 15:43

    I'll throw in my equivalent but shorter implementation:

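    from functools import reduce
    import tensorflow as tf

    def count_params():
        """Print the number of trainable parameters, in thousands."""
        # Product of a variable's dimensions; the initial 1 covers scalar (shape []) variables.
        size = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list(), 1)
        # TF 1.x API; under TF 2.x it is available as tf.compat.v1.trainable_variables().
        n = sum(size(v) for v in tf.trainable_variables())
        print("Model size: %dK" % (n // 1000,))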
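
The arithmetic behind this count can be checked without TensorFlow. The sketch below is only an illustration: the helper name `count_from_shapes` and the layer shapes are hypothetical, and it assumes each shape is already a plain list of integers, such as the one `v.get_shape().as_list()` returns for a fully defined variable.

```python
from math import prod  # available from Python 3.8


def count_from_shapes(shapes):
    """Sum, over all shapes, the product of each shape's dimensions."""
    # prod([]) is 1, so a scalar variable (empty shape) counts as one parameter.
    return sum(prod(shape) for shape in shapes)


# Hypothetical model: two dense layers with biases, plus one scalar variable.
shapes = [[784, 128], [128], [128, 10], [10], []]
total = count_from_shapes(shapes)
print("Model size: %d parameters" % total)
```

A vector of length N contributes N and an N x M matrix contributes N*M, so the total is just the sum of these products over every trainable variable.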
    
