Question
I am working on image registration using deep learning with the Keras backend. The task is to register two images, fixed and moving. As a result I get a deformation field D of shape (200, 200, 2), where 200 is the image size and 2 is the number of offset components per pixel (dx, dy). I need to apply D to moving and compute the loss against fixed.
Is there a way to rearrange the pixels of moving according to D inside a Keras model?
Answer 1:
You should be able to implement the deformation using tf.contrib.resampler.resampler. It should be something like tf.contrib.resampler.resampler(moving, D), although note that it expects moving to be in the format (batch_size, height, width, num_channels), whereas D[..., 0] is expected to contain width coordinates and D[..., 1] height coordinates. The operation implements gradients for both inputs, so it should work fine for training in any case.
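Since the warp argument is described above as holding coordinates (width first, then height), a field of per-pixel offsets would first have to be turned into absolute sampling positions. Below is a minimal NumPy sketch of that conversion. The helper name offsets_to_xy_coords is hypothetical, and the (dy, dx) channel order of the offsets is an assumption; swap the channels if your field stores (dx, dy).

```python
import numpy as np

def offsets_to_xy_coords(offsets):
    """Turn (dy, dx) offsets of shape (B, H, W, 2) into absolute (x, y) coordinates."""
    _, h, w, _ = offsets.shape
    # Identity grid: grid_y[i, j] == i (row), grid_x[i, j] == j (column)
    grid_y, grid_x = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    y = grid_y[None] + offsets[..., 0]  # assumed: channel 0 is the row offset
    x = grid_x[None] + offsets[..., 1]  # assumed: channel 1 is the column offset
    # Width coordinate first, height coordinate second
    return np.stack([x, y], axis=-1)

offsets = np.zeros((1, 2, 3, 2), dtype=np.float32)
offsets[0, 1, 2] = [-0.5, 0.25]  # pixel at row 1, column 2: dy = -0.5, dx = 0.25
coords = offsets_to_xy_coords(offsets)
print(coords[0, 1, 2])  # x = 2.25, y = 0.5
```

The same grid-plus-offset construction can be written with tf.meshgrid and tf.stack inside the model.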
If you don't want to use tf.contrib because it is being removed from TensorFlow (it is no longer available in TensorFlow 2.x), you can roll your own implementation of bilinear interpolation. This is what it might look like:
import tensorflow as tf

def deform(moving, deformation):
    # Warps moving (batch, height, width, channels) by the per-pixel (dh, dw)
    # offsets in deformation (batch, height, width, 2) with bilinear interpolation
    s = tf.shape(moving)
    b, h, w = s[0], s[1], s[2]
    grid_b, grid_h, grid_w = tf.meshgrid(
        tf.range(b), tf.range(h), tf.range(w), indexing='ij')
    # Absolute (fractional) sampling position of every output pixel
    idx = tf.cast(tf.stack([grid_h, grid_w], axis=-1), deformation.dtype) + deformation
    idx_floor = tf.floor(idx)
    # Valid indices run from 0 to size - 1
    clip_low = tf.zeros([2], dtype=tf.int32)
    clip_high = tf.cast([h - 1, w - 1], dtype=tf.int32)
    # 0 0
    idx_00 = tf.clip_by_value(tf.cast(idx_floor, tf.int32), clip_low, clip_high)
    idx_batch = tf.expand_dims(grid_b, -1)
    idx_batch_00 = tf.concat([idx_batch, idx_00], axis=-1)
    moved_00 = tf.gather_nd(moving, idx_batch_00)
    # 0 1
    idx_01 = tf.clip_by_value(idx_00 + [0, 1], clip_low, clip_high)
    idx_batch_01 = tf.concat([idx_batch, idx_01], axis=-1)
    moved_01 = tf.gather_nd(moving, idx_batch_01)
    # 1 0
    idx_10 = tf.clip_by_value(idx_00 + [1, 0], clip_low, clip_high)
    idx_batch_10 = tf.concat([idx_batch, idx_10], axis=-1)
    moved_10 = tf.gather_nd(moving, idx_batch_10)
    # 1 1
    idx_11 = tf.clip_by_value(idx_00 + 1, clip_low, clip_high)
    idx_batch_11 = tf.concat([idx_batch, idx_11], axis=-1)
    moved_11 = tf.gather_nd(moving, idx_batch_11)
    # Interpolate (weights keep a trailing axis so they broadcast over channels)
    alpha = idx - idx_floor
    alpha_h = alpha[..., 0:1]
    alpha_h_1 = 1 - alpha_h
    alpha_w = alpha[..., 1:2]
    alpha_w_1 = 1 - alpha_w
    moved_0 = moved_00 * alpha_w_1 + moved_01 * alpha_w
    moved_1 = moved_10 * alpha_w_1 + moved_11 * alpha_w
    moved = moved_0 * alpha_h_1 + moved_1 * alpha_h
    return moved
Interestingly, this shouldn't actually work for training, yet it probably does. The gradients with respect to the deformation are estimated from the values of the pixels surrounding each interpolated coordinate, so they are more precise when a deformed position falls near the midpoint between two pixels than when it falls exactly on a pixel. For most images, however, the difference is probably negligible.
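To see where this behaviour comes from, the interpolation for a single pixel can be written out explicitly. The notation is introduced here only for brevity: a_h and a_w stand for alpha_h and alpha_w, and m_ij for moved_ij in the code.

```latex
\begin{aligned}
\text{moved} &= (1 - a_h)\,\bigl[(1 - a_w)\,m_{00} + a_w\,m_{01}\bigr]
              + a_h\,\bigl[(1 - a_w)\,m_{10} + a_w\,m_{11}\bigr] \\
\frac{\partial\,\text{moved}}{\partial a_h} &= (1 - a_w)\,(m_{10} - m_{00}) + a_w\,(m_{11} - m_{01}) \\
\frac{\partial\,\text{moved}}{\partial a_w} &= (1 - a_h)\,(m_{01} - m_{00}) + a_h\,(m_{11} - m_{10})
\end{aligned}
```

Each derivative is a difference between neighbouring pixel values. The derivative along h does not depend on a_h, so it stays constant while the sampling position moves between two rows. Such a difference is the central difference at the midpoint between the two rows, which is why the estimate is best there and least accurate at the pixel positions themselves.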
If you want a more principled approach, you can use tf.custom_gradient to interpolate better per-pixel gradient estimates:
import tensorflow as tf

@tf.custom_gradient
def deform(moving, deformation):
    # Same code as before...
    # Gradient function
    def grad(dy):
        # Pad one pixel on each spatial side so differences exist at the borders
        moving_pad = tf.pad(moving, [[0, 0], [1, 1], [1, 1], [0, 0]], 'SYMMETRIC')
        # Diff H (central difference along height at every pixel)
        moving_dh_ref = moving_pad[:, 1:, 1:-1] - moving_pad[:, :-1, 1:-1]
        moving_dh_ref = (moving_dh_ref[:, :-1] + moving_dh_ref[:, 1:]) / 2
        moving_dh_0 = tf.gather_nd(moving_dh_ref, idx_batch_00)
        moving_dh_1 = tf.gather_nd(moving_dh_ref, idx_batch_10)
        moving_dh = moving_dh_0 * alpha_h_1 + moving_dh_1 * alpha_h
        # Diff W (central difference along width at every pixel)
        moving_dw_ref = moving_pad[:, 1:-1, 1:] - moving_pad[:, 1:-1, :-1]
        moving_dw_ref = (moving_dw_ref[:, :, :-1] + moving_dw_ref[:, :, 1:]) / 2
        moving_dw_0 = tf.gather_nd(moving_dw_ref, idx_batch_00)
        moving_dw_1 = tf.gather_nd(moving_dw_ref, idx_batch_01)
        moving_dw = moving_dw_0 * alpha_w_1 + moving_dw_1 * alpha_w
        # Gradient of deformation
        deformation_grad = tf.stack([tf.reduce_sum(dy * moving_dh, axis=-1),
                                     tf.reduce_sum(dy * moving_dw, axis=-1)], axis=-1)
        # Gradient of moving would be computed by applying the inverse deformation to dy
        # or just resorting to standard TensorFlow gradient, if needed
        return None, deformation_grad
    return moved, grad
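To sanity-check the forward pass of either version without TensorFlow, the same bilinear scheme can be sketched in NumPy and run on a tiny image. This is only a reference sketch under stated assumptions: the name deform_np is hypothetical, the offsets are taken to be in (dh, dw) order as in the code above, and indices are clipped to the last valid row and column.

```python
import numpy as np

def deform_np(moving, deformation):
    """Bilinear warp of moving (B, H, W, C) by (dh, dw) offsets of shape (B, H, W, 2)."""
    b, h, w, _ = moving.shape
    grid_b, grid_h, grid_w = np.meshgrid(
        np.arange(b), np.arange(h), np.arange(w), indexing="ij")
    # Fractional sampling position of every output pixel
    idx = np.stack([grid_h, grid_w], axis=-1) + deformation
    idx_floor = np.floor(idx)
    high = np.array([h - 1, w - 1])  # assumed clipping bound: last valid index
    i0 = np.clip(idx_floor.astype(int), 0, high)  # top-left neighbour
    i1 = np.clip(i0 + 1, 0, high)                 # bottom-right neighbour
    alpha = idx - idx_floor
    ah, aw = alpha[..., 0:1], alpha[..., 1:2]     # trailing axis broadcasts over C
    m00 = moving[grid_b, i0[..., 0], i0[..., 1]]
    m01 = moving[grid_b, i0[..., 0], i1[..., 1]]
    m10 = moving[grid_b, i1[..., 0], i0[..., 1]]
    m11 = moving[grid_b, i1[..., 0], i1[..., 1]]
    top = m00 * (1 - aw) + m01 * aw
    bottom = m10 * (1 - aw) + m11 * aw
    return top * (1 - ah) + bottom * ah

# 2x2 single-channel image [[0, 1], [2, 3]], sampled half a pixel to the right
moving = np.arange(4, dtype=float).reshape(1, 2, 2, 1)
shift = np.zeros((1, 2, 2, 2))
shift[..., 1] = 0.5
warped = deform_np(moving, shift)[0, :, :, 0]
# warped is [[0.5, 1.0], [2.5, 3.0]]: the left column averages its two
# neighbours, the right column is clamped at the image border
```

Feeding the same inputs to the TensorFlow function and comparing against this output is a cheap way to catch indexing or broadcasting mistakes.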
Source: https://stackoverflow.com/questions/55099884/arrange-each-pixel-of-a-tensor-according-to-another-tensor