Keras: fine-tuning on a pretrained network


Keras: training, saving and restoring an image-classification model on a custom dataset
Keras: using bottleneck features of a pretrained network

The three steps of fine-tuning (a compact sketch of these steps follows the list):

  • build the VGG-16 convolutional base and load its weights;
  • attach the previously defined fully-connected network on top of the model and load its weights;
  • freeze part of the VGG16 network's parameters.
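If rebuilding the convolutional base layer by layer feels verbose, the same three steps can also be sketched with keras.applications.VGG16 and the functional API. This is only an illustrative sketch, not the code used in this post: it assumes the default channels_last layout (the example below uses 'th' ordering instead), and the top classifier here is freshly initialized rather than loaded from bottleneck_fc_model.h5, whose layer names would not match.

from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.layers import Flatten, Dense, Dropout
from keras import optimizers

# Step 1: the VGG16 convolutional base with ImageNet weights and no classifier
base = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))

# Step 2: the same small fully-connected classifier as in the example below,
# attached on top (freshly initialized here; transferring the saved
# bottleneck_fc_model.h5 weights would require matching layer names)
x = Flatten()(base.output)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
out = Dense(1, activation='sigmoid')(x)
model = Model(inputs=base.input, outputs=out)

# Step 3: freeze everything except the last convolutional block
for layer in base.layers:
    if not layer.name.startswith('block5'):
        layer.trainable = False

model.compile(loss='binary_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])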

In the earlier post "Keras: training, saving and restoring an image-classification model on a custom dataset", the experimental dataset was prepared and a first model was trained. Then, in "Keras: using bottleneck features of a pretrained network", the fully-connected network used here was defined and trained, and its weights were saved to bottleneck_fc_model.h5 (a rough sketch of that classifier is reproduced below for reference).
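That top classifier looked roughly like the sketch below. The feature file name bottleneck_features_train.npy, the label construction, the optimizer and the training settings are assumptions based on the usual bottleneck workflow, not code from this post; only the architecture and the saved file name bottleneck_fc_model.h5 come from the example that follows.

import numpy as np
from keras.models import Sequential
from keras.layers import Flatten, Dense, Dropout

# Bottleneck features precomputed by pushing the images through the VGG16
# convolutional base (file name and class ordering are assumptions)
train_data = np.load('bottleneck_features_train.npy')
train_labels = np.array([0] * (len(train_data) // 2) + [1] * (len(train_data) // 2))

top_model = Sequential()
top_model.add(Flatten(input_shape=train_data.shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid'))

top_model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
top_model.fit(train_data, train_labels, epochs=50, batch_size=16)

# These are the weights that the fine-tuning example loads with by_name=True
top_model.save_weights('bottleneck_fc_model.h5')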

Following the VGG16 definition in …/keras/applications/vgg16.py, build the convolutional part of VGG16 and load its weights (vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5). Then add the pretrained top model on top. During training, freeze the parameters of the convolutional layers before the last convolutional block.

Example:

from keras.models import Sequential
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Flatten, Dense, Dropout, Conv2D, MaxPooling2D
from keras import backend as K

K.set_image_dim_ordering('th')   # channels-first layout: shapes are (channels, height, width)

# Build the convolutional part of VGG16
model = Sequential()

# Block 1
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', input_shape=(3, 150, 150)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool'))

# Block 2
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool'))

# Block 3
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool'))

# Block 4
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool'))

# Block 5 (note: the reference VGG16 also has a block5_conv3 layer, which this
# script omits; load_weights(by_name=True) simply skips layers it cannot match)
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool'))

model.load_weights('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', by_name=True)
model.summary()

# Add the pretrained top classifier on top of the initialized VGG base
top_model = Sequential()
top_model.add(Flatten(input_shape=model.output_shape[1:]))  # conv base output: (512, 4, 4)
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid'))

top_model.load_weights('bottleneck_fc_model.h5', by_name=True)
model.add(top_model)

# Freeze everything up to and including block4_pool (the layers before the last
# conv block) so their weights are not updated; only block5 and the top
# classifier are fine-tuned
for layer in model.layers[:14]:
    layer.trainable = False

model.compile(loss='binary_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])

# Train with a low learning rate
train_datagen = ImageDataGenerator(rescale=1./255,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory('train',
                                                    target_size=(150, 150),
                                                    batch_size=32,
                                                    class_mode='binary')

validation_generator = test_datagen.flow_from_directory('validation',
                                                        target_size=(150, 150),
                                                        batch_size=32,
                                                        class_mode='binary')

model.fit_generator(train_generator,
                    steps_per_epoch=10,
                    epochs=50,
                    validation_data=validation_generator,
                    validation_steps=10)

Output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
block1_conv1 (Conv2D)        (None, 64, 150, 150)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 64, 150, 150)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 64, 75, 75)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 128, 75, 75)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 128, 75, 75)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 128, 37, 37)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 256, 37, 37)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 256, 37, 37)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 256, 37, 37)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 256, 18, 18)       0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 512, 18, 18)       1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 512, 18, 18)       2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 512, 18, 18)       2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 512, 9, 9)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 512, 9, 9)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 512, 9, 9)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 512, 4, 4)         0
=================================================================
Total params: 12,354,880
Trainable params: 12,354,880
Non-trainable params: 0
_________________________________________________________________
Found 60 images belonging to 2 classes.
Found 60 images belonging to 2 classes.
Epoch 1/50
10/10 [==============================] - 899s 90s/step - loss: 0.8125 - acc: 0.4260 - val_loss: 0.8145 - val_acc: 0.4000
Epoch 2/50
 9/10 [==========================>...] - ETA: 43s - loss: 0.8311 - acc: 0.4340
......
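Once training is done, the fine-tuned weights can be saved and the model reused for single-image prediction. A minimal sketch, assuming the placeholder file names finetuned_vgg16.h5 and test/some_image.jpg (neither appears in this post):

import numpy as np
from keras.preprocessing import image

# Save the fine-tuned weights (the file name is a placeholder)
model.save_weights('finetuned_vgg16.h5')

# Predict a single image. With the 'th' ordering set above, img_to_array
# returns a (channels, height, width) array; rescale to [0, 1] as in training
img = image.load_img('test/some_image.jpg', target_size=(150, 150))
x = image.img_to_array(img)
x = np.expand_dims(x / 255.0, axis=0)   # shape (1, 3, 150, 150)

prob = model.predict(x)[0][0]           # sigmoid output: probability of class 1
print('probability of class 1: %.3f' % prob)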