Reconstructing an image after using extract_image_patches

后端 未结 7 1377
一生所求
一生所求 2020-12-16 14:37

I have an autoencoder that takes an image as an input and produces a new image as an output.

The input image (1x1024x1024x3) is split into patches (1024x32x32x3) bef

7条回答
  •  -上瘾入骨i
    2020-12-16 15:26

    To specifically address the initial question, 'Reconstructing an image after using extract_image_patches', I propose using tf.scatter_nd() to build a stratified image. This works even when the extracted patches overlap or the image is under-sampled. Here is my proposed solution.

    import cv2
    import numpy as np
    import tensorflow as tf
    
    # Function to extract patches using 'extract_image_patches'
    def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):
        """Split a batch of images into (possibly overlapping) patches.

        Args:
            raw_input: 4-D image tensor laid out as [batch, rows, cols, channels],
                the layout `tf.image.extract_image_patches` expects.
            _patch_size: (height, width) of each patch in pixels.
            _stride: Step in pixels between neighbouring patch origins; the same
                value is used along rows and columns.

        Returns:
            A tuple `(patches, (h, w))`. `patches` is reshaped to
            [batch, h * w, patch_height, patch_width, channels]; `h` and `w`
            are scalar tensors with the number of patches along rows and
            columns.
        """
        with tf.variable_scope('im2_patches'):
            patches = tf.image.extract_image_patches(
                images=raw_input,
                ksizes=[1, _patch_size[0], _patch_size[1], 1],
                strides=[1, _stride, _stride, 1],
                rates=[1, 1, 1, 1],
                padding='SAME'
            )

            h = tf.shape(patches)[1]
            w = tf.shape(patches)[2]

            # Take batch size and channel count from the runtime shape so the
            # reshape also works with an unknown batch dimension and with
            # inputs that do not have exactly 3 channels.
            batch_size = tf.shape(raw_input)[0]
            channels = tf.shape(raw_input)[3]
            patches = tf.reshape(
                patches,
                (batch_size, -1, _patch_size[0], _patch_size[1], channels)
            )
        return patches, (h, w)
    
    
    # Function to reconstruct image
    def patches_to_img(update, _block_shape, _stride=100):
        """Reassemble an image from patches, averaging where patches overlap.

        Each patch is scattered onto its own layer of a "stratified" canvas;
        the layers are then summed and divided by the per-pixel patch count.

        Args:
            update: 5-D tensor [batch, num_patches, patch_h, patch_w, channels].
                Patches are expected in row-major grid order (row by row).
            _block_shape: (rows, cols) number of patches along each axis.
            _stride: Step in pixels between neighbouring patch origins.

        Returns:
            A tuple `(reconstructed_img, stratified_img)`:
            * `reconstructed_img` - [batch, hout, wout, channels], the mean of
              all patches covering each pixel. Pixels covered by no patch
              (stride larger than the patch size) are 0 rather than NaN.
            * `stratified_img` - [batch, num_patches, hout, wout, channels],
              one layer per patch, zero outside that patch.

            `hout`/`wout` span the whole patch grid, so the result is not
            cropped back to the size of the image the patches came from.
        """
        with tf.variable_scope('patches2im'):
            _h = _block_shape[0]
            _w = _block_shape[1]

            update_shape = tf.shape(update)
            bs = update_shape[0]  # batch size
            n_patches = update_shape[1]  # number of patches (avoids shadowing numpy's `np`)
            ps_h = update_shape[2]  # patch height
            ps_w = update_shape[3]  # patch width
            col_ch = update_shape[4]  # colour channel count

            # Extent of the canvas spanned by the patch grid.
            wout = (_w - 1) * _stride + ps_w
            hout = (_h - 1) * _stride + ps_h

            # Pixel offsets inside a single patch.
            x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
            x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
            y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
            # Top-left corner of every patch on the canvas, flattened row-major
            # below so it lines up with the patch axis of `update`.
            xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                         tf.range(0, (hout - ps_h) + 1, _stride))

            # Each index component broadcasts to
            # [batch, num_patches, patch_h, patch_w, channels, 1].
            bb = (tf.zeros((1, n_patches, ps_h, ps_w, col_ch, 1), dtype=tf.int32)
                  + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1)))  # batch indices
            yy = (tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32)
                  + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1)))  # y indices
            xx = (tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32)
                  + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1)))  # x indices
            cc = (tf.zeros((bs, n_patches, ps_h, ps_w, 1, 1), dtype=tf.int32)
                  + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1)))  # colour indices
            dd = (tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32)
                  + tf.reshape(tf.range(n_patches), (1, -1, 1, 1, 1, 1)))  # layer (patch) indices

            idx = tf.concat([bb, yy, xx, cc, dd], -1)
            canvas_shape = (bs, hout, wout, col_ch, n_patches)

            # The patch index is the last scatter axis; move it next to batch.
            stratified_img = tf.scatter_nd(idx, update, canvas_shape)
            stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))

            stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), canvas_shape)
            stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))

            with tf.variable_scope("consolidate"):
                sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
                stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
                # Clamp the count to >= 1: uncovered pixels have sum 0 and
                # count 0, which would otherwise produce 0 / 0 = NaN.
                safe_count = tf.maximum(stratified_img_count,
                                        tf.ones_like(stratified_img_count))
                reconstructed_img = tf.divide(sum_stratified_img, safe_count)

            return reconstructed_img, stratified_img
    
    
    
    if __name__ == "__main__":
        # Demo: split an image into overlapping patches, rebuild it, and show
        # both the reconstruction and every per-patch layer.

        # load initial image
        image_path = 'orig_img.jpg'
        image_org = cv2.imread(image_path)
        if image_org is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Could not read image file: " + image_path)
        # Add batch dimension
        image = np.expand_dims(image_org, axis=0)

        # set parameters
        patch_size = (228, 228)
        stride = 200

        input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

        # Extract patches using "extract_image_patches()"
        extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
        # block_shape is the number of patches extracted in the y and in the x dimension
        # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)

        reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # Reconstruct Image

        with tf.Session() as sess:
            ri, si = sess.run([reconstructed_img, stratified_img], feed_dict={input_img: image})
        si = si.astype(np.int32)

        # Show reconstructed image (scaled to [0, 1] for float display)
        cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
        cv2.waitKey(0)

        # Show stratified images, one layer per key press; without waitKey the
        # window is never given a chance to draw the layer.
        for layer in si[0]:
            cv2.imshow('sd', layer.astype(np.float32) / 255)
            cv2.waitKey(0)

        cv2.destroyAllWindows()
    

    The above solution should work for batched images with an arbitrary number of color channels.

提交回复
热议问题