TensorFlow - matmul of input matrix with batch data

独厮守ぢ 2020-11-30 22:55

I have some data represented by input_x. It is a tensor of unknown size (it is fed in batches), and each item in it has size n. input

5 answers
  •  情书的邮戳
    2020-11-30 23:29

    The matmul operation only works on matrices (2D tensors). There are two main approaches to do this; both assume that U is a 2D tensor.

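    1. Slice embed into 2D tensors and multiply each of them with U individually. This is probably easiest to do using tf.map_fn() (tf.scan() without an initializer would treat the first slice as the accumulator and return it unmultiplied), like this:

      h = tf.map_fn(lambda x: tf.matmul(x, U), embed)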
      
    2. On the other hand, if efficiency is important, it may be better to reshape embed into a 2D tensor so that the multiplication can be done with a single matmul, like this:

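      embed = tf.reshape(embed, [-1, m])   # flatten the batch: [batch * n, m]
      h = tf.matmul(embed, U)              # one 2D matmul: [batch * n, c]
      h = tf.reshape(h, [-1, n, c])        # restore the batch: [batch, n, c]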
      

      where c is the number of columns in U. The last reshape ensures that h is a 3D tensor whose 0th dimension corresponds to the batch, just like the original input_x and embed.
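
To check that the two approaches agree, here is a minimal sketch. It uses NumPy as a stand-in for TensorFlow, on the assumption that reshape and matrix multiplication behave the same way in both for this purpose. The sizes (batch=4, n=3, m=5, c=2) and the function names are made up for illustration and do not come from the question.

```python
import numpy as np


def batch_matmul_by_reshape(embed, U):
    """Approach 2: flatten the batch, do one 2D matmul, restore the batch."""
    _, n, m = embed.shape
    c = U.shape[1]
    flat = embed.reshape(-1, m)   # [batch * n, m]
    h = flat @ U                  # [batch * n, c]
    return h.reshape(-1, n, c)    # [batch, n, c]


def batch_matmul_by_slices(embed, U):
    """Approach 1: multiply each [n, m] slice by U and stack the results."""
    return np.stack([x @ U for x in embed])


# Illustrative sizes only: batch=4, n=3, m=5, c=2.
rng = np.random.default_rng(0)
embed = rng.normal(size=(4, 3, 5))
U = rng.normal(size=(5, 2))

h_fast = batch_matmul_by_reshape(embed, U)
h_slow = batch_matmul_by_slices(embed, U)

print(h_fast.shape)                 # shape of the batched result
print(np.allclose(h_fast, h_slow))  # whether both approaches agree
```

The reshape version works because each row of a slice is multiplied by U independently, so stacking all rows of all slices into one tall matrix does not change any individual result.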
