How to find an index of the first matching element in TensorFlow

予麋鹿 2021-01-01 21:44

I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.

Given a matrix and a value to find, I want to know the index of the first occurrence of the value in each row of the matrix.
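For reference, here is the requested behaviour in plain Python, applying list.index() to each row. This is only a specification to check a TensorFlow version against, not a TensorFlow solution. `first_index_reference` is a hypothetical name. Returning -1 for a row without the value is an assumption, since list.index() itself raises ValueError.

```python
def first_index_reference(matrix, value, missing=-1):
    """Plain-Python reference: list.index() applied to every row.

    `missing` (default -1) is an assumed convention for rows that do not
    contain `value`; list.index() itself would raise ValueError.
    """
    result = []
    for row in matrix:
        try:
            result.append(row.index(value))  # 0-based, first match only
        except ValueError:
            result.append(missing)
    return result


print(first_index_reference([[1, 2, 3], [3, 0, 0], [0, 0, 0]], 3))  # [2, 0, -1]
```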

4 Answers
  •  醉酒成梦
    2021-01-01 21:56

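    Here is a solution that also handles rows that do not contain the element (taken from a DeepMind GitHub repository). Unlike list.index(), it returns 1-based positions, i.e. the length of each row up to and including the first match, and it returns the row length for rows without a match.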

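    def get_first_occurrence_indices(sequence, eos_idx):
        '''
        args:
            sequence: [batch, length] integer tensor with a static shape
            eos_idx: scalar value to search for
        returns:
            [batch] int32 tensor: 1-based position of the first eos_idx in
            each row, or length if the row does not contain it
        '''
        batch_size, maxlen = sequence.get_shape().as_list()
        eos_idx = tf.convert_to_tensor(eos_idx, dtype=sequence.dtype)
        # Append one eos_idx column so every row has at least one match.
        tensor = tf.concat(
                [sequence, tf.tile(eos_idx[None, None], [batch_size, 1])], axis = -1)
        # [row, col] coordinates of every match, in row-major order.
        index_all_occurrences = tf.where(tf.equal(tensor, eos_idx))
        index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
        # Smallest column per row = first occurrence (row ids are already sorted).
        index_first_occurrences = tf.segment_min(index_all_occurrences[:, 1],
                                                 index_all_occurrences[:, 0])
        index_first_occurrences.set_shape([batch_size])
        # 0-based column -> 1-based position; rows that only matched the
        # appended column end up clipped to maxlen.
        index_first_occurrences = tf.minimum(index_first_occurrences + 1, maxlen)

        return index_first_occurrences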
    

    Usage (TensorFlow 1.x graph mode):

    import tensorflow as tf
    mat = tf.Variable([[1,2,3,4,5], [2,3,4,5,6], [3,4,5,6,7], [0,0,0,0,0]], dtype = tf.int32)
    idx = 3
    first_occurrences = get_first_occurrence_indices(mat, idx)
    
    # TensorFlow 1.x session API
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    sess.run(first_occurrences) # [3, 2, 1, 5]; the last row has no 3, so it gets maxlen = 5
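    For TensorFlow 2.x (eager mode), here is a shorter sketch that returns 0-based indices like list.index(). It uses a different technique from the DeepMind snippet: every non-matching cell is replaced by the row length, and tf.reduce_min is taken along each row. `first_index_per_row` is a hypothetical helper name, and returning -1 for a row without the value is an assumed convention. tf.argmax over the boolean mask looks tempting, but its documentation does not guarantee which index is returned on ties, so the minimum over masked column indices is used instead.

```python
import tensorflow as tf


def first_index_per_row(matrix, value):
    """0-based column of the first `value` in each row, or -1 if absent.

    Hypothetical helper sketched for TensorFlow 2.x; the -1 convention
    for "not found" is an assumption, not a TensorFlow behaviour.
    """
    matrix = tf.convert_to_tensor(matrix)
    value = tf.cast(value, matrix.dtype)
    shape = tf.shape(matrix)
    num_cols = shape[1]
    # Column index of every cell: [[0, 1, ..., C-1], [0, 1, ..., C-1], ...]
    positions = tf.broadcast_to(tf.range(num_cols), shape)
    # Keep the column index where the value matches; elsewhere use num_cols,
    # which is larger than any real column and so never wins the minimum.
    masked = tf.where(tf.equal(matrix, value), positions, tf.fill(shape, num_cols))
    first = tf.reduce_min(masked, axis=1)
    # Rows without a match still hold num_cols; report them as -1.
    return tf.where(first < num_cols, first, -tf.ones_like(first))


mat = tf.constant([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [0, 0, 0, 0, 0]],
                  dtype=tf.int32)
print(first_index_per_row(mat, 3).numpy().tolist())  # [2, 1, 0, -1]
```

    Subtracting 1 from the DeepMind result gives the same numbers for rows that contain the value; the two approaches differ only in how a missing value is reported.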
    
