How to explicitly broadcast a tensor to match another's shape in tensorflow?

爱一瞬间的悲伤 2021-02-13 09:17

I have three tensors, A, B and C, in TensorFlow. A and B are both of shape (m, n, r); C is a binary tensor of sha

4 answers
  •  情深已故
    2021-02-13 09:56

    In TensorFlow 2.0, you can use tf.broadcast_to, as shown below:

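    import tensorflow as tf

    # tf.random.normal replaces the TF 1.x name tf.random_normal
    A = tf.random.normal([20, 100, 10])
    B = tf.random.normal([20, 100, 10])
    C = tf.random.normal([20, 100, 1])
    # Turn C into a boolean mask of shape (20, 100, 1)
    C = tf.greater_equal(C, tf.zeros_like(C))
    # Expand the mask explicitly to A's shape, (20, 100, 10)
    C = tf.broadcast_to(C, A.shape)

    # Take elements from A where C is True, otherwise from B
    D = tf.where(C, A, B)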
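
To see why a mask of shape (20, 100, 1) can be expanded to (20, 100, 10), here is a rough pure-Python sketch of the NumPy-style shape rule that TensorFlow broadcasting follows. The helper name broadcast_shape is made up for illustration and is not a TensorFlow function. It also checks compatibility in both directions, whereas tf.broadcast_to only stretches its input towards the target shape.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the shape two arrays broadcast to, or raise ValueError.

    Hypothetical helper for illustration; it is not part of TensorFlow.
    """
    result = []
    # Compare dimensions from the trailing axis backwards; a missing
    # leading dimension is treated as size 1.
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} do not broadcast")
    return tuple(reversed(result))


# The mask C of shape (20, 100, 1) stretches along its last axis to match A.
print(broadcast_shape((20, 100, 1), (20, 100, 10)))
```

Two dimensions are compatible when they are equal or when one of them is 1, which is why only the last axis of C needs to be stretched here.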
    
