Implement perceptual loss with pretrained VGG using keras

强颜欢笑 submitted on 2020-12-29 04:23:18

Question


I am relatively new to DL and Keras.

I am trying to implement perceptual loss using the pretrained VGG16 in Keras but have some troubles. I already found that question but I am still struggling :/

A short explanation of what my network should do:

I have a CNN (subsequently called mainModel) that gets grayscale images as input (#TrainData, 512, 512, 1) and outputs grayscale images of the same size. The network should reduce artifacts in the images, but I think that is not very important for this question. Instead of using e.g. MSE as the loss function, I would like to implement the perceptual loss.

What I want to do (I hope I have properly understood the concept of perceptual loss):

I would like to append a lossModel (pretrained VGG16 with fixed params) to my mainModel. Then I would like to pass the output of the mainModel to the lossModel. In addition, I pass the label images (Y_train) to the lossModel. Finally, I compare the activations at a specific layer (e.g. block1_conv2) of the lossModel using e.g. MSE, and use that as the loss function.
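Written out, the loss described above looks roughly like this, assuming MSE as the comparison and optionally summing over a set J of chosen VGG layers. Here phi_j is the activation map of VGG layer j with C_j channels and spatial size H_j x W_j, y_hat is mainModel's output and y is the label image. Keras's 'mse' computes exactly this per-layer mean (and additionally averages over the batch):

```latex
\mathcal{L}_{\mathrm{perc}}(\hat{y}, y) = \sum_{j \in J} \frac{1}{C_j H_j W_j} \left\lVert \phi_j(\hat{y}) - \phi_j(y) \right\rVert_2^2
```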

What I did so far:

Load in data and create the mainModel:

### Load data ###
import h5py

# raw string, so the "\t" in the Windows path is not read as a tab character
with h5py.File(r'.\train_test_val.h5', 'r') as hf:
    X_train = hf['X_train'][:]
    Y_train = hf['Y_train'][:]
    X_test = hf['X_test'][:]
    Y_test = hf['Y_test'][:]
    X_val = hf['X_val'][:]
    Y_val = hf['Y_val'][:]

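### Create Main Model ###
from keras.layers import Input, Conv2D
from keras.activations import relu
from keras.models import Model

input_1 = Input((512,512,9))
conv0 = Conv2D(64, (3,3), strides=(1,1), activation=relu, use_bias=True, padding='same')(input_1)
.
.
.

mainModel = Model(inputs=input_1, outputs=output)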

Create lossModel, append it to mainModel and fix params:

### Create Loss Model (VGG16) ###
from keras.applications import vgg16

lossModel = vgg16.VGG16(include_top=False, weights='imagenet', input_tensor=mainModel.output, input_shape=(512,512, 1))
lossModel.trainable=False

for layer in lossModel.layers:
    layer.trainable=False

Create a new model including both networks and compile it:

### Create new Model ###
fullModel = Model(inputs=mainModel.input, outputs=lossModel.output)

fullModel.compile(loss='mse', optimizer='adam',metrics=['mse','mae'])
fullModel.summary()

Adjust the label images by passing them through the lossModel:

Y_train_lossModel = lossModel.predict(Y_train)

Fit the fullModel using the perceptual loss:

fullModel.fit(X_train, Y_train_lossModel, batch_size=32, epochs=5, validation_data=[X_val,Y_val])

Problems that occur:

  • VGG16 expects inputs of shape (?,?,3), but my mainModel outputs grayscale images of shape (?,?,1)

  • An issue with appending the lossModel to the mainModel:

RuntimeError: Graph disconnected: cannot obtain value for tensor Tensor("conv2d_2/Relu:0", shape=(?, 512, 512, 3), dtype=float32) at layer "input_2". The following previous layers were accessed without issue: []

  • How can I calculate the MSE on a specific layer's activations instead of on the output of the lossModel?

Thank you so much for your help and sorry for the extremely long question :)


Answer 1:


Number of channels

Well, the first problem is significant.

VGG models were made for color images with 3 channels, so it's not quite the right model for your case. I'm not sure whether there are models for black-and-white images, but you should search for them.

A workaround, though I don't know how well it will work, is to make 3 copies of mainModel's output:

from keras.layers import Concatenate
tripleOut = Concatenate()([mainModel.output,mainModel.output,mainModel.output])
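The Concatenate trick simply repeats the single grayscale channel three times along the last axis. A minimal NumPy sketch of the same operation (the helper name gray_to_three_channels is hypothetical), which is what the layer does to mainModel's output and what the answer later does to Y_train:

```python
import numpy as np

def gray_to_three_channels(images):
    """Repeat a single grayscale channel three times along the last axis.

    images: array of shape (batch, height, width, 1)
    returns: array of shape (batch, height, width, 3)
    """
    if images.shape[-1] != 1:
        raise ValueError("expected a trailing channel axis of size 1")
    return np.concatenate([images, images, images], axis=-1)

# Example: two 4x4 grayscale "images"
batch = np.random.rand(2, 4, 4, 1).astype("float32")
rgb_like = gray_to_three_channels(batch)
print(rgb_like.shape)  # (2, 4, 4, 3)
```

All three output channels are identical, so VGG sees a "color" image with no color information. Inside a TensorFlow graph, tf.image.grayscale_to_rgb performs the same repetition.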

Graph disconnected

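This means that nowhere in your code did you create a connection between the input and the output of fullModel. You must connect the output of mainModel to the input of lossModel. For that, lossModel should be a standalone VGG16 with its own 3-channel input (e.g. input_shape=(512,512,3) and no input_tensor argument), which is then called on mainModel's output.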

But first, let's prepare the VGG model for multiple outputs.

Preparing lossModel for multiple outputs

You must select which layers of the VGG model will be used to calculate the loss. If you use only the final output, you won't really get a good perceptual loss, because the final output encodes high-level concepts rather than image features.

So, after you select the layers, make a list of their indices or names:

selectedLayers = [1,2,9,10,17,18] #for instance
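To see what those indices refer to, here is a small plain-Python sketch that rebuilds the layer names. It assumes the standard layer order of keras.applications VGG16 with include_top=False (index 0 is the input layer, then each block's convolutions followed by one pooling layer); checking it against lossModel.summary() on your Keras version is advisable. The helper vgg16_layer_names is hypothetical:

```python
# Assumed layer order of keras.applications VGG16(include_top=False):
# index 0 is the Input layer, then each block's conv layers followed by one pool.
VGG16_BLOCKS = [2, 2, 3, 3, 3]  # number of conv layers in block1..block5

def vgg16_layer_names():
    names = ["input"]  # placeholder; the real Input layer name varies
    for block, n_convs in enumerate(VGG16_BLOCKS, start=1):
        for conv in range(1, n_convs + 1):
            names.append(f"block{block}_conv{conv}")
        names.append(f"block{block}_pool")
    return names

names = vgg16_layer_names()
selectedLayers = [1, 2, 9, 10, 17, 18]
print([names[i] for i in selectedLayers])
# ['block1_conv1', 'block1_conv2', 'block3_conv3', 'block3_pool', 'block5_conv3', 'block5_pool']
```

So [1,2,9,10,17,18] picks both convolutions of block 1 plus the last convolution and the pooling layer of blocks 3 and 5. Selecting by name with get_layer is more robust than raw indices.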

Let's make a new model from VGG16, but with multiple outputs:

#a list with the output tensors for each selected layer:
selectedOutputs = [lossModel.layers[i].output for i in selectedLayers]
     #or [lossModel.get_layer(name).output for name in selectedLayers]

#a new model that has multiple outputs:
lossModel = Model(lossModel.inputs,selectedOutputs)

Joining the models

Now, here we create the connection between the two models.

We call the lossModel (as if it were a layer) taking the output of the mainModel as input:

lossModelOutputs = lossModel(tripleOut) #or mainModel.output if not using tripleOut

Now, with the graph entirely connected from the input of mainModel to the output of lossModel, we can create the fullModel:

fullModel = Model(mainModel.input, lossModelOutputs)

#if the line above doesn't work due to a type problem, turn lossModelOutputs into a plain list and create fullModel again:
lossModelOutputs = [lossModelOutputs[i] for i in range(len(selectedLayers))]
fullModel = Model(mainModel.input, lossModelOutputs)
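Putting the pieces together, here is a minimal end-to-end sketch, assuming TensorFlow 2.x with tf.keras. The tiny two-layer network is a hypothetical stand-in for mainModel, the image size is shrunk to 64x64 to keep it light, and weights defaults to None so nothing is downloaded (pass weights='imagenet' for a real perceptual loss). The key point is that VGG16 is built on its own 3-channel input and then called like a layer on the tripled output, which is what connects the graph:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import vgg16

def build_models(size=64, layer_names=("block1_conv2", "block3_conv3"), weights=None):
    # Hypothetical stand-in for mainModel: grayscale in, grayscale out
    inp = layers.Input((size, size, 1))
    x = layers.Conv2D(8, 3, padding="same", activation="relu")(inp)
    out = layers.Conv2D(1, 3, padding="same")(x)
    main_model = Model(inp, out)

    # VGG16 gets its OWN 3-channel input (no input_tensor=...)
    vgg = vgg16.VGG16(include_top=False, weights=weights, input_shape=(size, size, 3))
    vgg.trainable = False  # freezes every VGG layer
    loss_model = Model(vgg.input, [vgg.get_layer(n).output for n in layer_names])

    # Triple the grayscale channel, then call loss_model like a layer
    triple_out = layers.Concatenate()([out, out, out])
    full_model = Model(inp, loss_model(triple_out))
    # one 'mse' per selected layer
    full_model.compile(optimizer="adam", loss=["mse"] * len(layer_names))
    return main_model, loss_model, full_model

main_model, loss_model, full_model = build_models()

# Random data just to exercise the pipeline
X = np.random.rand(4, 64, 64, 1).astype("float32")
Y = np.random.rand(4, 64, 64, 1).astype("float32")

# Targets: labels tripled and passed through the frozen loss_model
targets = loss_model.predict(np.concatenate([Y, Y, Y], axis=-1), verbose=0)
full_model.fit(X, targets, batch_size=2, epochs=1, verbose=0)
```

After training, main_model alone produces the cleaned images. The validation targets need the same treatment as the training targets: pass the tripled Y_val through loss_model and give fit a tuple such as validation_data=(X_val, Y_val_targets) (Y_val_targets being a hypothetical name), because raw Y_val does not match the model's outputs.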

Training

Compute the targets with this new lossModel, just as you did. But for the workaround, make the labels triple-channel as well:

import numpy as np
triple_Y_train = np.concatenate((Y_train,Y_train,Y_train),axis=-1)
Y_train_lossModel = lossModel.predict(triple_Y_train)
#the output will be a list of numpy arrays, one for each of the selected layers
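Precomputed targets can get very large, because the early VGG layers keep the full 512x512 resolution. A rough plain-Python estimate, assuming float32 outputs and the standard VGG16 geometry (3x3 'same' convolutions keep the spatial size, every pooling layer halves it; 64, 128, 256, 512, 512 channels per block). The helper names are hypothetical:

```python
# Assumed VGG16 geometry: 'same' convs keep the spatial size,
# each blockN_pool halves it; channels per block are 64, 128, 256, 512, 512.
CHANNELS = {1: 64, 2: 128, 3: 256, 4: 512, 5: 512}

def feature_shape(layer_name, size=512):
    block = int(layer_name[5])          # "block3_conv3" -> 3
    pools_before = block - 1            # pools of the earlier blocks
    if layer_name.endswith("_pool"):
        pools_before += 1               # this block's own pool
    side = size // (2 ** pools_before)
    return (side, side, CHANNELS[block])

def megabytes_per_image(layer_names, size=512, bytes_per_value=4):
    total = 0
    for name in layer_names:
        h, w, c = feature_shape(name, size)
        total += h * w * c * bytes_per_value
    return total / 2**20

selected = ["block1_conv1", "block1_conv2", "block3_conv3",
            "block3_pool", "block5_conv3", "block5_pool"]
for name in selected:
    print(name, feature_shape(name))
print(round(megabytes_per_image(selected), 1), "MiB per image")  # 150.5 MiB per image
```

With these six layers, each training image needs about 150 MiB of targets, 128 MiB of it from the two block1 layers alone. For a large training set it may be necessary to generate the targets batch by batch, or to switch to a different technique that computes the VGG features of the labels inside a custom loss function.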

Make sure each layer of lossModel is non-trainable before calling fullModel.compile().

If you want 'mse' for all outputs, you just do:

fullModel.compile(loss='mse', ...)

If you want a different loss for each layer, pass a list of losses:

fullModel.compile(loss=[loss1,loss2,loss3,...], ...)
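With several outputs, Keras adds the per-output losses into one scalar, each scaled by the optional loss_weights argument of compile() (default 1). Since the selected layers can produce losses of very different magnitudes, one layer may dominate, and loss_weights lets you rebalance them. A NumPy sketch of that combination, with hypothetical helper names:

```python
import numpy as np

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def total_loss(y_true_list, y_pred_list, loss_weights=None):
    """Weighted sum of per-output MSEs, mirroring how Keras combines
    the losses of a multi-output model (weights default to 1)."""
    if loss_weights is None:
        loss_weights = [1.0] * len(y_true_list)
    return sum(w * mse(t, p) for w, t, p in zip(loss_weights, y_true_list, y_pred_list))

# Two fake "feature maps" with a known difference
t1, p1 = np.zeros((2, 8, 8, 4)), np.full((2, 8, 8, 4), 1.0)   # MSE = 1.0
t2, p2 = np.zeros((2, 4, 4, 8)), np.full((2, 4, 4, 8), 2.0)   # MSE = 4.0
print(total_loss([t1, t2], [p1, p2]))               # 5.0
print(total_loss([t1, t2], [p1, p2], [1.0, 0.25]))  # 2.0
```

In Keras itself this corresponds to something like fullModel.compile(loss='mse', loss_weights=[1.0, 0.25], optimizer='adam') for a two-output model.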

Additional considerations

Since VGG expects images in the caffe format, you might want to add a few layers after mainModel to make its output suitable. It's not strictly required, but it would get the best performance out of VGG.

See how Keras transforms an input image ranging from 0 to 255 into the caffe format here, at line 15 or 44.
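A NumPy sketch of the caffe transform, following the 'caffe' mode of Keras's preprocess_input (convert RGB to BGR, subtract the ImageNet per-channel mean, no scaling). It assumes mainModel outputs grayscale values in [0, 1], which is an assumption to adjust to your data; the helper names are hypothetical:

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras's "caffe" preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype="float32")

def caffe_preprocess(rgb_0_255):
    """RGB images in [0, 255], shape (..., 3) -> BGR, mean-subtracted, unscaled."""
    bgr = rgb_0_255[..., ::-1]
    return bgr - IMAGENET_MEAN_BGR

def gray01_to_vgg_input(gray_0_1):
    """Hypothetical helper: grayscale in [0, 1], shape (..., 1) -> VGG-ready input."""
    rgb = np.concatenate([gray_0_1] * 3, axis=-1) * 255.0
    return caffe_preprocess(rgb)

white = np.ones((1, 2, 2, 1), dtype="float32")
vgg_ready = gray01_to_vgg_input(white)
print(vgg_ready.shape)  # (1, 2, 2, 3)
```

For tripled grayscale the RGB-to-BGR flip changes nothing, so only the mean subtraction (and the rescaling to 0..255) matters. Inside the model, the same arithmetic could go in a Lambda layer placed between the Concatenate and lossModel; for NumPy data, Keras also ships vgg16.preprocess_input.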



Source: https://stackoverflow.com/questions/47675094/implement-perceptual-loss-with-pretrained-vgg-using-keras
