Looping over a tensor

野趣味 2020-12-31 03:02

I am trying to process a tensor of variable size. In plain Python, it would look something like this:

# X is of shape [m, n]
for x in X:
    process(x)
         


        
  •  死守一世寂寞
    2020-12-31 03:38

    To loop over a tensor, you could try tf.unstack. From its documentation:

    Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
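A minimal sketch of what that means for shapes, assuming TensorFlow 2.x with eager execution (its default) rather than the TF1 graph code used in this answer. The tensor t and its (2, 3, 4) shape are made up for illustration; tf.stack with the same axis undoes the unstack.

```python
import tensorflow as tf

# A made-up rank-3 tensor of shape (2, 3, 4)
t = tf.reshape(tf.range(24), (2, 3, 4))

# Unstacking along axis 0 yields 2 rank-2 tensors of shape (3, 4)
parts0 = tf.unstack(t, axis=0)
# Unstacking along axis 1 yields 3 rank-2 tensors of shape (2, 4)
parts1 = tf.unstack(t, axis=1)

print(len(parts0), parts0[0].shape)  # 2 (3, 4)
print(len(parts1), parts1[0].shape)  # 3 (2, 4)

# tf.stack along the same axis rebuilds the original tensor
rebuilt = tf.stack(parts1, axis=1)
print(bool(tf.reduce_all(tf.equal(rebuilt, t))))  # True
```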

    So adding 1 to each tensor would look something like:

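    import numpy as np
    import tensorflow.compat.v1 as tf  # TF1-style graph API, also available in TF 2.x
    tf.disable_v2_behavior()
    
    # tf.unstack must know the size of the unpacked axis when the graph is built,
    # so the first dimension is fixed here; shape=(None, 10) would raise a ValueError.
    x = tf.placeholder(tf.float32, shape=(5, 10))
    x_unpacked = tf.unstack(x)  # defaults to axis 0, returns a list of 5 tensors of shape (10,)
    
    processed = []  # this will be the list of processed tensors
    for t in x_unpacked:
        # do whatever
        result_tensor = t + 1
        processed.append(result_tensor)
    
    # tf.stack restores the (5, 10) shape; tf.concat along axis 0 would give shape (50,)
    output = tf.stack(processed, 0)
    
    with tf.Session() as sess:
        print(sess.run([output], feed_dict={x: np.zeros((5, 10))}))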
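
tf.unstack needs the size of the unpacked dimension to be known when the graph is built, so it does not directly cover the variable-size case from the question. One alternative, swapped in here and not part of the original answer, is tf.map_fn, which applies a function to each slice along axis 0 even when that dimension is unknown. This sketch assumes TensorFlow 2.x; the function name add_one_per_row is hypothetical.

```python
import tensorflow as tf

# input_signature leaves the first dimension as None, so the traced graph
# accepts any number of rows of length 10
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
def add_one_per_row(x):
    # tf.map_fn calls the lambda once per row (slice along axis 0)
    # and stacks the results back into a tensor
    return tf.map_fn(lambda row: row + 1, x)

print(add_one_per_row(tf.zeros((3, 10))).shape)  # (3, 10)
print(add_one_per_row(tf.zeros((7, 10))).shape)  # (7, 10)
```

Both calls reuse the same traced function because the signature does not pin the number of rows.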
    

    You can further unpack each tensor in the list to process it, down to single elements. To avoid lots of nested unpacking, though, you could first flatten x with tf.reshape(x, [-1]) and then loop over the result:

    flattened_unpacked = tf.unstack(tf.reshape(x, [-1]))  # list of scalar tensors
    for elem in flattened_unpacked:
        process(elem)  # process() is your own per-element function
    

    In this case, elem is a scalar (rank-0) tensor.
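
The flattening approach can be sketched in TensorFlow 2.x eager mode (an assumption; the answer itself uses TF1 graph code), where a tensor can be iterated directly. Squaring each element is a hypothetical stand-in for process(), and reshaping with tf.shape(x) puts the results back into the original layout.

```python
import tensorflow as tf

x = tf.reshape(tf.range(6, dtype=tf.float32), (2, 3))

# Flatten to rank 1 so a single loop visits every element
flat = tf.reshape(x, [-1])

squares = []
for elem in flat:
    # each elem is a rank-0 (scalar) tensor; squaring stands in for process()
    squares.append(elem * elem)

# Stack the scalars into a rank-1 tensor, then restore x's shape
result = tf.reshape(tf.stack(squares), tf.shape(x))
print(result.numpy())
```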
