UNet的tensorflow2实现

穿精又带淫゛_ 提交于 2020-01-14 02:32:54

  本篇博客主要使用tensorflow2实现UNet。
参考博客
在这里插入图片描述
  从图中可以看到,UNet主要有下采样和上采样部分;在向上采样的过程中,会用到下采样过程的特征图。
图中2842284^2之类的表示图面积。

以下为UNet实现(tensorflow)

import tensorflow as tf
import cv2
from tensorflow import keras 
import numpy as np
from tensorflow.keras.layers import Cropping2D, Concatenate, BatchNormalization, Activation, Softmax, Dropout

input_h = 572
input_w = 572
down_feature_list = []
#save down feature
filter_list = [2**i for i in range(6,10)]

def down_sampling(inputs, filters):
    x = keras.layers.Conv2D(filters, [3,3])(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = keras.layers.Conv2D(filters, [3,3])(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    down_feature_list.append(x)
    return keras.layers.MaxPool2D()(x)

def up_sampling(inputs, down_data, filters):
    now_data = keras.layers.Conv2DTranspose(filters, 2, strides=2, padding='valid')(inputs)
    a_h, a_w = now_data.shape[1:3]
    b_h, b_w = down_data.shape[1:3]
    h_delta = b_h - a_h
    w_delta = b_w - a_w
    cropping = ((h_delta//2, h_delta//2), (w_delta//2, w_delta//2))
    crop_data = Cropping2D(cropping)(down_data)
    concat_data = Concatenate()([crop_data, now_data])

    out_data = keras.layers.Conv2D(filters, [3,3])(concat_data)
    out_data = keras.layers.Conv2D(filters, [3,3])(out_data)
    return out_data
    
def Unet():
    inputs = keras.Input(shape=(input_h, input_w, 3), name="input")
    layer = inputs
    for filters in filter_list:
        layer = down_sampling(layer, filters)

    for filter_num in [1024, 512]:
        layer = keras.layers.Conv2D(filter_num, [3,3])(layer)
        layer = BatchNormalization()(layer)
        layer = Activation('relu')(layer)

    for filters in filter_list[::-1]:
        down_feature = down_feature_list.pop()
        layer = up_sampling(layer, down_feature, filters)
        
    layer = keras.layers.Conv2D(2, 1, padding='valid')(layer)
    probabilities  = Softmax()(layer)
    model = tf.keras.models.Model(inputs, probabilities)
    return model

unet = Unet()
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!