TensorFlow, how to multiply a 2D tensor (matrix) by corresponding elements in a 1D vector

执笔经年 2021-01-03 12:17

I have a 2D matrix M of shape [batch x dim] and a vector V of shape [batch]. How can I multiply each of the columns in M by the corresponding element of V?

2 Answers
  •  感情败类
    2021-01-03 13:02

    In addition to @Divakar's answer, note that the order of M and V doesn't matter: tf.multiply supports broadcasting, so V[:, tf.newaxis] (shape [batch, 1]) is stretched across the dim columns of M whichever operand comes first.
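
    Why V needs the extra axis follows from the broadcasting rule: shapes are compared starting from the trailing dimension, and an axis of size 1 is stretched to match. A minimal sketch, assuming TensorFlow 2.x with eager execution and reusing the values from the example (the names column_v and broadcast_failed are just illustrative):

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution); values match the example.
M = tf.constant([[1, 2, 3, 4],
                 [2, 3, 4, 5],
                 [3, 4, 5, 6]], dtype=tf.int32)  # shape [batch, dim] = [3, 4]
V = tf.constant([10, 20, 30], dtype=tf.int32)     # shape [batch] = [3]

# Indexing with tf.newaxis turns V into a column of shape [3, 1]; the size-1
# axis is then stretched to 4 so every entry of row i is multiplied by V[i].
column_v = V[:, tf.newaxis]
print(column_v.shape.as_list())  # [3, 1]

# Without the new axis, [3, 4] and [3] are compared from the right (4 vs 3),
# which is not broadcast-compatible, so the multiply raises an error.
broadcast_failed = False
try:
    tf.multiply(M, V)
except (tf.errors.InvalidArgumentError, ValueError) as err:
    broadcast_failed = True
    print("incompatible shapes:", type(err).__name__)
```

    The exception type is caught broadly because eager execution and graph-mode shape inference may report the mismatch differently.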

    Example:

    In [55]: M.eval()
    Out[55]: 
    array([[1, 2, 3, 4],
           [2, 3, 4, 5],
           [3, 4, 5, 6]], dtype=int32)
    
    In [56]: V.eval()
    Out[56]: array([10, 20, 30], dtype=int32)
    
    In [57]: tf.multiply(M, V[:,tf.newaxis]).eval()
    Out[57]: 
    array([[ 10,  20,  30,  40],
           [ 40,  60,  80, 100],
           [ 90, 120, 150, 180]], dtype=int32)
    
    In [58]: tf.multiply(V[:, tf.newaxis], M).eval()
    Out[58]: 
    array([[ 10,  20,  30,  40],
           [ 40,  60,  80, 100],
           [ 90, 120, 150, 180]], dtype=int32)
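
    The session above calls .eval(), which needs a default session (as in TensorFlow 1.x). A minimal sketch of the same computation, assuming TensorFlow 2.x eager execution, might look like this; tf.expand_dims(V, axis=1) is shown as an equivalent way to add the size-1 axis:

```python
import tensorflow as tf

# Assumes TensorFlow 2.x: tensors are evaluated eagerly, so .numpy()
# replaces the session-bound .eval() used in the transcript.
M = tf.constant([[1, 2, 3, 4],
                 [2, 3, 4, 5],
                 [3, 4, 5, 6]], dtype=tf.int32)
V = tf.constant([10, 20, 30], dtype=tf.int32)

# Operand order does not matter: both broadcast V[:, tf.newaxis] to [3, 4].
left = tf.multiply(M, V[:, tf.newaxis])
right = tf.multiply(V[:, tf.newaxis], M)

# The * operator and tf.expand_dims(V, axis=1) are equivalent spellings.
via_operator = M * tf.expand_dims(V, axis=1)

print(left.numpy())  # same values as Out[57] and Out[58]
```

    All three results hold row i of M scaled by V[i]; which spelling to use is mostly a matter of taste.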
    
