What does TensorFlow's `conv2d_transpose()` operation do?

悲&欢浪女 2021-01-30 04:07

The documentation for the `conv2d_transpose()` operation does not clearly explain what it does:

The transpose of conv2d.

This opera

6 answers
  •  感动是毒
    2021-01-30 04:53

    Here's a simple explanation of what happens in a special case used in U-Net, which is one of the main use cases for transposed convolution.

    We're interested in the following layer:

    Conv2DTranspose(64, (2, 2), strides=(2, 2))
    

    What exactly does this layer do? Can we reproduce its behavior?

    Here’s the answer:

    • First, the default padding for this layer is valid, which means no padding is applied.
    • The output is twice the input size along each spatial dimension: an input of shape (m, n) produces an output of shape (2m, 2n). The next point explains why.
    • Take the first input element and multiply it by the filter weights of shape (2, 2), then write the resulting 2×2 block into the output. Take the next element, multiply it the same way, and place its block next to the first one. The blocks do not overlap because the strides (2, 2) equal the filter size.
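The size rule in the second point can be checked with a few lines of arithmetic. This is only a sketch: `transposed_output_size` is a hypothetical helper, not a TensorFlow function, and it assumes valid padding, no dilation, and a kernel at least as large as the stride.

```python
def transposed_output_size(in_size: int, kernel: int, stride: int) -> int:
    """Output length along one axis (hypothetical helper, valid padding only)."""
    if kernel < stride:
        raise ValueError("this sketch assumes kernel >= stride")
    # The last input element starts its block at (in_size - 1) * stride,
    # and that block is `kernel` elements wide.
    return (in_size - 1) * stride + kernel


# A 3 x 5 input with a (2, 2) kernel and strides (2, 2)
m, n = 3, 5
out_shape = (transposed_output_size(m, 2, 2), transposed_output_size(n, 2, 2))
print(out_shape)  # (6, 10)
```

With kernel and stride both equal to 2, the expression reduces to 2 * in_size, which is the (2m, 2n) result stated above.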

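    Here's an example input and output (see details here and here). The output shown is what you would get for a single channel with a filter whose weights are all 1 and no bias, so each input value is simply copied into its own 2×2 block: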

    In [15]: X.reshape(n, m)
    Out[15]:
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14]])
    In [16]: y_resh
    Out[16]:
    array([[ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.,  4.,  4.],
           [ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.,  4.,  4.],
           [ 5.,  5.,  6.,  6.,  7.,  7.,  8.,  8.,  9.,  9.],
           [ 5.,  5.,  6.,  6.,  7.,  7.,  8.,  8.,  9.,  9.],
           [10., 10., 11., 11., 12., 12., 13., 13., 14., 14.],
           [10., 10., 11., 11., 12., 12., 13., 13., 14., 14.]], dtype=float32)
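
The output above can be reproduced without TensorFlow. The following is a minimal NumPy sketch, not the library's implementation: it covers only the special case where the stride equals the kernel size with valid padding, it works on a single channel, and the function name and the all-ones kernel are assumptions chosen to match the example.

```python
import numpy as np


def conv2d_transpose_stride_eq_kernel(x, kernel):
    """Transposed convolution sketch for stride == kernel size, valid padding.

    x:      2-D array of shape (rows, cols), a single channel
    kernel: 2-D array of shape (kh, kw); the stride is assumed to be (kh, kw)
    """
    rows, cols = x.shape
    kh, kw = kernel.shape
    out = np.zeros((rows * kh, cols * kw), dtype=np.float32)
    for i in range(rows):
        for j in range(cols):
            # Each input element scales the whole kernel and fills its own
            # block; blocks never overlap because stride == kernel size.
            out[i * kh:(i + 1) * kh, j * kw:(j + 1) * kw] = x[i, j] * kernel
    return out


X = np.arange(15).reshape(3, 5)
ones_kernel = np.ones((2, 2), dtype=np.float32)  # assumed weights
y = conv2d_transpose_stride_eq_kernel(X, ones_kernel)
print(y)
```

In this non-overlapping case the result equals the Kronecker product `np.kron(X, ones_kernel)`. When the stride is smaller than the kernel, the blocks overlap and the overlapping contributions are summed, which this sketch does not handle.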
    

    This slide from Stanford's cs231n is useful for our question:
