“freeze” some variables/scopes in tensorflow: stop_gradient vs passing variables to minimize

无人共我 2020-11-30 18:31

I am trying to implement an adversarial NN, which requires 'freezing' one or the other part of the graph during alternating training minibatches, i.e. there are two sub-networks that must be updated in alternation.
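
For reference, tf.stop_gradient (the other option in the title) blocks backpropagation through a tensor, effectively freezing everything upstream of it. A toy sketch with made-up values, in TF 1.x graph mode:

    import tensorflow as tf

    x = tf.Variable(2.0)
    y = x * 3.0
    # stop_gradient makes y act as a constant during backprop,
    # so no gradient flows back into x through y.
    z = tf.stop_gradient(y) * x
    grad = tf.gradients(z, x)[0]  # evaluates to y = 6.0, not 6*x = 12.0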

4 Answers
  •  孤城傲影
    2020-11-30 19:00

    @mrry's answer is completely right and perhaps more general than what I'm about to suggest. But I think a simpler way to accomplish it is to just pass the Python variable objects directly to var_list:

    import tensorflow as tf  # TF 1.x API (tf.train.AdamOptimizer)

    # Example shapes; the original used tf.Variable(...) placeholders.
    W = tf.Variable(tf.random_normal([4, 3]))
    C = tf.Variable(tf.random_normal([3, 5]))
    data = tf.placeholder(tf.float32, shape=[4, 5])

    Y_est = tf.matmul(W, C)
    loss = tf.reduce_sum((data - Y_est)**2)
    optimizer = tf.train.AdamOptimizer(0.001)

    # You can pass the Python objects directly; each op only
    # updates the variables named in its var_list.
    train_W = optimizer.minimize(loss, var_list=[W])
    train_C = optimizer.minimize(loss, var_list=[C])
    

    I have a self-contained example here: https://gist.github.com/ahwillia/8cedc710352eb919b684d8848bc2df3a
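
    A minimal sketch of how the two ops might be alternated in a training loop (the session setup and the random data batch are my assumptions, not from the gist):

    import numpy as np

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        data_batch = np.random.randn(4, 5).astype(np.float32)  # stand-in data
        for step in range(1000):
            # Each call updates only the variables in that op's var_list,
            # so the other factor stays frozen for that half-step.
            sess.run(train_W, feed_dict={data: data_batch})
            sess.run(train_C, feed_dict={data: data_batch})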
