How to set adaptive learning rate for GradientDescentOptimizer?

醉酒成梦 2020-11-28 01:08

I am using TensorFlow to train a neural network. This is how I am initializing the GradientDescentOptimizer:

init = tf.initialize_all_variables(         


        
5 answers
  •  南笙 (original poster)
    2020-11-28 01:35

    If you want to set specific learning rates for intervals of epochs, with boundaries 0 < a < b < c < ..., you can define the learning rate as a conditional tensor that depends on the global step, and pass it to the optimiser as usual.

    You could achieve this with a bunch of nested tf.cond statements, but it's easier to build the tensor recursively:

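    import tensorflow as tf  # TF 1.x API

    def make_learning_rate_tensor(reduction_steps, learning_rates, global_step):
        """Build a tensor that selects a learning rate based on global_step."""
        # One more rate than boundaries: N boundaries split training into N + 1 intervals.
        assert len(reduction_steps) + 1 == len(learning_rates)
        if len(reduction_steps) == 1:
            # Base case: a single boundary chooses between the last two rates.
            return tf.cond(
                global_step < reduction_steps[0],
                lambda: learning_rates[0],
                lambda: learning_rates[1],
            )
        # Recursive case: use the first rate before the first boundary,
        # otherwise defer to the schedule built from the remaining entries.
        return tf.cond(
            global_step < reduction_steps[0],
            lambda: learning_rates[0],
            lambda: make_learning_rate_tensor(
                reduction_steps[1:], learning_rates[1:], global_step
            ),
        )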
    

    To use it, you need to know how many training steps there are in a single epoch, so that the global step can trigger each switch at the right time; then define the epochs and learning rates you want. Make sure the same global_step is passed to the optimiser's minimize() call, so that it is incremented on every training step. For example, to get the learning rates [0.1, 0.01, 0.001, 0.0001] during the epoch intervals [0, 19], [20, 59], [60, 99], [100, ∞) respectively:

    global_step = tf.train.get_or_create_global_step()
    learning_rates = [0.1, 0.01, 0.001, 0.0001]
    steps_per_epoch = 225
    epochs_to_switch_at = [20, 60, 100]
    # Convert the epoch boundaries into global-step boundaries.
    steps_to_switch_at = [epoch * steps_per_epoch for epoch in epochs_to_switch_at]
    learning_rate = make_learning_rate_tensor(steps_to_switch_at, learning_rates, global_step)
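
Before wiring the schedule into a graph, it can help to check the boundaries in plain Python. The snippet below is only a sketch of the same lookup, without TensorFlow; `learning_rate_at` and `switch_steps` are hypothetical names used for illustration and are not part of the answer's code or of any library.

```python
from bisect import bisect_right


def learning_rate_at(step, switch_steps, learning_rates):
    """Return the rate in effect at `step` (plain Python, no TensorFlow)."""
    assert len(switch_steps) + 1 == len(learning_rates)
    # bisect_right counts how many switch points are <= step, which is the
    # index of the active rate. A step equal to a switch point therefore
    # already uses the next rate, mirroring the `global_step < boundary` test.
    return learning_rates[bisect_right(switch_steps, step)]


learning_rates = [0.1, 0.01, 0.001, 0.0001]
steps_per_epoch = 225  # same assumed value as in the example above
switch_steps = [epoch * steps_per_epoch for epoch in [20, 60, 100]]

# Print the rate just before and exactly at each switch point.
for step in [0, 4499, 4500, 13499, 13500, 22499, 22500]:
    print(step, learning_rate_at(step, switch_steps, learning_rates))
```

With 225 steps per epoch, the switches fall at global steps 4500, 13500 and 22500, so step 4499 still uses 0.1 and step 4500 is the first to use 0.01.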
    
