What does `tf.strided_slice()` do?

隐瞒了意图╮ 2020-12-29 05:24

I am wondering what the tf.strided_slice() operator actually does.
The doc says:

To a first order, this operation extracts a slice of siz

5 Answers
  •  心在旅途
    2020-12-29 06:09

    tf.strided_slice() does NumPy-style slicing of a tensor. It takes four main parameters: input, begin, end and strides. Along each dimension, the slice starts at the begin index and keeps adding the stride to it for as long as the index is less than end (so end is exclusive, as in Python slicing). For example, take a tensor constant named "sample" with shape [3,2,3]:

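    import tensorflow as tf

    # A 3x2x3 integer tensor
    sample = tf.constant(
        [[[11, 12, 13], [21, 22, 23]],
         [[31, 32, 33], [41, 42, 43]],
         [[51, 52, 53], [61, 62, 63]]])

    # Every 2nd index along each axis, from begin (inclusive) to end (exclusive)
    sliced = tf.strided_slice(sample, begin=[0, 0, 0], end=[3, 2, 3], strides=[2, 2, 2])

    # TensorFlow 2.x (eager execution): the result can be printed directly
    print(sliced.numpy())

    # TensorFlow 1.x (graph mode) needs a session instead:
    # with tf.Session() as sess:
    #     print(sess.run(sliced))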
    

    Now, the output will be:

    [[[11 13]]
    
     [[51 53]]]
    

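    This is because each axis is stepped independently from begin (inclusive) towards end (exclusive): axis 0 gives indices 0 and 2, axis 1 gives only 0 (the next index, 2, is not less than end = 2), and axis 2 gives 0 and 2. So the result has shape [2,1,2], and the selected positions (holding 11, 13, 51 and 53) are:

    [[0,0,0], [0,0,2],
    [2,0,0], [2,0,2]]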
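The per-axis rule can be checked in plain Python, without TensorFlow. This is a minimal sketch that assumes positive strides and no mask arguments; `sample` here is the same data written as a nested list, and each axis's indices come from `range(begin, end, stride)`:

```python
from itertools import product

# Same data as the `sample` tensor above, as a nested Python list (shape [3, 2, 3])
sample = [[[11, 12, 13], [21, 22, 23]],
          [[31, 32, 33], [41, 42, 43]],
          [[51, 52, 53], [61, 62, 63]]]

begin, end, strides = [0, 0, 0], [3, 2, 3], [2, 2, 2]

# Each axis is stepped independently: begin, begin + stride, ... while < end
axis_indices = [list(range(b, e, s)) for b, e, s in zip(begin, end, strides)]
print(axis_indices)  # [[0, 2], [0], [0, 2]]

# Every combination of per-axis indices is one selected element
selected = list(product(*axis_indices))
print(selected)  # [(0, 0, 0), (0, 0, 2), (2, 0, 0), (2, 0, 2)]

# Look up the value stored at each selected position
values = [sample[i][j][k] for i, j, k in selected]
print(values)  # [11, 13, 51, 53]
```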
    

    With strides [1,1,1] (and the same begin and end), every index along every axis is selected, so the whole tensor is returned unchanged.
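To compare the [2,2,2] and [1,1,1] cases as whole tensors, here is a hypothetical recursive helper, `strided_slice_py`, that mirrors the basic behaviour on nested lists. It is a sketch, not TensorFlow's implementation: it assumes positive strides and `end` values within each axis's length, and it ignores the mask arguments (`begin_mask`, `end_mask`, etc.) that the real op also accepts.

```python
def strided_slice_py(data, begin, end, strides):
    """Hypothetical pure-Python analogue of tf.strided_slice for nested lists.

    Assumes positive strides, in-range begin/end, and no mask arguments.
    """
    b, e, s = begin[0], end[0], strides[0]
    if len(begin) == 1:
        # Last axis: an ordinary Python slice with a step
        return data[b:e:s]
    # Recurse into each selected sub-list for the remaining axes
    return [strided_slice_py(data[i], begin[1:], end[1:], strides[1:])
            for i in range(b, e, s)]


sample = [[[11, 12, 13], [21, 22, 23]],
          [[31, 32, 33], [41, 42, 43]],
          [[51, 52, 53], [61, 62, 63]]]

# Strides of 2: same values as the TensorFlow output above
half = strided_slice_py(sample, [0, 0, 0], [3, 2, 3], [2, 2, 2])
print(half)  # [[[11, 13]], [[51, 53]]]

# Strides of 1 over the full range: every element is kept
full = strided_slice_py(sample, [0, 0, 0], [3, 2, 3], [1, 1, 1])
print(full == sample)  # True

# NumPy equivalent of the first call (not run here):
# np.array(sample)[0:3:2, 0:2:2, 0:3:2]
```

The recursion handles one axis per call, which is exactly the "each axis is stepped independently" rule from the explanation above.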
