Convert between NHWC and NCHW in TensorFlow

Anonymous (unverified), submitted 2019-12-03 02:16:02

Question:

What is the best way to convert a tensor from NHWC format to NCHW format, and vice versa?

Is there an op that does this specifically, or do I need to combine split/concat-style operations?

Answer 1:

All you need is a permutation of the dimensions, from NHWC to NCHW or the reverse.

The meaning of each letter may help:

  • N: number of images in the batch
  • H: height of the image
  • W: width of the image
  • C: number of channels of the image (e.g. 3 for RGB, 1 for grayscale)

From NHWC to NCHW

The image shape is (N, H, W, C) and we want the output to have shape (N, C, H, W). Therefore we need to apply tf.transpose with a well-chosen permutation perm.

The returned tensor's dimension i will correspond to the input dimension perm[i].

    perm[0] = 0  # output dimension 0 will be 'N', which was dimension 0 in the input
    perm[1] = 3  # output dimension 1 will be 'C', which was dimension 3 in the input
    perm[2] = 1  # output dimension 2 will be 'H', which was dimension 1 in the input
    perm[3] = 2  # output dimension 3 will be 'W', which was dimension 2 in the input
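The rule "output dimension i comes from input dimension perm[i]" can be checked without TensorFlow. The following is a minimal pure-Python sketch that only computes shapes; the helper name permuted_shape is hypothetical (it is not a TensorFlow function), and the batch size of 8 is an assumed example value.

```python
def permuted_shape(shape, perm):
    """Return the shape a transpose with `perm` would produce.

    Output dimension i takes its size from input dimension perm[i].
    """
    if sorted(perm) != list(range(len(shape))):
        raise ValueError("perm must be a permutation of the input axes")
    return tuple(shape[axis] for axis in perm)


nhwc_shape = (8, 200, 300, 3)  # (N, H, W, C); batch size 8 is an assumed example value
nchw_shape = permuted_shape(nhwc_shape, [0, 3, 1, 2])
print(nchw_shape)  # (8, 3, 200, 300)
```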

In practice:

    import tensorflow as tf

    # tf.placeholder is the TF 1.x graph-mode API
    images_nhwc = tf.placeholder(tf.float32, [None, 200, 300, 3])  # input batch
    out = tf.transpose(images_nhwc, [0, 3, 1, 2])
    print(out.get_shape())  # the shape of out is [None, 3, 200, 300]
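To see what this transpose does to individual values, here is a small pure-Python illustration on nested lists. It is only a sketch of the index mapping out[n][c][h][w] = in[n][h][w][c], not how TensorFlow implements the op; the function name nhwc_to_nchw is hypothetical, and the input is assumed to be non-empty and rectangular.

```python
def nhwc_to_nchw(images):
    """Rearrange a nested list indexed [n][h][w][c] into one indexed [n][c][h][w]."""
    n_size = len(images)
    h_size = len(images[0])
    w_size = len(images[0][0])
    c_size = len(images[0][0][0])
    return [
        [
            [
                [images[n][h][w][c] for w in range(w_size)]
                for h in range(h_size)
            ]
            for c in range(c_size)
        ]
        for n in range(n_size)
    ]


# One 2x2 image with 3 channels; each innermost list is one pixel's channels.
batch_nhwc = [[[[1, 2, 3], [4, 5, 6]],
               [[7, 8, 9], [10, 11, 12]]]]
batch_nchw = nhwc_to_nchw(batch_nhwc)
print(batch_nchw[0][0])  # first channel plane: [[1, 4], [7, 10]]
```

The values of one channel, which were spread across pixels in NHWC, end up contiguous in a single H x W plane in NCHW.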

From NCHW to NHWC

The image shape is (N, C, H, W) and we want the output to have shape (N, H, W, C). Therefore we need to apply tf.transpose with a well-chosen permutation perm.

The returned tensor's dimension i will correspond to the input dimension perm[i].

    perm[0] = 0  # output dimension 0 will be 'N', which was dimension 0 in the input
    perm[1] = 2  # output dimension 1 will be 'H', which was dimension 2 in the input
    perm[2] = 3  # output dimension 2 will be 'W', which was dimension 3 in the input
    perm[3] = 1  # output dimension 3 will be 'C', which was dimension 1 in the input
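This permutation is the inverse of the NHWC to NCHW one, [0, 3, 1, 2], which is why applying both in sequence returns the original layout. A minimal pure-Python sketch of inverting a permutation follows; invert_perm is a hypothetical helper name, not part of TensorFlow.

```python
def invert_perm(perm):
    """Return the permutation that undoes `perm`."""
    inverse = [0] * len(perm)
    for out_axis, in_axis in enumerate(perm):
        # perm sends input axis `in_axis` to output axis `out_axis`,
        # so the inverse must send `out_axis` back to `in_axis`.
        inverse[in_axis] = out_axis
    return inverse


print(invert_perm([0, 3, 1, 2]))  # [0, 2, 3, 1]
```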

In practice:

    import tensorflow as tf

    images_nchw = tf.placeholder(tf.float32, [None, 3, 200, 300])  # input batch
    out = tf.transpose(images_nchw, [0, 2, 3, 1])
    print(out.get_shape())  # the shape of out is [None, 200, 300, 3]
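Rather than working out each permutation by hand, it can be derived from the two layout strings, since each letter names one axis. This is a sketch in plain Python; layout_perm is a hypothetical helper, and the resulting list is what would be passed as the perm argument of tf.transpose.

```python
def layout_perm(src, dst):
    """Return perm such that transposing a `src`-layout tensor gives `dst` layout."""
    if sorted(src) != sorted(dst) or len(set(src)) != len(src):
        raise ValueError("layouts must contain the same distinct axis letters")
    # Output axis i is the axis named dst[i]; look up where it sits in src.
    return [src.index(axis) for axis in dst]


print(layout_perm("NHWC", "NCHW"))  # [0, 3, 1, 2]
print(layout_perm("NCHW", "NHWC"))  # [0, 2, 3, 1]
```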

