How to find an index of the first matching element in TensorFlow

予麋鹿 2021-01-01 21:44

I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.
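For reference, the list.index() behaviour mentioned above can be sketched in plain Python. The helper name first_index, the -1 sentinel, and the sample rows are illustrative assumptions, not part of the original question.

```python
# Plain-Python reference for the behaviour being asked about.
# `first_index` and the sample data are hypothetical, for illustration only.

def first_index(row, value):
    """Return the index of the first element equal to `value`, or -1 if absent."""
    try:
        return row.index(value)  # list.index() returns the first match
    except ValueError:           # raised when `value` is not in the list
        return -1

matrix = [[1, 2, 3],
          [3, 1, 3],
          [0, 0, 0]]

# One result per row: first position of 3, or -1 when the row has no 3.
result = [first_index(row, 3) for row in matrix]
print(result)
```

Returning -1 for a missing value is a design choice made here for the sketch; list.index() itself raises ValueError in that case.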

Given a matrix and a value to find, I want to know the first

4 answers
  •  慢半拍i (OP)
    2021-01-01 22:00

    Looks a little ugly but works (assuming m and val are both tensors):

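    import tensorflow as tf

    # Collect the first matching index for each row of m.
    idx = []
    for t in tf.unstack(m, axis=0):  # tf.unpack was renamed to tf.unstack in TF 1.0
        idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
    idx = tf.stack(idx, axis=0)  # tf.pack was renamed to tf.stack in TF 1.0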
    

    EDIT: As Yaroslav Bulatov mentioned, you could achieve the same result with tf.map_fn:

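    # Return the first index in row t where the element equals val.
    def index1d(t):
        return tf.reduce_min(tf.where(tf.equal(t, val)))

    # Apply index1d to each row of m (TF 2.3+ prefers fn_output_signature over dtype).
    idx = tf.map_fn(index1d, m, dtype=tf.int64)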
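
As a quick sanity check, the tf.map_fn version above can be run on a small matrix. This is a minimal sketch assuming TensorFlow 2.x in eager mode. The sample values of m and val are made up for illustration, and fn_output_signature is used in place of dtype, which newer releases deprecate.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Hypothetical sample data; each row contains `val` at least once.
m = tf.constant([[1, 2, 3],
                 [3, 1, 3],
                 [0, 3, 0]])
val = tf.constant(3)

def index1d(t):
    # tf.where on a 1-D boolean mask returns the int64 positions of True entries;
    # reduce_min picks the smallest one, i.e. the first match.
    return tf.reduce_min(tf.where(tf.equal(t, val)))

# fn_output_signature is the TF 2.3+ spelling of the older dtype argument.
idx = tf.map_fn(index1d, m, fn_output_signature=tf.int64)
print(idx.numpy())
```

A row without any match makes tf.where return an empty tensor, so it is worth checking what reduce_min yields for that case in your TensorFlow version before relying on it.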
    
