
This module fuses features at four different pyramid scales. The first row (red in the figure) is the coarsest feature: a global pooling that produces a single-bin output; the following three rows are pooled features at progressively finer scales.
To keep the weight of the global feature in proportion, if the pyramid has N levels in total, a 1×1 convolution is applied after each level to reduce that level's channel count to 1/N of the original. Each level is then upsampled back to the pre-pooling size by bilinear interpolation, and finally all levels are concatenated together.
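As a quick sanity check of this channel bookkeeping, here is a minimal sketch (assuming a hypothetical 60×60 feature map with 2048 channels and the four bin sizes used in the code below) that reproduces the reduction to 1/N channels per level and the final concatenation:

import torch
from torch import nn
import torch.nn.functional as F

x = torch.randn(1, 2048, 60, 60)   # hypothetical backbone feature map (batch 1, 2048 channels)
levels = []
for s in (1, 2, 3, 6):             # the four bin settings used by PSPNet below
    pooled = nn.AdaptiveAvgPool2d(s)(x)                                # (1, 2048, s, s)
    reduced = nn.Conv2d(2048, 512, kernel_size=1, bias=False)(pooled)  # 1x1 conv: 2048 -> 2048/4 = 512
    levels.append(F.interpolate(reduced, x.size()[2:], mode='bilinear', align_corners=True))
out = torch.cat([x] + levels, dim=1)
print(out.shape)  # torch.Size([1, 4096, 60, 60]) -- matches the 4096 input channels of self.final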
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models

from utils import initialize_weights
from utils.misc import Conv2dDeformable
from .config import res101_path


# Pyramid pooling module: applies different pooling operations to the features extracted by the
# preceding convolutional backbone, obtaining different receptive fields as well as global context
# (i.e., information at different levels).
class _PyramidPoolingModule(nn.Module):
    def __init__(self, in_dim, reduction_dim, setting):
        super(_PyramidPoolingModule, self).__init__()
        self.features = []
        for s in setting:  # one branch per bin size: a single global bin, then finer multi-bin grids
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(s),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim, momentum=.95),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

    def forward(self, x):
        x_size = x.size()
        out = [x]
        for f in self.features:
            # upsample each pooled branch back to the input resolution, then concatenate
            out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
        out = torch.cat(out, 1)
        return out


# The overall PSPNet architecture
class PSPNet(nn.Module):
    def __init__(self, num_classes, pretrained=True, use_aux=True):
        super(PSPNet, self).__init__()
        self.use_aux = use_aux
        resnet = models.resnet101()  # use ResNet-101 as the backbone to extract features
        if pretrained:
            resnet.load_state_dict(torch.load(res101_path))
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # configure dilated (atrous) convolutions in layer3/layer4: set dilation and padding,
        # and change the strides to 1 so the output stride stays at 1/8
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        # add the PPM module and the final classification layers (convolutions)
        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )

        if use_aux:
            # auxiliary head on the layer3 output (1024 channels), used only during training
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)

        # initialize the weights of the newly added modules
        initialize_weights(self.ppm, self.final)

    def forward(self, x):
        x_size = x.size()
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.training and self.use_aux:
            aux = self.aux_logits(x)
        x = self.layer4(x)
        x = self.ppm(x)
        x = self.final(x)
        if self.training and self.use_aux:
            return (F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True),
                    F.interpolate(aux, x_size[2:], mode='bilinear', align_corners=True))
        return F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True)
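A minimal usage sketch follows. It assumes the repository's utils.initialize_weights and .config.res101_path modules are importable; pretrained=False avoids loading weights from res101_path, and the class count and input size are illustrative (e.g., 21 classes as in PASCAL VOC; PSPNet is commonly trained on 473×473 crops):

model = PSPNet(num_classes=21, pretrained=False, use_aux=True)
model.train()
img = torch.randn(2, 3, 473, 473)
out, aux = model(img)          # in training mode, main and auxiliary logits, both upsampled to input size
print(out.shape, aux.shape)    # torch.Size([2, 21, 473, 473]) for both

model.eval()
with torch.no_grad():
    out = model(img)           # in eval mode only the main output is returned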