What does `tf.strided_slice()` do?

隐瞒了意图╮ 2020-12-29 05:24

I am wondering what the tf.strided_slice() operator actually does.
The docs say:

To a first order, this operation extracts a slice of siz

5 answers
  •  感动是毒
    2020-12-29 06:27

    I experimented a bit with this method and gained some insights that I think might be useful. Let's say we have a tensor:

    import numpy as np

    a = np.array([[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
                  [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
                  [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]])
    # a.shape == (3, 3, 3)
    

    strided_slice() takes four required arguments: input_, begin, end, and strides. Here we pass a as input_. In tf.slice(), begin is a zero-based index while size counts how many elements to take along each dimension. In strided_slice(), by contrast, begin and end are both zero-based indices, and end is exclusive, as in Python slicing.
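To illustrate the difference between the two argument styles, here is a small sketch. The helper size_to_end is hypothetical (not a TensorFlow function). It converts a tf.slice-style (begin, size) pair into the exclusive end index that strided_slice expects, assuming all sizes are non-negative (tf.slice's special size of -1, meaning "the rest of the axis", is not handled).

```python
def size_to_end(begin, size):
    """Hypothetical helper: turn tf.slice-style (begin, size) into a
    strided_slice-style exclusive end index. Assumes non-negative sizes."""
    return [b + s for b, s in zip(begin, size)]


# tf.slice(a, [1, 0, 0], [1, 2, 3]) takes a 1 x 2 x 3 block starting at (1, 0, 0);
# the matching strided_slice end is begin + size along each axis.
print(size_to_end([1, 0, 0], [1, 2, 3]))  # [2, 2, 3]
```

Under that assumption, tf.slice(a, begin, size) selects the same elements as tf.strided_slice(a, begin, size_to_end(begin, size), [1, 1, 1]).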

    The method itself is simple: it behaves like a loop, where begin is the position in the tensor where the loop starts and end is where it stops (the end position itself is not included).
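The loop picture can be made concrete with a minimal pure-Python sketch. The function visited_indices is a hypothetical name, not a TensorFlow API. It assumes non-negative begin and end values, positive strides, and none of strided_slice's optional mask arguments, and lists the positions visited by one range(begin, end, stride) per axis.

```python
from itertools import product


def visited_indices(begin, end, strides):
    """Hypothetical model: list the index tuples a strided slice visits,
    using one range(begin, end, stride) per axis (end is exclusive)."""
    ranges = [range(b, e, s) for b, e, s in zip(begin, end, strides)]
    return list(product(*ranges))  # last axis varies fastest


# Stride 1 over the full (3, 3, 3) shape visits all 27 positions,
# which is why the first example returns the whole tensor.
idx = visited_indices([0, 0, 0], [3, 3, 3], [1, 1, 1])
print(len(idx))  # 27
print(idx[:4])   # [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)]
```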

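    import tensorflow as tf

    # Stride 1 over the full extent of every axis.
    tf.strided_slice(a, [0, 0, 0], [3, 3, 3], [1, 1, 1])

    # output = the tensor itself

    # Stride 2 keeps indices 0 and 2 along every axis.
    tf.strided_slice(a, [0, 0, 0], [3, 3, 3], [2, 2, 2])

    # output = [[[ 1.   1.3]
    #            [ 7.   7.3]]
    #           [[ 5.   5.3]
    #            [ 9.   9.3]]]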
    

    strides are the step sizes of the loop. Here, [2, 2, 2] makes the method pick the elements of a at (0, 0, 0), (0, 0, 2), (0, 2, 0), (0, 2, 2), (2, 0, 0), (2, 0, 2), and so on.
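Here is a stdlib-only sketch that reproduces the stride-2 example. It uses a plain nested-list copy of a and the same hypothetical visited_indices helper (a model of the loop, not TensorFlow's implementation; it assumes non-negative indices and positive strides). It prints every visited position and the values gathered there, which match the output above read row by row.

```python
from itertools import product

# Nested-list copy of the tensor a from the answer.
a = [[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
     [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
     [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]]


def visited_indices(begin, end, strides):
    """Hypothetical model: one range(begin, end, stride) per axis."""
    ranges = [range(b, e, s) for b, e, s in zip(begin, end, strides)]
    return list(product(*ranges))


idx = visited_indices([0, 0, 0], [3, 3, 3], [2, 2, 2])
values = [a[i][j][k] for i, j, k in idx]
print(idx)
# [(0, 0, 0), (0, 0, 2), (0, 2, 0), (0, 2, 2), (2, 0, 0), (2, 0, 2), (2, 2, 0), (2, 2, 2)]
print(values)
# [1, 1.3, 7, 7.3, 5, 5.3, 9, 9.3]
```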

    tf.strided_slice(a, [1, 1, 0], [2, -1, 3], [1, 1, 1])
    

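    produces the same output as tf.strided_slice(a, [1, 1, 0], [2, 2, 3], [1, 1, 1]), namely [[[4, 4.2, 4.3]]]. A negative end index counts back from the end of the axis, and since a has shape (3, 3, 3), -1 on the second axis resolves to 3 - 1 = 2.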
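One way to see why the negative end works is to model strided_slice (without any mask arguments) as Python's start:stop:step slicing applied axis by axis. Python slices also count negative indices from the end. The function below is a hypothetical sketch on nested lists under that assumption, not TensorFlow's implementation.

```python
def strided_slice(data, begin, end, strides):
    """Hypothetical model of tf.strided_slice on nested lists: apply
    Python's begin:end:stride slicing to each axis in turn (no masks)."""
    picked = data[begin[0]:end[0]:strides[0]]
    if len(begin) == 1:
        return picked
    return [strided_slice(row, begin[1:], end[1:], strides[1:]) for row in picked]


a = [[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
     [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
     [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]]

# -1 on the second axis (length 3) behaves like 2, so both calls agree.
print(strided_slice(a, [1, 1, 0], [2, -1, 3], [1, 1, 1]))  # [[[4, 4.2, 4.3]]]
print(strided_slice(a, [1, 1, 0], [2, 2, 3], [1, 1, 1]))   # [[[4, 4.2, 4.3]]]
```

The same model also reproduces the stride-2 example: strided_slice(a, [0, 0, 0], [3, 3, 3], [2, 2, 2]) gives [[[1, 1.3], [7, 7.3]], [[5, 5.3], [9, 9.3]]].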
