Question
I am a beginner with TensorFlow. I am trying to implement a function that takes a batch as input: it has to slice this batch into several smaller ones, apply some operations to each of them, and then concatenate the results into a new tensor to return. In my reading I found some existing functions such as slice_input_producer and batch_join, but I couldn't get them to work. I have attached what I came up with below, but it is rather slow, not very clean, and unable to detect the current size of the batch. Does anyone know a better way of doing this?
import tensorflow as tf

def model(x):
    W_1 = tf.Variable(tf.random.normal([6, 1]), name="W_1")
    x_size = x.get_shape().as_list()[0]
    # x is a batch of a bigger input of shape [None, 6], so I couldn't
    # get the proper size of the batch when feeding it
    if x_size is None:
        x_size = batch_size  # batch_size is a global defined elsewhere
    # initialize y_res with the result for the first slice (as a rank-1 tensor)
    dummy_x = tf.slice(x, [0, 0], [1, 6])
    y_res = tf.reshape(tf.reduce_sum(tf.multiply(dummy_x, W_1)), [1])
    # go through the remaining slices and concatenate their results
    for i in range(1, x_size):
        dummy_x = tf.slice(x, [i, 0], [1, 6])
        result = tf.reshape(tf.reduce_sum(tf.multiply(dummy_x, W_1)), [1])
        y_res = tf.concat([y_res, result], 0)
    return y_res
Answer 1:
The TensorFlow function tf.map_fn(fn, elems) allows you to apply a function (fn) to each slice of a tensor (elems) along its first dimension. For example, you could express your program as follows:
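import tensorflow as tf

def model(x):
    W_1 = tf.Variable(tf.random.normal([6, 1]), name="W_1")
    def fn(x_slice):
        # x_slice is one row of x; same computation as one loop step in the question
        return tf.reduce_sum(tf.multiply(x_slice, W_1))
    # apply fn to every row and stack the results: output shape [batch], any batch size
    return tf.map_fn(fn, x)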
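To make the semantics concrete, the following NumPy sketch mimics what tf.map_fn computes here: fn is called on each row of x and the results are stacked along a new first axis. map_fn_reference is a hypothetical helper written for illustration, not part of TensorFlow; the real tf.map_fn reads the size of the first dimension at run time, which is why it copes with a batch size that is only known when the data is fed.

```python
import numpy as np


def map_fn_reference(fn, elems):
    # Hypothetical NumPy stand-in for tf.map_fn: call fn on each slice
    # along axis 0 and stack the results. The batch size is read from
    # elems.shape[0] at call time, so it never has to be fixed in advance.
    return np.stack([fn(elems[i]) for i in range(elems.shape[0])])


rng = np.random.default_rng(0)
W_1 = rng.normal(size=(6, 1))


def fn(x_slice):
    # x_slice has shape (6,); broadcasting against W_1 (6, 1) yields (6, 6),
    # the same computation as one iteration of the question's loop
    return np.sum(x_slice * W_1)


for batch_size in (1, 3, 8):
    x = rng.normal(size=(batch_size, 6))
    print(batch_size, map_fn_reference(fn, x).shape)
```

Each printed line shows the batch size followed by the shape of the stacked result, which always has one entry per row.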
It may also be possible to implement your operation more concisely by using broadcasting in the tf.mul() operator (renamed tf.multiply() in TensorFlow 1.0), which follows NumPy broadcasting semantics, together with the axis argument of tf.reduce_sum().
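A sketch of that idea in NumPy, whose broadcasting rules tf.multiply follows: expanding x to shape (batch, 1, 6) lets it broadcast against W_1 of shape (6, 1), and summing over axes 1 and 2 reproduces the per-row result of the question's loop in one vectorized expression. By the same reasoning, the TensorFlow version would be tf.reduce_sum(tf.multiply(tf.expand_dims(x, 1), W_1), axis=[1, 2]); this translation is my assumption, not something the answer spells out. If a per-row dot product with W_1 is what was actually intended (also an assumption, since the question's loop sums a 6×6 outer product), tf.matmul(x, W_1) computes it directly.

```python
import numpy as np

rng = np.random.default_rng(0)
W_1 = rng.normal(size=(6, 1))  # same shape as W_1 in the question
x = rng.normal(size=(4, 6))    # a batch of 4 rows; any batch size works

# Loop version: one reduction per row, mirroring the question's tf.slice loop
y_loop = np.array([np.sum(x[i:i + 1] * W_1) for i in range(x.shape[0])])

# Vectorized version: (batch, 1, 6) * (6, 1) broadcasts to (batch, 6, 6),
# then reduce over every axis except the batch axis
y_vec = np.sum(x[:, np.newaxis, :] * W_1, axis=(1, 2))

print(np.allclose(y_loop, y_vec))

# Assumed alternative: if a per-row dot product is wanted instead,
# one matrix product gives a result of shape (batch, 1)
y_dot = x @ W_1
print(y_dot.shape)
```

The vectorized form avoids building one slice per row, so its cost no longer grows with the number of graph operations the way the concatenation loop does.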
Source: https://stackoverflow.com/questions/35575982/how-to-slice-a-batch-and-apply-an-operation-on-each-slice-in-tensorflow