Question
I am trying to build a convnet in PyTorch that operates on 2D matrices. I am using a 3x5 filter and I want it to have a custom stride: on even rows the filter should start at the element at position 0 (red in the image), on odd rows it should start at the element at position 1 (blue in the image), and in both cases it should use a stride of 2 in the x direction. In other words, for an input matrix like the one in the image, the filter should only ever have 0s at its center. I know this is very unusual for convnets, but this comes from a physics problem, so the exact stride is important.
Answer 1:
The custom conv2d layer below implements a convolution with the checkerboard stride described in the question. The difficulty is that PyTorch does not support a stride that changes from row to row. We can, however, break the operation into two ordinary strided convolutions, one for the even rows and one for the odd rows, and then interleave the two results. The code below also contains some details that make sure the input is padded correctly (if desired). Because the layer is built only from standard differentiable PyTorch operations, it fully supports back-propagation.
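To make the split-and-interleave idea concrete, here is a minimal plain-Python sketch on nested lists, with no PyTorch involved. It assumes a single channel, no padding, and the 3x5 kernel case; the helper names `corr2d_stride2` and `checkerboard_corr2d` are made up for this illustration and are not part of any library.

```python
def corr2d_stride2(x, k):
    """Valid cross-correlation of a 2D list x with kernel k, stride 2 in both directions."""
    kh, kw = len(k), len(k[0])
    out = []
    for r in range(0, len(x) - kh + 1, 2):
        row = []
        for c in range(0, len(x[0]) - kw + 1, 2):
            row.append(sum(k[i][j] * x[r + i][c + j]
                           for i in range(kh) for j in range(kw)))
        out.append(row)
    return out


def checkerboard_corr2d(x, k):
    """Two stride-2 passes with shifted starting points, then row interleaving.

    Assumption: this hard-codes the no-padding 3x5 case, where the "even" pass
    starts at (0, 0) and the "odd" pass starts at (1, 1).
    """
    even = corr2d_stride2([row[:-1] for row in x[:-1]], k)  # drop last row/column
    odd = corr2d_stride2([row[1:] for row in x[1:]], k)     # drop first row/column
    out = []
    for e, o in zip(even, odd):
        out.extend([e, o])  # even row, odd row, even row, ...
    return out


# 8x8 input holding 0..63, and a 3x5 delta kernel that picks out the kernel center
x = [[8 * r + c for c in range(8)] for r in range(8)]
delta = [[0] * 5 for _ in range(3)]
delta[1][2] = 1
result = checkerboard_corr2d(x, delta)
```

With the delta kernel, `result` lists the input elements the kernel was centered on, row by row, which makes the checkerboard pattern easy to inspect by eye.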
import torch
import torch.nn as nn
import torch.nn.functional as F

class AMNI_Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=True):
        super().__init__()
        # container for weight, bias, kernel_size and padding; its own forward is never called
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, padding=padding)
        # row and column index of the kernel center
        self.crow = self.conv.kernel_size[0] // 2
        self.ccol = self.conv.kernel_size[1] // 2

        # this module only works with odd sized kernels
        assert self.conv.kernel_size[0] % 2 == 1 and self.conv.kernel_size[1] % 2 == 1

    def forward(self, x):
        # currently only padding with zeros is supported
        if self.conv.padding[0] != 0 or self.conv.padding[1] != 0:
            # F.pad expects (left, right, top, bottom)
            x = F.pad(x, pad=(self.conv.padding[1], self.conv.padding[1], self.conv.padding[0], self.conv.padding[0]))

        # center filters on the "zeros" according to the diagram by AMNI, starting column for even/odd rows may need to change depending on padding/kernel size
        if (self.crow + self.ccol + self.conv.padding[0] + self.conv.padding[1]) % 2 == 0:
            x_even = F.conv2d(x[:, :, :-1, 1:], self.conv.weight, self.conv.bias, stride=2)
            x_odd = F.conv2d(x[:, :, 1:, :-1], self.conv.weight, self.conv.bias, stride=2)
        else:
            x_even = F.conv2d(x[:, :, :-1, :-1], self.conv.weight, self.conv.bias, stride=2)
            x_odd = F.conv2d(x[:, :, 1:, 1:], self.conv.weight, self.conv.bias, stride=2)
        b, c, h, w = x_even.shape

        # interleave even and odd rows back together:
        # stack gives (b, c, h, 2, w), view merges the last-but-one pair into 2*h rows
        return torch.stack((x_even, x_odd), dim=3).contiguous().view(b, c, -1, w)
Example
This layer basically acts like a normal Conv2d but with the checkerboard stride.
>>> x = torch.arange(64).view(1, 1, 8, 8).float()
>>> x
tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11., 12., 13., 14., 15.],
          [16., 17., 18., 19., 20., 21., 22., 23.],
          [24., 25., 26., 27., 28., 29., 30., 31.],
          [32., 33., 34., 35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44., 45., 46., 47.],
          [48., 49., 50., 51., 52., 53., 54., 55.],
          [56., 57., 58., 59., 60., 61., 62., 63.]]]])
>>> layer = AMNI_Conv2d(1, 1, (3, 5), bias=False)
>>> # set kernels to delta functions to demonstrate kernel centers
>>> with torch.no_grad():
...     layer.conv.weight.zero_()
...     layer.conv.weight[:,:,1,2] = 1
>>> result = layer(x)
>>> result
tensor([[[[10., 12.],
          [19., 21.],
          [26., 28.],
          [35., 37.],
          [42., 44.],
          [51., 53.]]]], grad_fn=<ViewBackward>)
You can also add padding so that the kernel is centered on every "zero" in the original diagram:
>>> layer = AMNI_Conv2d(1, 1, (3, 5), padding=(1, 2), bias=False)
>>> # set kernels to delta functions to demonstrate kernel centers
>>> with torch.no_grad():
...     layer.conv.weight.zero_()
...     layer.conv.weight[:,:,1,2] = 1
>>> result = layer(x)
>>> result
tensor([[[[ 1.,  3.,  5.,  7.],
          [ 8., 10., 12., 14.],
          [17., 19., 21., 23.],
          [24., 26., 28., 30.],
          [33., 35., 37., 39.],
          [40., 42., 44., 46.],
          [49., 51., 53., 55.],
          [56., 58., 60., 62.]]]], grad_fn=<ViewBackward>)
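As a sanity check on which input elements the kernel centers land on, the slicing logic of `forward` can be mirrored in plain Python. This is only a sketch: `checkerboard_centers` is a hypothetical helper, not part of the layer, and it assumes an odd-sized kernel and symmetric zero padding as in the class above.

```python
def checkerboard_centers(height, width, kernel_size, padding=(0, 0)):
    """Return, row by row, the (row, col) input coordinates of the kernel centers.

    Coordinates refer to the unpadded input. The branch on the parity of
    crow + ccol + padding mirrors the one used in AMNI_Conv2d.forward.
    """
    kh, kw = kernel_size
    ph, pw = padding
    crow, ccol = kh // 2, kw // 2
    padded_h, padded_w = height + 2 * ph, width + 2 * pw

    if (crow + ccol + ph + pw) % 2 == 0:
        even_start, odd_start = (0, 1), (1, 0)
    else:
        even_start, odd_start = (0, 0), (1, 1)

    # both passes run over a slice that is one row and one column smaller
    n_rows = (padded_h - 1 - kh) // 2 + 1
    n_cols = (padded_w - 1 - kw) // 2 + 1

    def centers(start):
        r0, c0 = start
        return [[(r0 + 2 * i + crow - ph, c0 + 2 * j + ccol - pw)
                 for j in range(n_cols)]
                for i in range(n_rows)]

    rows = []
    for even_row, odd_row in zip(centers(even_start), centers(odd_start)):
        rows.extend([even_row, odd_row])  # interleave even and odd passes
    return rows


def to_values(rows, width=8):
    """Map (row, col) to the value stored there by torch.arange(64).view(8, 8)."""
    return [[r * width + c for r, c in row] for row in rows]


no_pad = to_values(checkerboard_centers(8, 8, (3, 5)))
with_pad = to_values(checkerboard_centers(8, 8, (3, 5), padding=(1, 2)))
```

For the 8x8 input used above, `no_pad` and `with_pad` reproduce the two printed outputs, and in both cases every center sits on an element whose row index plus column index is odd.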
Source: https://stackoverflow.com/questions/59005379/how-can-i-implement-a-checkerboard-stride-for-conv2d-in-pytorch