Keras tensors - Get values with indices coming from another tensor

别说谁变了你拦得住时间么 提交于 2019-12-12 13:12:15

问题


Suppose I have these two tensors:

  • valueMatrix, shaped as (?, 3), where ? is the batch size
  • indexMatrix, shaped as (?, 1)

I want to retrieve values from valueMatrix at the indices contained in indexMatrix.

Example (pseudocode):

valueMatrix = [[7,15,5],[4,6,8]] -- shape=(2,3) -- type=float 
indexMatrix = [[1],[0]] -- shape = (2,1) -- type=int

I want from this example to do something like:

valueMatrix[indexMatrix] --> returns --> [[15],[4]]

I prefer Tensorflow over other backends, but the answer must be compatible with a Keras model using Lambda layers or other suitable layers for the task.


回答1:


import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])

# create the row index with tf.range
row_idx = tf.reshape(tf.range(indexMatrix.shape[0]), (-1,1))
# stack with column index
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# extract the elements with gather_nd
values = tf.gather_nd(valueMatrix, idx)

with tf.Session() as sess:
    print(sess.run(values))
#[[15]
# [ 4]]


来源:https://stackoverflow.com/questions/46526869/keras-tensors-get-values-with-indices-coming-from-another-tensor

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!