I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.
Given a matrix and a value to find, I want to know the index of the first occurrence of that value in each row of the matrix.
Here is another solution to the problem, assuming there is a hit on every row.
import tensorflow as tf
# Find, for every row of a matrix, the column index of the first entry equal
# to `val` (a row-wise analogue of Python's list.index()).
val = 3
m = tf.constant([
    [0, 0, val, 0, val],
    [val, 0, val, val, 0],
    [0, val, 0, 0, 0],
])

# The column count doubles as the "no match" sentinel: it is one past the
# largest valid column index, so it can never beat a real hit in the minimum.
num_cols = tf.shape(m)[1]

# Replace every entry with its own column index where it equals `val`, and
# with the out-of-range sentinel everywhere else. For the matrix above:
#   [[5, 5, 2, 5, 4],
#    [0, 5, 2, 3, 5],
#    [5, 1, 5, 5, 5]]
match_indices = tf.where(
    tf.equal(val, m),
    # Broadcasting the column range against ones_like(m) tiles it per row.
    x=tf.range(num_cols) * tf.ones_like(m),
    y=num_cols * tf.ones_like(m))

# The smallest value in each row is the first matching column. A row without
# any match yields `num_cols`, which is not a valid index.
result = tf.reduce_min(match_indices, axis=1)

# NOTE: tf.Session is the TensorFlow 1.x graph-mode API.
with tf.Session() as sess:
    print(sess.run(result))  # [2, 0, 1]