How does TensorFlow batch_matmul work?

迷失自我 2020-12-31 10:34

TensorFlow has a function called batch_matmul which multiplies higher-dimensional tensors. But I'm having a hard time understanding how it works, perhaps partially because

  •  北海茫月
    2020-12-31 10:39

    First of all, tf.batch_matmul() was removed and is no longer available. You are now supposed to use tf.matmul():

    The inputs must be matrices (or tensors of rank > 2, representing batches of matrices), with matching inner dimensions, possibly after transposition.
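To make the quoted rule concrete, here is a small plain-Python sketch of the shape check it describes. The function name `matmul_result_shape` is hypothetical (it is not a TensorFlow API); its flags mirror the `transpose_a`/`transpose_b` arguments of `tf.matmul`, which transpose only the last two dimensions of an input. It assumes the batch dimensions must match exactly and does not model any broadcasting behaviour.

```python
def matmul_result_shape(shape_a, shape_b, transpose_a=False, transpose_b=False):
    """Hypothetical helper: result shape of a (batched) matmul, or ValueError."""
    if len(shape_a) < 2 or len(shape_b) < 2:
        raise ValueError("inputs must have rank >= 2")
    # Split each shape into leading batch dimensions and the trailing matrix dims.
    *batch_a, rows_a, cols_a = shape_a
    *batch_b, rows_b, cols_b = shape_b
    # "Possibly after transposition": swap only the last two dimensions.
    if transpose_a:
        rows_a, cols_a = cols_a, rows_a
    if transpose_b:
        rows_b, cols_b = cols_b, rows_b
    # Assumption of this sketch: batch dimensions must be identical.
    if batch_a != batch_b:
        raise ValueError(f"batch dimensions differ: {batch_a} vs {batch_b}")
    # The inner dimensions must agree, as in ordinary matrix multiplication.
    if cols_a != rows_b:
        raise ValueError(f"inner dimensions differ: {cols_a} vs {rows_b}")
    return tuple(batch_a) + (rows_a, cols_b)


print(matmul_result_shape((10, 3, 5), (10, 5, 2)))                    # (10, 3, 2)
print(matmul_result_shape((10, 3, 5), (10, 2, 5), transpose_b=True))  # (10, 3, 2)
```

The second call shows why transposition matters: a (10, 2, 5) right-hand input only fits once its last two axes are swapped.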

    So let's assume you have the following code:

    import tensorflow as tf

    batch_size, n, m, k = 10, 3, 5, 2
    # A: batch_size matrices of shape n x m; B: batch_size matrices of shape m x k
    A = tf.random.normal(shape=(batch_size, n, m))
    B = tf.random.normal(shape=(batch_size, m, k))
    C = tf.matmul(A, B)
    print(C.shape)  # (10, 3, 2)
    

    The result is a tensor of shape (batch_size, n, k). Here is what happens: A holds batch_size matrices of shape n x m, and B holds batch_size matrices of shape m x k. For each pair (A[i], B[i]), an (n x m) by (m x k) product is computed, giving an n x k matrix, and the batch_size results are stacked along the first dimension.
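The pairing described above can be written out as an explicit loop. This is a plain-Python sketch with nested lists standing in for tensors, so no TensorFlow is needed; `batched_matmul` and the other helpers are hypothetical names, and the code only illustrates the semantics, not how TensorFlow actually computes the result.

```python
import random


def matmul_2d(x, y):
    """Multiply an n x m matrix by an m x k matrix (both as lists of rows)."""
    m, k = len(y), len(y[0])
    return [[sum(row[t] * y[t][j] for t in range(m)) for j in range(k)] for row in x]


def batched_matmul(a, b):
    """Rank-3 case: a is (batch, n, m), b is (batch, m, k) -> (batch, n, k)."""
    assert len(a) == len(b), "batch sizes must match"
    # One independent 2-D product per batch index i: C[i] = A[i] @ B[i].
    return [matmul_2d(a_i, b_i) for a_i, b_i in zip(a, b)]


def random_tensor(*shape):
    """Nested lists of random floats with the given shape."""
    if len(shape) == 1:
        return [random.random() for _ in range(shape[0])]
    return [random_tensor(*shape[1:]) for _ in range(shape[0])]


batch_size, n, m, k = 10, 3, 5, 2
A = random_tensor(batch_size, n, m)
B = random_tensor(batch_size, m, k)
C = batched_matmul(A, B)
print(len(C), len(C[0]), len(C[0][0]))  # 10 3 2
```

Each C[i] depends only on A[i] and B[i]; nothing is mixed across the batch dimension.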

    Notice that something like this is also valid:

    a, b = 4, 6  # example sizes for two leading batch dimensions
    A = tf.random.normal(shape=(a, b, n, m))
    B = tf.random.normal(shape=(a, b, m, k))
    print(tf.matmul(A, B).shape)  # (4, 6, 3, 2)
    

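    and gives a result of shape (a, b, n, k): every dimension except the last two is treated as a batch dimension, and one n x k product is computed for each of the a * b index pairs.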
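The same idea extends to any number of leading dimensions: everything except the last two axes acts as a batch index, and the matrix product is applied independently at each one. Below is a plain-Python sketch of that rule using nested lists; `matmul_nd` and the helpers are hypothetical names, the sizes a = 4, b = 6 are arbitrary, and it assumes both inputs have the same rank and identical batch dimensions (no broadcasting).

```python
def rank(t):
    """Number of nesting levels of a nested list."""
    r = 0
    while isinstance(t, list):
        t = t[0]
        r += 1
    return r


def shape(t):
    """Shape tuple of a rectangular nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)


def matmul_2d(x, y):
    """Multiply an n x m matrix by an m x k matrix (both as lists of rows)."""
    return [[sum(row[t] * y[t][j] for t in range(len(y))) for j in range(len(y[0]))]
            for row in x]


def matmul_nd(a, b):
    """Treat all but the last two axes as batch axes and recurse over them."""
    if rank(a) == 2:
        return matmul_2d(a, b)
    assert len(a) == len(b), "batch dimensions must match"
    return [matmul_nd(a_i, b_i) for a_i, b_i in zip(a, b)]


def ones(*dims):
    """Nested lists of 1.0 with the given shape."""
    if len(dims) == 1:
        return [1.0] * dims[0]
    return [ones(*dims[1:]) for _ in range(dims[0])]


a, b, n, m, k = 4, 6, 3, 5, 2
C = matmul_nd(ones(a, b, n, m), ones(a, b, m, k))
print(shape(C))       # (4, 6, 3, 2)
print(C[0][0][0][0])  # 5.0: each entry sums m = 5 products of 1.0
```

Recursing over the leading axes is equivalent to flattening them into a single batch of a * b matrices, multiplying, and reshaping back.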
