How to get the learning rate or iteration count when defining a new layer in Caffe

≯℡__Kan透↙ 提交于 2019-12-22 09:12:24

问题


I want to change the way the loss is calculated in a loss layer once the number of iterations reaches a certain value.
To do this, I think I need to get the current learning rate or iteration count inside the layer, and then use an if statement to decide whether to switch the loss calculation method.


回答1:


You can add a member variable to the Caffe class to store the current learning rate or iteration count, and then access it from whichever layer needs it.

For example, to get the current iteration count wherever you need it, you have to make 3 key modifications (simplified here):

  1. In common.hpp:

      // Excerpt of the Caffe singleton (common.hpp) extended with a globally
      // readable "current iteration" value. Elided members are marked "...".
      class Caffe {
        public:
          static Caffe& Get();

          ...//Some other public members

          // Returns the iteration count most recently stored by set_cur_iter().
          inline static int current_iter() { return Get().cur_iter_; }
          // Stores the current iteration count on the singleton.
          inline static void set_cur_iter(int iter) { Get().cur_iter_ = iter; }

        protected:

          // Holds the current iteration count.
          // NOTE: initialize this to 0 in the Caffe constructor (common.cpp);
          // otherwise a call to current_iter() made before the first
          // set_cur_iter() reads an indeterminate value.
          int cur_iter_;

          ...//Some other protected members
      };
    
  2. In solver.cpp:

      // Excerpt of Solver<Dtype>::Step (solver.cpp); elided code is marked "...".
      // The only addition is the Caffe::set_cur_iter() call inside the loop.
      template <typename Dtype>
      void Solver<Dtype>::Step(int iters) {
    
        ...
    
        while (iter_ < stop_iter) {
          // Publish the solver's iteration counter at the top of every pass so
          // the code that follows in this iteration can read it through
          // Caffe::current_iter().
          Caffe::set_cur_iter(iter_ );
          ...// Remaining (unchanged) operations of the loop body
        }
      }
    
  3. The place where you want to access the current iteration times:

      // Example consumer: any layer method can read the iteration count that
      // was stored through Caffe::set_cur_iter().
      template <typename Dtype>
      void SomeLayer<Dtype>::some_func() {
        // Value last written via Caffe::set_cur_iter().
        int current_iter = Caffe::current_iter();
        ...// Operations that depend on current_iter
      }
    



回答2:


AFAIK there is no direct access from within a python layer to the solver's iteration count and the learning rate.
However, you can keep a counter of your own:

import caffe

class IterCounterLossLayer(caffe.Layer):
    """Python loss layer that switches its loss after 1000 forward passes.

    The layer keeps its own counter of forward() calls because the solver's
    iteration count is not directly available inside a Python layer.
    """

    def setup(self, bottom, top):
        # do your setup here...
        self.iter_counter = 0  # number of forward passes seen so far
        # Branch chosen by the most recent forward(); backward() reuses it so
        # the gradient always matches the loss that was actually computed.
        self.use_first_loss = True

    def reshape(self, bottom, top):
        # reshape code here...
        # loss output is scalar
        top[0].reshape(1)

    def forward(self, bottom, top):
        # Decide once per iteration, *before* incrementing the counter.
        self.use_first_loss = self.iter_counter < 1000
        if self.use_first_loss:
            # some way of computing the loss
            # ...
            pass
        else:
            # another way
            # ...
            pass
        # increment; you may consider incrementing by bottom[0].shape[0]
        # (the batch size) instead of 1
        self.iter_counter += 1

    def backward(self, top, propagate_down, bottom):
        # Do not re-test iter_counter here: forward() has already incremented
        # it, so on the switch-over iteration the comparison would select a
        # different branch than the one used to compute the loss.
        if self.use_first_loss:
            # gradients need to fit the loss
            # ...
            pass
        else:
            # another way
            # ...
            pass



回答3:


To get the iteration count, you can use my count_layer as a bottom layer of your custom layer. This approach has the following benefits:

  1. When you fine-tune from saved weights, the iteration count continues from the value stored in those weights.
  2. The implementation is modular.
  3. There is no need to change existing Caffe code.

train_val.prototxt

# Counter layer: takes no bottom and exposes the running count as the
# top blob "iteration", which a custom layer can then use as a bottom.
layer {
  name: "iteration"
  top: "iteration"
  type: "Count"
}

count_layer.hpp

#ifndef CAFFE_COUNT_LAYER_HPP_
#define CAFFE_COUNT_LAYER_HPP_

#include <vector>

#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

/**
 * @brief Bottom-less layer that exposes a running counter as its single top.
 *
 * The counter lives in the parameter blob blobs_[0]. Each Forward_cpu() call
 * adds `delta` (CountParameter.delta) to it and shares the blob with top[0].
 * The blob is created with `count_param.shape` (or {1, 1} when unset) and
 * filled by `count_param.base_filler`.
 */
template <typename Dtype>
class CountLayer : public Layer<Dtype> {
 public:
  explicit CountLayer(const LayerParameter& param)
      : Layer<Dtype>(param), delta_(1) {}

  /**
   * @brief Creates and fills the counter blob, shapes top[0] to match, and
   *        defaults the parameter's decay_mult to 0.
   * @param bottom unused (the layer has no bottoms)
   * @param top    top[0] is reshaped to the counter blob's shape
   */
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    if (this->blobs_.size() > 0) {
      LOG(INFO) << "Skipping parameter initialization";
    } else {
      this->blobs_.resize(1);
      this->blobs_[0].reset(new Blob<Dtype>());
      if (this->layer_param_.count_param().has_shape()) {
        this->blobs_[0]->Reshape(this->layer_param_.count_param().shape());
      } else {
        this->blobs_[0]->Reshape(vector<int>{1, 1});
      }
      shared_ptr<Filler<Dtype> > base_filler(GetFiller<Dtype>(
          this->layer_param_.count_param().base_filler()));
      base_filler->Fill(this->blobs_[0].get());
    }
    top[0]->Reshape(this->blobs_[0]->shape());

    string name = this->layer_param().name();
    if (name == "") {
      name = "Count";
    }
    // Default decay_mult to 0 unless the user set it explicitly, so weight
    // decay is not applied to the counter.
    if (this->layer_param_.param_size() <= 0) {
      LOG(INFO) << "Layer " << name << "'s decay_mult has been set to 0";
      this->layer_param_.add_param()->set_decay_mult(Dtype(0));
    } else if (!this->layer_param_.param(0).has_decay_mult()) {
      LOG(INFO) << "Layer " << name << "'s decay_mult has been set to 0";
      this->layer_param_.mutable_param(0)->set_decay_mult(0);
    }

    delta_ = Dtype(this->layer_param_.count_param().delta());
    // Pre-subtract one step: Forward_cpu() adds delta_ before publishing, so
    // the first forward pass outputs exactly the base value.
    caffe_add_scalar(this->blobs_[0]->count(), -delta_,
        this->blobs_[0]->mutable_cpu_data());
  }

  /// No-op: top[0] is shaped once in LayerSetUp().
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) { }

  // Must match the name used in REGISTER_LAYER_CLASS(Count) and in the
  // prototxt (type: "Count"); the original returned "Parameter".
  virtual inline const char* type() const { return "Count"; }
  virtual inline int ExactNumBottomBlobs() const { return 0; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  /// Advances the counter by delta_ and exposes it through top[0].
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    caffe_add_scalar(this->blobs_[0]->count(), delta_,
        this->blobs_[0]->mutable_cpu_data());
    top[0]->ShareData(*(this->blobs_[0]));
    // NOTE(review): sharing the diff means anything a consumer layer writes
    // into top[0]'s diff lands in the counter parameter's diff. Confirm that
    // consumers do not back-propagate into this top, or set lr_mult to 0.
    top[0]->ShareDiff(*(this->blobs_[0]));
  }

  /// Intentionally empty: the layer has no bottoms to propagate to.
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  }

 private:
  Dtype delta_;  // Increment applied on every forward pass.
};

}  // namespace caffe

#endif

count_layer.cpp

#include "caffe/layers/count_layer.hpp"

namespace caffe {

// CountLayer is defined entirely in its header; this translation unit only
// instantiates and registers it.
// NOTE(review): presumably INSTANTIATE_CLASS emits the float/double
// instantiations and REGISTER_LAYER_CLASS(Count) binds the type string "Count"
// (as used in the prototxt) to CountLayer — confirm against Caffe's macros.
INSTANTIATE_CLASS(CountLayer);
REGISTER_LAYER_CLASS(Count);

}  // namespace caffe

caffe.proto

optional CountParameter count_param = 666;
...
// Options read by CountLayer through layer_param_.count_param().
message CountParameter {
  // Shape of the counter blob; CountLayer falls back to {1, 1} when unset.
  optional BlobShape shape = 1;
  optional FillerParameter base_filler = 2; // The filler for the base (initial counter value)
  // Amount added to the counter on every forward pass.
  optional float delta = 3 [default = 1];
}


来源:https://stackoverflow.com/questions/38369565/how-to-get-learning-rate-or-iteration-times-when-define-new-layer-in-caffe

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!