CN114492755A - 基于知识蒸馏的目标检测模型压缩方法 - Google Patents
基于知识蒸馏的目标检测模型压缩方法 Download PDFInfo
- Publication number
- CN114492755A CN114492755A CN202210106356.8A CN202210106356A CN114492755A CN 114492755 A CN114492755 A CN 114492755A CN 202210106356 A CN202210106356 A CN 202210106356A CN 114492755 A CN114492755 A CN 114492755A
- Authority
- CN
- China
- Prior art keywords
- network model
- teacher
- student
- model
- target detection
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
- G06N5/022—Knowledge engineering; Knowledge acquisition
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computational Linguistics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本发明提供一种基于知识蒸馏的目标检测模型压缩方法,通过FPN分别提取教师网络模型和学生网络模型的特征图,通过计算两者的对应的特征图的Gram矩阵之间的差异,并通过反向传播,能够使学生网络模型向教师网络模型学习不同通道之间的相程度,进而提高学生网络模型的检测精度,从而能够对目标检测模型进行有效压缩,并且在压缩的同时保证检测精度。其中,教师网络模型为以ResNet101为骨干网络的Faster RCNN,学生网络模型为以ResNet50为骨干网络的Faster RCNN,因此,减少了约一半的中间层的层数,实现了有效的模型压缩,并且通过应用Gram矩阵,保证了压缩后模型的检测精度。
Description
技术领域
本发明属于图像识别技术领域,具体涉及一种基于知识蒸馏的目标检测模型压缩方法。
背景技术
目标检测技术主要用于检测图像或视频中的目标,检测内容包括目标的类别以及目标的坐标。近几年来,基于深度学习目标检测算法取得了很大的进展,比较流行的算法可以分为两类,一类是基于Region Proposal的R-CNN系算法(包括R-CNN、FastR-CNN、FasterR-CNN等),它们是Two-stage算法,需要预先产生目标候选框,然后再对候选框做分类与回归。而另一类是YOLO、SSD这类One-stage算法,可以使用一个卷积神经网络CNN直接预测得到不同目标的类别与位置。相对而言,Two-stage方法准确度高一些,但是速度慢,而One-stage方法虽然速度快,但准确性要比Two-stage低一些。
然而,目前的目标检测模型通常很大,虽然其性能较好,但速度很慢,为了使目标检测技术能够应用于诸多对实时处理能力要求较高的应用场景,需要对目标检测模型进行压缩,并需要保证在压缩后,模型仍有较好的性能。现有技术中,还缺乏对目标检测模型进行压缩的有效方法。
知识蒸馏是一种常见的模型压缩方法,在图像识别,自然语言处理,推荐***等多个领域都有应用,主要是通过构建教师网络与学生网络的训练框架,在训练过程中,固定教师网络的权重,通过使学生网络在反向传播更新参数的过程中向教师网络进行学习,进而提升学生网络的性能,有望能够解决上述问题。然而,现有技术中还缺乏相应的方法,对相关技术人员而言,如何利用知识蒸馏对目标检测模型进行压缩,仍存在诸多技术上的困难。
发明内容
本发明是为解决上述问题而进行的,目的在于提供一种利用知识蒸馏方法对目标检测模型进行压缩的方法,本发明采用了如下技术方案:
本发明提供了一种基于知识蒸馏的目标检测模型压缩方法,其特征在于,包括:步骤S1,对教师网络模型进行训练;步骤S2,将训练好的所述教师网络模型的权重进行固定,通过FPN提取出所述教师网络模型的多层教师特征图;步骤S3,通过所述FPN提取出所述学生网络模型的多层学生特征图,并对所述学生特征图进行卷积,使所述学生特征图的通道数与所述教师特征图的通道数相同;步骤S4,对每一层的所述教师特征图和对应的所述学生特征图,计算两者的Gram矩阵;步骤S5,计算各个所述Gram矩阵的损失值,并对多个所述损失值求和;步骤S6,将所述损失值的和与所述教师网络模型的权重系数相乘,将相乘结果添加到最终的损失函数中,通过反射传播对所述学生网络模型进行优化。
本发明提供的基于知识蒸馏的目标检测模型压缩方法,还可以具有这样的技术特征,其中,所述教师网络模型为以ResNet101为骨干网络的目标检测模型Faster RCNN,所述学生网络模型为以ResNet50为骨干网络的目标检测模型Faster RCNN。
本发明提供的基于知识蒸馏的目标检测模型压缩方法,还可以具有这样的技术特征,其中,步骤S4中,通过以下公式计算所述Gram矩阵:
本发明提供的基于知识蒸馏的目标检测模型压缩方法,还可以具有这样的技术特征,其中,步骤S5中,对所述Gram矩阵进行逐元素L2差值计算,计算得到的差值即所述损失值。
本发明提供的基于知识蒸馏的目标检测模型压缩方法,还可以具有这样的技术特征,其中,所述FPN为四层结构的FPN,经所述FPN处理后,分别得到四种尺寸的所述教师特征图以及四种尺寸的所述学生特征图。
本发明提供的基于知识蒸馏的目标检测模型压缩方法,还可以具有这样的技术特征,其中,步骤S1中,在Pascal VOC数据集上对所述教师网络模型进行训练。
发明作用与效果
根据本发明的基于知识蒸馏的目标检测模型压缩方法,通过FPN分别提取教师网络模型和学生网络模型的特征图,通过计算两者的对应的特征图的Gram矩阵之间的差异,并通过反向传播,能够使学生网络模型向教师网络模型学习不同通道之间的相似程度,进而提高学生网络模型的检测精度,从而能够对目标检测模型进行有效压缩,并且在进行压缩的同时保证检测精度。
附图说明
图1是本发明实施例中基于知识蒸馏的目标检测模型压缩方法的流程图;
图2是本发明实施例中基于知识蒸馏的目标检测模型压缩方法的流程简图;
图3是本发明实施例中目标检测模型Faster RCNN的原理示意图;
图4是本发明实施例中计算Gram矩阵的示意图。
具体实施方式
为了使本发明实现的技术手段、创作特征、达成目的与功效易于明白了解,以下结合实施例及附图对本发明的基于知识蒸馏的目标检测模型压缩方法作具体阐述。
<实施例>
图1是本实施例中基于知识蒸馏的目标检测模型压缩方法的流程图。图2是本实施例中基于知识蒸馏的目标检测模型压缩方法的流程简图。
如图1和图2所示,基于知识蒸馏的目标检测模型压缩方法具体包括如下步骤:
步骤S1,对教师网络模型进行训练。
本实施例中,教师网络模型为以ResNet101为骨干网络(backbone)的目标检测模型Faster RCNN,在Pascal VOC数据集上对教师网络模型进行训练。教师网络模型也即待压缩的模型。
图3是本实施例中目标检测模型Faster RCNN的原理示意图。
如图3所示,目标检测模型Faster RCNN的工作包括如下步骤:
首先,将训练集中的图片用骨干网络backbone进行特征提取,然后,通过FPN进行特征融合,接着,通过Region proposal Network提取region proposal,使用生成的regionproposal对特征图截出proposal,将截出的proposal使用roi pooling处理后,送入到后续的检测网络进行定位与分类。Faster RCNN的结构以及原理为现有技术,因此不再赘述。
步骤S2,将训练好的教师网络模型的权重进行固定,通过FPN提取出教师网络模型的多层特征图(为叙述方便,记作教师特征图),用于供学生网络模型进行学习。
本实施例中,学生网络模型为以ResNet50为骨干网络的Faster RCNN,因此,其中间层的层数显著少于教师网络模型,其运行速度更快。教师网络模型的骨干网络ResNet101与学生网络模型的骨干网络ResNet50结构相同,均由四个stage拼接而成,区别之处在于,每个stage的深度不同,因此拼接而成的网络模型的整体深度存在差异。
FPN的全名为Feature Pyramid Network,FPN结构的主要目的是为了更好地融合特征,低分辨率的特征图具有较好的语义信息,高分辨率的特征图具有较好的位置信息,通过FPN结构,将高层的低分辨率特征图信息融合到低层高分辨率的特征图上,从而使得网络能够获得更好的检测效果。
步骤S3,通过FPN提取出学生网络模型的多层特征图(记作学生特征图),并对学生特征图进行卷积处理,使学生特征图的通道数与教师特征图的通道数相同。
本实施例中,教师网络模型和学生网络模型均会将其骨干网络得到的特征图送入FPN结构进行处理,经FPN处理后会得到与对应的骨干网络相同的四种尺寸的特征图,其中,学生网络模型得到特征图尺寸与教师网络模型得到的特征图尺寸相同,因此可以对教师网络模型与学生网络模型得到的特征图处理后计算差值,使得学生网络模型能够向教师网络模型进行学习。
对于学生网络模型经FPN处理后的特征图,为了将其与教师网络模型对应的特征图进行适配,会对每种分辨率的特征图进行一次卷积处理,防止教师网络模型与学生网络模型得到的特征图通道数不同,进而无法进行后续计算。同时,由过去的论文研究表明,即使教师网络模型与学生网络模型的通道数相同,进行该卷积层的适配处理仍然能够提升检测效果。
步骤S4,对每一层的教师特征图和对应的学生特征图,分别计算Gram矩阵。
为便于进行说明,将一个FPN层的维度记作[N,C,H,W],其中,N为batch size,C为特征图的通道数,H为特征图的高度,W为特征图的宽度。
图4是本实施例中计算Gram矩阵的示意图。
如图4所示,本实施例用到的Gram矩阵计算方法用公式可表示为:
输入的特征图的部分维度为[c,h,w],经过flatten(即将维度中的h×w平铺成一维向量)以及矩阵转置操作,可以变形为[c,h×w]矩阵以及[h×w,ch]矩阵。再对这两个矩阵作内积,得到Gram矩阵,最后得到[c×ch]的Gram矩阵。
步骤S5,计算教师特征图的Gram矩阵和对应的学生特征图的Gram矩阵的损失值,并对多个损失值求和。
本实施例中,对各个Gram矩阵进行Element-wiseL2Loss的计算,也即对Gram矩阵进行逐元素L2差值计算,并对求得的多个差值进行求和。由于差值的和过大,因此会进行归一化处理,通过计算特征图的尺寸,将尺寸求和作为归一化值,将得到的差值和与尺寸和相除,即可得到教师网络模型与学生网络模型对应层Gram矩阵的最终差值。
通过计算Gram矩阵,使学生网络模型向教师网络模型进行学习的主要思路借鉴自图像风格迁移。Gram矩阵可以看做特征之间的偏心协方差矩阵,在feature map中,每个数字都来自于一个特定滤波器在特定位置的卷积,因此每个数字代表一个特征的强度,Gram矩阵计算的实际上是两两特征之间的相关性,哪两个特征是同时出现的,哪两个是此消彼长的。
Gram矩阵的含义为n维欧式空间中任意k个向量之间两两的内积所组成的矩阵,Gram矩阵用于度量各个维度自己的特性以及各个维度之间的关系,内积之后得到的多尺度矩阵中,对角线元素提供了不同特征图各自的信息,其余元素提供了不同特征图之间的相关信息。这样一个矩阵,既能体现出有哪些特征,又能体现出不同特征间的紧密程度。通过计算教师网络模型与学生网络模型的Gram矩阵之间的差异,能够使学生网络模型向教师网络模型学习特征之间的上述特性。
步骤S6,将损失值的和与教师网络模型的权重系数相乘,将相乘得到的结果添加到最终的损失函数中,通过反射传播对学生网络模型进行优化。
通过上述步骤,得到了训练好的学生网络模型,该学生网络模型能够实现与教师网络模型相近的检测效果,且中间层的层数显著少于教师网络模型,因此,实现了对目标检测模型的压缩。
本实施例中,将训练好的学生网络模型在Pascal VOC数据集上进行测试,检测模型性能,并将原始的学生网络模型(也即未经本实施例的方法优化的模型)同样在PascalVOC数据集上进行测试,进行性能对比。经实验比较,经过知识蒸馏提升过的学生网络模型相比原始的学生网络模型,在检测精度上有明显的性能提升。
实施例作用与效果
根据本实施例提供的基于知识蒸馏的目标检测模型压缩方法,通过FPN分别提取教师网络模型和学生网络模型的特征图,通过计算两者的对应的特征图的Gram矩阵之间的差异,并通过反向传播,能够使学生网络模型向教师网络模型学习不同通道之间的相程度,进而提高学生网络模型的检测精度,从而能够对目标检测模型进行有效压缩,并且在压缩的同时保证检测精度。
实施例中,教师网络模型为以ResNet101为骨干网络的Faster RCNN,学生网络模型为以ResNet50为骨干网络的Faster RCNN,因此,减少了约一半的中间层的层数,实现了有效的模型压缩,并且通过应用Gram矩阵,保证了压缩后模型的检测精度。
上述实施例仅用于举例说明本发明的具体实施方式,而本发明不限于上述实施例的描述范围。
在上述实施例中,目标检测模型压缩方法应用于two-stage模型,在替代方案中,该方法同样可用于one-stage模型,也能实现本发明的技术效果。
Claims (6)
1.一种基于知识蒸馏的目标检测模型压缩方法,其特征在于,包括:
步骤S1,对教师网络模型进行训练;
步骤S2,将训练好的所述教师网络模型的权重进行固定,通过FPN提取出所述教师网络模型的多层教师特征图;
步骤S3,通过所述FPN提取出所述学生网络模型的多层学生特征图,并对所述学生特征图进行卷积,使所述学生特征图的通道数与所述教师特征图的通道数相同;
步骤S4,对每一层的所述教师特征图和对应的所述学生特征图,分别计算Gram矩阵;
步骤S5,计算所述教师特征图的所述Gram矩阵和对应的所述学生特征图的所述Gram矩阵的损失值,并对多个所述损失值求和;
步骤S6,将所述损失值的和与所述教师网络模型的权重系数相乘,将相乘结果添加到最终的损失函数中,通过反射传播对所述学生网络模型进行优化。
2.根据权利要求1所述的基于知识蒸馏的目标检测模型压缩方法,其特征在于:
其中,所述教师网络模型为以ResNet101为骨干网络的目标检测模型Faster RCNN,
所述学生网络模型为以ResNet50为骨干网络的目标检测模型Faster RCNN。
4.根据权利要求1所述的基于知识蒸馏的目标检测模型压缩方法,其特征在于:
其中,步骤S5中,对所述Gram矩阵进行逐元素L2差值计算,计算得到的差值即所述损失值。
5.根据权利要求1所述的基于知识蒸馏的目标检测模型压缩方法,其特征在于:
其中,所述FPN为四层结构的FPN,
经所述FPN处理后,分别得到四种尺寸的所述教师特征图以及四种尺寸的所述学生特征图。
6.根据权利要求1所述的基于知识蒸馏的目标检测模型压缩方法,其特征在于:
其中,步骤S1中,在Pascal VOC数据集上对所述教师网络模型进行训练。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210106356.8A CN114492755A (zh) | 2022-01-28 | 2022-01-28 | 基于知识蒸馏的目标检测模型压缩方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210106356.8A CN114492755A (zh) | 2022-01-28 | 2022-01-28 | 基于知识蒸馏的目标检测模型压缩方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114492755A true CN114492755A (zh) | 2022-05-13 |
Family
ID=81477414
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210106356.8A Pending CN114492755A (zh) | 2022-01-28 | 2022-01-28 | 基于知识蒸馏的目标检测模型压缩方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114492755A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114998570A (zh) * | 2022-07-19 | 2022-09-02 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
CN116502706A (zh) * | 2023-06-26 | 2023-07-28 | 中科领航智能科技(苏州)有限公司 | 一种面向车道线检测的知识蒸馏方法 |
-
2022
- 2022-01-28 CN CN202210106356.8A patent/CN114492755A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114998570A (zh) * | 2022-07-19 | 2022-09-02 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
CN116502706A (zh) * | 2023-06-26 | 2023-07-28 | 中科领航智能科技(苏州)有限公司 | 一种面向车道线检测的知识蒸馏方法 |
CN116502706B (zh) * | 2023-06-26 | 2023-10-10 | 中科领航智能科技(苏州)有限公司 | 一种面向车道线检测的知识蒸馏方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
EP3968179A1 (en) | Place recognition method and apparatus, model training method and apparatus for place recognition, and electronic device | |
WO2020228446A1 (zh) | 模型训练方法、装置、终端及存储介质 | |
CN111259940B (zh) | 一种基于空间注意力地图的目标检测方法 | |
CN111898432B (zh) | 一种基于改进YOLOv3算法的行人检测***及方法 | |
CN111968150B (zh) | 一种基于全卷积神经网络的弱监督视频目标分割方法 | |
CN115171165A (zh) | 全局特征与阶梯型局部特征融合的行人重识别方法及装置 | |
CN111709313B (zh) | 基于局部和通道组合特征的行人重识别方法 | |
CN114492755A (zh) | 基于知识蒸馏的目标检测模型压缩方法 | |
CN110321805B (zh) | 一种基于时序关系推理的动态表情识别方法 | |
CN114821390B (zh) | 基于注意力和关系检测的孪生网络目标跟踪方法及*** | |
CN112381763A (zh) | 一种表面缺陷检测方法 | |
CN112507920B (zh) | 一种基于时间位移和注意力机制的考试异常行为识别方法 | |
CN110705600A (zh) | 一种基于互相关熵的多深度学习模型融合方法、终端设备及可读存储介质 | |
CN114330499A (zh) | 分类模型的训练方法、装置、设备、存储介质及程序产品 | |
CN111639230B (zh) | 一种相似视频的筛选方法、装置、设备和存储介质 | |
CN114140623A (zh) | 一种图像特征点提取方法及*** | |
CN111179270A (zh) | 基于注意力机制的图像共分割方法和装置 | |
CN115761888A (zh) | 基于nl-c3d模型的塔吊操作人员异常行为检测方法 | |
CN116468919A (zh) | 图像局部特征匹配方法及*** | |
CN116580322A (zh) | 一种地面背景下无人机红外小目标检测方法 | |
CN113609904B (zh) | 一种基于动态全局信息建模和孪生网络的单目标跟踪算法 | |
CN117409244A (zh) | 一种SCKConv多尺度特征融合增强的低照度小目标检测方法 | |
CN115471718A (zh) | 基于多尺度学习的轻量级显著性目标检测模型的构建和检测方法 | |
CN116012903A (zh) | 一种人脸表情自动标注的方法及*** | |
CN112487927B (zh) | 一种基于物体关联注意力的室内场景识别实现方法及*** |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |