Get the value of some weights in a model trained by TensorFlow

温柔的废话 2020-11-27 12:12

I have trained a ConvNet model with TensorFlow, and I want to get the value of a particular weight in a given layer. For example, in torch7 I would simply access model.modules[2].weights.

3 answers
  •  伪装坚强ぢ
    2020-11-27 12:51

    In TensorFlow, trained weights are represented by tf.Variable objects. If you created the tf.Variable yourself (say, as v), you can get its value as a NumPy array by calling sess.run(v), where sess is a tf.Session.
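
    A minimal sketch of that first case, assuming TensorFlow 2.x is installed and used through its tf.compat.v1 graph-mode API (on TensorFlow 1.x the plain tf.* names play the same role). The variable name "filter" and its shape are made up for illustration, and the zero initializer stands in for trained values that would normally come from a checkpoint:

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, so tf.Session semantics apply

graph = tf.Graph()
with graph.as_default():
    # Hypothetical conv filter bank: 3x3 kernel, 1 input channel, 4 output channels.
    v = tf1.get_variable("filter", shape=[3, 3, 1, 4],
                         initializer=tf1.zeros_initializer())
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)         # a trained model would restore a checkpoint here instead
    weights = sess.run(v)  # NumPy array holding the variable's current value

print(type(weights), weights.shape)
```

    Once the value is a NumPy array, a single weight can be read with ordinary indexing, for example weights[0, 0, 0, 2].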

    If you do not currently hold a reference to the tf.Variable, call tf.trainable_variables(), which returns a list of all trainable tf.Variable objects in the current graph. You can then select the one you want by matching its v.name property. For example:

    import tensorflow as tf

    # The desired variable is called "tower_2/filter:0".
    var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]
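
    The lookup by name can be sketched end to end as follows. This again assumes TensorFlow 2.x with the tf.compat.v1 API; the "tower_2" scope, the filter shape, and the ones initializer are placeholders standing in for a real trained model. Building a name-to-variable dictionary is one possible variation on the list comprehension: it makes it easy to print the available names first, and a wrong name raises a KeyError that shows the name you asked for.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, so tf.Session semantics apply

graph = tf.Graph()
with graph.as_default():
    # Stand-in for a trained model: one variable under the scope "tower_2".
    with tf1.variable_scope("tower_2"):
        tf1.get_variable("filter", shape=[5, 5, 3, 8],
                         initializer=tf1.ones_initializer())
    init = tf1.global_variables_initializer()

    # Index every trainable variable in this graph by its full name.
    by_name = {v.name: v for v in tf1.trainable_variables()}
    var = by_name["tower_2/filter:0"]

with tf1.Session(graph=graph) as sess:
    sess.run(init)         # a trained model would restore a checkpoint here instead
    value = sess.run(var)  # NumPy array with the variable's current value

print(sorted(by_name))
print(value.shape)
```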
    
