Replicate a row tensor using tf.tile?

故里飘歌 2021-01-01 17:19

I have a tensor which is simply a vector, vector = [0.5 0.4], and tf.shape indicates that it has shape=(1,). I would like to replicate the vector m times and hav

6 answers
  •  爱一瞬间的悲伤
    2021-01-01 17:31

    I assume that the main use case of such replication is to match the dimensionality of two tensors (that you want to multiply?).

    In that case, there is a much simpler solution: let TensorFlow do the dimensionality matching for you with tf.einsum:

    import tensorflow as tf
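    # Eager execution is the default in TensorFlow 2.x. On TensorFlow 1.x (>= 1.7),
    # call tf.enable_eager_execution() here first.
    
    a = tf.constant([1, 2, 3])  # shape=(3,)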
    b = tf.constant([[[1, 3], [1, 3], [1, 3]], [[2, 0], [2, 0], [2, 0]]])  # shape=(2, 3, 2)
    
    print(tf.einsum('ijk,j->ijk', b, a))
    
    # OUTPUT:
    # tf.Tensor(
    # [[[1 3]
    #   [2 6]
    #   [3 9]]
    # 
    #  [[2 0]
    #   [4 0]
    #   [6 0]]], shape=(2, 3, 2), dtype=int32)
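For comparison, here is roughly what the einsum call saves you from writing by hand, using tf.tile as the question asks: reshape a so its single axis lines up with the second axis of b, replicate it along the other two axes, then multiply element-wise. This is a minimal sketch assuming TensorFlow 2.x (eager by default); the values are the same a and b as above, stored as float32.

```python
import tensorflow as tf  # assumes TensorFlow 2.x, where eager execution is the default

# Same values as the einsum example above, stored as float32.
a = tf.constant([1., 2., 3.])                      # shape=(3,)
b = tf.constant([[[1., 3.], [1., 3.], [1., 3.]],
                 [[2., 0.], [2., 0.], [2., 0.]]])  # shape=(2, 3, 2)

# Manual route: give `a` the shape (1, 3, 1) so its only axis lines up with j,
# then replicate it along i (2 copies) and k (2 copies) with tf.tile.
a_aligned = tf.reshape(a, [1, 3, 1])
a_tiled = tf.tile(a_aligned, multiples=[2, 1, 2])  # shape=(2, 3, 2)
manual = a_tiled * b

# einsum route from the answer: no explicit reshape or tile needed.
via_einsum = tf.einsum('ijk,j->ijk', b, a)

print(manual.shape)                                       # (2, 3, 2)
print(bool(tf.reduce_all(tf.equal(manual, via_einsum))))  # True
```

The tf.tile step is optional here: element-wise multiplication already broadcasts size-1 axes, so a_aligned * b gives the same result. It is shown only to make the replication explicit.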
    

    As you can see, this also works in more complex situations: when you need to replicate along both the first and last dimensions, when you are working with higher-rank shapes, etc. All you need to do is match the indices in the subscript string (above, we match the dimension of a, labeled j, with the second dimension of b, whose axes are labeled ijk).
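Applied to the vector from the question, index matching can also build the replicated matrix directly: label the vector's axis j, pair it with a ones vector labeled i, and request the output ij. The sketch below assumes TensorFlow 2.x and a hypothetical replication count m = 4; since the question's target shape is cut off, an (m, 2) result is an assumption. It checks the einsum result against the direct tf.tile route.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager by default)

vector = tf.constant([0.5, 0.4])  # shape=(2,)
m = 4                             # hypothetical number of copies

# Direct route: add a leading axis -> shape (1, 2), then repeat it m times along axis 0.
tiled = tf.tile(tf.expand_dims(vector, 0), multiples=[m, 1])  # shape=(m, 2)

# Index-matching route: out[i, j] = ones[i] * vector[j].
ones = tf.ones([m], dtype=vector.dtype)
via_einsum = tf.einsum('i,j->ij', ones, vector)               # shape=(m, 2)

print(tiled.numpy())
print(bool(tf.reduce_all(tf.equal(tiled, via_einsum))))      # True
```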

    Another example use case: I have a state per neuron, and since we simulate in batches, this state has dimensionality (n_batch, n_neuron). I need to use this state to modulate connections between neurons (synaptic weights), which in my case have an additional dimension, so their dimensionality is (n_neuron, n_neuron, n_X).

    Instead of making a mess with tiling, reshaping, etc., I can write it in a single line:

    W_modulated = tf.einsum('ijk,bi->bijk', self.W, ux)
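The one-liner above can be checked in isolation. This is a minimal sketch assuming TensorFlow 2.x, with hypothetical sizes (n_batch=2, n_neuron=3, n_X=4) and random stand-ins for the weights (self.W in the answer) and the state ux. It also spells out the tf.tile/reshape version that the einsum replaces and confirms the two agree.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager by default)

# Hypothetical sizes; in the answer, W is an attribute (self.W) of some model class.
n_batch, n_neuron, n_X = 2, 3, 4
W = tf.random.normal([n_neuron, n_neuron, n_X])  # synaptic weights, (n_neuron, n_neuron, n_X)
ux = tf.random.normal([n_batch, n_neuron])       # per-neuron state, (n_batch, n_neuron)

# One line: W_modulated[b, i, j, k] = W[i, j, k] * ux[b, i]
W_modulated = tf.einsum('ijk,bi->bijk', W, ux)   # (n_batch, n_neuron, n_neuron, n_X)

# The tile/reshape version it replaces:
W_tiled = tf.tile(W[tf.newaxis], [n_batch, 1, 1, 1])            # (n_batch, n_neuron, n_neuron, n_X)
ux_tiled = tf.tile(tf.reshape(ux, [n_batch, n_neuron, 1, 1]),
                   [1, 1, n_neuron, n_X])                        # (n_batch, n_neuron, n_neuron, n_X)
manual = W_tiled * ux_tiled

print(W_modulated.shape)  # (2, 3, 3, 4)
max_diff = float(tf.reduce_max(tf.abs(W_modulated - manual)))
print(max_diff < 1e-6)    # True
```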
    
