Tensorflow: Slicing a Tensor into overlapping blocks

我在风中等你 2021-01-02 06:26

I have a 1D tensor that I wish to partition into overlapping blocks. I'm thinking of something like:

    tensor = tf.constant([1, 2, 3, 4, 5, 6, 7])
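To make the intended semantics concrete, here is a minimal plain-Python sketch. The helper name `overlapping_blocks` is hypothetical, and it assumes only complete blocks are kept. Block `k` covers indices `k*stride` up to but not including `k*stride + block_size`, so an input of length `n >= block_size` yields `(n - block_size) // stride + 1` blocks.

```python
def overlapping_blocks(values, block_size, stride):
    """Split a sequence into overlapping blocks (hypothetical helper).

    Block k covers values[k*stride : k*stride + block_size];
    a trailing partial block is dropped.
    """
    n = len(values)
    if n < block_size:
        return []  # not even one complete block fits
    num_blocks = (n - block_size) // stride + 1
    return [values[k * stride : k * stride + block_size] for k in range(num_blocks)]


print(overlapping_blocks([1, 2, 3, 4, 5, 6, 7], block_size=3, stride=2))
# [[1, 2, 3], [3, 4, 5], [5, 6, 7]]
```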

  •  萌比男神i
    2021-01-02 07:09

    Here is a relatively straightforward approach using your example:

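    import tensorflow as tf

    def overlapping_blocker(tensor, block_size, stride):
        # Requires a statically known length for the 1-D input.
        n = tensor.get_shape().as_list()[0]
        # Start and (exclusive) end index of each block. zip() stops at the
        # shorter range, so a trailing partial block is dropped.
        starts = range(0, n, stride)
        ends = range(block_size, n + 1, stride)
        blocks = [tensor[lo:hi] for lo, hi in zip(starts, ends)]
        # tf.pack was renamed tf.stack in TensorFlow 1.0.
        return tf.stack(blocks, axis=0)

    # TensorFlow 2.x executes eagerly, so no tf.Session is needed.
    tensor = tf.constant([1., 2., 3., 4., 5., 6., 7.])
    block_tensor = overlapping_blocker(tensor, 3, 2)
    print(block_tensor.numpy())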
    

    Output:

    [[ 1.  2.  3.]
     [ 3.  4.  5.]
     [ 5.  6.  7.]]
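On TensorFlow 2.x, the built-in `tf.signal.frame` may be a simpler alternative. This is a swapped-in technique, not the approach above. It slides a window of `frame_length` elements along the last axis, advancing `frame_step` elements each time, and with the default `pad_end=False` it drops a trailing partial window, so it should produce the same three blocks.

```python
import tensorflow as tf

tensor = tf.constant([1., 2., 3., 4., 5., 6., 7.])

# Window of 3 elements, advancing 2 elements per step, along the last axis.
# pad_end=False (the default) drops a trailing window that would not be full.
blocks = tf.signal.frame(tensor, frame_length=3, frame_step=2)

print(blocks.numpy())
```

Unlike the loop, this builds a single op instead of one slice per block, and it should not require the input length to be known statically.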