V-Net(2.5D卷积)网络训练---Keras

五迷三道 提交于 2020-08-12 02:08:44

V-Net(2.5D卷积)网络训练

接下来,在服务器上训练网络。

 

2.5D V-Net 网络训练程序(Keras):

 

  1 import keras
  2 from keras.models import *
  3 from keras.layers import Input, Conv3D, Deconvolution3D, Dropout, Concatenate
  4 from keras.optimizers import *
  5 from keras import layers
  6 from keras import backend as K
  7 
  8 from keras.callbacks import ModelCheckpoint
  9 from fit_generator import get_path_list, get_train_batch
 10 import matplotlib.pyplot as plt
 11 
 12 train_batch_size = 1
 13 epoch = 10
 14 
 15 
 16 data_train_path = "./vnet_3_1_input/train"
 17 data_label_path = "./vnet_3_1_input/label"
 18 train_path_list, label_path_list, count = get_path_list(data_train_path, data_label_path)
 19 
 20 
 21 # 写一个LossHistory类,保存loss和acc
 22 class LossHistory(keras.callbacks.Callback):
 23    def on_train_begin(self, logs={}):
 24        self.losses = {'batch': [], 'epoch': []}
 25        self.accuracy = {'batch': [], 'epoch': []}
 26        self.val_loss = {'batch': [], 'epoch': []}
 27        self.val_acc = {'batch': [], 'epoch': []}
 28 
 29    def on_batch_end(self, batch, logs={}):
 30        self.losses['batch'].append(logs.get('loss'))
 31        self.accuracy['batch'].append(logs.get('dice_coef'))
 32        self.val_loss['batch'].append(logs.get('val_loss'))
 33        self.val_acc['batch'].append(logs.get('val_acc'))
 34 
 35    def on_epoch_end(self, batch, logs={}):
 36        self.losses['epoch'].append(logs.get('loss'))
 37        self.accuracy['epoch'].append(logs.get('dice_coef'))
 38        self.val_loss['epoch'].append(logs.get('val_loss'))
 39        self.val_acc['epoch'].append(logs.get('val_acc'))
 40 
 41    def loss_plot(self, loss_type):
 42        iters = range(len(self.losses[loss_type]))
 43        plt.figure()
 44        # acc
 45        plt.plot(iters, self.accuracy[loss_type], 'r', label='train dice')
 46        # loss
 47        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
 48        if loss_type == 'epoch':
 49            # val_acc
 50            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
 51            # val_loss
 52            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
 53        plt.grid(True)
 54        plt.xlabel(loss_type)
 55        plt.ylabel('dice-loss')
 56        plt.legend(loc="best")
 57        plt.show()
 58 
 59 class WeightedBinaryCrossEntropy(object):
 60  
 61     def __init__(self, pos_ratio=0.7):
 62         neg_ratio = 1. - pos_ratio
 63         self.pos_ratio = tf.constant(pos_ratio, tf.float32)
 64         self.weights = tf.constant(neg_ratio / pos_ratio, tf.float32)
 65         self.__name__ = "weighted_binary_crossentropy({0})".format(pos_ratio)
 66  
 67     def __call__(self, y_true, y_pred):
 68         return self.weighted_binary_crossentropy(y_true, y_pred)
 69  
 70     def weighted_binary_crossentropy(self, y_true, y_pred):
 71         # Transform to logits
 72         epsilon = tf.convert_to_tensor(K.common._EPSILON, y_pred.dtype.base_dtype)
 73         y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)
 74         y_pred = tf.log(y_pred / (1 - y_pred))
 75  
 76         cost = tf.nn.weighted_cross_entropy_with_logits(y_true, y_pred, self.weights)
 77         return K.mean(cost * self.pos_ratio, axis=-1)
 78 
 79 
 80 def dice_coef(y_true, y_pred):
 81     smooth = 1.
 82     y_true_f = K.flatten(y_true)
 83     y_pred_f = K.flatten(y_pred)
 84     intersection = K.sum(y_true_f * y_pred_f)
 85     return (2. * intersection + smooth) / (K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + smooth)
 86 
 87 
 88 def dice_coef_loss(y_true, y_pred):
 89     return 1. - dice_coef(y_true, y_pred)
 90 
 91 
 92 def mycrossentropy(y_true, y_pred, e=0.1):
 93     nb_classes = 10
 94     loss1 = K.categorical_crossentropy(y_true, y_pred)
 95     loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / nb_classes, y_pred)
 96     return (1 - e) * loss1 + e * loss2
 97 
 98 
 99 class myVnet(object):
100     def __init__(self, img_depth=3, img_rows=400, img_cols=400, img_channel=1, drop=0.5):
101         self.img_depth = img_depth
102         self.img_rows = img_rows
103         self.img_cols = img_cols
104         self.img_channel = img_channel
105         self.drop = drop
106 
107     def BN_operation(self, input):
108         output = keras.layers.normalization.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True,
109                                                                scale=True,
110                                                                beta_initializer='zeros', gamma_initializer='ones',
111                                                                moving_mean_initializer='zeros',
112                                                                moving_variance_initializer='ones',
113                                                                beta_regularizer=None,
114                                                                gamma_regularizer=None, beta_constraint=None,
115                                                                gamma_constraint=None)(input)
116         return output
117 
118     def encode_layer(self, kernel_num, kernel_size, input):
119         # 第一次卷积,层内卷积
120         conv1 = Conv3D(kernel_num, kernel_size, activation='relu', padding='same',
121                        kernel_initializer='he_normal')(input)
122         conv1 = self.BN_operation(conv1)
123         conv1 = Dropout(self.drop)(conv1)
124         # 第二次卷积,层间卷积
125         conv2 = Conv3D(kernel_num, [3, 1, 1], activation='relu', padding='same',
126                        kernel_initializer='he_normal')(conv1)
127         conv2 = self.BN_operation(conv2)
128         conv2 = Dropout(self.drop)(conv2)
129         # 残差
130         res = layers.add([conv1, conv2])
131         # res = Conv3D(kernel_num, [3, 1, 1], activation='relu', padding='valid',
132         #              kernel_initializer='he_normal')(res)
133         return res
134 
135     def down_operation(self, kernel_num, kernel_size, input):
136         down = Conv3D(kernel_num, kernel_size, strides=[1, 2, 2], activation='relu', padding='same',
137                       kernel_initializer='he_normal')(input)
138         return down
139 
140     def decode_layer(self, kernel_num, kernel_size, input, code_layer):
141         deconv = Deconvolution3D(kernel_num, kernel_size, strides=(1, 2, 2), activation='relu', padding='same',
142                                  kernel_initializer='he_normal')(input)
143         # deconv = Conv3D(kernel_num, kernel_size, activation='relu', padding='same', kernel_initializer='he_normal')(
144         #          UpSampling3D(size=(1, 2, 2))(input))
145 
146         merge = Concatenate(axis=4)([deconv, code_layer])
147         conv = Conv3D(kernel_num, kernel_size, activation='relu', padding='same',
148                       kernel_initializer='he_normal')(merge)
149         conv = Dropout(self.drop)(conv)
150 
151         res = layers.add([deconv, conv])
152         return res
153 
154     # V-Net网络
155     def get_vnet(self):
156         inputs = Input((self.img_depth, self.img_rows, self.img_cols, self.img_channel))
157 
158         # 卷积层1
159         conv1 = self.encode_layer(32, [1, 3, 3], inputs)
160         # 下采样1
161         down1 = self.down_operation(64, [1, 3, 3], conv1)
162 
163         # 卷积层2
164         conv2 = self.encode_layer(64, [1, 3, 3], down1)
165         # 下采样2
166         down2 = self.down_operation(128, [1, 3, 3], conv2)
167 
168         # 卷积层3
169         conv3 = self.encode_layer(128, [1, 3, 3], down2)
170         # 下采样3
171         down3 = self.down_operation(256, [1, 3, 3], conv3)
172 
173         # 卷积层4
174         conv4 = self.encode_layer(256, [1, 3, 3], down3)
175         # 下采样4
176         down4 = self.down_operation(512, [1, 3, 3], conv4)
177 
178         # 卷积层5
179         conv5 = self.encode_layer(512, [1, 3, 3], down4)
180         conv5 = Conv3D(512, [3, 1, 1], activation='relu', padding='valid',
181                        kernel_initializer='he_normal')(conv5)
182 #######################################################################################################################
183         # 反卷积6
184         deconv6 = Deconvolution3D(256, [1, 3, 3], strides=(1, 2, 2), activation='relu', padding='same',
185                                   kernel_initializer='he_normal')(conv5)
186         conv4 = Conv3D(256, [3, 1, 1], activation='relu', padding='valid',
187                        kernel_initializer='he_normal')(conv4)
188         merge6 = Concatenate(axis=4)([deconv6, conv4])
189         conv6 = Conv3D(256, [1, 3, 3], activation='relu', padding='same',
190                        kernel_initializer='he_normal')(merge6)
191         conv6 = Dropout(self.drop)(conv6)
192         res6 = layers.add([deconv6, conv6])
193 #######################################################################################################################
194 
195 #######################################################################################################################
196         # 反卷积7
197         deconv7 = Deconvolution3D(128, [1, 3, 3], strides=(1, 2, 2), activation='relu', padding='same',
198                                   kernel_initializer='he_normal')(res6)
199         conv3 = Conv3D(128, [3, 1, 1], activation='relu', padding='valid',
200                        kernel_initializer='he_normal')(conv3)
201         # conv3 = Conv3D(128, [3, 1, 1], activation='relu', padding='valid',
202         #                kernel_initializer='he_normal')(conv3)
203         merge7 = Concatenate(axis=4)([deconv7, conv3])
204         conv7 = Conv3D(128, [1, 3, 3], activation='relu', padding='same',
205                        kernel_initializer='he_normal')(merge7)
206         conv7 = Dropout(self.drop)(conv7)
207         res7 = layers.add([deconv7, conv7])
208 #######################################################################################################################
209 
210 #######################################################################################################################
211         # 反卷积8
212         deconv8 = Deconvolution3D(64, [1, 3, 3], strides=(1, 2, 2), activation='relu', padding='same',
213                                   kernel_initializer='he_normal')(res7)
214         conv2 = Conv3D(64, [3, 1, 1], activation='relu', padding='valid',
215                        kernel_initializer='he_normal')(conv2)
216         # conv2 = Conv3D(64, [3, 1, 1], activation='relu', padding='valid',
217         #                kernel_initializer='he_normal')(conv2)
218         # conv2 = Conv3D(64, [3, 1, 1], activation='relu', padding='valid',
219         #                kernel_initializer='he_normal')(conv2)
220         merge8 = Concatenate(axis=4)([deconv8, conv2])
221         conv8 = Conv3D(64, [1, 3, 3], activation='relu', padding='same',
222                        kernel_initializer='he_normal')(merge8)
223         conv8 = Dropout(self.drop)(conv8)
224         res8 = layers.add([deconv8, conv8])
225 #######################################################################################################################
226 
227 #######################################################################################################################
228         # 反卷积9
229         deconv9 = Deconvolution3D(32, [1, 3, 3], strides=(1, 2, 2), activation='relu', padding='same',
230                                   kernel_initializer='he_normal')(res8)
231         conv1 = Conv3D(32, [3, 1, 1], activation='relu', padding='valid',
232                        kernel_initializer='he_normal')(conv1)
233         # conv1 = Conv3D(32, [3, 1, 1], activation='relu', padding='valid',
234         #                kernel_initializer='he_normal')(conv1)
235         # conv1 = Conv3D(32, [3, 1, 1], activation='relu', padding='valid',
236         #                kernel_initializer='he_normal')(conv1)
237         # conv1 = Conv3D(32, [3, 1, 1], activation='relu', padding='valid',
238         #                kernel_initializer='he_normal')(conv1)
239         merge9 = Concatenate(axis=4)([deconv9, conv1])
240         conv9 = Conv3D(32, [1, 3, 3], activation='relu', padding='same',
241                        kernel_initializer='he_normal')(merge9)
242         conv9 = Dropout(self.drop)(conv9)
243         res9 = layers.add([deconv9, conv9])
244 #######################################################################################################################
245 
246         conv10 = Conv3D(1, [1, 1, 1], activation='sigmoid')(res9)
247 
248         model = Model(inputs=inputs, outputs=conv10)
249 
250         # 在这里可以自定义损失函数loss和准确率函数accuracy
251         # model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
252         
253         losses = WeightedBinaryCrossEntropy()
254         model.compile(optimizer=Adam(lr=1e-4), loss=losses.weighted_binary_crossentropy, metrics=['accuracy',dice_coef])
255 #        model.compile(optimizer=Adam(lr=1e-4), loss=dice_coef_loss, metrics=['accuracy',dice_coef])        
256 #        model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy',dice_coef])
257         print('model compile')
258         return model
259 
260     def train(self):
261         print("loading data")
262         print("loading data done")
263 
264 #        model = self.get_vnet()
265         model = load_model('./model_pre_mean/vnet_tumour_3_1_epoch5.hdf5', custom_objects={'dice_coef': dice_coef,'dice_coef_loss': dice_coef_loss})
266         print("got vnet")
267 
268         # 保存的是模型和权重
269         model_checkpoint = ModelCheckpoint('./model_pre_mean/vnet_tumour_3_1_epoch6.hdf5', monitor='loss',
270                                            verbose=1, save_best_only=True)
271         print('Fitting model...')
272 
273         # 创建一个实例history
274         history = LossHistory()
275         # 在callbacks中加入history最后才能绘制收敛曲线
276         model.fit_generator(
277             generator=get_train_batch(train_path_list, label_path_list, train_batch_size, 3, 400, 400),
278             epochs=epoch, verbose=1,
279             steps_per_epoch=count//train_batch_size,
280             callbacks=[model_checkpoint, history],
281             workers=1)
282         # 绘制acc-loss曲线
283         history.loss_plot('batch')
284         plt.savefig('./curve_figure_pre_mean/vnet_tumour_dice_loss_curve_3_1_epoch6.png')
285 
286 
if __name__ == '__main__':
    # Build the trainer with its default input geometry and run training.
    myVnet().train()

 

数据导入程序(保存为 fit_generator.py,供上面的训练程序以 fit_generator 方式调用):

 

  1 import numpy as np
  2 import cv2 as cv
  3 import os
  4 
  5 data_train_path = "../../Vnet_tf/V_Net_data/train"
  6 data_label_path = "../../Vnet_tf/V_Net_data/label"
  7 
  8 
  9 def get_path_list(data_train_path, data_label_path):
 10     dirs = os.listdir(data_train_path)
 11     dirs.sort(key=lambda x: int(x))
 12     count = 0
 13     for dir in dirs:
 14         dir_path = os.path.join(data_train_path, dir)
 15         count += len(os.listdir(dir_path))
 16     print("共有{}组训练数据".format(count))
 17 
 18     train_path_list = []
 19     label_path_list = []
 20     for dir in dirs:
 21         train_dir_path = os.path.join(data_train_path, dir)
 22         label_dir_path = os.path.join(data_label_path, dir)
 23         trains = os.listdir(train_dir_path)
 24         labels = os.listdir(label_dir_path)
 25         trains.sort(key=lambda x: int(x.split(".")[0]))
 26         labels.sort(key=lambda x: int(x.split(".")[0]))
 27         for name in trains:
 28             train_path = os.path.join(train_dir_path, name)
 29             label_path = os.path.join(label_dir_path, name)
 30 
 31             train_path_list.append(train_path)
 32             label_path_list.append(label_path)
 33 
 34     return train_path_list, label_path_list, count
 35 
 36 
 37 def get_train_img(paths, img_d, img_rows, img_cols):
 38     """
 39     参数:
 40         paths:要读取的图片路径列表
 41         img_rows:图片行
 42         img_cols:图片列
 43         color_type:图片颜色通道
 44     返回:
 45         imgs: 图片数组
 46     """
 47     # Load as grayscale
 48     datas = []
 49     for path in paths:
 50         data = np.load(path)
 51         # Reduce size
 52         resized = np.reshape(data, (img_d, img_rows, img_cols, 1))
 53         resized = resized.astype('float32')
 54         resized /= 255
 55         # 均值
 56         # 注意:这里取均值时,要考虑是输入一批train还是单个train
 57         # 一批train需要设置axis=0,这样就是对每一张图像求均值
 58         # 单个train,就不需要设置,就会直接对图像求均值
 59 #        mean = resized.mean(axis=0)
 60         # 标准差
 61         # std = np.std(resized, ddof=1)
 62         # 标准化
 63 #        resized -= mean
 64         # resized /= std
 65         datas.append(resized)
 66     datas = np.array(datas)
 67     return datas
 68 
 69 
 70 def get_label_img(paths, img_rows, img_cols):
 71     """
 72     参数:
 73         paths:要读取的图片路径列表
 74         img_rows:图片行
 75         img_cols:图片列
 76         color_type:图片颜色通道
 77     返回:
 78         imgs: 图片数组
 79     """
 80     # Load as grayscale
 81     datas = []
 82     for path in paths:
 83         data = np.load(path)
 84         # Reduce size
 85         resized = np.reshape(data, (1, img_cols, img_rows, 1))
 86         resized = resized.astype('float32')
 87         resized /= 255
 88         datas.append(resized)
 89     datas = np.array(datas)
 90     return datas
 91 
 92 
 93 def get_train_batch(train, label, batch_size, img_d, img_w, img_h):
 94     """
 95     参数:
 96         X_train:所有图片路径列表
 97         y_train: 所有图片对应的标签列表
 98         batch_size:批次
 99         img_w:图片宽
100         img_h:图片高
101         color_type:图片类型
102         is_argumentation:是否需要数据增强
103     返回:
104         一个generator,x: 获取的批次图片 y: 获取的图片对应的标签
105     """
106     while 1:
107         for i in range(0, len(train), batch_size):
108             x = get_train_img(train[i:i+batch_size], img_d, img_w, img_h)
109             y = get_label_img(label[i:i+batch_size], img_w, img_h)
110             # 最重要的就是这个yield,它代表返回,返回以后循环还是会继续,然后再返回。就比如有一个机器一直在作累加运算,但是会把每次累加中间结果告诉你一样,直到把所有数加完
111             yield(np.array(x), np.array(y))
112 
113 
if __name__ == "__main__":
    # Smoke test: index the default dataset and print the training paths.
    found = get_path_list(data_train_path, data_label_path)
    train_path_list, label_path_list, count = found
    print(train_path_list)

 

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!