Question
After much effort, I managed to build a TensorFlow 2 implementation of an existing PyTorch style-transfer project. Then I wanted to get all the nice extra features that come with standard Keras training, e.g. model.fit().
But the same model fails when trained via model.fit(). It seems to learn the content features, but it is unable to learn the style features. This is the model in question:
def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
    """ creates a VGG model that returns output values for the given layers
    see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19

    Returns:
        function(x, preprocess=True):
            Args:
                x: image tuple/ndarray h,w,c(RGB), domain=(0.,255.)
            Returns:
                a tuple of lists, ([content_features], [style_features])

    usage:
        (content_features, style_features) = vgg_layers19(content_layers, style_layers)(x_train)
    """
    preprocessingFn = tf.keras.applications.vgg19.preprocess_input
    base_model = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
    base_model.trainable = False

    content_features = [base_model.get_layer(name).output for name in content_layers]
    style_features = [base_model.get_layer(name).output for name in style_layers]
    output_features = content_features + style_features

    model = Model(inputs=base_model.input, outputs=output_features, name="vgg_layers")
    model.trainable = False

    def _get_features(x, preprocess=True):
        """
        Args:
            x: expecting tensor, domain=255. hwcRGB
        """
        if preprocess and callable(preprocessingFn):
            x = preprocessingFn(x)
        output = model(x)  # call as tf.keras.Layer()
        return (output[:len(content_layers)], output[len(content_layers):])

    return _get_features
class VGG_Features():
    """ get content and style features from VGG model """
    def __init__(self, loss_model, style_image=None, target_style_gram=None):
        self.loss_model = loss_model
        if style_image is not None:
            assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
            self.style_image = style_image
            self.target_style_gram = VGG_Features.get_style_gram(self.loss_model, self.style_image)
        if target_style_gram is not None:
            self.target_style_gram = target_style_gram

    @staticmethod
    def get_style_gram(vgg_features_model, style_image):
        # _batch_size is a module-level value defined elsewhere in the notebook
        style_batch = tf.repeat(style_image[tf.newaxis, ...], repeats=_batch_size, axis=0)
        # show([style_image], w=128, domain=(0.,255.) )
        # B, H, W, C = style_batch.shape
        (_, style_features) = vgg_features_model(style_batch, preprocess=True)  # hwcRGB
        target_style_gram = [fnstf_utils.gram(value) for value in style_features]  # list
        return target_style_gram

    def __call__(self, input_batch):
        content_features, style_features = self.loss_model(input_batch, preprocess=True)
        style_gram = tuple(fnstf_utils.gram(value) for value in style_features)  # tuple(<generator>)
        return (content_features[0],) + style_gram  # tuple = tuple + tuple
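fnstf_utils.gram is not shown here; a minimal sketch of a batched Gram-matrix helper consistent with the (None, C, C) output shapes listed further down, assuming the usual normalization by the number of spatial locations (the exact normalization in fnstf_utils may differ):

import tensorflow as tf

def gram(feature_maps):
    """Batched Gram matrix: (B, H, W, C) -> (B, C, C).
    Sketch only; fnstf_utils.gram may normalize differently."""
    result = tf.linalg.einsum('bijc,bijd->bcd', feature_maps, feature_maps)
    shape = tf.shape(feature_maps)
    num_locations = tf.cast(shape[1] * shape[2], tf.float32)
    return result / num_locations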
class TransformerNetwork_VGG(tf.keras.Model):
    def __init__(self, transformer=transformer, vgg_features=vgg_features):
        super(TransformerNetwork_VGG, self).__init__()

        self.transformer = transformer
        # type: tf.keras.models.Model
        # input_shapes: (None, 256,256,3)
        # output_shapes: (None, 256,256,3)

        style_model = {
            'content_layers': ['block5_conv2'],
            'style_layers': ['block1_conv1',
                             'block2_conv1',
                             'block3_conv1',
                             'block4_conv1',
                             'block5_conv1']
        }
        vgg_model = vgg_layers19(style_model['content_layers'], style_model['style_layers'])
        self.vgg_features = VGG_Features(vgg_model, style_image=style_image)
        # input_shapes: (None, 256,256,3)
        # output_shapes: [(None, 16, 16, 512), (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
        #                [ content_loss, style_loss_1, style_loss_2, style_loss_3, style_loss_4, style_loss_5 ]

    def call(self, inputs):
        x = inputs                              # shape=(None, 256,256,3)
        generated_image = self.transformer(x)   # shape=(None, 256,256,3)
        # shapes: [(None, 16, 16, 512), (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
        vgg_feature_losses = self.vgg_features(generated_image)
        return vgg_feature_losses  # tuple(content1, style1, style2, style3, style4, style5)
Style Image
FEATURE_WEIGHTS= [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
GradientTape learning
With the tf.GradientTape() loop, I'm manually handling the multiple outputs (a tuple of 6 tensors) from TransformerNetwork_VGG(x_train). This method learns correctly.
@tf.function()
def train_step(x_train, y_true, loss_weights=None, log_freq=10):
    with tf.GradientTape() as tape:
        y_pred = TransformerNetwork_VGG(x_train)
        generated_content_features = y_pred[:1]
        generated_style_gram = y_pred[1:]

        y_true = TransformerNetwork_VGG.vgg_features(x_train)
        target_content_features = y_true[:1]
        target_style_gram = TransformerNetwork_VGG.vgg_features.target_style_gram

        content_loss = get_MEAN_mse_loss(target_content_features, generated_content_features, loss_weights)
        style_loss = tuple(get_MEAN_mse_loss(x, y) * w for x, y, w in zip(target_style_gram, generated_style_gram, loss_weights))
        total_loss = content_loss + tf.reduce_sum(style_loss)

    TransformerNetwork = TransformerNetwork_VGG.transformer
    grads = tape.gradient(total_loss, TransformerNetwork.trainable_weights)
    optimizer.apply_gradients(zip(grads, TransformerNetwork.trainable_weights))

# GradientTape epoch=5:
# losses: [ 6078.71     70.23   4495.13  13817.65  88217.99     48.36]
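get_MEAN_mse_loss is not defined in the snippet; a sketch of what the loop above assumes, i.e. a scalar MSE over matched feature tensors (the optional weights handling is my guess):

import tensorflow as tf

def get_MEAN_mse_loss(y_true, y_pred, weights=None):
    """Scalar MSE between matched feature tensors (sketch; the real helper may differ).
    Accepts a single tensor pair or matched lists/tuples; `weights`, if given,
    scales the per-tensor losses before summing."""
    if not isinstance(y_true, (list, tuple)):
        y_true, y_pred = [y_true], [y_pred]
    losses = [tf.reduce_mean(tf.square(t - p)) for t, p in zip(y_true, y_pred)]
    if weights is not None:
        losses = [w * l for w, l in zip(weights, losses)]
    return tf.add_n(losses)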
model.fit() learning
With tf.keras.models.Model.fit(), the multiple outputs (a tuple of 6 tensors) are fed to the loss function individually as loss(y_true, y_pred) and then multiplied by the corresponding weight on reduction. This method does learn to approximate the content image, but it does not learn to minimize the style losses! I cannot figure out why.
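The compile() call is not shown in the question; assuming one MSE loss per output weighted by FEATURE_WEIGHTS (the optimizer and the exact loss objects here are my assumptions), it would look roughly like:

# sketch: six outputs -> six losses, each weighted on reduction by Keras
TransformerNetwork_VGG.compile(
    optimizer=tf.keras.optimizers.Adam(),             # optimizer choice is an assumption
    loss=[tf.keras.losses.MeanSquaredError()] * 6,    # content + 5 style-gram outputs
    loss_weights=FEATURE_WEIGHTS,
)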
history = TransformerNetwork_VGG.fit(
    x=train_dataset.repeat(NUM_EPOCHS),
    epochs=NUM_EPOCHS,
    steps_per_epoch=NUM_BATCHES,
    callbacks=callbacks,
)
# model.fit() epoch=5:
# losses: [ 4661.08    219.95   6959.01   4897.39 209201.16     84.68]
After 50 epochs with boosted style weights, FEATURE_WEIGHTS = [0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79]  # boost style loss x100
step=50, losses=[269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]
check mse losses * weights
I tested my model with the losses and weights fixed as follows:
* FEATURE_WEIGHTS = SEQ = [1.,2.,3.,4.,5.,6.,]
* MSELoss(y_true, y_pred) == tf.ones() of equal shape
and confirmed that model.fit() is handling multiple output losses * weights correctly
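As a rough sketch of that check (the constant-valued loss and the SEQ weights come from the description above; the rest is assumed):

SEQ = [1., 2., 3., 4., 5., 6.]

def constant_mse(y_true, y_pred):
    # debug loss: always ~1.0, but kept dependent on y_pred (with zero gradient)
    # so that fit() still has gradients to apply
    return tf.reduce_mean(0.0 * y_pred) + 1.0

TransformerNetwork_VGG.compile(
    optimizer='adam',
    loss=[constant_mse] * 6,
    loss_weights=SEQ,
)
# after a training step, the six reported per-output losses should be
# ~[1., 2., 3., 4., 5., 6.], i.e. loss * weight is applied per output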
I've checked everything I can think of, but I cannot figure out how to make the model learn correctly with model.fit(). What am I missing??
The full notebook is available here: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb
Source: https://stackoverflow.com/questions/60545104/why-does-my-model-work-with-tf-gradienttape-but-fail-when-using-keras-model