How does TensorFlow batch_matmul work?

迷失自我 2020-12-31 10:34

TensorFlow has a function called batch_matmul which multiplies higher-dimensional tensors. But I'm having a hard time understanding how it works, perhaps partially because

6 Answers
  •  半阙折子戏
    2020-12-31 10:42

    The answer to this particular question is to use the tf.scan function.
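To see what tf.scan does on its own, here is a simplified pure-Python model of its semantics. It assumes the documented behaviour that fn is called as fn(previous_accumulator, current_element) and that the result collects every intermediate accumulator. The scan helper is hypothetical and ignores tf.scan's other options, such as reverse and parallel_iterations.

```python
def scan(fn, elems, initializer):
    """Simplified model of tf.scan: fold fn over elems, keeping every intermediate accumulator."""
    acc = initializer
    outputs = []
    for x in elems:
        acc = fn(acc, x)  # fn receives (previous accumulator, current element)
        outputs.append(acc)
    return outputs


# Running sum: each output is the accumulator after consuming one more element.
print(scan(lambda acc, x: acc + x, [1, 2, 3, 4], 0))  # [1, 3, 6, 10]
```

Note that the initializer itself does not appear in the outputs; it only seeds the first call to fn.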

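    Suppose a has shape [5, 3, 2] (a batch of 5 samples, each a 3x2 matrix) and
    b has shape [2, 3] (a constant matrix to be multiplied with each sample):

        import tensorflow as tf

        a = tf.random.normal([5, 3, 2])
        b = tf.random.normal([2, 3])

        def fn(acc, x):
            # acc: previous output (ignored); x: one 3x2 slice of a
            return tf.matmul(x, b)  # (3x2) @ (2x3) -> 3x3

        # Only the initializer's shape and dtype matter, because fn ignores acc.
        initializer = tf.zeros([3, 3])

        h = tf.scan(fn, a, initializer=initializer)

    h stores all the outputs, stacked into a tensor of shape [5, 3, 3].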
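Because fn ignores its accumulator, the scan amounts to applying one matrix product to each batch slice independently. A batched matmul does the same thing slice by slice, pairing a[i] with b[i]. The pure-Python sketch below models both operations on nested lists. The helper names matmul, batch_matmul and batch_matmul_const are hypothetical, and this models the semantics only, not TensorFlow's implementation.

```python
def matmul(x, y):
    """Ordinary product of an n x m matrix x and an m x p matrix y (nested lists)."""
    m = len(y)
    if any(len(row) != m for row in x):
        raise ValueError("inner dimensions do not match")
    p = len(y[0])
    return [[sum(row[k] * y[k][j] for k in range(m)) for j in range(p)] for row in x]


def batch_matmul(a, b):
    """Batched product: a is [batch, n, m], b is [batch, m, p]; out[i] = a[i] @ b[i]."""
    if len(a) != len(b):
        raise ValueError("batch sizes differ")
    return [matmul(x, y) for x, y in zip(a, b)]


def batch_matmul_const(a, b):
    """Multiply every [n, m] slice of a by the same [m, p] matrix b -> [batch, n, p]."""
    return [matmul(x, b) for x in a]


a = [[[1, 2], [3, 4], [5, 6]],
     [[0, 1], [1, 0], [2, 2]]]  # shape [2, 3, 2]
b = [[1, 0, 2],
     [0, 1, 3]]                 # shape [2, 3]

# Sharing b across the batch is the same as pairing each slice with its own copy of b.
print(batch_matmul_const(a, b))
# [[[1, 2, 8], [3, 4, 18], [5, 6, 28]], [[0, 1, 3], [1, 0, 2], [2, 2, 10]]]
print(batch_matmul(a, [b, b]) == batch_matmul_const(a, b))  # True
```

In TensorFlow itself, the same constant-matrix product can also be written without a scan, for example as tf.einsum('bij,jk->bik', a, b).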
