How to find an index of the first matching element in TensorFlow

予麋鹿 2021-01-01 21:44

I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.

Given a matrix and a value to find, I want to know the index of the first occurrence of that value in each row of the matrix.
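For reference, here is a minimal pure-Python sketch of the per-row behavior being asked for. It assumes, as the answer below does, that the goal is one index per row. `first_index_per_row` is a hypothetical helper name, not part of any library.

```python
def first_index_per_row(matrix, val):
    """Return, for each row, the index of the first element equal to val.

    Uses list.index(), so it raises ValueError if a row does not contain val.
    """
    return [row.index(val) for row in matrix]


val = 3
m = [[0, 0, val, 0, val],
     [val, 0, val, val, 0],
     [0, val, 0, 0, 0]]
print(first_index_per_row(m, val))  # [2, 0, 1]
```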

4 answers
  •  暗喜 (original poster)
    2021-01-01 22:10

    Judging from the tests, tf.argmax seems to behave like np.argmax: when the maximum value occurs more than once, it returns the index of the first occurrence. So you can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1) to get what you want. However, tf.argmax's behavior is currently documented as undefined when the maximum value occurs more than once, so this is not guaranteed.
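The mask-then-argmax idea can be checked in NumPy, whose np.argmax documents that ties resolve to the first occurrence. This is a minimal NumPy sketch for illustration, not TensorFlow code, and `first_match_argmax` is a hypothetical name. It also exposes a caveat: a row with no match gives an all-zero mask, so argmax returns 0, which looks exactly like a real match at column 0. The sketch therefore marks such rows with -1 using `mask.any(axis=1)`.

```python
import numpy as np


def first_match_argmax(m, val):
    """Index of the first element equal to val in each row, or -1 if absent.

    Same idea as tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1),
    relying on np.argmax returning the first index when there are ties.
    """
    mask = (np.asarray(m) == val).astype(np.int32)  # 1 where m == val
    idx = np.argmax(mask, axis=1)                   # first 1 in each row
    # An all-zero row also yields 0, so flag rows that have no match at all.
    return np.where(mask.any(axis=1), idx, -1)


val = 3
m = [[0, 0, val, 0, val],
     [val, 0, val, val, 0],
     [0, val, 0, 0, 0],
     [0, 0, 0, 0, 0]]
print(first_match_argmax(m, val))  # [ 2  0  1 -1]
```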

    If you are worried about the undefined behavior, you can instead take the minimum over the output of tf.where, as @Igor Tsvetkov suggested with tf.argmin. The example below uses tf.segment_min to keep the smallest column index in each row:

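    # Tested with TensorFlow r1.0 (TF 1.x graph-mode API: tf.placeholder and
    # tf.Session were removed from the top-level namespace in TF 2.x)
    import tensorflow as tf

    val = 3
    m = tf.placeholder(tf.int32)
    m_feed = [[0  ,   0, val,   0, val],
              [val,   0, val, val,   0],
              [0  , val,   0,   0,   0]]

    # tf.where(cond) returns an [N, 2] int64 tensor holding the [row, col]
    # coordinates of every True entry, in row-major order
    tmp_indices = tf.where(tf.equal(m, val))
    # Group the column indices by row (the row ids are already sorted, as
    # segment ops require) and keep the smallest column index in each row
    result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

    with tf.Session() as sess:
        print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]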
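
To make the data flow of the tf.where + tf.segment_min version concrete, here is a NumPy mirror of it, a sketch for illustration rather than TensorFlow code. np.nonzero, like tf.where, returns match coordinates in row-major order, so each row's matches form one contiguous run; np.minimum.reduceat then takes the minimum of each run, which is what tf.segment_min computes per segment. `first_match_coords` is a hypothetical name. Note that rows without any match are simply absent from the output here.

```python
import numpy as np


def first_match_coords(m, val):
    """Return (rows, first_cols): each row that contains val, together with
    the smallest column index where m == val in that row."""
    rows, cols = np.nonzero(np.asarray(m) == val)  # row-major coordinates
    # rows is sorted, so each row id occupies one contiguous run; `starts`
    # holds the position where each run begins.
    uniq_rows, starts = np.unique(rows, return_index=True)
    # Minimum of cols within each run, like tf.segment_min(cols, rows).
    return uniq_rows, np.minimum.reduceat(cols, starts)


val = 3
m = [[0, 0, val, 0, val],
     [val, 0, val, val, 0],
     [0, val, 0, 0, 0]]
rows, first_cols = first_match_coords(m, val)
print(first_cols)  # [2 0 1]
```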
    

    Note that tf.segment_min will raise an InvalidArgumentError when some row does not contain val. Likewise, row_elems.index(val) in your code raises an exception (ValueError) when row_elems does not contain val.
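If rows without a match should yield a placeholder instead of an error, one alternative (a swapped-in technique, not the method used in the answer) is to replace every non-matching position with a sentinel larger than any column index and take the row-wise minimum. The NumPy sketch below returns -1 for rows with no match; `first_index_or_minus_one` is a hypothetical name. In TensorFlow 2.x, where tf.where broadcasts its arguments, the same idea should map onto tf.where, tf.range and tf.reduce_min, but that port is an untested assumption.

```python
import numpy as np


def first_index_or_minus_one(m, val):
    """Smallest column index where m == val in each row, or -1 if none."""
    m = np.asarray(m)
    n_cols = m.shape[1]
    cols = np.arange(n_cols)  # [0, 1, ..., n_cols - 1], broadcast over rows
    # Keep the column index where the value matches; elsewhere use the
    # sentinel n_cols, which is larger than every valid column index.
    candidates = np.where(m == val, cols, n_cols)
    first = candidates.min(axis=1)
    return np.where(first == n_cols, -1, first)  # sentinel means "no match"


val = 3
m = [[0, 0, val, 0, val],
     [0, 0, 0, 0, 0],
     [0, val, 0, 0, 0]]
print(first_index_or_minus_one(m, val))  # [ 2 -1  1]
```

Unlike tf.argmax, this does not depend on how ties are broken, because the minimum of a set of column indices is unique.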
