CN115565021A - 基于可学习特征变换的神经网络知识蒸馏方法 - Google Patents
基于可学习特征变换的神经网络知识蒸馏方法 Download PDFInfo
- Publication number
- CN115565021A CN115565021A CN202211196707.5A CN202211196707A CN115565021A CN 115565021 A CN115565021 A CN 115565021A CN 202211196707 A CN202211196707 A CN 202211196707A CN 115565021 A CN115565021 A CN 115565021A
- Authority
- CN
- China
- Prior art keywords
- knowledge distillation
- loss
- model
- feature map
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/7715—Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
本发明提出了一种基于可学习特征变换的神经网络知识蒸馏方法,属于计算机视觉技术领域。本发明对齐学生模型与教师模型的中间特征和输出结果,无需针对不同任务设计复杂的特征变换模块,不引入复杂的超参数,免去了繁琐的参数调整步骤,可以提高知识蒸馏在多个任务上的通用性,提升知识蒸馏效果的同时免去了手工设计结构的繁琐,在多个计算机视觉任务上(如图片分类、目标检测、语义分割等)实现了性能提升。
Description
技术领域
本发明属于计算机视觉技术领域,涉及计算机视觉、神经网络模型压缩、基于中间特征的神经网络知识蒸馏等深度学习技术。
背景技术
近年来,随着深度学***台上进行部署。为解决这一问题,需要使用到神经网络模型压缩技术。
知识蒸馏是目前神经网络模型压缩技术中一种重要的方法,该方法将大规模神经网络作为教师网络,将小规模神经网络作为学生网络,将教师网络的知识传递到学生网络中,进而获得一个复杂度低、性能好、易于部署的神经网络,达到模型压缩的目的。
目前,主流的知识蒸馏方法分为基于输出响应和基于中间特征的知识蒸馏,基于输出响应的知识蒸馏方法将教师模型尾层的预测结果作为监督信息,指导学生模型对教师模型的行为进行模仿。基于中间特征的知识蒸馏方法则将教师模型中间隐藏层的特征作为监督信号指导学生模型训练。在实际应用中,针对不同的视觉任务衍生出了多种多样的知识蒸馏方法,而这些方法往往有很多手工设计的部分,如损失函数、特征掩膜,而这些手工设计的部分一方面使得蒸馏方法的通用性降低,另一方面带来额外的超参数,使得调参难度增大。
发明内容
为了解决上述问题,本发明提出了一种基于可学习特征变换的知识蒸馏方法,将学生模型的中间特征与输出响应与教师模型进行对齐,提升知识蒸馏效果的同时免去了手工设计结构的繁琐,在多个计算机视觉任务上(如图片分类、目标检测、语义分割等)实现了性能提升。
本发明提供的技术方案是:
一种基于可学习特征变换的知识蒸馏方法,如图1所示,其步骤包括:
1)将输入数据输入教师模型,所述教师模型的中间层输出第一特征图,将所述输入数据输入学生模型,所述学生模型的中间层输出第二特征图;
2)将第二特征图与第一特征图进行空间维度和通道维度上的对齐,对齐后的特征图通过一个多层感知机模块得到第三特征图;同时,对对齐后的特征图的形状展开和转置,再通过另一个多层感知机模块得到变换后的特征图,再将变换后的特征图形状恢复成变换前的形状,得到第四特征图;
3)计算第一特征图和第三特征图间的均方差损失作为空间特征损失,计算第一特征图和第四特征图间的均方差损失作为通道特征损失,将所述空间特征损失和所述通道特征损失加权求和作为教师模型与学生模型间的知识蒸馏损失函数;
4)根据所述知识蒸馏损失函数,对学生模型进行训练实现知识蒸馏。
优选地,所述多层感知机模块为隐藏层数为1,激活函数为ReLU的多层感知机结构。
优选地,通过双线性插值和1x1卷积将所述第二特征图与所述第一特征图进行空间维度和通道维度上的对齐。
进一步,取得所述学生模型的下游任务,根据下游任务类型匹配模型的目标函数,将目标函数和知识蒸馏损失函数组合对学生模型进行训练。
进一步,根据所述教师模型、所述学生模型、所述下游任务调整所述蒸馏损失函数的超参数,将所述目标函数中的回归损失函数、分类损失函数和知识蒸馏损失函数求和获得所述学生模型训练的总损失函数,根据该总损失函数对所述学生模型进行训练。
本发明的有益效果:
本发明提供一种基于可学习特征变换的知识蒸馏方法,对齐教师模型和学生模型的特征,提高蒸馏效果,同时无需针对不同任务设计复杂的特征变换模块,不引入复杂的超参数,免去了繁琐的参数调整步骤,提高了知识蒸馏在多个任务上的通用性,在多种计算机视觉任务上均能取得不错的效果。
附图说明
图1为本发明基于可学习特征变换的知识蒸馏方法流程示意图;
图2为本发明具体实施例学生模型的训练过程架构示意图。
具体实施方式
下面结合附图,通过实例进一步描述本发明,但不以任何方式限制本发明的范围。
以大规模目标检测数据集COCO为例,以在该数据上预训练好的RetinNet-rx101作为教师模型,并选取RetinaNet-R50作为学生模型来说明如何通过可学习变换模块进行目标检测任务上的知识蒸馏,如图2所示。
步骤S1:将输入数据输入教师模型得到所述教师模型的中间层输出的第一特征图,将所述输入数据输入学生模型得到所述学生模型的中间层输出的第二特征图,具体包括:
S11:将任意一批原始的训练图片输入进教师模型RetinNet-rx101中,在所述教师模型的FPN部分得到中间层输出的第一特征图。
S12:将所述训练图片输入进学生模型RetinaNet-R50中,在所述学生模型的FPN部分得到中间层输出的第二特征图。
步骤S2:利用多层感知机模块得到第三特征图和第四特征图,具体包括:
S21:通过双线性插值和1x1卷积将所述第二特征图与所述第一特征图进行空间维度和通道维度上的对齐,得到对齐后的特征图;
S22:将所述对齐后的特征图通过一个隐藏层数为1,激活函数为ReLU的多层感知机模块得到第三特征图。
S23:设所述对齐后的特征图形状为[N,C,H,W],将该特征图的形状通过展开和转置操作调整为[N,(H*W),C],将调整后的特征图通过一个隐藏层数为1,激活函数为ReLU的多层感知机模块得到变换后的特征图,再将变换后的特征图形状调整为[N,C,H,W]得到所述第四特征图。
步骤S3:根据所述第一特征图、第三特征图和第四特征图,计算所述教师模型和所述学生模型间的空间特征损失和通道特征损失,将所述空间特征损失和所述通道特征损失加权求和作为所述教师模型与所述学生模型间的知识蒸馏损失函数,具体包括:
S31:计算所述第一特征图和所述第三特征图间的均方差损失作为所述空间特征损失,其表达式为:
S32:计算所述第一特征图和所述第四特征图间的均方差损失作为所述通道特征损失,其表达式为:
S33:将所述空间特征损失和所述通道特征损失加权求和得到所述知识蒸馏损失函数,其表达式为:
Ldistill=αLossSpatial+βLossChannel
其中α,β为超参数,在本实施例中分别设定为2e-5和1e-6。
步骤S4:根据所述知识蒸馏损失函数,对学生模型进行训练实现知识蒸馏。
进一步,取得所述学生模型的下游任务,在本实施例中,下游任务为目标检测任务。
步骤S5:根据所述下游任务类型匹配模型目标函数,在本实施例中,模型的目标函数分为回归损失函数和分类损失函数,所述回归损失函数表达式为:
在本实施例中,所述分类损失函数采用Focal Loss,其表达式为:
Lcls=-αt(1-pt)γlog(pt)
其中pt为样本被正确分类的概率值,αt,γ为超参数,在本实施例中分别设定为0.25,2.0。
步骤S6:根据教师模型、学生模型、下游任务调整所述蒸馏损失函数的超参数,目标函数、知识蒸馏损失函数和超参数获得所述学生模型训练的总损失函数;根据所述总损失函数对所述学生模型进行训练,其中所述总损失函数的表达式为:
Ltotal=Lreg+Lcls+Ldistill。
对于图像分类任务,在ImageNet数据集上的结果表明,使用ResNet34作为教师模型,ResNet18作为学生模型,采用本发明所提出的蒸馏方法进行知识蒸馏,可以将测试集上的Top-1准确率从69.9%提升到了71.4%;对于目标检测任务,在MSCOCO数据集上的结果表明,使用RetinaNet-RX101作为教师模型,RetinaNet-R50作为学生模型,采用本发明所提的知识蒸馏方法,可以将学生模型的mAP从37.4%提升到41.0%;对于语义分割任务,在CityScapes数据集上的结果表明,使用PSPNet-ResNet34作为教师模型,PSPNet-ResNet18作为学生模型,采用本发明所提的知识蒸馏方法,可以将学生模型的mIoU从69.9%提升到74.2%(注:ImageNet是一个大规模图像分类数据集,Top1-accuracy用于衡量图像分类准确率;MSCOCO是一个大规模数据集,包含目标检测等任务,bbox的mAP是衡量目标检测性能的一个指标;CityScapes是一个语义分割数据集,mIoU是衡量语义分割性能的一个指标。)此外,本发明也可用于实现跨模型的知识蒸馏,并能取得不错的效果。例如,对于图像分类任务,在Cifar-100数据集上,使用基于卷积神经网络架构的ResNet56作为教师模型,基于Transformer架构的ViT-tiny作为学生模型,可以将学生模型的Top1-accuracy由57.8%提升至77.5%(注:Cifar100是一个小规模图像分类数据集)。
以上通过详细实施案例描述了本发明,本领域的研究人员和技术人员可以根据上述的步骤作出形式或内容方面的非实质性的改变而不偏离本发明实质保护的范围。因此,本发明不局限于以上实施例中所公开的内容,本发明的保护范围应以权利要求所述为准。
Claims (5)
1.一种基于可学习特征变换的知识蒸馏方法,其特征在于,其步骤包括:
1)将输入数据输入教师模型,所述教师模型的中间层输出第一特征图,将所述输入数据输入学生模型,所述学生模型的中间层输出第二特征图;
2)将第二特征图与第一特征图进行空间维度和通道维度上的对齐,对齐后的特征图通过一个多层感知机模块得到第三特征图;同时,对对齐后的特征图的形状展开和转置,再通过另一个多层感知机模块得到变换后的特征图,再将变换后的特征图形状恢复成变换前的形状,得到第四特征图;
3)计算第一特征图和第三特征图间的均方差损失作为空间特征损失,计算第一特征图和第四特征图间的均方差损失作为通道特征损失,将所述空间特征损失和所述通道特征损失加权求和作为教师模型与学生模型间的知识蒸馏损失函数;
4)根据所述知识蒸馏损失函数,对学生模型进行训练实现知识蒸馏。
2.如权利要求1所述的基于可学习特征变换的知识蒸馏方法,其特征在于,步骤2)中所述多层感知机模块采用隐藏层数为1,激活函数为ReLU的多层感知机。
3.如权利要求1所述的基于可学习特征变换的知识蒸馏方法,其特征在于,步骤2)中通过双线性插值和1x1卷积将第二特征图与第一特征图进行空间维度和通道维度上的对齐。
4.如权利要求1所述的基于可学习特征变换的知识蒸馏方法,其特征在于,步骤4)中获取学生模型的下游任务,根据下游任务类型匹配模型的目标函数,将目标函数和知识蒸馏损失函数组合对学生模型进行训练。
5.如权利要求4所述的基于可学习特征变换的知识蒸馏方法,其特征在于,步骤4)中将所述目标函数中的回归损失函数、分类损失函数和知识蒸馏损失函数求和获得所述学生模型训练的总损失函数,根据该总损失函数对所述学生模型进行训练。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211196707.5A CN115565021A (zh) | 2022-09-28 | 2022-09-28 | 基于可学习特征变换的神经网络知识蒸馏方法 |
PCT/CN2022/143756 WO2024066111A1 (zh) | 2022-09-28 | 2022-12-30 | 图像处理模型的训练、图像处理方法、装置、设备及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211196707.5A CN115565021A (zh) | 2022-09-28 | 2022-09-28 | 基于可学习特征变换的神经网络知识蒸馏方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115565021A true CN115565021A (zh) | 2023-01-03 |
Family
ID=84743371
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211196707.5A Pending CN115565021A (zh) | 2022-09-28 | 2022-09-28 | 基于可学习特征变换的神经网络知识蒸馏方法 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN115565021A (zh) |
WO (1) | WO2024066111A1 (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118015316B (zh) * | 2024-04-07 | 2024-06-11 | 之江实验室 | 一种图像匹配模型训练的方法、装置、存储介质、设备 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111160409A (zh) * | 2019-12-11 | 2020-05-15 | 浙江大学 | 一种基于共同特征学习的异构神经网络知识重组方法 |
CN112819050B (zh) * | 2021-01-22 | 2023-10-27 | 北京市商汤科技开发有限公司 | 知识蒸馏和图像处理方法、装置、电子设备和存储介质 |
CN114998694A (zh) * | 2022-06-08 | 2022-09-02 | 上海商汤智能科技有限公司 | 图像处理模型的训练方法、装置、设备、介质和程序产品 |
-
2022
- 2022-09-28 CN CN202211196707.5A patent/CN115565021A/zh active Pending
- 2022-12-30 WO PCT/CN2022/143756 patent/WO2024066111A1/zh unknown
Also Published As
Publication number | Publication date |
---|---|
WO2024066111A1 (zh) | 2024-04-04 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113743514B (zh) | 一种基于知识蒸馏的目标检测方法及目标检测终端 | |
CN110349185B (zh) | 一种rgbt目标跟踪模型的训练方法及装置 | |
CN114419449B (zh) | 一种自注意力多尺度特征融合的遥感图像语义分割方法 | |
CN114943963A (zh) | 一种基于双分支融合网络的遥感图像云和云影分割方法 | |
CN114066831B (zh) | 一种基于两阶段训练的遥感图像镶嵌质量无参考评价方法 | |
CN114937204A (zh) | 一种轻量级多特征聚合的神经网络遥感变化检测方法 | |
CN115565021A (zh) | 基于可学习特征变换的神经网络知识蒸馏方法 | |
CN116524307A (zh) | 一种基于扩散模型的自监督预训练方法 | |
Lin et al. | Application of ConvLSTM network in numerical temperature prediction interpretation | |
CN116977872A (zh) | 一种CNN+Transformer遥感图像检测方法 | |
Manzari et al. | A robust network for embedded traffic sign recognition | |
CN111179272A (zh) | 一种面向道路场景的快速语义分割方法 | |
CN112989843B (zh) | 意图识别方法、装置、计算设备及存储介质 | |
CN117830788A (zh) | 一种多源信息融合的图像目标检测方法 | |
Yuan et al. | Multi-branch bounding box regression for object detection | |
Pan et al. | PMT-IQA: progressive multi-task learning for blind image quality assessment | |
Cui et al. | WetlandNet: Semantic segmentation for remote sensing images of coastal wetlands via improved UNet with deconvolution | |
Huan et al. | Remote sensing image reconstruction using an asymmetric multi-scale super-resolution network | |
CN112149496A (zh) | 一种基于卷积神经网络的实时道路场景分割方法 | |
Yang et al. | Amd: Adaptive masked distillation for object detection | |
Wang et al. | Face super-resolution via hierarchical multi-scale residual fusion network | |
CN115601745A (zh) | 一种面向应用端的多视图三维物体识别方法 | |
Chen et al. | Multi-modal feature fusion based on variational autoencoder for visual question answering | |
CN112287989B (zh) | 一种基于自注意力机制的航空影像地物分类方法 | |
CN115546590B (zh) | 一种基于多模态预训练持续学习的目标检测优化方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |