TensorFlow - matmul of an input matrix with batch data

独厮守ぢ 2020-11-30 22:55

I have some data represented by input_x. It is a tensor of unknown size (it should be fed in batches), and each item in it has size n. input

5 answers
  •  误落风尘
    2020-11-30 23:44

    As answered by @Stryke, there are two ways to achieve this: 1. scanning, and 2. reshaping.

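    1. tf.scan repeatedly applies a function (often written as a lambda) to the slices of a tensor along its first dimension, and is generally used for recursive operations where each step depends on the previous result. Some examples are here: https://rdipietro.github.io/tensorflow-scan-examples/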

    2. I personally prefer reshaping, since it is more intuitive. If you are trying to matrix-multiply each matrix in the 3D tensor by the 2D matrix, i.e. C[i,j,l] = sum_k A[i,j,k] * B[k,l], you can do it with a simple reshape: flatten the first two dimensions of A, do one ordinary 2D matmul, then restore the shape.

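      import tensorflow as tf
      # A has shape (i, j, k), B has shape (k, l); i may be the unknown batch size
      i, j, k = tf.shape(A)[0], tf.shape(A)[1], tf.shape(A)[2]
      A_flat = tf.reshape(A, [i * j, k])               # A': (i*j, k)
      C_flat = tf.matmul(A_flat, B)                     # C': (i*j, l)
      C = tf.reshape(C_flat, [i, j, tf.shape(B)[1]])    # C:  (i, j, l)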
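    A minimal sketch of the tf.scan route (option 1), assuming TensorFlow 2.x with eager execution; the helper name batched_matmul_scan is hypothetical. tf.scan calls fn(accumulator, element) once per slice along the first axis and stacks the results. A genuinely recursive use (a running sum) is shown first for contrast; in the matmul case the accumulator is ignored, so the initializer only fixes the shape (j, l) of each step's output.

```python
import tensorflow as tf

# Recursive use: each step adds the current element to the previous result.
running_sum = tf.scan(lambda acc, x: acc + x, tf.constant([1, 2, 3, 4]))
# running_sum -> [1, 3, 6, 10]

def batched_matmul_scan(A, B):
    """Multiply every (j, k) slice of A (shape (i, j, k)) by B (shape (k, l))."""
    # The initializer fixes the shape/dtype of each step's output: (j, l).
    init = tf.zeros(tf.stack([tf.shape(A)[1], tf.shape(B)[1]]), dtype=A.dtype)
    # The accumulator is unused: each step depends only on its own slice of A.
    return tf.scan(lambda acc, a: tf.matmul(a, B), A, initializer=init)

A = tf.random.normal([4, 3, 5])  # a batch of 4 matrices, each 3 x 5
B = tf.random.normal([5, 2])
C = batched_matmul_scan(A, B)    # stacked results, shape (4, 3, 2)
```

    Because the steps here are independent, scan runs them sequentially for no benefit, which is one reason the reshape approach is usually preferred for this particular problem.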
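    A quick sanity check of the reshape approach, again a sketch assuming TensorFlow 2.x; matmul_3d_by_2d is a hypothetical helper name. It reads i, j, k from tf.shape(A) (the dynamic shape) so that an unknown batch size works, compares the result with tf.einsum and tf.tensordot, which compute the same contraction C[i,j,l] = sum_k A[i,j,k] * B[k,l] directly, and traces it once with a None batch dimension.

```python
import tensorflow as tf

def matmul_3d_by_2d(A, B):
    """C[i, j, l] = sum_k A[i, j, k] * B[k, l] via reshape -> matmul -> reshape."""
    # Dynamic shapes, so this also works when the batch size i is None at trace time.
    i, j, k = tf.shape(A)[0], tf.shape(A)[1], tf.shape(A)[2]
    A_flat = tf.reshape(A, [i * j, k])                 # all rows of all matrices: (i*j, k)
    C_flat = tf.matmul(A_flat, B)                      # one 2D matmul: (i*j, l)
    return tf.reshape(C_flat, [i, j, tf.shape(B)[1]])  # restore (i, j, l)

A = tf.random.normal([4, 3, 5])
B = tf.random.normal([5, 2])
C = matmul_3d_by_2d(A, B)

# The same contraction written directly; all three should agree up to float error.
C_einsum = tf.einsum("ijk,kl->ijl", A, B)
C_tensordot = tf.tensordot(A, B, axes=[[2], [0]])

# Trace once with an unknown batch dimension, then call it with a batch of 7.
f = tf.function(matmul_3d_by_2d, input_signature=[
    tf.TensorSpec([None, 3, 5], tf.float32), tf.TensorSpec([5, 2], tf.float32)])
C7 = f(tf.random.normal([7, 3, 5]), B)
```

    The reshape version does a single large matmul instead of i small ones; tf.einsum expresses the same thing in one call if you find the index notation clearer.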
      
