I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.
Given a matrix and a value to find, I want to know the index of the first occurrence of that value in each row of the matrix.
Here is a solution that also handles the case where the element is not contained in the matrix (taken from a DeepMind GitHub repository):
def get_first_occurrence_indices(sequence, eos_idx):
    '''Locate the first occurrence of `eos_idx` in every row of `sequence`.

    args:
        sequence: [batch, length] tensor. Both dimensions are read from its
            static shape, so they are expected to be known at graph
            construction time.
        eos_idx: scalar value to search for.

    returns:
        int32 tensor of shape [batch]. Each entry is the 1-based position of
        the first occurrence of `eos_idx` in that row, clipped to `length`.
        Rows that do not contain `eos_idx` therefore yield `length`.
    '''
    batch_size, maxlen = sequence.get_shape().as_list()
    # Build the scalar in the dtype of `sequence`. Without an explicit dtype a
    # Python int becomes int32 and could not be concatenated with, or compared
    # to, a sequence of another integer type such as int64.
    eos_idx = tf.convert_to_tensor(eos_idx, dtype=sequence.dtype.base_dtype)
    # Append one sentinel column holding `eos_idx` so that every row has at
    # least one match, even when the value is absent from the original row.
    sentinel_column = tf.tile(eos_idx[None, None], [batch_size, 1])
    tensor = tf.concat([sequence, sentinel_column], axis=-1)
    # One (row, column) coordinate pair per match.
    index_all_occurrences = tf.where(tf.equal(tensor, eos_idx))
    index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
    # Smallest matching column per row: the row index is the segment id.
    index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
                                             index_all_occurrences[:, 0])
    index_first_occurrences.set_shape([batch_size])
    # Convert to a 1-based position and clip, so that a hit on the sentinel
    # column (value not found) maps to `maxlen`.
    index_first_occurrences = tf.minimum(index_first_occurrences + 1, maxlen)
    return index_first_occurrences
And an example of how to use it:
import tensorflow as tf
# Example input: four rows of length five; the last row does not contain `idx`.
mat = tf.Variable(
    [[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [0, 0, 0, 0, 0]],
    dtype=tf.int32)
idx = 3
first_occurrences = get_first_occurrence_indices(mat, idx)
sess = tf.InteractiveSession()
# `mat` is a tf.Variable, so it must be initialized before it is evaluated.
sess.run(tf.global_variables_initializer())
# Evaluate the tensor built above (the original referenced the undefined
# name `first_occurrence`, which raises NameError).
sess.run(first_occurrences)  # [3, 2, 1, 5]