I have a 1-D tensor that I want to partition into overlapping blocks. I'm thinking of something like:
# Example input: the integers 1 through 7 (no dtype given, so TF infers it from the ints).
tensor = tf.constant(list(range(1, 8)))
You can use tf.nn.conv2d to do this. Slide a filter of width block_size over the input, advancing by stride elements at each step. Some reshaping is needed so that the input and filter dimensions line up with what conv2d expects.
import tensorflow as tf
def overlap(tensor, block_size=3, stride=2):
    """Split a 1-D tensor into (possibly overlapping) blocks using conv2d.

    A window of ``block_size`` elements slides over ``tensor``, advancing
    ``stride`` elements per step. Windows that would run past the end are
    dropped (``padding='VALID'``).

    Args:
      tensor: 1-D tensor (or anything ``tf.reshape`` accepts). Its dtype must
        be one that ``tf.nn.conv2d`` supports (floating point); the filter is
        built with the same dtype as the input.
      block_size: Number of elements in each block.
      stride: Distance, in elements, between the starts of consecutive blocks.

    Returns:
      A tensor of shape ``[num_blocks, block_size]`` whose row ``i`` holds
      ``tensor[i * stride : i * stride + block_size]``.
    """
    # conv2d wants NHWC input: batch=1, height=1, width=len(tensor), channels=1.
    reshaped = tf.reshape(tensor, [1, 1, -1, 1])

    # One output channel per window offset: filters[0, j, 0, i] == 1 iff j == i,
    # so output channel i copies the element at offset i of each window.
    # Layout: [filter_height=1, filter_width=block_size, in=1, out=block_size].
    filters = tf.reshape(tf.eye(block_size, dtype=reshaped.dtype),
                         [1, block_size, 1, block_size])

    stride_window = [1, 1, stride, 1]
    conv = tf.nn.conv2d(reshaped, filters, stride_window, padding='VALID')

    # conv has shape [1, 1, num_blocks, block_size]. Reshape instead of a bare
    # squeeze so that a single block still yields a 2-D [1, block_size] result.
    return tf.reshape(conv, [-1, block_size])
# Calculate the overlapping strided slice for the input tensor.
tensor = tf.constant([1, 2, 3, 4, 5, 6, 7], dtype=tf.float32)
overlap_tensor = overlap(tensor, block_size=3, stride=2)

# The graph contains no variables, so there is nothing to initialize.
with tf.Session() as sess:
    in_t, overlap_t = sess.run([tensor, overlap_tensor])

# print() function form works on both Python 2 and Python 3.
print('input tensor:')
print(in_t)
print('overlapping strided slice:')
print(overlap_t)
This should print:
input tensor:
[ 1. 2. 3. 4. 5. 6. 7.]
overlapping strided slice:
[[ 1. 2. 3.]
[ 3. 4. 5.]
[ 5. 6. 7.]]
Below is the initial version I got working. It doesn't support a variable block_size, but I think it makes it easier to see what the convolution filters are doing: each of the three one-hot filters picks out one of the 3 values in the window, and the window advances by stride elements at each step.
def overlap(tensor, stride=2):
    """Split a 1-D tensor into overlapping 3-element blocks using conv2d.

    Fixed block-size version: three one-hot filters each select one position
    of a 3-wide window, and the window advances ``stride`` elements per step.
    Windows that would run past the end are dropped (``padding='VALID'``).

    Args:
      tensor: 1-D tensor (or anything ``tf.reshape`` accepts). Its dtype must
        be one that ``tf.nn.conv2d`` supports (floating point); the filters
        are built with the same dtype as the input.
      stride: Distance, in elements, between the starts of consecutive blocks.

    Returns:
      A tensor of shape ``[num_blocks, 3]``.
    """
    # Reshape to NHWC [1, 1, length, 1] so the tensor can be passed to conv2d.
    reshaped = tf.reshape(tensor, [1, 1, -1, 1])

    # Build the three [1, 3, 1, 1] filters; each selects offset 0, 1 or 2.
    filter_dim = [1, -1, 1, 1]
    dtype = reshaped.dtype
    x_filt = tf.reshape(tf.constant([1., 0., 0.], dtype=dtype), filter_dim)
    y_filt = tf.reshape(tf.constant([0., 1., 0.], dtype=dtype), filter_dim)
    z_filt = tf.reshape(tf.constant([0., 0., 1.], dtype=dtype), filter_dim)

    # Stride along the tensor with the above filters.
    stride_window = [1, 1, stride, 1]
    x = tf.nn.conv2d(reshaped, x_filt, stride_window, padding='VALID')
    y = tf.nn.conv2d(reshaped, y_filt, stride_window, padding='VALID')
    z = tf.nn.conv2d(reshaped, z_filt, stride_window, padding='VALID')

    # Each result is [1, 1, num_blocks, 1]; stacking on a new trailing axis
    # (index 4) gives [1, 1, num_blocks, 1, 3].
    result = tf.stack([x, y, z], axis=4)

    # Flatten to [num_blocks, 3]. A reshape (unlike a bare squeeze) keeps the
    # result 2-D even when there is only one block.
    return tf.reshape(result, [-1, 3])