How to find the index of the first matching element in TensorFlow

予麋鹿 2021-01-01 21:44

I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.

Given a matrix and a value to find, I want to know the index of the first occurrence of that value in each row.
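For reference, the behaviour being asked for can be written in plain Python by calling list.index() on each row. This is a minimal sketch; the helper name first_index_per_row and the choice to return -1 for a row without the value are assumptions for illustration (list.index() itself raises ValueError in that case).

```python
def first_index_per_row(matrix, val):
    """Return the index of the first occurrence of val in each row.

    Hypothetical helper: rows that do not contain val yield -1,
    because list.index() would raise ValueError for them.
    """
    result = []
    for row in matrix:
        try:
            result.append(row.index(val))  # first matching position in this row
        except ValueError:
            result.append(-1)  # assumed "not found" marker
    return result


m = [
    [0, 0, 3, 0, 3],
    [3, 0, 3, 3, 0],
    [0, 3, 0, 0, 0],
]
print(first_index_per_row(m, 3))  # [2, 0, 1]
```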

4 Answers
  •  轻奢々 (original poster)
    2021-01-01 22:15

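    Here is another solution to the problem, assuming every row contains at least one match. A row with no match would return the number of columns (5 here) instead of a valid index.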

    import tensorflow as tf
    
    val = 3
    m = tf.constant([
        [0  ,   0,   val,   0, val],
        [val,   0,   val, val,   0],
        [0  , val,     0,   0,   0]])
    
    # Replace each entry with its column index where it equals val,
    # and with an out-of-range sentinel (the number of columns) everywhere else
    match_indices = tf.where(                          # [[5, 5, 2, 5, 4],
        tf.equal(val, m),                              #  [0, 5, 2, 3, 5],
        x=tf.range(tf.shape(m)[1]) * tf.ones_like(m),  #  [5, 1, 5, 5, 5]]
        y=(tf.shape(m)[1])*tf.ones_like(m))
    
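    # The smallest surviving index in each row is the position of the first match
    result = tf.reduce_min(match_indices, axis=1)
    
    # tf.Session is TensorFlow 1.x API (tf.compat.v1.Session in 2.x, with eager execution disabled)
    with tf.Session() as sess: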
        print(sess.run(result)) # [2, 0, 1]
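
The snippet above relies on tf.Session, which belongs to TensorFlow 1.x. The sketch below is an adaptation, not part of the original answer. It applies the same sentinel-plus-reduce_min idea in TensorFlow 2.x eager mode and additionally maps rows without a match to -1. The function name first_match_index and the -1 convention are hypothetical choices.

```python
import tensorflow as tf


def first_match_index(m, val):
    """Hypothetical helper: for each row of the 2-D tensor m, return the
    column index of the first element equal to val, or -1 if none match."""
    num_cols = tf.shape(m)[1]
    # Column index of every entry, broadcast to the shape of m.
    col_ids = tf.broadcast_to(tf.range(num_cols), tf.shape(m))
    # Keep the column index where m == val; elsewhere use num_cols,
    # which is larger than any valid index (tf.where broadcasts the scalar).
    candidates = tf.where(tf.equal(m, val), col_ids, num_cols)
    first = tf.reduce_min(candidates, axis=1)
    # A row whose minimum is still the sentinel contained no match.
    return tf.where(tf.equal(first, num_cols), -tf.ones_like(first), first)


m = tf.constant([
    [0, 0, 3, 0, 3],
    [3, 0, 3, 3, 0],
    [0, 3, 0, 0, 0],
    [0, 0, 0, 0, 0],
])
print(first_match_index(m, 3).numpy().tolist())  # [2, 0, 1, -1]
```

A shorter-looking alternative would be tf.argmax over the boolean mask cast to integers. However, the TensorFlow documentation notes that the index returned for ties is not guaranteed, so it may not reliably pick the first match; reduce_min over explicit indices avoids that ambiguity.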
    
