Continuing from the previous post, which briefly covered how OpenPose works, this section walks through the code.
Code Walkthrough
The code we follow is https://github.com/tensorboy/pytorch_Realtime_Multi-Person_Pose_Estimation. It is well written, and we mainly read demo/picture_demo.py.
First, let's look at the result.
The left image is the input picture (a cool shot of Jackson Yee), and the right image shows the detected keypoints. Using the demo's single-image forward pass, we will walk through the model's inference stage, and then briefly go over training.
Within the inference stage, a handful of functions do the key work, so we pull them out and introduce them one by one below. Download links for the weight files are given in the repo.

We mainly look at the following parts.
1. model = get_model('vgg19')
This is the F in the network diagram described earlier: VGG19 serves as the feature extractor that produces the feature maps. The function lives in lib/network/rtpose_vgg.py.
1 """CPM Pytorch Implementation"""
2
3 from collections import OrderedDict
4
5 import torch
6 import torch.nn as nn
7 import torch.nn.functional as F
8 import torch.utils.data as data
9 import torch.utils.model_zoo as model_zoo
10 from torch.autograd import Variable
11 from torch.nn import init
12
13 def make_stages(cfg_dict):
14 """Builds CPM stages from a dictionary
15 Args:
16 cfg_dict: a dictionary
17 """
18 layers = []
19 for i in range(len(cfg_dict) - 1):
20 one_ = cfg_dict[i]
21 for k, v in one_.items():
22 if 'pool' in k:
23 layers += [nn.MaxPool2d(kernel_size=v[0], stride=v[1],
24 padding=v[2])]
25 else:
26 conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
27 kernel_size=v[2], stride=v[3],
28 padding=v[4])
29 layers += [conv2d, nn.ReLU(inplace=True)]
30 one_ = list(cfg_dict[-1].keys())
31 k = one_[0]
32 v = cfg_dict[-1][k]
33 conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
34 kernel_size=v[2], stride=v[3], padding=v[4])
35 layers += [conv2d]
36 return nn.Sequential(*layers)
37
38
39 def make_vgg19_block(block):
40 """Builds a vgg19 block from a dictionary
41 Args:
42 block: a dictionary
43 """
44 layers = []
45 for i in range(len(block)):
46 one_ = block[i]
47 for k, v in one_.items():
48 if 'pool' in k:
49 layers += [nn.MaxPool2d(kernel_size=v[0], stride=v[1],
50 padding=v[2])]
51 else:
52 conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
53 kernel_size=v[2], stride=v[3],
54 padding=v[4])
55 layers += [conv2d, nn.ReLU(inplace=True)]
56 return nn.Sequential(*layers)
57
58
59
60 def get_model(trunk='vgg19'):
61 """Creates the whole CPM model
62 Args:
63 trunk: string, 'vgg19' or 'mobilenet'
64 Returns: Module, the defined model
65 """
66 blocks = {}
67 # block0 is the preprocessing stage
68 if trunk == 'vgg19':
69 block0 = [{'conv1_1': [3, 64, 3, 1, 1]},
70 {'conv1_2': [64, 64, 3, 1, 1]},
71 {'pool1_stage1': [2, 2, 0]},
72 {'conv2_1': [64, 128, 3, 1, 1]},
73 {'conv2_2': [128, 128, 3, 1, 1]},
74 {'pool2_stage1': [2, 2, 0]},
75 {'conv3_1': [128, 256, 3, 1, 1]},
76 {'conv3_2': [256, 256, 3, 1, 1]},
77 {'conv3_3': [256, 256, 3, 1, 1]},
78 {'conv3_4': [256, 256, 3, 1, 1]},
79 {'pool3_stage1': [2, 2, 0]},
80 {'conv4_1': [256, 512, 3, 1, 1]},
81 {'conv4_2': [512, 512, 3, 1, 1]},
82 {'conv4_3_CPM': [512, 256, 3, 1, 1]},
83 {'conv4_4_CPM': [256, 128, 3, 1, 1]}]
84
85 elif trunk == 'mobilenet':
86 block0 = [{'conv_bn': [3, 32, 2]}, # out: 3, 32, 184, 184
87 {'conv_dw1': [32, 64, 1]}, # out: 32, 64, 184, 184
88 {'conv_dw2': [64, 128, 2]}, # out: 64, 128, 92, 92
89 {'conv_dw3': [128, 128, 1]}, # out: 128, 256, 92, 92
90 {'conv_dw4': [128, 256, 2]}, # out: 256, 256, 46, 46
91 {'conv4_3_CPM': [256, 256, 1, 3, 1]},
92 {'conv4_4_CPM': [256, 128, 1, 3, 1]}]
93
94 # Stage 1
95 blocks['block1_1'] = [{'conv5_1_CPM_L1': [128, 128, 3, 1, 1]},
96 {'conv5_2_CPM_L1': [128, 128, 3, 1, 1]},
97 {'conv5_3_CPM_L1': [128, 128, 3, 1, 1]},
98 {'conv5_4_CPM_L1': [128, 512, 1, 1, 0]},
99 {'conv5_5_CPM_L1': [512, 38, 1, 1, 0]}]
100
101 blocks['block1_2'] = [{'conv5_1_CPM_L2': [128, 128, 3, 1, 1]},
102 {'conv5_2_CPM_L2': [128, 128, 3, 1, 1]},
103 {'conv5_3_CPM_L2': [128, 128, 3, 1, 1]},
104 {'conv5_4_CPM_L2': [128, 512, 1, 1, 0]},
105 {'conv5_5_CPM_L2': [512, 19, 1, 1, 0]}]
106
107 # Stages 2 - 6
108 for i in range(2, 7):
109 blocks['block%d_1' % i] = [
110 {'Mconv1_stage%d_L1' % i: [185, 128, 7, 1, 3]},
111 {'Mconv2_stage%d_L1' % i: [128, 128, 7, 1, 3]},
112 {'Mconv3_stage%d_L1' % i: [128, 128, 7, 1, 3]},
113 {'Mconv4_stage%d_L1' % i: [128, 128, 7, 1, 3]},
114 {'Mconv5_stage%d_L1' % i: [128, 128, 7, 1, 3]},
115 {'Mconv6_stage%d_L1' % i: [128, 128, 1, 1, 0]},
116 {'Mconv7_stage%d_L1' % i: [128, 38, 1, 1, 0]}
117 ]
118
119 blocks['block%d_2' % i] = [
120 {'Mconv1_stage%d_L2' % i: [185, 128, 7, 1, 3]},
121 {'Mconv2_stage%d_L2' % i: [128, 128, 7, 1, 3]},
122 {'Mconv3_stage%d_L2' % i: [128, 128, 7, 1, 3]},
123 {'Mconv4_stage%d_L2' % i: [128, 128, 7, 1, 3]},
124 {'Mconv5_stage%d_L2' % i: [128, 128, 7, 1, 3]},
125 {'Mconv6_stage%d_L2' % i: [128, 128, 1, 1, 0]},
126 {'Mconv7_stage%d_L2' % i: [128, 19, 1, 1, 0]}
127 ]
128
129 models = {}
130
131 if trunk == 'vgg19':
132 print("Bulding VGG19")
133 models['block0'] = make_vgg19_block(block0)
134
135 for k, v in blocks.items():
136 models[k] = make_stages(list(v))
137
138 class rtpose_model(nn.Module):
139 def __init__(self, model_dict):
140 super(rtpose_model, self).__init__()
141 self.model0 = model_dict['block0']
142 self.model1_1 = model_dict['block1_1']
143 self.model2_1 = model_dict['block2_1']
144 self.model3_1 = model_dict['block3_1']
145 self.model4_1 = model_dict['block4_1']
146 self.model5_1 = model_dict['block5_1']
147 self.model6_1 = model_dict['block6_1']
148
149 self.model1_2 = model_dict['block1_2']
150 self.model2_2 = model_dict['block2_2']
151 self.model3_2 = model_dict['block3_2']
152 self.model4_2 = model_dict['block4_2']
153 self.model5_2 = model_dict['block5_2']
154 self.model6_2 = model_dict['block6_2']
155
156 self._initialize_weights_norm()
157
158 def forward(self, x):
159
160 saved_for_loss = []
161 out1 = self.model0(x)
162
163 out1_1 = self.model1_1(out1)
164 out1_2 = self.model1_2(out1)
165 out2 = torch.cat([out1_1, out1_2, out1], 1)
166 saved_for_loss.append(out1_1)
167 saved_for_loss.append(out1_2)
168
169 out2_1 = self.model2_1(out2)
170 out2_2 = self.model2_2(out2)
171 out3 = torch.cat([out2_1, out2_2, out1], 1)
172 saved_for_loss.append(out2_1)
173 saved_for_loss.append(out2_2)
174
175 out3_1 = self.model3_1(out3)
176 out3_2 = self.model3_2(out3)
177 out4 = torch.cat([out3_1, out3_2, out1], 1)
178 saved_for_loss.append(out3_1)
179 saved_for_loss.append(out3_2)
180
181 out4_1 = self.model4_1(out4)
182 out4_2 = self.model4_2(out4)
183 out5 = torch.cat([out4_1, out4_2, out1], 1)
184 saved_for_loss.append(out4_1)
185 saved_for_loss.append(out4_2)
186
187 out5_1 = self.model5_1(out5)
188 out5_2 = self.model5_2(out5)
189 out6 = torch.cat([out5_1, out5_2, out1], 1)
190 saved_for_loss.append(out5_1)
191 saved_for_loss.append(out5_2)
192
193 out6_1 = self.model6_1(out6)
194 out6_2 = self.model6_2(out6)
195 saved_for_loss.append(out6_1)
196 saved_for_loss.append(out6_2)
197 #其中out6_1 为38个feature maps
198 #其中out6_2 为19个feature maps
199 #saved_for_loss表示需要计算loss的层,对于训练的时候有用
200 return (out6_1, out6_2), saved_for_loss
201
202 def _initialize_weights_norm(self):
203
204 for m in self.modules():
205 if isinstance(m, nn.Conv2d):
206 init.normal_(m.weight, std=0.01)
207 if m.bias is not None: # mobilenet conv2d doesn't add bias
208 init.constant_(m.bias, 0.0)
209
210 # last layer of these block don't have Relu
211 init.normal_(self.model1_1[8].weight, std=0.01)
212 init.normal_(self.model1_2[8].weight, std=0.01)
213
214 init.normal_(self.model2_1[12].weight, std=0.01)
215 init.normal_(self.model3_1[12].weight, std=0.01)
216 init.normal_(self.model4_1[12].weight, std=0.01)
217 init.normal_(self.model5_1[12].weight, std=0.01)
218 init.normal_(self.model6_1[12].weight, std=0.01)
219
220 init.normal_(self.model2_2[12].weight, std=0.01)
221 init.normal_(self.model3_2[12].weight, std=0.01)
222 init.normal_(self.model4_2[12].weight, std=0.01)
223 init.normal_(self.model5_2[12].weight, std=0.01)
224 init.normal_(self.model6_2[12].weight, std=0.01)
225
226 model = rtpose_model(models)
227 return model
228
229
230 """Load pretrained model on Imagenet
231 :param model, the PyTorch nn.Module which will train.
232 :param model_path, the directory which load the pretrained model, will download one if not have.
233 :param trunk, the feature extractor network of model.
234 """
235
236
237 def use_vgg(model):
238
239 url = 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
240 vgg_state_dict = model_zoo.load_url(url)
241 vgg_keys = vgg_state_dict.keys()
242
243 # load weights of vgg
244 weights_load = {}
245 # weight+bias,weight+bias.....(repeat 10 times)
246 for i in range(20):
247 weights_load[list(model.state_dict().keys())[i]
248 ] = vgg_state_dict[list(vgg_keys)[i]]
249
250 state = model.state_dict()
251 state.update(weights_load)
252 model.load_state_dict(state)
253 print('load imagenet pretrained model')
254
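Before moving on, here is a minimal sketch of how the demo builds and loads this model for inference; the weight path is a placeholder, use the file downloaded from the links in the repo.

import torch
from lib.network.rtpose_vgg import get_model

weight_name = 'pose_model.pth'  # placeholder path to the trained weights
model = get_model('vgg19')                    # VGG19 trunk + 6 two-branch stages
model.load_state_dict(torch.load(weight_name))
model = torch.nn.DataParallel(model).cuda()   # the demo wraps the model in DataParallel
model.float()
model.eval()

use_vgg, by contrast, is only called at training time, to initialize the first ten conv layers of block0 from the ImageNet-pretrained VGG19 weights.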
2. The get_outputs function covers the preprocessing and the forward pass of the model, producing the paf and heatmap outputs. It deserves a closer look.
def get_outputs(img, model, preprocess):
    """Computes the heatmap and paf for the given image
    :param img: numpy array, the image being processed
    :param model: pytorch model
    :param preprocess: string, which normalization scheme to apply
    :returns: numpy arrays, the paf and heatmap, plus the image scale
    """
    inp_size = cfg.DATASET.IMAGE_SIZE  # inp_size is 368
    # pad/crop so the size is a multiple of the network stride
    # (cfg.MODEL.DOWNSAMPLE defaults to 8)
    im_croped, im_scale, real_shape = im_transform.crop_with_factor(
        img, inp_size, factor=cfg.MODEL.DOWNSAMPLE, is_ceil=True)
    # normalize the image; im_croped has shape (*, 368, 3)
    if preprocess == 'rtpose':
        im_data = rtpose_preprocess(im_croped)
    elif preprocess == 'vgg':
        im_data = vgg_preprocess(im_croped)
    elif preprocess == 'inception':
        im_data = inception_preprocess(im_croped)
    elif preprocess == 'ssd':
        im_data = ssd_preprocess(im_croped)

    batch_images = np.expand_dims(im_data, 0)

    # several scales as a batch
    batch_var = torch.from_numpy(batch_images).cuda().float()
    # the model returns a tuple; the second element (saved_for_loss) is only
    # needed during training, so we discard it with _
    predicted_outputs, _ = model(batch_var)
    output1, output2 = predicted_outputs[-2], predicted_outputs[-1]
    heatmap = output2.cpu().data.numpy().transpose(0, 2, 3, 1)[0]
    paf = output1.cpu().data.numpy().transpose(0, 2, 3, 1)[0]
    # after the 8x downsampling: paf is (h // 8, w // 8, 38),
    # heatmap is (h // 8, w // 8, 19)

    return paf, heatmap, im_scale
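The 38 paf channels are the x/y components of one 2D vector field per limb (19 limbs), and the 19 heatmap channels are the 18 body parts plus one background channel. As a toy illustration of what a heatmap channel encodes (this is not the repo's code; the grouping step below does its own peak finding inside the C++ extension), candidate keypoints can be read off with smoothing plus a local-maximum test:

import numpy as np
from scipy.ndimage import gaussian_filter, maximum_filter

def find_peaks(heatmap_channel, thresh=0.1):
    # smooth, then keep pixels that are local maxima above a threshold
    smoothed = gaussian_filter(heatmap_channel, sigma=3)
    is_peak = (smoothed == maximum_filter(smoothed, size=3)) & (smoothed > thresh)
    ys, xs = np.nonzero(is_peak)
    return list(zip(xs, ys, smoothed[ys, xs]))  # one (x, y, score) per candidate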
3. The paf_to_pose_cpp function is the tricky one: it wraps C++ code exposed to Python through SWIG. Its overall logic can be cross-checked against another write-up, https://blog.csdn.net/l297969586/article/details/80346254, and I will add a walkthrough of the SWIG code later if time permits. In essence, it samples points along each candidate limb to compute the affinity between pairs of keypoints, then clusters keypoints into individual people; a sketch of the scoring idea follows below. Some of the constants inside are not obvious; they may simply be values the authors tuned empirically. Whatever the case, this function gives us everything we want: how many people are in the image, each person's keypoint coordinates, and which keypoints form which limb of which person.
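As a rough illustration of the affinity scoring, here is a numpy sketch of the idea, not the actual SWIG/C++ code; the function name and the number of samples are my own choices:

import numpy as np

def paf_score(paf_x, paf_y, p1, p2, num_samples=10):
    # unit vector pointing from candidate keypoint p1 to p2
    p1, p2 = np.asarray(p1, float), np.asarray(p2, float)
    vec = p2 - p1
    unit = vec / (np.linalg.norm(vec) + 1e-8)
    # sample the PAF at evenly spaced points along the segment and
    # average the dot product with the limb direction
    xs = np.clip(np.round(np.linspace(p1[0], p2[0], num_samples)).astype(int),
                 0, paf_x.shape[1] - 1)
    ys = np.clip(np.round(np.linspace(p1[1], p2[1], num_samples)).astype(int),
                 0, paf_x.shape[0] - 1)
    dots = paf_x[ys, xs] * unit[0] + paf_y[ys, xs] * unit[1]
    return dots.mean()

Pairs with the highest average score are matched greedily per limb type, and matched limbs that share a keypoint are merged into the same person.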
4. The draw_humans function does the visualization: with the results above, we can take each person's joint coordinates and directly draw the points and connect them with lines, as in the minimal sketch below.
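A minimal OpenCV sketch of that drawing step, assuming each human is a dict mapping a part index to an (x, y) pixel coordinate and limbs is a list of part-index pairs; both are illustrative, not the repo's exact data structures:

import cv2

def draw_humans_simple(image, humans, limbs):
    canvas = image.copy()
    for human in humans:                 # human: {part_id: (x, y)}
        for x, y in human.values():      # draw each detected keypoint
            cv2.circle(canvas, (int(x), int(y)), 4, (0, 0, 255), -1)
        for a, b in limbs:               # connect the two parts of each limb
            if a in human and b in human:
                pa = tuple(map(int, human[a]))
                pb = tuple(map(int, human[b]))
                cv2.line(canvas, pa, pb, (0, 255, 0), 2)
    return canvas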