I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.
Given a matrix and a value to find, I want to know the first occurrence of the value in each row of the matrix.
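For reference, here is a plain-Python sketch of the behaviour being asked about, applying list.index() to each row. The helper name first_index_per_row and the sample matrix are made up for illustration only.

```python
def first_index_per_row(matrix, value):
    """Return, for each row, the index of the first element equal to value."""
    # list.index() raises ValueError if value is missing from a row.
    return [row.index(value) for row in matrix]


m = [[1, 2, 3],
     [3, 1, 2],
     [2, 3, 1]]
print(first_index_per_row(m, 3))  # [2, 0, 1]
```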
It looks a little ugly, but this works (assuming m and val are both tensors):
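# Note: tf.unpack/tf.pack were renamed tf.unstack/tf.stack in TensorFlow 1.0.
idx = list()
for t in tf.unpack(m, axis=0):
    # tf.where lists the positions where t == val; the minimum is the first match.
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
idx = tf.pack(idx, axis=0)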
EDIT:
As Yaroslav Bulatov mentioned, you could achieve the same result with tf.map_fn:
def index1d(t):
    return tf.reduce_min(tf.where(tf.equal(t, val)))
idx = tf.map_fn(index1d, m, dtype=tf.int64)
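To try the tf.map_fn version end to end, here is a minimal sketch. It assumes TensorFlow 2.3 or later with eager execution, where fn_output_signature is the current spelling of the dtype argument; the sample values for m and val are made up for illustration.

```python
import tensorflow as tf  # assumption: TensorFlow 2.3+ (eager execution)

# Illustrative inputs, not from the original post.
m = tf.constant([[1, 2, 3],
                 [3, 1, 2],
                 [2, 3, 1]])
val = tf.constant(3)


def index1d(t):
    # tf.where returns the int64 coordinates of every match in the row;
    # the minimum of those coordinates is the first occurrence.
    return tf.reduce_min(tf.where(tf.equal(t, val)))


# Apply index1d to each row of m and stack the scalar results.
idx = tf.map_fn(index1d, m, fn_output_signature=tf.int64)
print(idx.numpy())  # [2 0 1]
```

Unlike list.index(), this does not raise when a row lacks val: tf.where comes back empty for that row, so it is worth checking what reduce_min yields in that case before relying on the result.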