Theano: using tensordot to compute the dot product of two tensors

Submitted by 醉酒当歌 on 2019-11-28 11:35:33

Well, do you want this in numpy or in Theano? If, as you state, you would like to contract the third (last) axis of A against the second (last) axis of B, both are straightforward:

import numpy as np

a = np.arange(3 * 4 * 5).reshape(3, 4, 5).astype('float32')
b = np.arange(3 * 5).reshape(3, 5).astype('float32')

result = a.dot(b.T)

In Theano, this is written as

import theano.tensor as T

A = T.ftensor3()
B = T.fmatrix()

out = A.dot(B.T)

out.eval({A: a, B: b})

However, the output then has shape (3, 4, 3), because dot pairs every a[i] with every row of b. Since you seem to want an output of shape (3, 4), the numpy alternative uses einsum, like so

einsum_out = np.einsum('ijk, ik -> ij', a, b)

Theano, however, has no einsum, so this specific case can be emulated with broadcasting as follows

out = (a * b[:, np.newaxis]).sum(2)

which can also be written in Theano

out = (A * B.dimshuffle(0, 'x', 1)).sum(2)
out.eval({A: a, B: b})
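
As a quick sanity check of the NumPy expressions above, the following sketch compares their shapes and values. It reuses the arrays a and b from the answer; the names dot_out and bcast_out are only illustrative and are not part of the original.

```python
import numpy as np

a = np.arange(3 * 4 * 5).reshape(3, 4, 5).astype('float32')
b = np.arange(3 * 5).reshape(3, 5).astype('float32')

# Plain dot contracts the last axis of a with every row of b,
# so the result keeps an extra axis of length 3.
dot_out = a.dot(b.T)

# einsum pairs a[i] only with b[i] and sums over k.
einsum_out = np.einsum('ijk,ik->ij', a, b)

# Broadcasting emulation: b becomes shape (3, 1, 5), then sum over the last axis.
bcast_out = (a * b[:, np.newaxis]).sum(2)

print(dot_out.shape, einsum_out.shape, bcast_out.shape)
print(np.allclose(einsum_out, bcast_out))
```

The einsum and broadcasting results should agree, while the dot result carries the additional axis described above.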

Joe Kington

In this specific case, einsum is probably easier to understand than tensordot. For example:

c = np.einsum('ijk,ik->ij', a, b)

I'm going to over-simplify the explanation a bit to make things more immediately understandable. We have two input arrays (separated by the comma) and this yields our output array (to the right of the ->).

  • a has shape 3, 4, 5 and we'll refer to it as ijk
  • b has shape 3, 5 (ik)
  • We want the output c to have shape 3, 4 (ij)

Seems a bit magical, right? Let's break that down a bit.

  • The letters we "lose" as we cross the -> are axes that will be summed over. That's what dot is doing, as well.
  • We want output with shape 3, 4, so we're eliminating k
  • Therefore, the output c should be ij
  • b has shape 3, 5, which lines up with the first and last axes of a, so we'll refer to b as ik.
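
To make the index notation above concrete, here is a deliberately slow sketch that spells out 'ijk,ik->ij' as explicit loops: k is the index that gets summed away, and i and j survive into the output. The name c_loops is just for illustration.

```python
import numpy as np

a = np.random.random((3, 4, 5))
b = np.random.random((3, 5))

# c_loops[i, j] = sum over k of a[i, j, k] * b[i, k]
c_loops = np.zeros((3, 4))
for i in range(3):
    for j in range(4):
        for k in range(5):
            c_loops[i, j] += a[i, j, k] * b[i, k]

# Should match the einsum expression from the text.
print(np.allclose(c_loops, np.einsum('ijk,ik->ij', a, b)))
```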

As a full example:

import numpy as np

a = np.random.random((3, 4, 5))
b = np.random.random((3, 5))

# Looping through things
c1 = []
for i in range(3):
    c1.append(a[i].dot(b[i]))
c1 = np.array(c1)

# Using einsum instead
c2 = np.einsum('ijk,ik->ij', a, b)

assert np.allclose(c1, c2)

You can do this with tensordot as well. I'll add an example of that as soon as I have a bit more time. (Of course, if anyone else would like to add a tensordot example as another answer in the meantime, feel free!)
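
One possible tensordot version is sketched below. This is not the answerer's own example, only a guess at what such an example could look like. Since tensordot has no notion of a shared "batch" axis, it computes all pairings of a[i] with b[l], and the wanted entries (where l == i) are picked out afterwards. The names full, idx and c3 are illustrative.

```python
import numpy as np

a = np.random.random((3, 4, 5))
b = np.random.random((3, 5))

# Contract axis 2 of a with axis 1 of b:
# full[i, j, l] = sum over k of a[i, j, k] * b[l, k], shape (3, 4, 3).
# This is the same array as a.dot(b.T).
full = np.tensordot(a, b, axes=([2], [1]))

# Keep only the entries where the first and last indices coincide.
idx = np.arange(a.shape[0])
c3 = full[idx, :, idx]  # shape (3, 4)

print(np.allclose(c3, np.einsum('ijk,ik->ij', a, b)))
```

Note that this does roughly three times the necessary work here, since the off-diagonal pairings are computed and then thrown away, which is part of why einsum is the more natural fit for this case.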
