Question
To build a capsule network training script, I need to compute many small matrix-vector multiplications.
Each weight matrix is at most 20 by 20.
There are more than 900 weight matrices.
I'm curious whether tf.matmul or tf.linalg.matvec is the better option for this.
Could anybody give me a hint for optimizing the training script?
Answer 1:
EDIT:
Looking at the notebook that you are referring to, it seems you have the following parameters:
batch_size = 50
caps1_n_caps = 1152
caps1_n_dims = 8
caps2_n_caps = 10
caps2_n_dims = 16
You also have a tensor w with shape (caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims) (in the notebook it has an extra leading dimension of size 1, which I am skipping) and another tensor caps1_output with shape (batch_size, caps1_n_caps, caps1_n_dims). You need to combine them to produce caps2_predicted, with shape (batch_size, caps1_n_caps, caps2_n_caps, caps2_n_dims).
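Written out per component, each prediction vector is one caps2_n_dims × caps1_n_dims (16 × 8) matrix applied to one caps1_n_dims-vector, so the product sums over the input dimension only and keeps the output-capsule index. The index letters below are chosen for illustration: b indexes the batch, i the input capsules, j the output capsules, k the output dimensions and l the input dimensions, with W standing for w, u for caps1_output and \hat{u} for caps2_predicted:

```latex
$$\hat{u}_{b,i,j,k} = \sum_{l=1}^{8} W_{i,j,k,l}\, u_{b,i,l}, \qquad 8 = \texttt{caps1\_n\_dims}$$
```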
In the notebook, they tile both tensors so that they can be multiplied with tf.linalg.matmul, but you can compute the same result without any tiling by using tf.einsum:
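import tensorflow as tf
batch_size = 50
caps1_n_caps = 1152
caps1_n_dims = 8
caps2_n_caps = 10
caps2_n_dims = 16
# One (caps2_n_dims x caps1_n_dims) weight matrix per (input capsule, output capsule) pair
w = tf.zeros((caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims), dtype=tf.float32)
# One caps1_n_dims-vector per batch element and input capsule
caps1_output = tf.zeros((batch_size, caps1_n_caps, caps1_n_dims), dtype=tf.float32)
# Sum over l (caps1_n_dims); keep b (batch), i (input capsule), j (output capsule), k (output dim)
caps2_predicted = tf.einsum('ijkl,bil->bijk', w, caps1_output)
print(caps2_predicted.shape)
# (50, 1152, 10, 16)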
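To check that an einsum contraction reproduces the notebook's tiled tf.linalg.matmul, here is a minimal sketch with random inputs. The reduced sizes (batch_size = 4, caps1_n_caps = 6) are assumptions chosen only so it runs quickly. It rebuilds the tiling approach with tf.tile and compares it against tf.einsum('ijkl,bil->bijk', ...), which sums over the input dimension l and keeps the output-capsule index j.

```python
import numpy as np
import tensorflow as tf

# Reduced sizes (assumed, for a quick check); the dims match the notebook
batch_size = 4
caps1_n_caps = 6
caps1_n_dims = 8
caps2_n_caps = 10
caps2_n_dims = 16

tf.random.set_seed(0)
w = tf.random.normal((caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims))
caps1_output = tf.random.normal((batch_size, caps1_n_caps, caps1_n_dims))

# Tiling approach: copy w once per batch element ...
w_tiled = tf.tile(w[tf.newaxis], [batch_size, 1, 1, 1, 1])
# -> (batch, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims)
# ... and copy each input vector once per output capsule, as a column vector
u_tiled = tf.tile(caps1_output[:, :, tf.newaxis, :, tf.newaxis],
                  [1, 1, caps2_n_caps, 1, 1])
# -> (batch, caps1_n_caps, caps2_n_caps, caps1_n_dims, 1)
pred_tiled = tf.squeeze(tf.linalg.matmul(w_tiled, u_tiled), axis=-1)
# -> (batch, caps1_n_caps, caps2_n_caps, caps2_n_dims)

# einsum approach: no explicit tiling
pred_einsum = tf.einsum('ijkl,bil->bijk', w, caps1_output)

print(pred_tiled.shape, pred_einsum.shape)
print(np.allclose(pred_tiled.numpy(), pred_einsum.numpy(), atol=1e-3))
```

The einsum version skips the explicit tf.tile copies of w (one per batch element) and of caps1_output (one per output capsule).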
I'm not sure I have understood exactly what you want, but you say you want to compute something like $\hat{u}_{j|i} = W_{ij} u_i$ for a collection of matrices W and vectors u. Assuming you have 900 matrices and 900 vectors, with matrices of size 20×20 and vectors of size 20, you can represent them as two tensors: ws, with shape (900, 20, 20), and us, with shape (900, 20). Your result us_hat, with shape (900, 20), can then be computed as a broadcast multiplication followed by a sum over the last axis:
us_hat = tf.reduce_sum(ws * tf.expand_dims(us, axis=1), axis=-1)
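Since the question asks whether tf.matmul or tf.linalg.matvec is the better choice: both accept a leading batch axis, so either handles the whole stack in one call instead of a Python loop. Below is a minimal sketch, assuming 900 random 20×20 matrices and 20-vectors (the sizes from the question), that checks the broadcast-and-sum form, tf.linalg.matmul and tf.einsum against tf.linalg.matvec. It only checks that they agree; it does not measure speed.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
n, d = 900, 20  # sizes from the question: 900 matrices of size 20x20

ws = tf.random.normal((n, d, d))  # ws[m] is the m-th matrix W
us = tf.random.normal((n, d))     # us[m] is the m-th vector u

# Broadcast u along the rows of W, multiply elementwise, then sum over columns:
# result[m, r] = sum_c ws[m, r, c] * us[m, c]
hat_broadcast = tf.reduce_sum(ws * tf.expand_dims(us, axis=1), axis=-1)

# Batched matrix-vector product (the leading axis is treated as a batch axis)
hat_matvec = tf.linalg.matvec(ws, us)

# Batched matmul with each u as a (d, 1) column, then drop the trailing axis
hat_matmul = tf.squeeze(tf.linalg.matmul(ws, tf.expand_dims(us, axis=-1)), axis=-1)

# The same contraction written with einsum
hat_einsum = tf.einsum('mrc,mc->mr', ws, us)

for name, t in [("broadcast", hat_broadcast), ("matmul", hat_matmul), ("einsum", hat_einsum)]:
    # Each variant should match tf.linalg.matvec up to float32 rounding
    print(name, t.shape, np.allclose(t.numpy(), hat_matvec.numpy(), atol=1e-3))
```

tf.linalg.matvec and the einsum form state the intent most directly, while the broadcast form builds an intermediate (900, 20, 20) tensor before reducing it. Which variant is fastest on a given device has to be measured.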
Source: https://stackoverflow.com/questions/64336531/optimizing-tensorflow-for-many-small-matrix-vector-multiplications