Indexing tensor with index matrix in theano?

独自空忆成欢 提交于 2019-12-06 06:17:31

You can do the following in numpy:

import numpy as np

A = np.arange(4 * 2 * 5).reshape(4, 2, 5)
B = np.arange(4 * 2).reshape(4, 2) % 5

C = A[np.arange(A.shape[0])[:, np.newaxis], np.arange(A.shape[1]), B]

So you can do the same thing in theano:

import theano
import theano.tensor as T

AA = T.tensor3()
BB = T.imatrix()

CC = AA[T.arange(AA.shape[0]).reshape((-1, 1)), T.arange(AA.shape[1]), BB]

f = theano.function([AA, BB], CC)

f(A.astype(theano.config.floatX), B)
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!