Tensorflow - matmul of input matrix with batch data

独厮守ぢ 2020-11-30 22:55

I have some data represented by input_x. It is a tensor of unknown size (it should be fed in batches), and each item in it has size n. input

5 answers
  •  旧巷少年郎
    2020-11-30 23:43

    Previous answers are obsolete. tf.matmul() now supports tensors with rank > 2. From the documentation:

    The inputs must be matrices (or tensors of rank > 2, representing batches of matrices), with matching inner dimensions, possibly after transposition.

    Also tf.batch_matmul() was removed and tf.matmul() is the right way to do batch multiplication. The main idea can be understood from the following code:

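    import tensorflow as tf

    batch_size, n, m, k = 10, 3, 5, 2
    # tf.random.normal is the current name; the tf.random_normal alias was removed in TensorFlow 2.x
    A = tf.Variable(tf.random.normal(shape=(batch_size, n, m)))
    B = tf.Variable(tf.random.normal(shape=(batch_size, m, k)))
    C = tf.matmul(A, B)  # one n x k product per batch entry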
    

    The result is a tensor of shape (batch_size, n, k). Here is what happens: you have batch_size matrices of size n x m and batch_size matrices of size m x k. For each pair with the same batch index, the n x m matrix is multiplied by the m x k matrix, which gives an n x k matrix. You end up with batch_size of them.
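
To make the per-pair behaviour concrete, here is a minimal sketch in plain Python (nested lists, no TensorFlow) of what a batched matrix product computes for rank-3 inputs. The helper names matmul_2d and batch_matmul are hypothetical and exist only for this illustration; they are not TensorFlow functions.

```python
def matmul_2d(x, y):
    """Multiply an n x m matrix by an m x k matrix given as nested lists."""
    m = len(y)
    assert all(len(row) == m for row in x), "inner dimensions must match"
    return [[sum(row[t] * y[t][j] for t in range(m)) for j in range(len(y[0]))]
            for row in x]


def batch_matmul(a, b):
    """Reference semantics for rank-3 inputs: multiply pairs with the same batch index."""
    assert len(a) == len(b), "batch sizes must match"
    return [matmul_2d(x, y) for x, y in zip(a, b)]


# Illustrative sizes: batch_size = 2, n = 2, m = 3, k = 1
a = [[[1, 2, 3], [4, 5, 6]],
     [[1, 0, 0], [0, 1, 0]]]
b = [[[1], [1], [1]],
     [[7], [8], [9]]]

c = batch_matmul(a, b)
print(c)  # [[[6], [15]], [[7], [8]]]
```

The output has shape (2, 2, 1), that is (batch_size, n, k): each batch entry is multiplied independently of the others.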

    Notice that something like this is also valid:

    a, b = 4, 6  # illustrative sizes for two leading batch dimensions
    A = tf.Variable(tf.random.normal(shape=(a, b, n, m)))
    B = tf.Variable(tf.random.normal(shape=(a, b, m, k)))
    tf.matmul(A, B)
    

    and it gives a result of shape (a, b, n, k).
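
The same idea extends to any number of leading dimensions: only the last two axes are multiplied, and everything in front of them is treated as batch indices. Below is a rough plain-Python sketch of that rule. The names shape_of, matmul_nd, filled, a_dim and b_dim are made up for this example, and the sketch assumes the leading dimensions of both inputs match exactly.

```python
def shape_of(t):
    """Shape of a regular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0]
    return tuple(shape)


def matmul_nd(x, y):
    """Recurse through the leading dimensions and multiply the innermost two."""
    if len(shape_of(x)) == 2:
        inner = len(y)
        return [[sum(row[t] * y[t][j] for t in range(inner))
                 for j in range(len(y[0]))] for row in x]
    assert len(x) == len(y), "leading dimensions must match in this sketch"
    return [matmul_nd(xs, ys) for xs, ys in zip(x, y)]


def filled(shape, value):
    """Build a nested list of the given shape with every entry set to value."""
    if not shape:
        return value
    return [filled(shape[1:], value) for _ in range(shape[0])]


# Illustrative sizes for (a, b, n, m) and (a, b, m, k)
a_dim, b_dim, n, m, k = 2, 3, 4, 5, 6
A = filled((a_dim, b_dim, n, m), 1)
B = filled((a_dim, b_dim, m, k), 2)

C = matmul_nd(A, B)
print(shape_of(C))  # (2, 3, 4, 6)
```

Every entry of C is 10 here, since each one sums m = 5 products of 1 and 2.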
