Implement transfer learning on NiftyNet


Question


I want to implement transfer learning using the Dense V-Net architecture. While searching for how to do this, I found that this feature is currently being worked on (How do I implement transfer learning in NiftyNet?).

Although that answer makes it clear that there is no straightforward way to implement it yet, I was trying to:

1) Create the Dense V-Net

2) Restore weights from the .ckpt file

3) Implement transfer learning on my own

To perform step 1, I thought I could use the niftynet.network.dense_vnet module. Therefore, I tried the following:

import tensorflow as tf
from niftynet.network.dense_vnet import DenseVNet

checkpoint = '/path_to_ckpt/model.ckpt-3000.index'
x = tf.placeholder(dtype=tf.float32, shape=[None,1,144,144,144])
architecture_parameters = dict(
    use_bdo=False,
    use_prior=False,
    use_dense_connections=True,
    use_coords=False)

hyperparameters = dict(
    prior_size=12,
    n_dense_channels=(4, 8, 16),
    n_seg_channels=(12, 24, 24),
    n_input_channels=(24, 24, 24),
    dilation_rates=([1] * 5, [1] * 10, [1] * 10),
    final_kernel=3,
    augmentation_scale=0)
model_instance = DenseVNet(num_classes=9,
                           hyperparameters=hyperparameters,
                           architecture_parameters=architecture_parameters)

model_net = DenseVNet.layer_op(model_instance, x)

However, I get the following error:

TypeError: Failed to convert object of type <type 'list'> to Tensor. Contents: [None, 1, 72, 72, 24]. Consider casting elements to a supported type.

So, the question is:

Is there any way to implement this?


Answer 1:


Transfer learning has been added to NiftyNet.

You can select which variables you want to restore through the vars_to_restore config parameter and which variables to freeze through the vars_to_freeze config parameter.

See here for more information.
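
For illustration, the relevant part of a NiftyNet configuration file could look roughly like the snippet below. The regular expressions are purely illustrative (they are meant to match every variable whose name does not contain "final"), and the exact matching behaviour of vars_to_restore and vars_to_freeze should be checked against the NiftyNet configuration documentation:

[TRAINING]
# illustrative regexes only -- adapt them to the variable names in your graph
# vars_to_restore: which variables are initialised from the checkpoint
vars_to_restore = ^((?!final).)*$
# vars_to_freeze: controls which variables are frozen during training
# (see the NiftyNet docs for the exact matching rule)
vars_to_freeze = ^((?!final).)*$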




Answer 2:


A simple form of transfer learning can be achieved by restoring the weights of an existing model: set the starting_iter parameter in the [TRAINING] section of your config file to the iteration number of the pretrained model. In your example, starting_iter=3000.

This will restore the weights from your model, and the new training iterations will start from this initialisation.
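
For the checkpoint in the question (model.ckpt-3000), the relevant configuration entries could look roughly like this; model_dir is an assumption here and must point at the NiftyNet model directory that holds the checkpoint files (NiftyNet usually keeps them in a models/ sub-folder):

[SYSTEM]
# directory of the pretrained model
model_dir = /path_to_ckpt

[TRAINING]
# resume from the pretrained weights and keep training
starting_iter = 3000
max_iter = 10000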

Note that the architecture of your model has to be exactly the same, otherwise you will get an error.

For more sophisticated transfer learning, or fine-tuning where you restore only part of the weights, there is a great implementation here. It will probably be merged into the official NiftyNet repository soon, but you can already use it.
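
Independently of that implementation, the underlying TensorFlow 1.x mechanism (restore a subset of variables with tf.train.Saver and freeze the rest via the optimiser's var_list) can be sketched as follows. This is a minimal toy stand-in for DenseVNet, not the linked implementation; the scope names, layer sizes and checkpoint path are assumptions, and the variable names must of course match those stored in the checkpoint for the restore to succeed:

import tensorflow as tf

# Toy stand-in for a pretrained backbone plus a new task-specific head.
x = tf.placeholder(tf.float32, shape=[None, 144, 144, 144, 1])
labels = tf.placeholder(tf.int32, shape=[None, 144, 144, 144])

with tf.variable_scope('backbone'):   # weights to restore and freeze (assumed scope name)
    features = tf.layers.conv3d(x, 8, 3, padding='same', activation=tf.nn.relu)
with tf.variable_scope('new_head'):   # weights trained from scratch (assumed scope name)
    logits = tf.layers.conv3d(features, 9, 1, padding='same')

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# Restore only the backbone; train (and therefore update) only the new head.
backbone_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='backbone')
head_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='new_head')
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=head_vars)
saver = tf.train.Saver(var_list=backbone_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Pass the checkpoint prefix, not the .index file.
    saver.restore(sess, '/path_to_ckpt/model.ckpt-3000')
    # ... run train_op in a loop; only the 'new_head' variables are updated ...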



Source: https://stackoverflow.com/questions/53050961/implement-transfer-learning-on-niftynet
