I have an autoencoder that takes an image as an input and produces a new image as an output.
The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the autoencoder.
To specifically address the original question, "Reconstructing an image after using extract_image_patches", I propose using `tf.scatter_nd()` to build a stratified image. Every patch is scattered into its own layer at its original position, and the layers are then averaged pixel by pixel.
This works even when the extracted patches overlap or when the image is under-sampled. Here is my proposed solution.
import cv2
import numpy as np
import tensorflow as tf
# Function to extract patches using 'extract_image_patches'
def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):
    """Split a batch of images into (possibly overlapping) patches.

    Uses ``tf.image.extract_image_patches`` with ``padding='SAME'`` and
    reshapes its flattened output so that every patch is an explicit image.

    Args:
        raw_input: rank-4 image tensor indexed as (batch, height, width, channels).
        _patch_size: (patch_height, patch_width) of each extracted patch.
        _stride: step in pixels between neighbouring patch origins; the same
            value is used for rows and columns.

    Returns:
        A tuple ``(patches, (h, w))``. ``patches`` has shape
        (batch, h * w, patch_height, patch_width, channels). ``h`` and ``w``
        are scalar tensors holding the number of patch rows and columns.
    """
    with tf.variable_scope('im2_patches'):
        patches = tf.image.extract_image_patches(
            images=raw_input,
            ksizes=[1, _patch_size[0], _patch_size[1], 1],
            strides=[1, _stride, _stride, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )
        patches_shape = tf.shape(patches)
        h = patches_shape[1]
        w = patches_shape[2]

        # Prefer static dimensions so the result keeps a well-defined static
        # shape. Fall back to dynamic ones when they are unknown (e.g. a
        # placeholder with batch size None), which would otherwise break the
        # reshape. The channel count is taken from the input instead of being
        # hard-coded to 3.
        static_shape = raw_input.get_shape().with_rank(4).as_list()
        input_shape = tf.shape(raw_input)
        batch = static_shape[0] if static_shape[0] is not None else input_shape[0]
        channels = static_shape[3] if static_shape[3] is not None else input_shape[3]

        patches = tf.reshape(patches, [batch, -1, _patch_size[0], _patch_size[1], channels])
    return patches, (h, w)
# Function to reconstruct image
def patches_to_img(update, _block_shape, _stride=100):
    """Reassemble patches produced by ``img_to_patches`` into images.

    Each patch is scattered with ``tf.scatter_nd`` into its own layer of a
    "stratified" image (one layer per patch), so overlapping patches never
    write to the same element. The layers are then summed and divided by the
    number of patches covering each pixel, which averages overlapping regions.

    Args:
        update: patch tensor of shape
            (batch, n_patches, patch_height, patch_width, channels), laid out
            as returned by ``img_to_patches``.
        _block_shape: (h, w) number of patch rows and columns. ``h * w`` is
            expected to equal ``n_patches``.
        _stride: step in pixels between neighbouring patch origins. It must
            match the stride used when extracting the patches.

    Returns:
        A tuple ``(reconstructed_img, stratified_img)``.

        ``reconstructed_img`` has shape (batch, hout, wout, channels).
        ``stratified_img`` has shape (batch, n_patches, hout, wout, channels).
        Here ``hout = (h - 1) * _stride + patch_height`` and
        ``wout = (w - 1) * _stride + patch_width``. This is the extent of the
        patch grid, not necessarily the original image size.

        Pixels covered by no patch (when the stride exceeds the patch size)
        are 0 in ``reconstructed_img``.
    """
    with tf.variable_scope('patches2im'):
        _h = _block_shape[0]
        _w = _block_shape[1]

        update_shape = tf.shape(update)
        bs = update_shape[0]  # batch size
        # Named n_patches (not `np`) so it does not shadow the numpy alias.
        n_patches = update_shape[1]
        ps_h = update_shape[2]  # patch height
        ps_w = update_shape[3]  # patch width
        col_ch = update_shape[4]  # colour channel count

        # Recalculate the output extent of "extract_image_patches", including padded pixels.
        wout = (_w - 1) * _stride + ps_w
        hout = (_h - 1) * _stride + ps_h

        # Pixel offsets inside a single patch.
        x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
        x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
        y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))

        # Top-left corner of every patch. meshgrid flattens in row-major order,
        # matching the patch order produced by the reshape in img_to_patches.
        xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                     tf.range(0, (hout - ps_h) + 1, _stride))

        # Build one index component per output axis; each is broadcast to
        # (bs, n_patches, ps_h, ps_w, col_ch, 1).
        bb = tf.zeros((1, n_patches, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))  # batch indices
        yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))  # y indices
        xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))  # x indices
        cc = tf.zeros((bs, n_patches, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))  # color indices
        dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(n_patches), (1, -1, 1, 1, 1, 1))  # layer (patch) indices
        idx = tf.concat([bb, yy, xx, cc, dd], -1)

        # The patch index is the last scatter axis so every patch gets its own layer;
        # the transpose then moves it next to the batch axis.
        stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, n_patches))
        stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))
        stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, n_patches))
        stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))

        with tf.variable_scope("consolidate"):
            sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
            stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
            # Pixels no patch touched have sum 0 and count 0. Clamping the count
            # to 1 yields 0 there instead of NaN from 0/0; covered pixels
            # (count >= 1) are unaffected.
            reconstructed_img = tf.divide(sum_stratified_img, tf.maximum(stratified_img_count, 1))
    return reconstructed_img, stratified_img
if __name__ == "__main__":
    # Load the initial image. cv2.imread returns None (rather than raising)
    # when the file is missing or unreadable, so check explicitly.
    image_org = cv2.imread('orig_img.jpg')
    if image_org is None:
        raise FileNotFoundError("could not read image file 'orig_img.jpg'")

    # Add batch dimension
    image = np.expand_dims(image_org, axis=0)

    # Set parameters
    patch_size = (228, 228)
    stride = 200

    input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

    # Extract patches using "extract_image_patches()".
    # block_shape is the number of patches extracted in the y and in the x dimension;
    # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], channels)
    extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)

    reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # Reconstruct Image

    with tf.Session() as sess:
        ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img],
                                  feed_dict={input_img: image})
        si = si.astype(np.int32)

        # Show reconstructed image
        cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
        cv2.waitKey(0)

        # Show stratified images, one layer per patch. waitKey is required after
        # each imshow for the window to actually render before it is overwritten.
        for i in range(si.shape[1]):
            im_1 = si[0, i, :, :, :]
            cv2.imshow('sd', im_1.astype(np.float32) / 255)
            cv2.waitKey(0)

        cv2.destroyAllWindows()
The above solution should work for batched images with an arbitrary number of color channels.