Tensorflow, how to multiply a 2D tensor (matrix) by corresponding elements in a 1D vector

执笔经年 2021-01-03 12:17

I have a 2D matrix M of shape [batch x dim] and a vector V of shape [batch]. How can I multiply each of the columns in M by the corresponding elements in V?

2 answers
  •  半阙折子戏
    2021-01-03 13:00

    In NumPy, we would make V 2D (shape [batch, 1]) and then let broadcasting perform the element-wise multiplication (i.e. the Hadamard product), so that row i of M is scaled by V[i]. TensorFlow follows the same broadcasting rules, and tf.multiply broadcasts its arguments. To add the extra dimension in TensorFlow, we can use tf.newaxis (in newer versions), tf.expand_dims, or tf.reshape:

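    import tensorflow as tf

    # Three equivalent options: each turns V from [batch] into [batch, 1], which broadcasts against M [batch, dim]
    tf.multiply(M, V[:, tf.newaxis])
    tf.multiply(M, tf.expand_dims(V, 1))
    tf.multiply(M, tf.reshape(V, (-1, 1)))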
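The NumPy reasoning behind this answer can be checked directly. The sketch below uses made-up values for M and V (shapes [3, 4] and [3], chosen only for illustration), compares the broadcast product against an explicit per-row loop, and shows that multiplying by the 1D V without expanding it fails to broadcast.

```python
import numpy as np

# Hypothetical example data: M has shape [batch, dim], V has shape [batch].
batch, dim = 3, 4
M = np.arange(batch * dim, dtype=np.float32).reshape(batch, dim)
V = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Make V 2D with shape [batch, 1]; broadcasting stretches it across the
# dim axis, so row i of M is multiplied element-wise by V[i].
result = M * V[:, np.newaxis]

# Reference result computed with an explicit loop over rows.
expected = np.empty_like(M)
for i in range(batch):
    expected[i] = M[i] * V[i]

print(result.shape)                      # (3, 4)
print(np.array_equal(result, expected))  # True

# Without the extra axis, shapes (3, 4) and (3,) are aligned on the
# last axis (4 vs 3), so NumPy refuses to broadcast them.
try:
    M * V
    broadcast_failed = False
except ValueError as err:
    broadcast_failed = True
    print("broadcast error:", err)
```

One caveat: if batch happened to equal dim, `M * V` would not raise. Broadcasting would align V with the last axis and silently scale column j by V[j] instead of row i by V[i], which is why explicitly reshaping V to [batch, 1] is the safer habit.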
    
