CN117011640A - 基于伪标签滤波的模型蒸馏实时目标检测方法及装置 - Google Patents
基于伪标签滤波的模型蒸馏实时目标检测方法及装置 Download PDFInfo
- Publication number
- CN117011640A CN117011640A CN202310815686.9A CN202310815686A CN117011640A CN 117011640 A CN117011640 A CN 117011640A CN 202310815686 A CN202310815686 A CN 202310815686A CN 117011640 A CN117011640 A CN 117011640A
- Authority
- CN
- China
- Prior art keywords
- model
- loss
- real
- data set
- student model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 80
- 238000004821 distillation Methods 0.000 title claims abstract description 31
- 238000001914 filtration Methods 0.000 title claims abstract description 27
- 238000012549 training Methods 0.000 claims abstract description 30
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 12
- 238000000034 method Methods 0.000 claims abstract description 12
- PXFBZOLANLWPMH-UHFFFAOYSA-N 16-Epiaffinine Natural products C1C(C2=CC=CC=C2N2)=C2C(=O)CC2C(=CC)CN(C)C1C2CO PXFBZOLANLWPMH-UHFFFAOYSA-N 0.000 claims description 21
- 230000009466 transformation Effects 0.000 claims description 18
- 238000000605 extraction Methods 0.000 claims description 9
- 238000004590 computer program Methods 0.000 claims description 6
- 230000006870 function Effects 0.000 claims description 6
- 238000013508 migration Methods 0.000 claims description 3
- 230000005012 migration Effects 0.000 claims description 3
- 238000013528 artificial neural network Methods 0.000 claims 1
- 238000013527 convolutional neural network Methods 0.000 description 12
- 239000011159 matrix material Substances 0.000 description 8
- 238000004364 calculation method Methods 0.000 description 5
- 230000000007 visual effect Effects 0.000 description 5
- 230000008859 change Effects 0.000 description 4
- 238000013519 translation Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000005286 illumination Methods 0.000 description 3
- 230000000903 blocking effect Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000004927 fusion Effects 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 238000010008 shearing Methods 0.000 description 2
- 239000013598 vector Substances 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000014509 gene expression Effects 0.000 description 1
- RVRCFVVLDHTFFA-UHFFFAOYSA-N heptasodium;tungsten;nonatriacontahydrate Chemical compound O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.O.[Na+].[Na+].[Na+].[Na+].[Na+].[Na+].[Na+].[W].[W].[W].[W].[W].[W].[W].[W].[W].[W].[W] RVRCFVVLDHTFFA-UHFFFAOYSA-N 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/778—Active pattern-learning, e.g. online learning of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Multimedia (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种基于伪标签滤波的模型蒸馏实时目标检测方法及装置,基于教师模型对学生模型进行训练,然后将待测数据输入训练好的学生模型获得实时目标检测结果。基于教师模型对学生模型进行训练先获取扩充数据集,然后将扩充数据集输入训练完成的教师模型,再将生成的伪标签数据集输入质量分类器生成新的伪标签数据集,然后将新的伪标签数据集和原始数据的合集分别输入学生模型和训练好的教师模型,然后基于检测结果计算学生模型的原始损失和两个模型之间的知识蒸馏损失,最后根据原始损失和知识蒸馏损失计算整体损失反向更新学生模型参数。本发明目标检测模型实时性高且泛化能力强,方法检测精度高。
Description
技术领域
本发明涉及目标检测技术领域,特别是涉及一种基于伪标签滤波的模型蒸馏实时目标检测方法及装置。
背景技术
实时目标检测的目的是实现在实时性要求下对图像或视频中的物体进行检测和识别,实时目标检测在自动驾驶、安防监控、智能家居、医疗影像等领域都有着广泛的应用。
近年来,深度学习技术的发展为实时目标检测的研究提供了强有力的支持。目前,深度学习在实时目标检测中的应用主要分为以下两类:(1)单阶段检测方法:单阶段检测方法通常采用卷积神经网络(CNN)结构进行特征提取和分类,通过回归框的位置和大小来检测物体。典型的单阶段检测算法包括YOLO和SSD等;(2)两阶段检测方法:两阶段检测方法通常先通过卷积神经网络进行候选框的生成,然后再对候选框进行分类和定位。典型的两阶段检测算法包括RCNN、Fast RCNN和Faster-RCNN等。总的来说,实时目标检测技术在近年来得到了快速的发展和改进,不断有新的方法和算法被提出,这使得实时目标检测在实际应用中具有了更加广泛的应用前景。
但是,现有的目标检测方法的训练通常需要大量的标注数据来学习目标的特征和上下文信息。由于目标检测任务的复杂性和多样性,为了获得较好的性能,需要使用深度网络模型,例如基于卷积神经网络(CNN)或基于Transformer的模型。这些模型通常具有数百万到数十亿个参数,需要大量的计算资源和存储空间,在模型推理阶段实时性较差。且现有的实时目标检测方法通常对光照和视角的变化比较敏感。不同的视角、角度、遮挡等因素都会导致图像中目标的外观变化,这些变化会对目标的大小和形状造成影响,从而使现有的目标检测方法难以准确地检测和定位目标。当相机变焦、视角变化、以及目标被遮挡时,现有的目标检测方法泛化能力变差。
发明内容
基于此,有必要针对上述技术问题,提供一种模型实时性高且泛化能力强的基于伪标签滤波的模型蒸馏实时目标检测方法及装置。
第一方面,本发明提供了一种基于伪标签滤波的模型蒸馏实时目标检测方法,基于教师模型对学生模型进行训练,然后将待测数据输入训练好的学生模型获得实时目标检测结果,教师模型和学生模型均为目标检测模型,教师模型的层数比学生模型的层数多,教师模型的深度比学生模型的深度大;
基于教师模型对学生模型进行训练具体包括以下步骤:
获取扩充数据集;
基于扩充数据集对教师模型进行训练,将扩充数据集输入训练完成的教师模型,生成伪标签数据集;
将伪标签数据集输入质量分类器生成新的伪标签数据集;
将新的伪标签数据集和原始数据的合集分别输入学生模型和训练好的教师模型分别获得检测结果和预训练结果;
基于检测结果计算学生模型的原始损失,基于检测结果和预训练结果计算知识蒸馏损失;
根据学生模型的原始损失和知识蒸馏损失计算整体损失;
基于整体损失调整学生模型的参数,获得训练好的学生模型。
在其中一个实施例中,获取扩充数据集包括:
获取原始数据集;
对原始数据集进行随机仿射变换。
在其中一个实施例中,教师模型为YOLOv5-l模型,基于扩充数据集对教师模型进行训练时训练300epoches;
质量分类器为正负样本质量分离器。
在其中一个实施例中,学生模型的原始损失包括置信度、类别损失和边框回归损失;
置信度为
LCE_obj=-αlog(β)-(1-α)log(1-β) (1);
式中,元素β表示样本属于前景或者背景的概率,即边界框的置信度值,α=是真实标签中是否包含目标的标志(1表示包含目标,0表示不包含目标);类别损失为
式中,p(x)是一个实际得到的概率分布,每个元素pi表示样本属于第i类的概率,当样本属于第类别i时yi=1,其它均为0,nc是样本总类别数;
边框回归损失为
式中,c为同时包含预测框和真实框的最小矩形,bgt为真实框,b为预测框,ρ(bgt,b)表示真实框和预测框中心点的欧式距离,β是用于平衡函数权重的参数,ν是用来衡量两个框之间长宽比的一致性的参数,IoU项和α表示附加功率正则化项;
学生模型的原始损失为
LSTU=λ1×LCE_cls+λ2×LCE_obj+λ3×Lα-CIoU (4);
式中,λ1为0.3,λ2为0.4,λ3为0.3。
在其中一个实施例中,知识蒸馏损失为
式中,m,n表示输出结果的张量的行列,outputT,outputS分别教师模型和学生模型的输出结果。
在其中一个实施例中,整体损失为
Ltotal=α1×LSTU+α2×LDistill (6);
式中,α1为0.8,α2为0.2。
在其中一个实施例中,基于整体损失调整学生模型的参数,获得训练好的学生模型是将整体损失反向传播至学生模型,调整学生模型参数,获得训练好的学生模型。
在其中一个实施例中,使用深度可分离卷积模块取代学生模型中特征提取部分传统卷积神经网络模块。
在其中一个实施例中,基于教师模型对学生模型进行训练还包括在获得训练好的学生模型后,使用原始数据对训练好的学生模型进行迁移训练。
第二方面,本发明还提供了一种基于伪标签滤波的模型蒸馏实时目标检测装置,包括存储器和处理器,存储器存储有计算机程序,处理器执行计算机程序时实现基于伪标签滤波的模型蒸馏实时目标检测的步骤。
本发明的有益效果:
(1)本发明采用随机仿射变换操作来模拟不同的视角和光照条件,包括对图像平移、缩放、旋转、剪切,来模拟目标位置变化、目标在不同距离下的大小变化、不同视角下的目标、目标部分被遮挡,对原始标签进行转换,扩充数据集,提升对多视角目标、遮挡目标以及多尺度目标的检测效果,提高了目标检测模型即学生模型的泛化能力,减少学生模型的过拟合。
(2)本发明同时进行教师模型和学生模型两个目标检测模型的训练,通过深度可分离卷积模块取代学生模型中特征提取部分传统卷积神经网络模块,实现深度学习神经网络模型的结构优化与设计,然后结合模型蒸馏的方法进行模型学习与融合,将一个复杂的教师模型的知识转移到学生模型中,学生模型层数少且深度小,达到轻量化的效果,使学生模型具有更高的检测精度的同时进一步提高了学生模型的实时性和泛化能力。
附图说明
图1是本发明实施例提供的基于伪标签滤波的模型蒸馏实时目标检测方法的流程示意图之一;
图2是本发明实施例提供的基于伪标签滤波的模型蒸馏实时目标检测方法的流程示意图之一;
图3是本发明实施例提供的基于伪标签滤波的模型蒸馏实时目标检测方法的流程示意图之一;
图4本发明实施例提供的使用深度可分离卷积模块取代特征提取部分传统卷积神经网络模块后的学生模型结构示意图。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
在一个实施例中,如图1所示,图1是本发明实施例提供的基于伪标签滤波的模型蒸馏实时目标检测方法的流程示意图之一,以该方法应用于计算机设备,包括以下步骤:
S101、基于教师模型对学生模型进行训练。
具体的,使用教师模型对学生模型进行训练即将教师模型的知识转移到学生模型中。
S102、将待测数据输入训练好的学生模型获得实时目标检测结果。
具体的,教师模型和学生模型均为目标检测模型,教师模型的层数比学生模型的层数多,教师模型的深度比学生模型的深度大。
本实施例中,如图2所示,图2是本发明实施例提供的基于伪标签滤波的模型蒸馏实时目标检测方法的流程示意图之一,基于教师模型对学生模型进行训练具体包括以下步骤:
S201、获取扩充数据集。
S202、基于扩充数据集对教师模型进行训练,将扩充数据集输入训练完成的教师模型,生成伪标签数据集。
具体的,教师模型的输出结果即为伪标签,我们将数据集图像表示为XU,和教师模型FT:XU→YU,在训练教师模型时,并引入一个温度因子T,调整Softmax概率分布,以生成伪标签。当T值较小时,负标签会变得更小,后续用于训练的学生模型对于负标签的关注也会减少,将伪标签定义为:
soft_labels=softmax(YU/T)
S203、将伪标签数据集输入质量分类器生成新的伪标签数据集。
经过随机仿射变换之后,有许多变形严重的原始数据图像被用于训练学生模型。依据教师模型的检测结果,采用质量分类器选取合适的伪标签样本,过滤后的样本占总体样本的5%~20%。
S204、将新的伪标签数据集和原始数据的合集分别输入学生模型和训练好的教师模型分别获得检测结果和预训练结果。
需要说明的是,学生模型通常选择与教师模型同系列算法的模型。然后将新的伪标签数据集和原始数据集一起输入到学生模型中进行训练,使得学生模型能够学习到教师模型中的知识。
将新的伪标签数据集和原始数据集重构数据集图像表示为XC,标签表示为YC,以及学生模型学生模型的预测结果表示为:
S205、基于检测结果计算学生模型的原始损失,基于检测结果和预训练结果计算知识蒸馏损失。教师模型的输出被视为伪标签,学生模型需要尽可能地模仿教师模型的输出。
S206、根据学生模型的原始损失和知识蒸馏损失计算整体损失。
S207、基于整体损失调整学生模型的参数,获得训练好的学生模型。
本实施例中,通过模型学习与融合,将一个复杂的教师模型的知识转移到学生模型中,使学生模型具有更高的检测精度。
在其中一个实施例中,如图3所示,图3是本发明实施例提供的基于伪标签滤波的模型蒸馏实时目标检测方法的流程示意图之一,本实施例涉及的是如何获取扩充数据集,在上述实施例的基础上,步骤S201包括:
S301、获取原始数据集。
S302、对原始数据集进行随机仿射变换。
具体的,随机仿射变换是一种基于随机采样的图像扭曲方法。以下是实现随机仿射变换的基本步骤:
随机生成仿射矩阵参数:仿射矩阵参数包括随机生成旋转角度、随机生成缩放比例、随机生成平移距离、随机生成错切参数等,可以通过以上随机生成矩阵参数来实现随机仿射变换。
构建仿射变换矩阵:根据生成的仿射矩阵参数,构建仿射变换矩阵,实现图像的仿射变换。
对图像进行仿射变换:将构建好的仿射变换矩阵应用于原始数据集,实现随机仿射变换。需要说明的是,原始数据集为原始采集的图像集合。
如表1所示,详细给出了3×3仿射变换矩阵的参数:
表1仿射变换矩阵的参数
sx×cos(θ) | -sy×sin(θ+hx) | tx |
-sx×sin(θ+hy) | sy×cos(θ) | ty |
0 | 0 | 1 |
其中,sx和sy表示将图像沿着x轴和y轴的缩放比例,tx和ty分别为图像沿x轴和y轴的平移距离,θ表示为旋转角度,hx和hy分别为图像沿x轴和y轴的错切参数。随机仿射变换可以增加数据集的多样性和数量,通过缩放、旋转、错切模拟不同的距离、角度、遮挡下的目标从而提高模型的检测精度和泛化能力。
采用随机仿射变换操作来模拟不同的视角和光照条件,包括对图像平移、缩放、旋转、剪切,来模拟目标位置变化、目标在不同距离下的大小变化、不同视角下的目标、目标部分被遮挡,并且对原始数据集进行转换,扩充数据集,能够提升对多视角目标、遮挡目标以及多尺度目标的检测效果,进而可以提高目标检测模型的泛化能力,减少目标检测模型的过拟合。
在一个可选的实施例中,教师模型为YOLOv5-l模型,基于扩充数据集对教师模型进行训练时训练300epoches。质量分类器为正负样本质量分离器。
在其中一个实施例中,学生模型的原始损失包括置信度、类别损失和边框回归损失;
置信度为
LCE_obj=-αlog(β)-(1-α)log(1-β) (1);
式中,元素β表示样本属于前景或者背景的概率,即边界框的置信度值,α=是真实标签中是否包含目标的标志(1表示包含目标,0表示不包含目标)。
需要说明的是,本实施例中的样本指的是新的伪标签数据集和原始数据集组合的数据集中的数据。
需要说明的是,置信度损失和类别损失使用交叉熵进行计算。熵是一种衡量信息不确定性的量,广泛应用于通信与信息领域。对于概率分布为的随机变量X,熵如式(9)所示,为f(x):
f(x)=-∫p(x)log p(x)dx (9);
使用交叉熵衡量预测的类别和置信度的不确定性,熵的值越大,不确定性越大,预测的结果越差;熵的值越小,不确定性越小,预测的结果越准确。
对于每个目标,学生模型会输出一个类别概率分布,表示该目标属于每个类别的概率。而对于每个目标,真实标签只有一个类别。因此,交叉熵损失函数可以用来评估模型预测的类别与真实标签之间的差距。
类别损失为
式中,p(x)是一个实际得到的概率分布,每个元素pi表示样本属于第i类的概率,当样本属于第类别i时yi=1,其它均为0,nc是样本总类别数。
边框回归损失为
式中,c为同时包含预测框和真实框的最小矩形,bgt为真实框,b为预测框,ρ(bgt,b)表示真实框和预测框中心点的欧式距离,β是用于平衡函数权重的参数,ν是用来衡量两个框之间长宽比的一致性的参数,IoU项和α表示附加功率正则化项。
目标检测模型中,边界框回归最常用的损失函数是IoU系列。原始用于计算损失的IoU系列表达式可用式(10)、式(11)表示:
LIoU=1-IoU
(11);
其中,bgt为真实框,b为预测框。本实施例采用一个新的功率IoU损耗系列来度量知识蒸馏损失,该系列具有一个IoU项和一个参数α表示附加功率正则化项。α-IoU是用于边界框回归损失的幂交集族。通过调节α,自适应的增加高IoU对象的损失和梯度的权重,以提高边界框回归精度。将上述α-IoU损失扩展为更一般的形式,如式(12)所示:
其中,一般取α1=α2=3,表示基于b和bgt计算的任何惩罚项。上述IoU可以取代为任意的GIoU,DIoU,CIoU,/>为其公式中对应的惩罚项。本专利采用α-CIoU用于边界框回归的损失函数计算公式如式(3)所示,ν和β的计算公式如式(13)、式(14)所示。
学生模型的原始损失为
LSTU=λ1×LCE_cls+λ2×LCE_obj+λ3×Lα-CIoU (4);
式中,λ1为0.3,λ2为0.4,λ3为0.3。
在其中一个实施例中,知识蒸馏损失为
式中,m,n表示输出结果的张量的行列,outputT,outputS分别教师模型和学生模型的输出结果。
在其中一个实施例中整体损失为
Ltotal=α1×LSTU+α2×LDistill (6);
式中,α1为0.8,α2为0.2。
在其中一个实施例中,基于整体损失调整学生模型的参数,获得训练好的学生模型是将整体损失反向传播至学生模型,调整学生模型参数,获得的学生模型为训练好的学生模型。
常用的目标检测器均采用深度卷积神经网络。在一个实施例中,使用深度可分离卷积模块取代学生模型中特征提取部分传统卷积神经网络模块。如图4所示,图4本发明实施例提供的使用深度可分离卷积模块取代特征提取部分传统卷积神经网络模块后的学生模型结构示意图。将深度可分离卷积应用于目标检测模型的特征提取(Backbone)网络的后三层卷积层,可以在减少模型参数数量和计算量的同时保持模型精度。深度可分离卷积模块可以分解卷积操作,将传统的卷积操作分解为深度卷积和逐点卷积两个操作。深度可分离卷积可以用于加速卷积神经网络的计算,并在保持模型精度的情况下减少模型的参数数量和计算量。
本实施例中采用深度可分离卷积模块代替传统的卷积操作,设计轻量级的网络特征提取结构,优化目标检测模型结构,进一步提高模型的实时性。
深度卷积是指将标准卷积的通道卷积(Channel-wise Convolution)和空间卷积(Spatial Convolution)拆分开来进行。假设输入特征图的形状为[H,W,C],卷积核的形状为[k,k,C,D],其中k表示卷积核的大小,C表示输入特征图的通道数,D表示输出特征图的通道数。则深度卷积的计算过程可以表示为:
(1)对于每一个输出通道d,使用大小为[k,k,C]的卷积核对输入特征图的每一个通道进行卷积,得到一个[H,W]的二维特征图。
(2)将得到的所有二维特征图沿着通道维度进行拼接,得到一个形状为[H,W,D]的输出特征图。
逐点卷积是指使用大小为[1,1,D,D']的卷积核对深度卷积得到的特征图进行卷积。逐点卷积的作用是对每个通道之间的信息进行交互,同时将深度卷积得到的特征图中的低级特征与高级特征进行融合。假设深度卷积得到的输出特征图形状为[H,W,D],逐点卷积的卷积核形状为[1,1,D,D'],则逐点卷积的计算过程可以表示为:
(1)对于输出特征图中的每一个位置[i,j],使用大小为[1,1,D]的卷积核对深度卷积得到的特征图的每一个通道进行加权求和,得到一个长度为D的向量。
(2)将得到的所有向量沿着通道维度进行拼接,得到一个形状为[H,W,D']的输出特征图。
综合上述两个操作,深度可分离卷积的计算过程即为:先进行深度卷积,得到一个形状为[H,W,D]的特征图,再进行逐点卷积,得到一个形状为[H,W,D']的输出特征图。
在其中一个实施例中,基于教师模型对学生模型进行训练还包括在获得训练好的学生模型后,使用原始数据对训练好的学生模型进行迁移训练。
优选的,本发明使用测试数据集评估最终生成的学生模型的整体性能。使用准确率(Precision)和召回率(Recall)、[email protected]等指标来评估模型的检测精度。采用帧率(FPS)或推理时间(ms)来衡量模型推理速度,较高的帧率或较短的推理时间表示模型具有更快的实时性。
基于同样的发明构思,本申请实施例还提供了一种用于实现上述所涉及的基于伪标签滤波的模型蒸馏实时目标检测方法的基于伪标签滤波的模型蒸馏实时目标检测装置。该装置所提供的解决问题的实现方案与上述方法中所记载的实现方案相似,故下面所提供的一个或多个点基于伪标签滤波的模型蒸馏实时目标检测装置实施例中的具体限定可以参见上文中对于基于伪标签滤波的模型蒸馏实时目标检测方法的限定,在此不再赘述。
在一个实施例中,基于伪标签滤波的模型蒸馏实时目标检测装置,包括存储器和处理器,存储器存储有计算机程序,处理器执行计算机程序时实现基于伪标签滤波的模型蒸馏实时目标检测的步骤。
应该理解的是,虽然如上所述的各实施例所涉及的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,如上所述的各实施例所涉及的流程图中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
以上所述实施例仅表达了本发明的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对本发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干变形和改进,这些都属于本发明的保护范围。因此,本发明的保护范围应以所附权利要求为准。
Claims (10)
1.一种基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,基于教师模型对学生模型进行训练,然后将待测数据输入训练好的学生模型获得实时目标检测结果,所述教师模型和学生模型均为目标检测模型,所述教师模型的层数比学生模型的层数多,所述教师模型的深度比学生模型的深度大;
所述基于教师模型对学生模型进行训练具体包括以下步骤:
获取扩充数据集;
基于扩充数据集对教师模型进行训练,将扩充数据集输入训练完成的教师模型,生成伪标签数据集;
将所述伪标签数据集输入质量分类器生成新的伪标签数据集;
将新的伪标签数据集和原始数据的合集分别输入学生模型和训练好的教师模型分别获得检测结果和预训练结果;
基于检测结果计算学生模型的原始损失,基于检测结果和预训练结果计算知识蒸馏损失;
根据学生模型的原始损失和知识蒸馏损失计算整体损失;
基于所述整体损失调整学生模型的参数,获得训练好的学生模型。
2.根据权利要求1所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,获取扩充数据集包括:
获取原始数据集;
对所述原始数据集进行随机仿射变换。
3.根据权利要求2所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,所述教师模型为YOLOv5-l模型,基于扩充数据集对教师模型进行训练时训练300epoches;
所述质量分类器为正负样本质量分离器。
4.根据权利要求1所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,所述学生模型的原始损失包括置信度、类别损失和边框回归损失;
所述置信度为
LCE_obj=-αlog(β)-(1-α)log(1-β) (1);
式中,元素β表示样本属于前景或者背景的概率,即边界框的置信度值,α=是真实标签中是否包含目标的标志(1表示包含目标,0表示不包含目标);
所述类别损失为
式中,p(x)是一个实际得到的概率分布,每个元素pi表示样本属于第i类的概率,当样本属于第类别i时yi=1,其它均为0,nc是样本总类别数;
所述边框回归损失为
式中,c为同时包含预测框和真实框的最小矩形,bgt为真实框,b为预测框,ρ(bgt,b)表示真实框和预测框中心点的欧式距离,β是用于平衡函数权重的参数,ν是用来衡量两个框之间长宽比的一致性的参数,IoU项和α表示附加功率正则化项;
所述学生模型的原始损失为
LSTU=λ1×LCE_cls+λ2×LCE_obj+λ3×Lα-CIoU (4);
式中,λ1为0.3,λ2为0.4,λ3为0.3。
5.根据权利要求4所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,所述知识蒸馏损失为
式中,m,n表示输出结果的张量的行列,outputT,outputS分别教师模型和学生模型的输出结果。
6.根据权利要求5所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,所述整体损失为
Ltotal=α1×LSTU+α2×LDistill (6);
式中,α1为0.8,α2为0.2。
7.根据权利要求6所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,基于所述整体损失调整学生模型的参数,获得训练好的学生模型是将整体损失反向传播至学生模型,调整学生模型参数,获得训练好的学生模型。
8.根据权利要求3所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,使用深度可分离卷积模块取代学生模型中特征提取部分传统卷积神经网络模块。
9.根据权利要求2至8任意一项所述的基于伪标签滤波的模型蒸馏实时目标检测方法,其特征在于,所述基于教师模型对学生模型进行训练还包括在获得训练好的学生模型后,使用原始数据对训练好的学生模型进行迁移训练。
10.一种基于伪标签滤波的模型蒸馏实时目标检测装置,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至9中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310815686.9A CN117011640A (zh) | 2023-07-04 | 2023-07-04 | 基于伪标签滤波的模型蒸馏实时目标检测方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310815686.9A CN117011640A (zh) | 2023-07-04 | 2023-07-04 | 基于伪标签滤波的模型蒸馏实时目标检测方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117011640A true CN117011640A (zh) | 2023-11-07 |
Family
ID=88566434
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310815686.9A Pending CN117011640A (zh) | 2023-07-04 | 2023-07-04 | 基于伪标签滤波的模型蒸馏实时目标检测方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117011640A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117372819A (zh) * | 2023-12-07 | 2024-01-09 | 神思电子技术股份有限公司 | 用于有限模型空间的目标检测增量学习方法、设备及介质 |
-
2023
- 2023-07-04 CN CN202310815686.9A patent/CN117011640A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117372819A (zh) * | 2023-12-07 | 2024-01-09 | 神思电子技术股份有限公司 | 用于有限模型空间的目标检测增量学习方法、设备及介质 |
CN117372819B (zh) * | 2023-12-07 | 2024-02-20 | 神思电子技术股份有限公司 | 用于有限模型空间的目标检测增量学习方法、设备及介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111242208B (zh) | 一种点云分类方法、分割方法及相关设备 | |
Ye et al. | Inverted pyramid multi-task transformer for dense scene understanding | |
Sahu et al. | A survey on deep learning: convolution neural network (CNN) | |
Yang et al. | A dual attention network based on efficientNet-B2 for short-term fish school feeding behavior analysis in aquaculture | |
Li et al. | Robust tensor subspace learning for anomaly detection | |
WO2022001805A1 (zh) | 一种神经网络蒸馏方法及装置 | |
Yin et al. | End-to-end face parsing via interlinked convolutional neural networks | |
CN114049381A (zh) | 一种融合多层语义信息的孪生交叉目标跟踪方法 | |
Wang et al. | Face mask extraction in video sequence | |
CN113592060A (zh) | 一种神经网络优化方法以及装置 | |
CN116310850B (zh) | 基于改进型RetinaNet的遥感图像目标检测方法 | |
Ettaouil | Generalization Ability Augmentation and Regularization of Deep Convolutional Neural Networks Using l1/2 Pooling | |
CN117011640A (zh) | 基于伪标签滤波的模型蒸馏实时目标检测方法及装置 | |
Zhang et al. | Unsupervised remote sensing image segmentation based on a dual autoencoder | |
Zhang et al. | Crop pest recognition based on a modified capsule network | |
Wu et al. | Dynamic activation and enhanced image contour features for object detection | |
Hua et al. | Real-time object detection in remote sensing images based on visual perception and memory reasoning | |
Cao et al. | QuasiVSD: efficient dual-frame smoke detection | |
CN116740362A (zh) | 一种基于注意力的轻量化非对称场景语义分割方法及*** | |
Pengcheng et al. | Lightweight detection method of coal gangue based on multispectral and improved YOLOv5s | |
CN115222896B (zh) | 三维重建方法、装置、电子设备及计算机可读存储介质 | |
Pei et al. | FGO-Net: Feature and Gaussian Optimization Network for visual saliency prediction | |
Tang et al. | Fast semantic segmentation network with attention gate and multi-layer fusion | |
Shang et al. | Recognition of coal and gangue under low illumination based on SG-YOLO model | |
Jain et al. | Flynet–neural network model for automatic building detection from satellite images |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |