I视线
关注上方公号,获取更多干货!
近两年井喷了很多anchor free的目标检测方法,性能也让one-stage和two-stage的距离又缩小了一点,尤其是Centernet,在我心中是真正的anchor free + nms free方法,较YOLO3都涨了4个点。作者也在工程中对Centernet进行了实用和改进,效果确实提升客观。本文对该方法做个浅析。
不太重要的摘要
Centernet的与众不同
CenterNet没有anchor这个概念,只负责预测物体的中心点,所以也没有所谓的box overlap大于多少多少的算positive anchor,小于多少算negative anchor这一说,也不需要区分这个anchor是物体还是背景 - 因为每个目标只对应一个中心点,这个中心点是通过heatmap中预测出来的,所以不需要NMS再进行来筛选。
CenterNet的输出分辨率的下采样因子是4,比起其他的目标检测框架算是比较小的(Mask-Rcnn最小为16、SSD为最小为16)。之所以设置为4是因为centernet没有采用FPN结构,因此所有中心点要在一个Feature map上出,因此分辨率不能太低。
相对于其他基于关键点的目标检测方法相比(比如CornerNet,ExtremeNet),无需grouping过程以及后处理操作(这些操作会拖慢速度);

Centernet的原理实现
-
预处理得到热图
输出热点图Y∈[0,1]W/R∗H/R∗C,其中R是输出stride大小,默认为4,C是图像中的中心点种类数。在目标检测问题中C就是目标种类数,在人体动作检测的时候C就是人的关节数(一般是17)。 也就是说我们在这一步中需要得到各个种类中心点的热点图 。
热点图的计算方法采用二维高斯分布。假设实际的中心点(也就是物体边框的几何中心点)为p∈R2∈R2,由于输出图像Yxyz∈[0,1]W/R∗H/R∗C,所以p映射到大小W/R, H/R的图像中的对应点~p = p/R。 (这一步向下取整会造成误差) 然后以~p为中心,并利用二维高斯分布计算周围点在[0,1]范围内的值,从而生成一个中心点的热点图。这个热点图会用于之后的误差分析,由于论文中最开始讲的就是这一部分,所以我也在一开始讲解这个处理过程。
目标检测的预测结果
我们的目标是通过输入图像I,预测出图像的中心点图Yxyc∈[0,1]W/R∗H/R∗C。Yxyc在0,1之间,0表示(x,y,c)处为背景,1表示(x,y,c)处为中心点(根据c的值确定属于哪一类中心点),Yxyc值的大小表示像素点(x,y)是类别c的中心点的一个置信度大小。
而且由于我们对图像进行了下采样,也就是因为输出中心点图大小比原始图像长宽都缩小了R倍,而像素点位置p~=p/R这一步是向下取整,所以也造成了很少量的中心点偏移,所以我们还需要预测一个中心点的偏移值。将上一步中预测中心点位置加上偏移值才是实际预测中心点位置。
进行目标检测除了需要知道目标的中心点位置,我们还需要知道目标所处的边框大小,也就是需要预测一个边框的长宽值。在已知边框中心的情况下,通过边框长宽就可以完全定位目标边框的位置。
综上所述,对于每个像素点我们需要使用网络预测出C+4个值。
-
C个是否是中心点的信息 -
w(预测边框的宽) -
h(预测边框的高) -
δx(中心点在x方向上的偏移) -
δy(中心点在y方向上的偏移)
结合以上C+4个参数我们就可以得到目标边框的预测结果:(结果用边框的左上角和右下角位置表示)
由于每个像素点都可以计算出一个目标边框,所以我们还需要对预测结果进行一步筛选。在前面我们说过,Yxyc值的大小表示像素点(x,y)是类别c中心的置信度大小,所以我们只需要分别选择出每个类别,也就是每个通道上的极值点位置。极值点的判断可以简单的将像素点与其八邻域像素点值进行比较,如果像素点(x,y)在通道c上比它的八邻域像素点在通道c上的值大,那么就保留(x,y)作为类别c的一个中心点预测结果。
损失函数
明确模型的输入和输出之后,在开始训练之前还需要明确模型的损失函数,也就是需要明确如何度量预测结果的好坏。预测误差根据预测结果的类型也可以分为三类:
中心点误差Lk
首先第一个误差就是由于中心点估计产生的误差,误差Lk的计算方式如下:
可以看到中心点误差的计算中使用到了在第一部分中的热点图,由于热点图是由实际中心点求二维高斯分布得到的,所以由以上公式可以体现出预测中心点和实际中心点之间的偏差程度。在论文中α=2 β=4,N是图像I中的中心点个数,在最后结果除以N是为了使结果归一化。。
中心点偏移误差Loff
由于图像下采样,使得最后预测的中心点和原图中实际中心点的位置存在着微小的误差,这是由于向下取整的问题造成的。所以我们还需要一个偏移预测量
OpOp来表示中心点由于下采样造成的偏移,
OpOp由x,y这两个方向上偏移构成。Loff就是将所有中心点的偏移预测量和实际偏移量做差(L1 loss)再求平均的结果。
边框尺寸误差Lsize
网络预测结果中包含了边框长宽的预测,为了表示对边框的预测误差,这里直接计算预测边框的总面积与实际边框总面积的误差。
总损失函数
将以上三个方面的误差加权求和就可以得到总的误差值Ldet,在论文中人size=0.1,人off=1。
目标检测的训练网络
明确前面几个问题之后就需要选择训练网络。在论文中,作者使用了四种网络结构:ResNet-18,ResNet-101,DLA-34,和Hourglass-104。其中两个ResNet和DLA都使用了可变形卷积层。
hm head
整个网络主要有bone net特征提取网络和输出部分组成,网络结果如上图所示,特征提取网络就不细讲了。用高斯分布来表示目标,网络第三个分支/hm输出部分网络如下:
nn.Sequential( nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(256, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))
其实就是一个conv2d(64,256),relu(),conv2d(256,1),最后的输出为n_category×128×128,一个类别一个通道,其中每个点的值表示:是目标的概率有多大。
reg head
网络定义为:
nn.Sequential( nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(256, 2, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))
分支最后输出:2×128×128,所有类别用共同的预测宽度w和高度h。
wh head
网络定义为:
nn.Sequential( nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(256, 2, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True)
分支输出:2×128×128,每个点的两个值表示,当前index为目标时hm输出位置的偏差,所有类别用共同的w,h预测值。
来张实验结果图测测效果:
Centernet的优缺点
优点如下:
1.设计模型的结构比较简单,一般人也可以轻松看明白,不仅对于two-stage,对于one-stage的目标检测算法来说该网络的模型设计也是优雅简单的。
2.该模型的思想不仅可以用于目标检测,还可以用于3D检测和人体姿态识别,虽然论文中没有是深入探讨这个,但是可以说明这个网络的设计还是很好的,我们可以借助这个框架去做一些其他的任务。
3. 虽然目前尚未尝试轻量级的模型,但是可以猜到这个模型对于嵌入式端这种算力比较小的平台还是很有优势的。
缺点如下:
1.在实际训练中,如果在图像中,同一个类别中的某些物体的GT中心点,在下采样时会挤到一块,也就是两个物体在GT中的中心点重叠了,CenterNet对于这种情况也是无能为力的,也就是将这两个物体的当成一个物体来训练(因为只有一个中心点)。同理,在预测过程中,如果两个同类的物体在下采样后的中心点也重叠了,那么CenterNet也是只能检测出一个中心点,不过CenterNet对于这种情况的处理要比faster-rcnn强一些的,具体指标可以查看论文相关部分。
2. 有一个需要注意的点,CenterNet在训练过程中,如果同一个类的不同物体的高斯分布点互相有重叠,那么则在重叠的范围内选取较大的高斯点。
文章参考:
[1]https://blog.csdn.net/Eric3778/article/details/101098013?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-49.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-49.nonecase
[2]https://blog.csdn.net/u011622208/article/details/103072220?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-47.nonecase&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-47.nonecase
如有侵权,请联系删除。
最新人工智能、深度学习、SLAM干货奉上!
本文分享自微信公众号 - AI深度学习视线(AI_DeepSight)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。
来源:oschina
链接:https://my.oschina.net/u/4590228/blog/4410209