TensorFlow has a function called batch_matmul which multiplies higher-dimensional tensors. But I'm having a hard time understanding how it works, perhaps partially because
First of all, tf.batch_matmul() was removed and is no longer available. Now you are supposed to use tf.matmul():
The inputs must be matrices (or tensors of rank > 2, representing batches of matrices), with matching inner dimensions, possibly after transposition.
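To make "matching inner dimensions, possibly after transposition" concrete, here is a small sketch. It uses NumPy's np.matmul as a stand-in so it runs without a TensorFlow install; np.matmul follows the same convention of treating the last two axes as matrices and the leading axes as batch dimensions. The second operand is assumed to be stored as (batch, k, m), so its last two axes must be swapped before the product. In TensorFlow, the transpose_b=True argument of tf.matmul requests that same swap. The sizes are illustrative only.

```python
import numpy as np

# Illustrative sizes only.
batch_size, n, m, k = 10, 3, 5, 2
rng = np.random.default_rng(0)

A = rng.standard_normal((batch_size, n, m))
# C is stored "the wrong way round": its last two axes are (k, m), not (m, k).
C = rng.standard_normal((batch_size, k, m))

# Swap the last two axes of C so the inner dimensions match (m with m).
# This is the per-matrix transposition that transpose_b=True asks tf.matmul for.
C_t = np.swapaxes(C, -1, -2)   # shape (batch_size, m, k)
out = np.matmul(A, C_t)        # shape (batch_size, n, k)

print(out.shape)  # (10, 3, 2)
```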
So let's assume you have the following code:
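import tensorflow as tf
batch_size, n, m, k = 10, 3, 5, 2
# tf.random_normal is gone in TensorFlow 2.x; tf.random.normal is the current name.
A = tf.Variable(tf.random.normal(shape=(batch_size, n, m)))
B = tf.Variable(tf.random.normal(shape=(batch_size, m, k)))
C = tf.matmul(A, B)  # one (n x m) by (m x k) product per batch entry
print(C.shape)       # (10, 3, 2)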
Now you will receive a tensor of shape (batch_size, n, k). Here is what is going on: assume you have batch_size matrices of size n x m and batch_size matrices of size m x k. For each pair of them you compute the product (n x m) X (m x k), which gives you an n x k matrix. You will have batch_size of these, stacked along the first dimension.
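The per-pair description above can be written out as a plain-Python reference, which may help when checking small cases by hand. batch_matmul_reference is a hypothetical helper for illustration only; it spells out the semantics and is not how TensorFlow actually implements the op.

```python
def batch_matmul_reference(A, B):
    """Hypothetical helper: multiply A[i] (n x m) by B[i] (m x k) for every i.

    A and B are nested lists of shape (batch, n, m) and (batch, m, k).
    Returns a nested list of shape (batch, n, k).
    """
    if len(A) != len(B):
        raise ValueError("A and B must have the same batch size")
    result = []
    for a, b in zip(A, B):  # one (n x m, m x k) pair per batch entry
        m = len(b)          # rows of b = inner dimension
        if any(len(row) != m for row in a):
            raise ValueError("inner dimensions must match")
        k = len(b[0])
        # Ordinary n x k matrix product for this pair.
        product = [
            [sum(a_row[t] * b[t][j] for t in range(m)) for j in range(k)]
            for a_row in a
        ]
        result.append(product)
    return result


A = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]  # batch of two 2x2 matrices
B = [[[5, 6], [7, 8]], [[2, 3], [4, 5]]]
print(batch_matmul_reference(A, B))
# [[[19, 22], [43, 50]], [[2, 3], [4, 5]]]
```

The second pair multiplies by the identity matrix, so its result is simply B[1], which makes the output easy to verify.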
Notice that something like this is also valid:
a, b = 4, 6  # illustrative leading (batch) dimensions; n, m, k as above
A = tf.Variable(tf.random.normal(shape=(a, b, n, m)))
B = tf.Variable(tf.random.normal(shape=(a, b, m, k)))
tf.matmul(A, B)
and will give you a tensor of shape (a, b, n, k): every dimension except the last two is treated as a batch dimension.
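One way to see why the 4-D case works: merging the two leading dimensions into a single batch dimension, doing an ordinary 3-D batched product, and splitting them back out gives the same result. The sketch below checks this with NumPy's np.matmul, used here as a stand-in for tf.matmul because it follows the same stacked-matrices convention; the sizes are illustrative.

```python
import numpy as np

a, b, n, m, k = 4, 6, 3, 5, 2  # illustrative sizes
rng = np.random.default_rng(0)
A = rng.standard_normal((a, b, n, m))
B = rng.standard_normal((a, b, m, k))

out = np.matmul(A, B)  # shape (a, b, n, k)

# Same result: flatten (a, b) into one batch axis of length a * b,
# multiply the stacks pairwise, then restore the (a, b) layout.
flat = np.matmul(A.reshape(a * b, n, m), B.reshape(a * b, m, k))
out_via_flat = flat.reshape(a, b, n, k)

print(out.shape)                       # (4, 6, 3, 2)
print(np.allclose(out, out_via_flat))  # True
```

Because reshape uses row-major order, entry i * b + j of the flattened batch corresponds to A[i, j] and B[i, j], so reshaping back puts each product in the right place.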