V-Net(2.5D卷积)网络训练
然后,在服务器中训练网络
2.5D网络程序
1 import keras
2 from keras.models import *
3 from keras.layers import Input, Conv3D, Deconvolution3D, Dropout, Concatenate
4 from keras.optimizers import *
5 from keras import layers
6 from keras import backend as K
7
8 from keras.callbacks import ModelCheckpoint
9 from fit_generator import get_path_list, get_train_batch
10 import matplotlib.pyplot as plt
11
# Training hyper-parameters.
train_batch_size = 1
epoch = 10

# Directories that hold the training volumes and the matching labels.
data_train_path = "./vnet_3_1_input/train"
data_label_path = "./vnet_3_1_input/label"

# Resolve every sample path once, at import time; `count` is the number of
# training samples reported by get_path_list.
(train_path_list,
 label_path_list,
 count) = get_path_list(data_train_path, data_label_path)
19
20
# LossHistory: a callback that stores the loss and metric values seen in training.
class LossHistory(keras.callbacks.Callback):
    """Record loss/metric values per batch and per epoch and plot them.

    Four dictionaries are filled, each with a ``'batch'`` and an ``'epoch'``
    list: ``losses`` (log key ``'loss'``), ``accuracy`` (log key
    ``'dice_coef'``), ``val_loss`` (``'val_loss'``) and ``val_acc``
    (``'val_acc'``).  A key that is absent from the logs is stored as
    ``None`` so that all lists keep the same length.
    """

    def on_train_begin(self, logs=None):
        """Reset all recorded histories when training starts."""
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def _record(self, scope, logs):
        """Append the values found in *logs* to the lists selected by *scope*."""
        # Guard against a missing logs dict instead of using a shared
        # mutable default argument.
        logs = logs or {}
        self.losses[scope].append(logs.get('loss'))
        self.accuracy[scope].append(logs.get('dice_coef'))
        self.val_loss[scope].append(logs.get('val_loss'))
        self.val_acc[scope].append(logs.get('val_acc'))

    def on_batch_end(self, batch, logs=None):
        """Store the values reported at the end of a batch."""
        self._record('batch', logs)

    def on_epoch_end(self, batch, logs=None):
        """Store the values reported at the end of an epoch.

        The first parameter keeps its original name ``batch``.
        """
        self._record('epoch', logs)

    @staticmethod
    def _has_values(values):
        """Return True when at least one recorded entry is not None."""
        return any(value is not None for value in values)

    def loss_plot(self, loss_type, save_path=None):
        """Plot the recorded curves for ``'batch'`` or ``'epoch'``.

        Args:
            loss_type: ``'batch'`` or ``'epoch'``; selects which lists to draw.
            save_path: optional file path.  When given, the figure is written
                with ``plt.savefig`` *before* ``plt.show()`` is called, so the
                saved image still contains the curves.
        """
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        # Training metric (read from the 'dice_coef' log entry).
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train dice')
        # Training loss.
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # Validation curves are drawn only if validation values were
            # actually logged; otherwise the lists hold nothing but None.
            if self._has_values(self.val_acc[loss_type]):
                plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            if self._has_values(self.val_loss[loss_type]):
                plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('dice-loss')
        plt.legend(loc="best")
        if save_path is not None:
            plt.savefig(save_path)
        plt.show()
58
class WeightedBinaryCrossEntropy(object):
    """Binary cross-entropy whose positive term is re-weighted.

    The positive-class weight is ``(1 - pos_ratio) / pos_ratio`` and the
    resulting cost is scaled by ``pos_ratio``.  Only the public Keras backend
    API (``K``) is used, so the class does not depend on a ``tf`` name being
    in scope.

    Args:
        pos_ratio: ratio used to derive the positive-class weight; must be
            non-zero (a zero value raises ``ZeroDivisionError``).
    """

    def __init__(self, pos_ratio=0.7):
        neg_ratio = 1. - pos_ratio
        # Plain Python floats broadcast against tensors of any float dtype.
        self.pos_ratio = pos_ratio
        self.weights = neg_ratio / pos_ratio
        self.__name__ = "weighted_binary_crossentropy({0})".format(pos_ratio)

    def __call__(self, y_true, y_pred):
        """Make the instance itself usable as a loss function."""
        return self.weighted_binary_crossentropy(y_true, y_pred)

    def weighted_binary_crossentropy(self, y_true, y_pred):
        """Return the weighted cross-entropy averaged over the last axis.

        ``y_pred`` is treated as probabilities.  It is clipped to
        ``[eps, 1 - eps]`` and the cost is

            -(w * y_true * log(p) + (1 - y_true) * log(1 - p))

        which is the closed form of a weighted sigmoid cross-entropy
        evaluated at ``logit = log(p / (1 - p))``.
        """
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        positive_term = self.weights * y_true * K.log(y_pred)
        negative_term = (1. - y_true) * K.log(1. - y_pred)
        cost = -(positive_term + negative_term)
        return K.mean(cost * self.pos_ratio, axis=-1)
78
79
def dice_coef(y_true, y_pred):
    """Return a smoothed Dice coefficient of the two flattened tensors.

    Computes ``(2 * sum(t * p) + 1) / (sum(t * t) + sum(p * p) + 1)`` where
    ``t`` and ``p`` are ``y_true`` and ``y_pred`` flattened with the Keras
    backend.  The constant 1 is a smoothing term that keeps the ratio defined
    when both sums are zero.
    """
    smooth = 1.
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth * truth) + K.sum(prediction * prediction) + smooth
    return numerator / denominator
86
87
def dice_coef_loss(y_true, y_pred):
    """Return ``1 - dice_coef(y_true, y_pred)`` so that a higher Dice gives a lower loss."""
    score = dice_coef(y_true, y_pred)
    return 1. - score
90
91
def mycrossentropy(y_true, y_pred, e=0.1, nb_classes=10):
    """Blend categorical cross-entropy with a uniform-target cross-entropy.

    Args:
        y_true: target tensor passed to ``K.categorical_crossentropy``.
        y_pred: prediction tensor.
        e: mixing factor; the result is ``(1 - e) * loss1 + e * loss2``.
        nb_classes: divisor used to build the uniform target
            ``ones_like(y_pred) / nb_classes``.  Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        The blended loss tensor.
    """
    loss1 = K.categorical_crossentropy(y_true, y_pred)
    uniform_target = K.ones_like(y_pred) / nb_classes
    loss2 = K.categorical_crossentropy(uniform_target, y_pred)
    return (1 - e) * loss1 + e * loss2
97
98
class myVnet(object):
    """Builder and trainer for the 2.5D V-Net style segmentation model.

    Args:
        img_depth: number of slices in one input volume.
        img_rows: rows of each slice.
        img_cols: columns of each slice.
        img_channel: channels of the input.
        drop: dropout rate used after the convolutions.
    """

    def __init__(self, img_depth=3, img_rows=400, img_cols=400, img_channel=1, drop=0.5):
        self.img_depth = img_depth
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.img_channel = img_channel
        self.drop = drop

    def BN_operation(self, input):
        """Apply batch normalisation over the last axis of *input*."""
        # Same layer class as before, reached through the public
        # `keras.layers` namespace instead of the `normalization` submodule.
        output = layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True,
                                           scale=True,
                                           beta_initializer='zeros', gamma_initializer='ones',
                                           moving_mean_initializer='zeros',
                                           moving_variance_initializer='ones',
                                           beta_regularizer=None,
                                           gamma_regularizer=None, beta_constraint=None,
                                           gamma_constraint=None)(input)
        return output

    def encode_layer(self, kernel_num, kernel_size, input):
        """Encoder stage: in-slice conv, across-slice conv, residual add.

        Args:
            kernel_num: number of filters of both convolutions.
            kernel_size: kernel of the first (in-slice) convolution.
            input: input tensor.

        Returns:
            Sum of the outputs of the first and the second convolution branch.
        """
        # First convolution: inside each slice.
        conv1 = Conv3D(kernel_num, kernel_size, activation='relu', padding='same',
                       kernel_initializer='he_normal')(input)
        conv1 = self.BN_operation(conv1)
        conv1 = Dropout(self.drop)(conv1)
        # Second convolution: across slices ([3, 1, 1] kernel).
        conv2 = Conv3D(kernel_num, [3, 1, 1], activation='relu', padding='same',
                       kernel_initializer='he_normal')(conv1)
        conv2 = self.BN_operation(conv2)
        conv2 = Dropout(self.drop)(conv2)
        # Residual connection.
        res = layers.add([conv1, conv2])
        return res

    def down_operation(self, kernel_num, kernel_size, input):
        """Down-sample with a convolution of stride [1, 2, 2] (depth is kept)."""
        down = Conv3D(kernel_num, kernel_size, strides=[1, 2, 2], activation='relu', padding='same',
                      kernel_initializer='he_normal')(input)
        return down

    def decode_layer(self, kernel_num, kernel_size, input, code_layer):
        """Generic decoder stage: up-sample, concatenate the skip tensor, conv, add.

        Not called by ``get_vnet``, which additionally reduces the depth of
        the skip tensor before concatenating it.
        """
        deconv = Deconvolution3D(kernel_num, kernel_size, strides=(1, 2, 2), activation='relu', padding='same',
                                 kernel_initializer='he_normal')(input)

        merge = Concatenate(axis=4)([deconv, code_layer])
        conv = Conv3D(kernel_num, kernel_size, activation='relu', padding='same',
                      kernel_initializer='he_normal')(merge)
        conv = Dropout(self.drop)(conv)

        res = layers.add([deconv, conv])
        return res

    def _decode_stage(self, kernel_num, input, skip):
        """One decoder stage of ``get_vnet``.

        Up-samples *input* with stride (1, 2, 2), shrinks the depth of *skip*
        with a 'valid' [3, 1, 1] convolution, concatenates both on axis 4,
        convolves, applies dropout and adds the result to the up-sampled
        tensor.  Layers are created in the same order as in the previously
        inlined stages.
        """
        deconv = Deconvolution3D(kernel_num, [1, 3, 3], strides=(1, 2, 2), activation='relu', padding='same',
                                 kernel_initializer='he_normal')(input)
        skip = Conv3D(kernel_num, [3, 1, 1], activation='relu', padding='valid',
                      kernel_initializer='he_normal')(skip)
        merge = Concatenate(axis=4)([deconv, skip])
        conv = Conv3D(kernel_num, [1, 3, 3], activation='relu', padding='same',
                      kernel_initializer='he_normal')(merge)
        conv = Dropout(self.drop)(conv)
        return layers.add([deconv, conv])

    # V-Net network
    def get_vnet(self):
        """Build and compile the network.

        Four encoder/down-sampling stages are followed by a bottleneck and
        four decoder stages.  The 'valid' [3, 1, 1] convolutions reduce the
        depth by two, so the default ``img_depth=3`` gives an output of depth 1.

        Returns:
            The compiled Keras model (Adam, lr=1e-4, weighted binary
            cross-entropy loss, metrics 'accuracy' and ``dice_coef``).
        """
        inputs = Input((self.img_depth, self.img_rows, self.img_cols, self.img_channel))

        # Encoder 1 + down-sampling 1
        conv1 = self.encode_layer(32, [1, 3, 3], inputs)
        down1 = self.down_operation(64, [1, 3, 3], conv1)

        # Encoder 2 + down-sampling 2
        conv2 = self.encode_layer(64, [1, 3, 3], down1)
        down2 = self.down_operation(128, [1, 3, 3], conv2)

        # Encoder 3 + down-sampling 3
        conv3 = self.encode_layer(128, [1, 3, 3], down2)
        down3 = self.down_operation(256, [1, 3, 3], conv3)

        # Encoder 4 + down-sampling 4
        conv4 = self.encode_layer(256, [1, 3, 3], down3)
        down4 = self.down_operation(512, [1, 3, 3], conv4)

        # Bottleneck; the 'valid' across-slice convolution shrinks the depth.
        conv5 = self.encode_layer(512, [1, 3, 3], down4)
        conv5 = Conv3D(512, [3, 1, 1], activation='relu', padding='valid',
                       kernel_initializer='he_normal')(conv5)

        # Decoder stages 6-9, each fed with the matching encoder output.
        res6 = self._decode_stage(256, conv5, conv4)
        res7 = self._decode_stage(128, res6, conv3)
        res8 = self._decode_stage(64, res7, conv2)
        res9 = self._decode_stage(32, res8, conv1)

        conv10 = Conv3D(1, [1, 1, 1], activation='sigmoid')(res9)

        model = Model(inputs=inputs, outputs=conv10)

        # A custom loss and custom metrics are plugged in here.
        losses = WeightedBinaryCrossEntropy()
        model.compile(optimizer=Adam(lr=1e-4), loss=losses.weighted_binary_crossentropy,
                      metrics=['accuracy', dice_coef])
        print('model compile')
        return model

    def train(self):
        """Resume training from a saved checkpoint and plot the batch curves.

        Loads ``vnet_tumour_3_1_epoch5.hdf5``, trains with the batch generator
        and checkpoints the best model (by training loss) to
        ``vnet_tumour_3_1_epoch6.hdf5``.
        """
        import math
        import os

        print("loading data")
        print("loading data done")

        # Use `model = self.get_vnet()` instead to train from scratch.
        # NOTE(review): custom_objects has no entry for the weighted loss that
        # get_vnet compiles with - confirm which loss the checkpoint was saved with.
        model = load_model('./model_pre_mean/vnet_tumour_3_1_epoch5.hdf5',
                           custom_objects={'dice_coef': dice_coef, 'dice_coef_loss': dice_coef_loss})
        print("got vnet")

        # The checkpoint stores both the model and its weights.
        model_checkpoint = ModelCheckpoint('./model_pre_mean/vnet_tumour_3_1_epoch6.hdf5', monitor='loss',
                                           verbose=1, save_best_only=True)
        print('Fitting model...')

        # The history callback has to be registered to be able to plot later.
        history = LossHistory()
        # The generator also yields the final, shorter batch, so round up to
        # make one epoch cover every sample exactly once.
        steps = int(math.ceil(count / float(train_batch_size)))
        model.fit_generator(
            # Feed the dimensions this instance was configured with rather
            # than repeating the literals 3, 400, 400.
            generator=get_train_batch(train_path_list, label_path_list, train_batch_size,
                                      self.img_depth, self.img_rows, self.img_cols),
            epochs=epoch, verbose=1,
            steps_per_epoch=steps,
            callbacks=[model_checkpoint, history],
            workers=1)
        # Plot the metric/loss curves.
        history.loss_plot('batch')
        figure_path = './curve_figure_pre_mean/vnet_tumour_dice_loss_curve_3_1_epoch6.png'
        # Make sure the output directory exists before writing the image.
        os.makedirs(os.path.dirname(figure_path), exist_ok=True)
        # NOTE(review): loss_plot() calls plt.show() first; on interactive
        # backends the figure may already be closed here and the file empty.
        plt.savefig(figure_path)
285
286
if __name__ == '__main__':
    # Script entry point: create the trainer with its default dimensions
    # and run the training routine.
    myvnet = myVnet()
    myvnet.train()
导入数据的小程序(需保存为 fit_generator.py,供上面的训练程序通过 from fit_generator import ... 导入),fit_generator形式
1 import numpy as np
2 import cv2 as cv
3 import os
4
# Default dataset locations, used only when this module is run as a script.
data_train_path = "../../Vnet_tf/V_Net_data/train"
data_label_path = "../../Vnet_tf/V_Net_data/label"
7
8
def get_path_list(data_train_path, data_label_path):
    """Collect matching training/label file paths in numeric order.

    Both roots are expected to contain sub-directories with integer names;
    each sub-directory contains files whose name up to the first ``.`` is an
    integer.  Sub-directories and files are visited in ascending numeric
    order, and the label path is built from the same sub-directory and file
    name as the training file.

    Args:
        data_train_path: root directory of the training data.
        data_label_path: root directory of the labels.

    Returns:
        ``(train_path_list, label_path_list, count)`` where ``count`` is the
        number of training files found.

    Raises:
        FileNotFoundError: a training file has no label of the same name
            (or a directory is missing).
        ValueError: a directory or file name is not numeric.
    """
    case_names = sorted(os.listdir(data_train_path), key=int)

    train_path_list = []
    label_path_list = []
    for case_name in case_names:
        train_dir_path = os.path.join(data_train_path, case_name)
        label_dir_path = os.path.join(data_label_path, case_name)
        train_names = sorted(os.listdir(train_dir_path),
                             key=lambda name: int(name.split(".")[0]))
        label_names = set(os.listdir(label_dir_path))
        for name in train_names:
            # Fail here, with a clear message, instead of in the middle of
            # training when the label file is first opened.
            if name not in label_names:
                raise FileNotFoundError(
                    "no label named {!r} in {!r}".format(name, label_dir_path))
            train_path_list.append(os.path.join(train_dir_path, name))
            label_path_list.append(os.path.join(label_dir_path, name))

    count = len(train_path_list)
    print("共有{}组训练数据".format(count))
    return train_path_list, label_path_list, count
35
36
def get_train_img(paths, img_d, img_rows, img_cols):
    """Load training volumes from ``.npy`` files and scale them by 1/255.

    Args:
        paths: list of file paths readable by ``np.load``.
        img_d: depth (number of slices) each array is reshaped to.
        img_rows: rows of each slice.
        img_cols: columns of each slice.

    Returns:
        A float32 array built from one ``(img_d, img_rows, img_cols, 1)``
        volume per path.  Only the division by 255 is applied; no mean/std
        normalisation is performed.
    """
    volume_shape = (img_d, img_rows, img_cols, 1)
    volumes = [
        np.reshape(np.load(path), volume_shape).astype('float32') / 255
        for path in paths
    ]
    return np.array(volumes)
68
69
def get_label_img(paths, img_rows, img_cols):
    """Load label images from ``.npy`` files and scale them by 1/255.

    Args:
        paths: list of file paths readable by ``np.load``.
        img_rows: rows of each label image.
        img_cols: columns of each label image.

    Returns:
        A float32 array built from one ``(1, img_rows, img_cols, 1)`` label
        per path.
    """
    # Rows come before columns, the same axis order get_train_img uses;
    # reshaping to (cols, rows) would scramble non-square labels.
    label_shape = (1, img_rows, img_cols, 1)
    datas = []
    for path in paths:
        label = np.reshape(np.load(path), label_shape).astype('float32')
        label /= 255
        datas.append(label)
    return np.array(datas)
91
92
def get_train_batch(train, label, batch_size, img_d, img_w, img_h):
    """Endlessly yield ``(x, y)`` batches for ``fit_generator``.

    Args:
        train: list of all training file paths.
        label: list of label file paths, aligned with *train*.
        batch_size: number of samples per batch; the last batch of a pass
            may be smaller.
        img_d: depth of each training volume.
        img_w: forwarded as ``img_rows`` to the loaders.
        img_h: forwarded as ``img_cols`` to the loaders.

    Yields:
        ``(x, y)`` where ``x`` comes from ``get_train_img`` and ``y`` from
        ``get_label_img`` for the same slice of paths.

    Raises:
        ValueError: on the first ``next()`` if *train* is empty or *label*
            has fewer entries than *train*.
    """
    sample_count = len(train)
    # Without samples the endless loop below would spin forever and never
    # yield, so fail loudly instead.
    if sample_count == 0:
        raise ValueError("train path list is empty")
    if len(label) < sample_count:
        raise ValueError("fewer label paths than train paths")

    while True:
        for start in range(0, sample_count, batch_size):
            stop = start + batch_size
            x = get_train_img(train[start:stop], img_d, img_w, img_h)
            y = get_label_img(label[start:stop], img_w, img_h)
            # Both loaders already return fresh numpy arrays, so they are
            # yielded as they are.  After yielding, the loop resumes with
            # the next slice on the following request.
            yield x, y
112
113
if __name__ == "__main__":
    # Quick manual check: list the default dataset and print the training paths.
    (train_path_list,
     label_path_list,
     count) = get_path_list(data_train_path, data_label_path)
    print(train_path_list)
来源:oschina
链接:https://my.oschina.net/u/4398140/blog/4333675