CN113569882A - 一种基于知识蒸馏的快速行人检测方法 - Google Patents
一种基于知识蒸馏的快速行人检测方法 Download PDFInfo
- Publication number
- CN113569882A CN113569882A CN202010352095.9A CN202010352095A CN113569882A CN 113569882 A CN113569882 A CN 113569882A CN 202010352095 A CN202010352095 A CN 202010352095A CN 113569882 A CN113569882 A CN 113569882A
- Authority
- CN
- China
- Prior art keywords
- model
- loss function
- regression
- detection
- classification
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 82
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 29
- 238000000034 method Methods 0.000 claims abstract description 26
- 238000012549 training Methods 0.000 claims abstract description 23
- 238000004364 calculation method Methods 0.000 claims abstract description 5
- 230000006870 function Effects 0.000 claims description 70
- 238000004088 simulation Methods 0.000 claims description 34
- 238000004422 calculation algorithm Methods 0.000 claims description 24
- 230000008447 perception Effects 0.000 claims description 15
- 230000004927 fusion Effects 0.000 claims description 11
- 238000000926 separation method Methods 0.000 claims description 10
- 238000010586 diagram Methods 0.000 claims description 9
- 230000006978 adaptation Effects 0.000 claims description 6
- 238000013461 design Methods 0.000 claims description 6
- 238000004821 distillation Methods 0.000 claims description 5
- 230000000694 effects Effects 0.000 claims description 5
- 238000005457 optimization Methods 0.000 claims description 4
- 230000009467 reduction Effects 0.000 claims description 4
- 238000009826 distribution Methods 0.000 claims description 3
- 230000007613 environmental effect Effects 0.000 claims description 3
- 230000005764 inhibitory process Effects 0.000 claims description 3
- 238000012360 testing method Methods 0.000 claims description 3
- 238000001914 filtration Methods 0.000 claims description 2
- 230000001629 suppression Effects 0.000 claims 1
- 238000013527 convolutional neural network Methods 0.000 abstract description 7
- 238000011897 real-time detection Methods 0.000 abstract description 4
- 230000001133 acceleration Effects 0.000 abstract description 3
- 230000006835 compression Effects 0.000 abstract description 3
- 238000007906 compression Methods 0.000 abstract description 3
- 238000013135 deep learning Methods 0.000 abstract description 2
- 230000008569 process Effects 0.000 description 9
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000005065 mining Methods 0.000 description 2
- 238000012544 monitoring process Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000005520 cutting process Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 238000005286 illumination Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000003860 storage Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/253—Fusion techniques of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computational Linguistics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Traffic Control Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及计算机视觉、深度学习、目标检测、模型压缩加速等领域,具体是一种在图像或视频中对行人目标进行快速识别和定位的方法。针对行人检测网络参数量计算量较大,无法达到实时检测的问题,本发明提出一种基于知识蒸馏的快速行人检测方法,将常规的基于卷积神经网络的行人检测模型作为教师模型,通过优化的卷积方式降低其参数量和计算量,得到轻量的检测模型作为学生模型;然后针对检测任务,利用教师模型辅助学生模型的训练,提高轻量模型在复杂场景下的检测性能,从而在不牺牲过多检测精度的基础上加快检测速度,达到实时检测的要求。
Description
技术领域
本发明涉及计算机视觉、深度学习、目标检测、模型压缩加速等领域,具体是一种在图像或视频中对行人目标进行快速识别和定位的方法。
背景技术
随着互联网时代数据信息的极大丰富以及计算设备性能的不断提升,基于图像或视频来模拟人类视觉的计算机视觉技术迅猛发展。目标检测是计算机视觉中的基础任务之一,其主要目的是对图像或视频中的物体进行识别和定位。行人检测由于其检测目标的特殊性,一直是通用目标检测的重要分支。鲁棒Robust、快速的行人检测算法在智能交通、自动驾驶、视频监控、人机交互等领域均有着广泛的应用,同时也是目标跟踪、姿态识别、视频分析、场景理解等高级视觉任务的基础。因此,如何在复杂背景的干扰下稳健、快速的检测行人是计算机视觉技术在实际应用中亟待解决的难题。
卷积神经网络提取的特征相比于手工设计特征具备更好的表达能力和鲁棒性,极大改善了传统行人检测算法在复杂交通场景下由于遮挡、多尺度、光照变化等因素带来的性能降低问题,有效提升了算法的性能。目前基于卷积神经网络的检测算法主要有两类:一是以Faster RCNN为代表的双阶段(two stage)方法和以YOLO(You Only Look Once)为代表的单阶段(one stage)方法。前者首先利用区域建议网络生成可能存在目标的前景区域,然后针对每个区域判断是否是待检测物体,同时进一步微调其位置,是一个由粗到精的过程。后者的基本思路则是利用统一的网络直接回归出待检测物体的类别和位置,本质上是一个密集采样的过程。通常two stage检测算法具备更高的检测精度,但其检测速度相对较慢;得益于更加统一的检测框架,one stage检测算法的检测速度具备一定的优势,但其检测精度会略有降低。在目标检测方法发展的过程中,两类算法相互借鉴、相互融合,算法的性能和速度均取得了长足的进步。
目前,基于卷积神经网络的行人检测算法性能提升显著,在充足训练样本支持下,经典的检测网络如RetinaNet对于遮挡、杂乱背景等问题也具有较好鲁棒性。但是,这些检测网络的参数量、计算量很大,对计算资源要求很高。例如,精度较高的RetinaNet,即使在配备RTX2080的服务器上,其处理速度也难以达到实时检测。在智能交通、自动驾驶和智能监控等领域的实际应用中,行人检测算法通常需要运行于前端的嵌入式设备,而这些设备的计算和存储能力是非常有限的。因此,如何在保证一定检测精度的前提下,对检测网络进行压缩加速从而提高算法的实时性,是当前行人检测算法研究与应用的重要难点。
发明内容
本发明为了解决上述技术问题采用以下技术方案:
一种基于知识蒸馏的快速行人检测方法,包括如下步骤:
一种基于知识蒸馏的快速行人检测方法,其特征在于,包括以下步骤:
一种基于知识蒸馏的快速行人检测方法,其特征在于,包括以下步骤:
步骤1:教师模型设计,在RetinaNet网络基础上,引入小尺度检测模块以及尺度感知损失函数,改善复杂环境下的行人检测效果;
步骤2:教师模型训练,对改进后的RetinaNet进行训练,将训练完成的模型作为教师模型,为后续学生模型的训练提供辅助信息;
步骤3,学生模型生成,通过新型的卷积方式替换教师网络中的传统卷积方式,同时降低生成融合特征图的通道数,生成参数量和计算量较低的轻量学生模型。
步骤4,知识蒸馏,综合利用教师模型提供的特征信息,分类置信度以及回归偏置作“软标签”指导步骤3中生成的学生模型的训练,通过知识蒸馏缓解轻量学生模型由于容量较低带来的性能降低问题,包括:
4a)将教师模型的融合特征层经过adaption调整层调节至与学生模型对应特征维度一致,取其中局部重要特征作为标签,与学生模型对应的特征计算特征模拟损失函数Lfeature_imit,进行特征模拟;
4b)将教师模型的分类支路输出的分类置信度作为标签,与学生模型输出的分类置信度计算分类损失函数Lcls_imit,进行分类模拟;
4c)在教师模型输出的回归偏置足够可信的情况下,将其作为标签,与学生模型输出的回归偏置计算回归模拟损失函数Lreg_imit,进行回归模拟;
4d)将学生模型输出的分类置信度和回归偏置与行人的真实标签分别计算分类损失函数Lcls和回归损失函数Lreg,与步骤4a),4b),4c)中得到的Lfeature_imit、Lcls_imit、Lreg_imit加权求和后得到整体损失函数Lall,从而联合地对轻量学生模型进行参数优化,不断迭代得到最终的算法模型。
步骤5,输出测试结果,将待检测图片输入到已经训练好的轻量学生模型中,设定阈值来滤除置信度较低的预测框,对剩下的预测框采用非极大值抑制法去除重叠程度较高的框,进而得到最终的检测结果。
进一步的,所述的步骤2.所述步骤1中的小尺度检测模块,是结合多分支结构以及跳跃连接获取具备不同感受野和深度的特征信息,以环境信息弥补小尺度行人特征不足的问题,强化特征的表示能力,具体为尺度感知回归损失函数,其形式如下所示:
其中:A是所有预设边界框,pn*=1表示仅对判断为正样本的预测框计算回归损失函数,dx,dy,dw,dh分别表示预测框相对于预设边界框的偏移量和缩放比例。Wscalen为尺度感知系数,Wimg,Himg为输入图片尺寸,gwn和ghn分别为第n个预测框对应行人真实框的宽和高。β为控制尺度感知系数影响程度的权值,可根据样本分布调节,本发明中设为1。
进一步的,所述的步骤3中新的卷积方式为非对称深度分离卷积,其将普通3×3卷积拆分为3×1和1×3深度分离卷积,将二者串联后再接1×1卷积,构成新的卷积模块后代替原本的3×3卷积。此外,为了进一步降低参数量,将RetinaNet融合生成的5个特征层p1~p5的通道数从256降低到128。
进一步的,5.所述的步骤4中的知识蒸馏方法,对步骤3中得到的轻量学生模型进行训练,利用知识蒸馏从特征、分类置信度和回归偏置三个维度挖掘教师模型的“知识”,辅助轻量学生模型的训练;
从特征维度提出的特征模拟损失函数Lfeature_imit形式为:
其中:k为5个融合特征层,W、H、C分别为融合特征层的宽、高、通道数。I是分类置信度大于阈值θ的预设边界框的重叠区域掩膜,本发明中θ设为0.5,Np是掩膜区域所有特征点的个数。fadap为适应层。s和t分别为学生模型特征和教师模型特征;
从分类置信度维度提出的分类模拟损失函数Lcls_imit形式为:
其中:A表示所有预设边界框的集合,N为其总数目。ptn和psn分别表示第n个预测框的教师模型预测概率和学生模型预测概率;
从回归偏置维度提出的回归模拟损失函数Lreg_imit形式为:
其中:M是所有计算回归模拟损失函数的预测框数目,A是所有预设边界框的集合,Rs和Rt分别为学生模型和教师模型输出的相对于第n个预设框的回归偏置。boxsn、boxtn、boxgtn分别表示第n个学生模型预测框,教师模型预测框以及对应的行人真实框,
最终通过整体损失函数Lall联合地对整个知识蒸馏框架进行优化,其形式为:
Lall=Lhard+Lsoft
Lhard=Lcls+λ1Lreg
Lsoft=Lfeature_imit+λ2Lcls_imit+λ3Lreg_imit
其中:Lhard为轻量模型预测值与行人真实标签计算的损失,主要包括分类损失函数Lcls和回归损失函数Lreg。分类损失函数采用RetinaNet中的focal loss,回归损失函数采用添加尺度感知权重的新型回归损失函数,具体形式见步骤1,λ1为平衡二者的权重,本发明中设为0.8。Lfeature_imit、Lcls_imit、Lreg_imit分别为上文所述的特征模拟损失函数、分类模拟损失函数和回归模拟损失函数,λ1、λ2为平衡三者的权重,本发明中均设为1。
进一步的,所述的步骤5中滤除置信度较低预测框的阈值θ1设为0.3,非极大值抑制的重叠程度阈值θ2设为0.4。
有益效果
本发明采用以上技术方案与现有技术相比,具有以下技术效果:
1,本发明引入环境信息设计了专门的小尺度行人检测模块,通过多分支结构以及跳跃连接获得更宽、更深的特征,提高小尺度行人特征的表示能力。同时设计了尺度感知回归损失函数,使得算法更加关注小尺度行人的定位过程,从而提高复杂场景下对小尺度行人的处理能力。
2,本发明通过优化的卷积方式极大降低了基于卷积神经网络的行人检测算法RetinaNet的参数量和计算量,使算法能够在实际应用场景中达到实时的检测速度,降低了对硬件设备的性能要求。
3,本发明针对检测任务设计了一种新型的知识蒸馏方法,从特征、分类置信度、回归偏置三个维度充分挖掘教师模型的“知识”,从而提高模型容量较低的轻量网络在复杂场景下的检测性能,最终在检测速度和检测性能两个方面取得较好的平衡。
附图说明
图1是本发明一种基于知识蒸馏的快速行人检测方法的流程示意图;
图2是本发明中为小尺度行人设计的检测模块结构示意图;
图3是本发明中非对称深度分离卷积模块示意图;
图4是本发明针对检测任务设计的知识蒸馏整体结构示意图。
具体实施方式
下面详细描述本发明的实施方式,所述实施方式的示例在附图中示出。下面通过参考附图描述的实施方式是示例性的,仅用于解释本发明,而不能解释为对本发明的限制。
如图1所示,本发明涉及的基于知识蒸馏的快速行人检测方法,具体步骤如下:
步骤1:教师模型设计,首先针对基于卷积神经网络的行人检测算法RetinaNet在复杂背景下对小尺度行人检测效果不佳的问题,通过引入环境信息为小尺度行人设计专门的检测模块,将其添加至RetinaNet融合特征层p1~p5中的p1和p2层。
其具体结构图如图2所示,虚线框中为环境感知模块,其将输入特征图经通道数为256的3×3卷积后分成两个分支,一个分支接3×3卷积,另一分支则串联两个3×3卷积,从而获得更大的感受野,捕捉足够的环境信息。然后采用concat操作将两个分支输出特征融合,再经3×3卷积后通过跳跃连接与原本的分支输入特征对应相加,从而获得更宽、更深的小尺度行人特征,加强对小尺度行人的判别能力。完整的小尺度行人检测模块在环境感知模块的基础上继续通过分支结构和跳跃连接进一步强化特征的表示能力,最终分别通过通道数为2A和4A的3×3卷积输出分类置信度和回归偏置。
为了进一步改善小尺度行人的定位精度,本发明还设计了尺度感知的回归损失函数,其基本思路是:对于相同的偏移距离,小尺度行人预测框与真实框的重叠程度更低,即位置更不准确,而大尺度行人预测框则影响相对较小。因此对小尺度行人位置的精确回归更加困难。故提出形式如下的新型回归损失函数:
其中:A是所有预设边界框,pn*=1表示仅对判断为正样本的预测框计算回归损失函数,dx,dy,dw,dh分别表示预测框相对于预设边界框的偏移量和缩放比例。Wscalen为尺度感知系数,Wimg,Himg为输入图片尺寸,gwn和ghn分别为第n个预测框对应行人真实框的宽和高。对于每个判断为正样本的行人预测框,计算其面积与输入图片尺寸的比值,以2相减后作为尺度感知权重系数。即降低大尺度行人的回归损失,增大小尺度行人的回归损失,从而迫使算法在回归过程中更加关注小尺度行人,提高其定位精度。β为控制尺度感知系数影响程度的权值,可根据样本分布调节,本发明中设为1。
步骤2:教师模型训练,对基于卷积神经网络的行人检测方法RetinaNet改进之后,利用行人数据集对其进行充分的训练,将得到的高性能行人检测模型作为教师模型,从而辅助后续的知识蒸馏过程。主要的训练参数设置为:输入图片尺寸为1200×900,数据增强方式采用随机裁剪、水平翻转、颜色变换,优化器采用adam,学习率初始设为0.0001,若3个epoch内整体不下降,则将学习率降低10倍,总共训练100个epoch,batch_size大小为2。
步骤3:学生模型生成,为了提升ReitnaNet在实际场景应用时的检测速度,本发明中优化的卷积方式—非对称深度分离卷积,从而极大降低算法的参数量和计算量,有效地实现模型的压缩加速。
图3是常规卷积模块与非对称深度分离卷积结构对比。将普通3×3卷积拆解为大小为1×3和3×1,通道数为1的非对称深度分离卷积串联,然后接批归一化层(BN)和激活函数Relu,为了改善深度分离卷积通道之间信息无法流通的问题,接通道数与输入特征图相同的1×1卷积,最后再加一组BN层和Relu。以此优化的卷积方式取代步骤1中改进RetinaNet的主干网络后两层、融合特征层p1~p5、卷积预测模块中的常规3×3卷积,从而降低参数量和计算量,理论分析可得此卷积方式参数量和计算量约为常规3×3卷积的1/9。为了进一步降低模型的参数量,本发明还削减了原始ReitnaNet融合特征层p1~p5的通道数,由256降低到128,因为本发明只对行人一个种类进行检测,所以适当的降低通道数对性能影响不大。经过卷积优化以及通道数降低之后,可以获得一个参数量和计算量大幅减小的轻量学生模型,其检测速度可以达到实时的要求,但由于其模型容量较低,检测性能相比参数量较多的原始RetinaNet会有所下降。
步骤4:知识蒸馏,为了提升步骤3中得到的轻量检测模型在复杂场景下的检测性能,本发明针对检测任务采用新的知识蒸馏方法。其整体框架结构如图4所示,以步骤2中训练好的改进RetinaNet作为教师模型,从三个角度充分挖掘教师模型的“知识”用以指导步骤3中得到的轻量学生模型的训练过程,从而弥补模型容量低带来的性能损失,更加有效地对其参数进行优化。
首先是挖掘教师模型特征维度的“知识”,通过特征模拟损失函数Lfeature_imit对教师模型的局部特征进行模拟,其具体形式为:
其中:k为5个融合特征层,W、H、C分别为融合特征层的宽、高、通道数。s和t分别为学生模型特征和教师模型特征。fadap为适应层,其实际上是通道数为256的1×1卷积,将轻量模型融合特征层的特征经适应层调节至与教师模型融合特征层输出通道数一致。Iij为分类置信度大于阈值的预设边界框的重叠区域,对教师模型与学生模型在此区域内的特征计算L2损失函数,从而迫使轻量学生模型仅在可能存在待检测行人的区域对教师模型进行特征模拟。丰富的特征信息可以改善检测任务类别信息较少对传统知蒸馏方法的影响,同时仅针对重要前景特征的模拟可以有效避免噪声的引入。本发明中设为0.5,Np是掩膜区域所有特征点的个数。
其次是挖掘教师模型分类维度的“知识”,与常规基于分类任务的知识蒸馏方法类似,利用教师模型卷积预测模块中分类分支输出的分类置信度作为“软标签”,与对应轻量学生模型输出的分类置信度计算分类模拟损失函数Lcls_imit,其具体形式为:
其中:A表示所有预设边界框的集合,N为其总数目。ptn和psn分别表示第n个预测框的教师模型预测概率和学生模型预测概率。pi是当前预测框属于第i个类别的概率,zi为模型输出值。T为温度系数,在softmax函数中加入可以“软化”模型预测的类别概率值。对于检测等预测难度更高的任务,温度系数T越高,会引入越多的噪声。故本发明中温度系数设为1,即采用常规的softmax函数获得前后景的类别概率值。
最后是挖掘教师模型定位维度的“知识”,由于检测任务相比于分类任务更加复杂,其包含分类和定位两个任务。如何通过回归分支准确的回归出物体边界框的位置和大小对于检测任务是至关重要的。故本发明考虑将教师模型回归分支输出的定位信息同样作为“软标签”用来指导轻量模型回归分支的学习。与分类任务预测的离散的类别概率不同,回归分支预测的是相对于预设边界框的位置偏移和缩放比例,是连续值。故知识蒸馏的过程中对教师模型的输出更为敏感,教师模型的错误预测值可能与真实标签相去甚远,若把教师模型的误检作为“软标签”可能会导致损失值巨大,影响轻量模型的正常收敛。因此本发明提出一种受限的回归模拟损失函数Lreg_imit,其具体形式为:
其中:M是所有计算回归模拟损失函数的预测框数目,A是所有预设边界框的集合,Rs和Rt分别为学生模型和教师模型输出的相对于第n个预设框的回归偏置。boxsn、boxtn、boxgtn分别表示第n个学生模型预测框,教师模型预测框以及对应的行人真实框。以轻量模型输出的预测框与对应真实框的重叠程度作为限制,仅当教师模型的预测框比学生模型更加接近真实框的时候才将其作为“软标签”,否则不计算回归模拟损失函数。此时对于轻量模型来说只有真实框作为标签指导其回归分支的学习,避免了教师模型的误检带来的影响,从而更加鲁棒、稳定地指导轻量模型的训练过程。
此外,轻量模型预测的分类置信度和回归偏置同样与真实标签计算常规的分类损失函数Lcls和回归损失函数Lreg。最终将五部分损失函数加权后作为整个蒸馏模型的总损失,联合地对各部分进行优化,整体损失函数形式为:
Lall=Lhard+Lsoft
Lhard=Lcls+λ1Lreg
Lsoft=Lfeature_imit+λ2Lcls_imit+λ3Lreg_imit
其中:Lhard为轻量模型预测值与行人真实标签计算的损失,主要包括分类损失函数Lcls和回归损失函数Lreg。分类损失函数采用RetinaNet中的focal loss,回归损失函数采用添加尺度感知权重的新型回归损失函数,具体形式见步骤1,λ1为平衡二者的权重,本发明中设为0.8。Lfeature_imit、Lcls_imit、Lreg_imit分别为上文所述的特征模拟损失函数、分类模拟损失函数和回归模拟损失函数,λ1、λ2为平衡三者的权重,本发明中均设为1。
图4中虚线表示梯度反向传播的方向,即添加知识蒸馏联合训练时固定复杂教师模型参数不更新,仅更新轻量模型参数。训练时采用随机水平翻转,裁剪实现数据增强。采用Adam作为权值优化算法,初始学习率为0.0001。因为添加知识蒸馏联合训练收敛较慢,故学习率更新策略为:如果6个epoch内损失函数不下降,学习率缩小10倍,共训练120个epoch,batch_size大小设为2。
步骤5:输出测试结果,通过步骤4中知识蒸馏方法对轻量模型进行有效训练后,得到最终的行人检测算法模型。将原始待检测图片送入模型中,设定阈值θ1来滤除置信度较低的预测框,本发明中设为0.3,若预测框的分类置信度低于阈值,则舍弃此结果。对剩下的预测框采用非极大值抑制法去除重叠程度较高的框,重叠程度阈值θ2设为0.4,最后保留的预测框就是算法的检测结果。
以上实施例仅为说明本发明的技术思想,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明保护范围之内。
Claims (6)
1.一种基于知识蒸馏的快速行人检测方法,其特征在于,包括以下步骤:
步骤1:教师模型设计,在RetinaNet网络基础上,引入小尺度检测模块以及尺度感知损失函数,改善复杂环境下的行人检测效果;
步骤2:教师模型训练,对改进后的RetinaNet进行训练,将训练完成的模型作为教师模型,为后续学生模型的训练提供辅助信息;
步骤3,学生模型生成,通过新型的卷积方式替换教师网络中的传统卷积方式,同时降低生成融合特征图的通道数,生成参数量和计算量较低的轻量学生模型。
步骤4,知识蒸馏,综合利用教师模型提供的特征信息,分类置信度以及回归偏置作“软标签”指导步骤3中生成的学生模型的训练,通过知识蒸馏缓解轻量学生模型由于容量较低带来的性能降低问题,包括:
4a)将教师模型的融合特征层经过adaption调整层调节至与学生模型对应特征维度一致,取其中局部重要特征作为标签,与学生模型对应的特征计算特征模拟损失函数Lfeature_imit,进行特征模拟;
4b)将教师模型的分类支路输出的分类置信度作为标签,与学生模型输出的分类置信度计算分类损失函数Lcls_imit,进行分类模拟;
4c)在教师模型输出的回归偏置足够可信的情况下,将其作为标签,与学生模型输出的回归偏置计算回归模拟损失函数Lreg_imit,进行回归模拟;
4d)将学生模型输出的分类置信度和回归偏置与行人的真实标签分别计算分类损失函数Lcls和回归损失函数Lreg,与步骤4a),4b),4c)中得到的Lfeature_imit、Lcls_imit、Lreg_imit加权求和后得到整体损失函数Lall,从而联合地对轻量学生模型进行参数优化,不断迭代得到最终的算法模型。
步骤5,输出测试结果,将待检测图片输入到已经训练好的轻量学生模型中,设定阈值来滤除置信度较低的预测框,对剩下的预测框采用非极大值抑制法去除重叠程度较高的框,进而得到最终的检测结果。
2.根据权利要求1所述的一种基于知识蒸馏的快速行人检测方法,其特征在于,所述步骤1中的小尺度检测模块,是结合多分支结构以及跳跃连接获取具备不同感受野和深度的特征信息,以环境信息弥补小尺度行人特征不足的问题,强化特征的表示能力,具体为尺度感知回归损失函数,其形式如下所示:
其中:A是所有预设边界框,pn*=1表示仅对判断为正样本的预测框计算回归损失函数,dx,dy,dw,dh分别表示预测框相对于预设边界框的偏移量和缩放比例。Wscalen为尺度感知系数,Wimg,Himg为输入图片尺寸,gwn和ghn分别为第n个预测框对应行人真实框的宽和高。β为控制尺度感知系数影响程度的权值,可根据样本分布调节,本发明中设为1。
3.根据权利要求1所述的一种基于知识蒸馏的快速行人检测方法,其特征在于,所述的步骤2的对改进后的RetinaNet进行训练,主要训练设置包括batch_size,优化器选择,学习率衰减策略。
4.根据权利要求1所述的一种基于知识蒸馏的快速行人检测方法,其特征在于,所述的步骤3中新的卷积方式为非对称深度分离卷积,其将普通3×3卷积拆分为3×1和1×3深度分离卷积,将二者串联后再接1×1卷积,构成新的卷积模块后代替原本的3×3卷积。此外,为了进一步降低参数量,将RetinaNet融合生成的5个特征层p1~p5的通道数从256降低到128。
5.根据权利要求1所述的一种基于知识蒸馏的快速行人检测方法,其特征在于,所述的步骤4中的知识蒸馏方法,对步骤3中得到的轻量学生模型进行训练,利用知识蒸馏从特征、分类置信度和回归偏置三个维度挖掘教师模型的“知识”,辅助轻量学生模型的训练;
从特征维度提出的特征模拟损失函数Lfeature_imit形式为:
其中:k为5个融合特征层,W、H、C分别为融合特征层的宽、高、通道数。I是分类置信度大于阈值θ的预设边界框的重叠区域掩膜,本发明中θ设为0.5,Np是掩膜区域所有特征点的个数。fadap为适应层。s和t分别为学生模型特征和教师模型特征;
从分类置信度维度提出的分类模拟损失函数Lcls_imit形式为:
其中:A表示所有预设边界框的集合,N为其总数目。ptn和psn分别表示第n个预测框的教师模型预测概率和学生模型预测概率;
从回归偏置维度提出的回归模拟损失函数Lreg_imit形式为:
其中:M是所有计算回归模拟损失函数的预测框数目,A是所有预设边界框的集合,Rs和Rt分别为学生模型和教师模型输出的相对于第n个预设框的回归偏置。boxsn、boxtn、boxgtn分别表示第n个学生模型预测框,教师模型预测框以及对应的行人真实框,
最终通过整体损失函数Lall联合地对整个知识蒸馏框架进行优化,其形式为:
Lall=Lhard+Lsoft
Lhard=Lcls+λ1Lreg
Lsoft=Lfeature_imit+λ2Lcls_imit+λ3Lreg_imit
其中:Lhard为轻量模型预测值与行人真实标签计算的损失,主要包括分类损失函数Lcls和回归损失函数Lreg。分类损失函数采用RetinaNet中的focal loss,回归损失函数采用添加尺度感知权重的新型回归损失函数,具体形式见步骤1,λ1为平衡二者的权重,本发明中设为0.8。Lfeature_imit、Lcls_imit、Lreg_imit分别为上文所述的特征模拟损失函数、分类模拟损失函数和回归模拟损失函数,λ1、λ2为平衡三者的权重,本发明中均设为1。
6.根据权利要求1所述的一种基于知识蒸馏的快速行人检测方法,其特征在于,所述的步骤5中滤除置信度较低预测框的阈值θ1设为0.3,非极大值抑制的重叠程度阈值θ2设为0.4。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010352095.9A CN113569882A (zh) | 2020-04-28 | 2020-04-28 | 一种基于知识蒸馏的快速行人检测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010352095.9A CN113569882A (zh) | 2020-04-28 | 2020-04-28 | 一种基于知识蒸馏的快速行人检测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113569882A true CN113569882A (zh) | 2021-10-29 |
Family
ID=78158191
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010352095.9A Pending CN113569882A (zh) | 2020-04-28 | 2020-04-28 | 一种基于知识蒸馏的快速行人检测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113569882A (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114358206A (zh) * | 2022-01-12 | 2022-04-15 | 合肥工业大学 | 二值神经网络模型训练方法及***、图像处理方法及*** |
CN114972931A (zh) * | 2022-08-03 | 2022-08-30 | 国连科技(浙江)有限公司 | 一种基于知识蒸馏的货物存放方法及装置 |
CN115223117A (zh) * | 2022-05-30 | 2022-10-21 | 九识智行(北京)科技有限公司 | 三维目标检测模型的训练和使用方法、装置、介质及设备 |
CN116206275A (zh) * | 2023-02-23 | 2023-06-02 | 南通探维光电科技有限公司 | 基于知识蒸馏的识别模型训练方法及装置 |
CN117372819A (zh) * | 2023-12-07 | 2024-01-09 | 神思电子技术股份有限公司 | 用于有限模型空间的目标检测增量学习方法、设备及介质 |
CN117496509A (zh) * | 2023-12-25 | 2024-02-02 | 江西农业大学 | 一种融合多教师知识蒸馏的Yolov7柚子计数方法 |
WO2024012607A3 (zh) * | 2022-07-14 | 2024-04-04 | 顺丰科技有限公司 | 人员检测方法、装置、设备和存储介质 |
-
2020
- 2020-04-28 CN CN202010352095.9A patent/CN113569882A/zh active Pending
Cited By (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114358206B (zh) * | 2022-01-12 | 2022-11-01 | 合肥工业大学 | 二值神经网络模型训练方法及***、图像处理方法及*** |
CN114358206A (zh) * | 2022-01-12 | 2022-04-15 | 合肥工业大学 | 二值神经网络模型训练方法及***、图像处理方法及*** |
CN115223117A (zh) * | 2022-05-30 | 2022-10-21 | 九识智行(北京)科技有限公司 | 三维目标检测模型的训练和使用方法、装置、介质及设备 |
CN115223117B (zh) * | 2022-05-30 | 2023-05-30 | 九识智行(北京)科技有限公司 | 三维目标检测模型的训练和使用方法、装置、介质及设备 |
WO2024012607A3 (zh) * | 2022-07-14 | 2024-04-04 | 顺丰科技有限公司 | 人员检测方法、装置、设备和存储介质 |
CN114972931A (zh) * | 2022-08-03 | 2022-08-30 | 国连科技(浙江)有限公司 | 一种基于知识蒸馏的货物存放方法及装置 |
CN114972931B (zh) * | 2022-08-03 | 2022-12-30 | 国连科技(浙江)有限公司 | 一种基于知识蒸馏的货物存放方法及装置 |
CN116206275A (zh) * | 2023-02-23 | 2023-06-02 | 南通探维光电科技有限公司 | 基于知识蒸馏的识别模型训练方法及装置 |
CN116206275B (zh) * | 2023-02-23 | 2024-03-01 | 南通探维光电科技有限公司 | 基于知识蒸馏的识别模型训练方法及装置 |
CN117372819B (zh) * | 2023-12-07 | 2024-02-20 | 神思电子技术股份有限公司 | 用于有限模型空间的目标检测增量学习方法、设备及介质 |
CN117372819A (zh) * | 2023-12-07 | 2024-01-09 | 神思电子技术股份有限公司 | 用于有限模型空间的目标检测增量学习方法、设备及介质 |
CN117496509A (zh) * | 2023-12-25 | 2024-02-02 | 江西农业大学 | 一种融合多教师知识蒸馏的Yolov7柚子计数方法 |
CN117496509B (zh) * | 2023-12-25 | 2024-03-19 | 江西农业大学 | 一种融合多教师知识蒸馏的Yolov7柚子计数方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113569882A (zh) | 一种基于知识蒸馏的快速行人检测方法 | |
CN105787458A (zh) | 基于人工设计特征和深度学习特征自适应融合的红外行为识别方法 | |
CN109801297B (zh) | 一种基于卷积实现的图像全景分割预测优化方法 | |
CN111723693A (zh) | 一种基于小样本学习的人群计数方法 | |
CN110705412A (zh) | 一种基于运动历史图像的视频目标检测方法 | |
CN111368634B (zh) | 基于神经网络的人头检测方法、***及存储介质 | |
CN110852199A (zh) | 一种基于双帧编码解码模型的前景提取方法 | |
CN116343185A (zh) | 一种面向助盲领域的指示牌语义信息提取方法 | |
CN115019039A (zh) | 一种结合自监督和全局信息增强的实例分割方法及*** | |
CN103500456A (zh) | 一种基于动态贝叶斯模型网络的对象跟踪方法和设备 | |
CN113780187A (zh) | 交通标志识别模型训练方法、交通标志识别方法和装置 | |
CN113361475A (zh) | 一种基于多阶段特征融合信息复用的多光谱行人检测方法 | |
CN116452472A (zh) | 基于语义知识引导的低照度图像增强方法 | |
CN112069997B (zh) | 一种基于DenseHR-Net的无人机自主着陆目标提取方法及装置 | |
CN114694090A (zh) | 一种基于改进PBAS算法与YOLOv5的校园异常行为检测方法 | |
CN113963021A (zh) | 一种基于时空特征和位置变化的单目标跟踪方法及*** | |
CN113642498A (zh) | 一种基于多层次时空特征融合的视频目标检测***及方法 | |
CN112347962A (zh) | 一种基于感受野的卷积神经网络目标检测***与方法 | |
Ni et al. | Fusion learning model for mobile face safe detection and facial gesture analysis | |
CN116152699B (zh) | 用于水电厂视频监控***的实时运动目标检测方法 | |
CN116778277B (zh) | 基于渐进式信息解耦的跨域模型训练方法 | |
Li et al. | Single image defogging method based on improved generative adversarial network | |
CN114359336A (zh) | 基于光流和动态级联rpn的目标跟踪算法 | |
Wang | Research on classroom behavior recognition based on convolutional neural network | |
Li et al. | Building Recognition of Aerial Images Based on Improved Unet Network |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
WD01 | Invention patent application deemed withdrawn after publication | ||
WD01 | Invention patent application deemed withdrawn after publication |
Application publication date: 20211029 |