CN113610069B - 基于知识蒸馏的目标检测模型训练方法 - Google Patents
基于知识蒸馏的目标检测模型训练方法 Download PDFInfo
- Publication number
- CN113610069B CN113610069B CN202111179182.XA CN202111179182A CN113610069B CN 113610069 B CN113610069 B CN 113610069B CN 202111179182 A CN202111179182 A CN 202111179182A CN 113610069 B CN113610069 B CN 113610069B
- Authority
- CN
- China
- Prior art keywords
- target detection
- detection frame
- pixel position
- label
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 388
- 238000012549 training Methods 0.000 title claims abstract description 97
- 238000000034 method Methods 0.000 title claims abstract description 37
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 32
- 239000011159 matrix material Substances 0.000 claims abstract description 61
- 238000010586 diagram Methods 0.000 claims abstract description 27
- 238000004821 distillation Methods 0.000 claims description 7
- 230000009466 transformation Effects 0.000 claims description 3
- 239000000523 sample Substances 0.000 description 44
- 238000004364 calculation method Methods 0.000 description 5
- 230000008569 process Effects 0.000 description 4
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000004891 communication Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 241001415288 Coccidae Species 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000011426 transformation method Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明提供了一种基于知识蒸馏的目标检测模型训练方法,包括:利用训练样本图像集训练目标检测教师模型,训练样本图像具有:第一标签:目标检测框中心点像素位置硬标签概率矩阵;第二标签:目标检测框的宽和高;第三标签:目标检测框中心点像素位置偏移量;目标检测教师模型的预测输出结果包括:目标检测框中心点像素位置概率热力图、目标检测框的宽和高、目标检测框中心点像素位置偏移量;以知识蒸馏的方式改进目标检测学生模型的损失函数后,训练生成目标检测学生模型。本发明的解决了利用现有的知识蒸馏方法训练获取的目标检测模型无法同时保证网络结构简单而满足终端设备使用需求,以及目标检测模型的识别率优良以确保模型检测精度的问题。
Description
技术领域
本发明涉及人工智能模型训练技术领域,具体而言,涉及一种基于知识蒸馏的目标检测模型训练方法。
背景技术
知识蒸馏是通过引入教师模型的网络结构指导学生模型的网络结构的训练,进而实现知识迁移。具体的方法步骤是先训练教师模型,然后利用此教师模型的输出和数据的真实标签去训练学生模型,从而将教师模型的网络结构的知识转移到学生模型的网络结构中,在保证了学生模型的网络结构能够获得接近于教师模型的网络结构的性能的同时,还使得学生模型的网络结构尽可能小,参数量更少,从而更有利于降低对部署模型的算力需求,提升模型的推理效率。
实施目标检测任务的终端设备通常为摄像机、照相机或监控探头等小型设备,其所搭载的芯片的算力有限,因此,目标检测模型的网络结构的大小受到了严格的限制。利用传统的知识蒸馏方法训练得到的目标检测模型虽然在网络结构大小上能够匹配终端设备的算力要求;但却无法保证获得的目标检测模型实施目标检测任务时的精度。
这是因为,传统的知识蒸馏方法常用于实施单一的分类任务的模型的训练,而采用CenterNet网络结构的目标检测模型实施的目标检测任务同时包括分类任务和回归任务,这样目标检测模型的网络结构本身便较为复杂,传统的知识蒸馏方法直接将学生模型的损失函数中的真实标签部分替换成教师模型的输出,而没有做到对目标检测模型的损失函数分级分类指导优化,因此最终训练出的目标检测模型存在识别效果差、检测精度低的问题。
由此可知,如何利用知识蒸馏的方法使得训练出目标检测模型同时兼顾网络结构简单而满足终端设备使用需求,以及同时确保目标检测模型的识别率优良以保证模型检测精度,便成了现有技术中亟待解决的问题。
发明内容
本发明的主要目的在于提供一种基于知识蒸馏的目标检测模型训练方法,以解决利用现有技术中的知识蒸馏方法训练获取的目标检测模型无法同时保证网络结构简单而满足终端设备使用需求,以及目标检测模型的识别率优良以确保模型检测精度的问题。
为了实现上述目的,本发明提供了一种基于知识蒸馏的目标检测模型训练方法,包括:步骤S1,利用训练样本图像集训练生成目标检测教师模型,训练样本图像集中的各训练样本图像具有:第一标签:目标检测框中心点像素位置硬标签概率矩阵;第二标签:目标检测框的宽和高;第三标签:目标检测框中心点像素位置偏移量;目标检测教师模型的对应于三类标签的预测输出结果包括:目标检测框中心点像素位置概率热力图、目标检测框的宽和高、目标检测框中心点像素位置偏移量;步骤S2,以知识蒸馏的方式通过目标检测教师模型改进目标检测学生模型的损失函数后,利用训练样本图像集以及预测输出结果,训练生成目标检测学生模型。
进一步地,目标检测学生模型的损失函数Losstotal定义为:
其中,Losshm为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图对应的损失函数部分;Losswh为目标检测学生模型预测输出的目标检测框的宽和高对应的损失函数部分;Lossreg为目标检测学生模型预测输出的目标检测框中心点像素位置偏移量对应的损失函数部分;λwh为目标检测框的宽和高对应的损失函数部分的权重比例系数;λreg为目标检测框中心点像素位置偏移量的损失函数部分的权重比例系数。
进一步地,目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图对应的损失函数部分Losshm定义为:
其中,为第一标签对应的目标检测框中心点像素位置硬标签概率矩阵转化后得到的目标检测框中心点像素位置软标签概率矩阵指导的子损失函数;为目标检测教师模型和第一标签对应的目标检测框中心点像素位置软标签概率矩阵共同指导的子损失函数;λhm为目标检测教师模型和第一标签对应的目标检测框中心点像素位置软标签概率矩阵共同指导的子损失函数的权重比例系数。
其中,N为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图中像素点的个数;为将目标检测框中心点像素位置硬标签概率矩阵进行坐标变换后得到的目标检测框中心点像素位置软标签概率矩阵中的数位坐标点(x,y)的概率值;为目标检测教师模型预测输出的目标检测框中心点像素位置概率热力图中像素点(x,y)的概率值;为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图中像素点(x,y)的概率值;和均为指数常数。
进一步地,目标检测框中心点像素位置硬标签概率矩阵通过高斯核函数坐标变换后得到目标检测框中心点像素位置软标签概率矩阵;目标检测框中心点像素位置软标签概率矩阵的数位坐标点(x,y)的概率值为高斯核函数的结果值G;高斯核函数为:
…………………………(5)其中,m,n分别为目标检测框中心点像素位置硬标签概率矩阵中概率值为1的数位坐标点的横坐标和纵坐标;x,y分别为目标检测框中心点像素位置软标签概率矩阵中任意一个数位坐标点的横坐标和纵坐标;为对应于目标检测框的尺度常数。
进一步地,目标检测学生模型预测输出的目标检测框的宽和高对应的损失函数部分Losswh为;
……………………………………………(6),其中, 为第二标签对应的目标检测框的宽和高指导的子损失函数;为目标检测教师模型以及第二标签对应的目标检测框的宽和高共同指导的子损失函数;为目标检测教师模型以及第二标签对应的目标检测框的宽和高共同指导的子损失函数的权重比例系数。
其中,K为训练样本图像中第二标签对应的目标检测框的宽和高的个数;k指代训练样本图像中任意一个第二标签;为训练样本图像中第二标签对应的目标检测框的宽和高的乘积;为目标检测学生模型预测输出的目标检测框的宽和高的乘积;为目标检测教师模型预测输出的目标检测框的宽和高的乘积;为和之间的L1距离;为和之间的L2距离;为和之间的L2距离;为第一间隔常数。
进一步地,目标检测学生模型预测输出的目标检测框中心点像素位置偏移量对应的损失函数部分Lossreg 为:
其中,为第三标签对应的目标检测框中心点像素位置偏移量指导的子损失函数;为目标检测教师模型以及第三标签对应的目标检测框中心点像素位置偏移量共同指导的子损失函数;为目标检测教师模型以及第三标签对应的目标检测框中心点像素位置偏移量共同指导的子损失函数的权重比例系数。
其中,Z为训练样本图像中第三标签对应的目标检测框中心点像素位置偏移量的个数;z指代训练样本图像中任意一个第三标签;为训练样本图像中第三标签对应的目标检测框中心点像素位置偏移量的横轴偏移量与纵轴偏移量乘积; 为目标检测学生模型预测输出的目标检测框中心点像素位置偏移量的横轴偏移量与纵轴偏移量的乘积;为目标检测教师模型预测输出的目标检测框中心点像素位置偏移量的的横轴偏移量与纵轴偏移量乘积;为和之间的L1距离;为和之间的L2距离;为和之间的L2距离;为第二间隔常数。
应用本发明的技术方案,由于对训练样本图像集的训练样本图像进行了标签分类,训练完成的目标检测教师模型的目标检测任务根据分类后的标签能够得到清晰的区分,具体地,目标检测教师模型的预测输出结果中,获得目标检测框中心点像素位置概率热力图属于分类任务,获得目标检测框的宽和高以及获得目标检测框中心点像素位置偏移量均为回归任务。这样,在使用目标检测教师模型指导训练目标检测学生模型的过程中,目标检测学生模型的损失函数能够针根据目标检测任务的任务类型进行针对性地得到分级分类改进优化,从而在依托知识蒸馏确保得到的目标检测学生模型的网络结构足够简单以满足终端设备使用需求的同时,更能够确保目标检测学生模型更好地迁移获取目标检测教师模型的知识,遗传目标检测教师模型的性能,使得目标检测学生模型具有优良的识别效果和检测精度,具有良好的实用性。
附图说明
构成本申请的一部分的说明书附图用来提供对本发明的进一步理解,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1示出了根据本发明的基于知识蒸馏的目标检测模型训练方法的步骤流程图;
图2示出实施本发明的基于知识蒸馏的目标检测模型训练方法时,训练样本图像集中的一张可选实施例的训练样本图像的示意图,示意图中有一个目标行人,该目标行人的头部作为检测目标,使用目标检测框框选出;
图3示出了图2中的训练样本图像的第一标签,即目标检测框中心点像素位置硬标签概率矩阵;
图4示出了图3中的目标检测框中心点像素位置硬标签概率矩阵转化后的目标检测框中心点像素位置软标签概率矩阵。
其中,上述附图包括以下附图标记:
A、目标行人;B、目标行人的头部;C、目标检测框。
具体实施方式
需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本发明。
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例。此外,术语“包括”、“和”、“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、***、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
为了解决利用现有技术中的知识蒸馏方法训练获取的目标检测模型无法同时保证网络结构简单而满足终端设备使用需求,以及目标检测模型的识别率优良以确保模型检测精度的问题,本发明提供了一种基于知识蒸馏的目标检测模型训练方法。
图1是根据本发明一种可选实施例的基于知识蒸馏的目标检测模型训练方法的步骤流程图。如图1所示,该的目标检测模型训练方法包括:步骤S1,利用训练样本图像集训练生成目标检测教师模型,训练样本图像集中的各训练样本图像具有:第一标签:目标检测框中心点像素位置硬标签概率矩阵;第二标签:目标检测框的宽和高;第三标签:目标检测框中心点像素位置偏移量;目标检测教师模型的对应于三类标签的预测输出结果包括:目标检测框中心点像素位置概率热力图、目标检测框的宽和高、目标检测框中心点像素位置偏移量;步骤S2,以知识蒸馏的方式通过目标检测教师模型改进目标检测学生模型的损失函数后,利用训练样本图像集以及预测输出结果,训练生成目标检测学生模型。
由于对训练样本图像集的训练样本图像进行了标签分类,训练完成的目标检测教师模型的目标检测任务根据分类后的标签能够得到清晰的区分,具体地,目标检测教师模型的预测输出结果中,获得目标检测框中心点像素位置概率热力图属于分类任务,获得目标检测框的宽和高以及获得目标检测框中心点像素位置偏移量均为回归任务。这样,在使用目标检测教师模型指导训练目标检测学生模型的过程中,目标检测学生模型的损失函数能够针根据目标检测任务的任务类型进行针对性地得到分级分类改进优化,从而在依托知识蒸馏确保训练得到的目标检测学生模型的网络结构足够简单以满足终端设备使用需求的同时,更能够确保目标检测学生模型更好地迁移获取目标检测教师模型的知识,遗传目标检测教师模型的性能,使得目标检测学生模型具有优良的识别效果和检测精度,具有良好的实用性。
可选地,目标检测任务的获得目标检测框中心点像素位置概率热力图属于二分类任务。
需要解释说明的是,在使用训练样本图像集中的训练样本图像对目标检测教师模型或目标检测学生模型训练之前,需要对所有的训练样本图像进行三类标签的标注,以一张训练样本图像为例,如图2所示,该训练样本图像中仅有一个目标行人A,采用人工标注的方式使用目标检测框C框选出目标行人的头部B。
之后使用预设程序对该训练样本图像进行标签标注,标注的第一标签为目标检测框中心点像素位置硬标签概率矩阵(如图3所示),目标检测框中心点像素位置硬标签概率矩阵的各数位概率值一一对应于训练样本图像的各像素点为目标检测框中心点的概率值,其值为0或1,其中,数位概率值为1的数位坐标点即为目标检测框C框的几何中心点,其余数位概率值为0。当然,当训练样本图像中有多个目标行人时,数位概率值为1的数位坐标点也为对应的多个。为了保证目标检测教师模型以及目标检测学生模型可以更好地学习到练样本图像中第一标签的特征信息,以提高模型检测精度,需要将目标检测框中心点像素位置硬标签概率矩阵转化得到目标检测框中心点像素位置软标签概率矩阵;这是因为,虽然在训练样本图像中每个目标检测框的只有一个中心点,但是在该中心点附近周围的像素点仍然会表征目标行人的头部的特征,应该与头部之外的像素点以真实的区别,因此,采用目标检测框中心点像素位置软标签概率矩阵能够使得目标检测教师模型以及目标检测学生模型学习到训练样本图像中更真实的特征信息。图4为图3中的目标检测框中心点像素位置硬标签概率矩阵经过转化后获得的目标检测框中心点像素位置软标签概率矩阵;在该图中,与数位概率值为1的数位坐标点邻近的数位坐标点的数位概率值会更接近1(图未示),而远离数位概率值为1的数位坐标点邻近的数位坐标点的数位概率值更接近0。
在本实施例中,两者的转化方法为:目标检测框中心点像素位置硬标签概率矩阵通过高斯核函数坐标变换后得到目标检测框中心点像素位置软标签概率矩阵;目标检测框中心点像素位置软标签概率矩阵的数位坐标点(x,y)的概率值为高斯核函数的结果值G;高斯核函数为:
其中,m,n分别为目标检测框中心点像素位置硬标签概率矩阵中概率值为1的数位坐标点的横坐标和纵坐标;即目标检测框中心点像素位置硬标签概率矩阵的第m列第n行;x,y分别为目标检测框中心点像素位置软标签概率矩阵中任意一个数位坐标点的横坐标和纵坐标;即目标检测框中心点像素位置软标签概率矩阵的第x列第y行;为对应于目标检测框的尺度常数。可选地,目标检测框的尺度常数的取值范围在10像素至80像素之间。
当然,当目标检测框中心点像素位置硬标签概率矩阵中概率值为1的数位坐标点为多个时,即当图2中的目标检测框C为多个时,目标检测框中心点像素位置软标签概率矩阵中各数位坐标点(x,y)的概率值取多个高斯核函数结果值G中的最大者。
对训练样本图像标注的第二标签为目标检测框的宽和高(未图示),对训练样本图像标注的第三标签为目标检测框中心点像素位置偏移量(未图示)。
在本实施例中,目标检测学生模型的损失函数Losstotal定义为:
Losshm为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图对应的损失函数部分;Losswh为目标检测学生模型预测输出的目标检测框的宽和高对应的损失函数部分;Lossreg为目标检测学生模型预测输出的目标检测框中心点像素位置偏移量对应的损失函数部分;λwh为目标检测框的宽和高对应的损失函数部分的权重比例系数;λreg为目标检测框中心点像素位置偏移量的损失函数部分的权重比例系数。
可选地,目标检测框的宽和高对应的损失函数部分的权重比例系数λwh和目标检测框中心点像素位置偏移量的损失函数部分的权重比例系数λreg的取值范围均为[0.5,1),这说明目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图所占权重最大,是影响目标检测学生模型后期检测精度的最关键因素。
可选地,目标检测框的宽和高对应的损失函数部分的权重比例系数λwh大于目标检测框中心点像素位置偏移量的损失函数部分的权重比例系数λreg。这是因为相比于目标检测框中心点像素位置偏移量,目标检测学生模型后期检测精度受到目标检测框的宽和高的影响更重。
具体而言,目标检测学生模型的损失函数Losstotal分级的第一部分为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图对应的损失函数部分Losshm,通过知识蒸馏对这部分分类任务的损失函数进行优化改进,其对应的损失函数部分Losshm定义为:
其中,为第一标签对应的目标检测框中心点像素位置硬标签概率矩阵转化后得到的目标检测框中心点像素位置软标签概率矩阵指导的子损失函数;为目标检测教师模型和第一标签对应的目标检测框中心点像素位置软标签概率矩阵共同指导的子损失函数;λhm为目标检测教师模型和第一标签对应的目标检测框中心点像素位置软标签概率矩阵共同指导的子损失函数的权重比例系数。
需要说明的是,本实施例没有给出目标检测教师模型预测输出的以及目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图的图示,但是模型的理想训练状态是希望两者预测输出的目标检测框中心点像素位置概率热力图所对应的目标检测框中心点像素位置概率矩阵都学习靠近图4中的目标检测框中心点像素位置软标签概率矩阵,从而确保目标检测教师模型以及目标检测学生模型均具备良好的检测精度。
目标检测框中心点像素位置软标签概率矩阵指导的子损失函数用于评价目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图所对应的目标检测框中心点像素位置概率矩阵与目标检测框中心点像素位置软标签概率矩阵之间的差异。
为基于知识蒸馏的损失函数,用来评价目标检测学生模型的预测输出和目标检测教师模型的预测输出之间的分布差异,相比较于目标检测框中心点像素位置软标签概率矩阵指导的子损失函数,子损失函数增加了和,用于指导标检测学生模型的网络结构学习标检测教师模型的网络结构后的输出分布,子损失函数的计算公式定义为:
其中,N为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图中像素点的个数;为将目标检测框中心点像素位置硬标签概率矩阵进行坐标变换后得到的目标检测框中心点像素位置软标签概率矩阵中的数位坐标点(x,y)的概率值;为目标检测教师模型预测输出的目标检测框中心点像素位置概率热力图中像素点(x,y)的概率值;为目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图中像素点(x,y)的概率值;和均为指数常数。
再上述公式(3)和公式(4)中,和是为了增加困难样本的权重系数,目标检测学生模型的预测输出偏差越大,两权重系数越大。是用来调节负样本损失占比的权重系数,负样本越偏离目标,该权重系数越大。可选地,和的取值范围为[2,4]。
目标检测学生模型的损失函数Losstotal分级的第二部分为目标检测学生模型预测输出的目标检测框的宽和高对应的损失函数部分Losswh,通过知识蒸馏对这部分回归任务的损失函数进行优化改进,其对应的损失函数部分Losswh用L1损失函数和L2损失函数相结合的方式定义为:
其中, 为第二标签对应的目标检测框的宽和高指导的子损失函数;为目标检测教师模型以及第二标签对应的目标检测框的宽和高共同指导的子损失函数;为目标检测教师模型以及第二标签对应的目标检测框的宽和高共同指导的子损失函数的权重比例系数。
其中,K为训练样本图像中第二标签对应的目标检测框的宽和高的个数;k指代训练样本图像中任意一个第二标签;为训练样本图像中第二标签对应的目标检测框的宽和高的乘积;为目标检测学生模型预测输出的目标检测框的宽和高的乘积;为目标检测教师模型预测输出的目标检测框的宽和高的乘积;为和之间的L1距离;为和之间的L2距离;为和之间的L2距离;为第一间隔常数。
通过判断目标检测学生模型的预测输出与原始输入的训练样本图像的第二标签的差距大于目标检测学生模型的预测输出与目标检测教师模型的预测输出的差距,并且超第一间隔常数时,会给添加目标检测学生模型添加第二标签的L2损失。
目标检测学生模型的损失函数Losstotal分级的第三部分为目标检测学生模型预测输出的目标检测框中心点像素位置偏移量对应的损失函数部分Lossreg,通过知识蒸馏对这部分回归任务的损失函数进行优化改进,其对应的损失函数部分Lossreg用L1损失函数和L2损失函数相结合的方式定义为:
其中,为第三标签对应的目标检测框中心点像素位置偏移量指导的子损失函数;为目标检测教师模型以及第三标签对应的目标检测框中心点像素位置偏移量共同指导的子损失函数;为目标检测教师模型以及第三标签对应的目标检测框中心点像素位置偏移量共同指导的子损失函数的权重比例系数。
可选地,子损失函数的权重比例系数的取值范围为[0.5,1),确保其不超过第三标签对应的目标检测框中心点像素位置偏移量指导的子损失函数的权重。需要说明的是,目标检测框中心点像素位置偏移量即为目标检测学生模型预测输出的目标检测框中心点的像素坐标位置与训练样本图像中的实际位置的差值。
其中,Z为训练样本图像中第三标签对应的目标检测框中心点像素位置偏移量的个数;z指代训练样本图像中任意一个第三标签;为训练样本图像中第三标签对应的目标检测框中心点像素位置偏移量的横轴偏移量与纵轴偏移量乘积;为目标检测学生模型预测输出的目标检测框中心点像素位置偏移量的横轴偏移量与纵轴偏移量的乘积;为目标检测教师模型预测输出的目标检测框中心点像素位置偏移量的的横轴偏移量与纵轴偏移量乘积;为和之间的L1距离;为和之间的L2距离;为和之间的L2距离;为第二间隔常数。
通过判断目标检测学生模型预测输出与原始输入的训练样本图像的第三标签的差距的差距大于目标检测学生模型预测输出与目标检测教师模型的预测输出的差距,并且超第二间隔常数时,会给添加目标检测学生模型添加第三标签的L2损失。
需要说明的是,本发明提供的,目标检测教师模型的网络结构和目标检测学生模型的网络结构均采用沙漏网络结构,区别在于目标检测教师模型的网络结构的网络深度和宽度都要比目标检测学生模型的网络结构的深度和宽度大,目标检测教师模型的网络结构的参数量是目标检测学生模型的网络结构的参数量的5-10倍。通过本发明提供的基于知识蒸馏的目标检测模型训练方法训练出来的目标检测学生模型的召回率和检测精度都要优于一般的知识蒸馏方式训练方式训练出来的目标检测学生模型。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
上述实施例中的集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在上述计算机可读取的存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在存储介质中,包括若干指令用以使得一台或多台计算机设备(可为个人计算机、服务器或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。
在本发明的上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见其他实施例的相关描述。
在本申请所提供的几个实施例中,应该理解到,所揭露的客户端,可通过其它的方式实现。其中,以上所描述的装置实施例仅仅是示意性的,例如所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个***,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,单元或模块的间接耦合或通信连接,可以是电性或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (8)
1.一种基于知识蒸馏的目标检测模型训练方法,其特征在于,包括:
步骤S1,利用训练样本图像集训练生成目标检测教师模型,所述训练样本图像集中的各训练样本图像具有:第一标签:目标检测框中心点像素位置硬标签概率矩阵;第二标签:目标检测框的宽和高;第三标签:目标检测框中心点像素位置偏移量;所述目标检测教师模型的对应于三类标签的预测输出结果包括:目标检测框中心点像素位置概率热力图、目标检测框的宽和高、目标检测框中心点像素位置偏移量;
步骤S2,以知识蒸馏的方式通过所述目标检测教师模型改进所述目标检测学生模型的损失函数后,利用所述训练样本图像集以及所述预测输出结果,训练生成目标检测学生模型;
所述目标检测学生模型的损失函数Losstotal定义为:
Losstotal=Losshm+λwhLosswh+λregLossreg..................................(1),其中,
Losshm为所述目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图对应的损失函数部分;
Losswh为所述目标检测学生模型预测输出的目标检测框的宽和高对应的损失函数部分;
Lossreg为所述目标检测学生模型预测输出的目标检测框中心点像素位置偏移量对应的损失函数部分;
λwh为所述目标检测框的宽和高对应的损失函数部分的权重比例系数;
λreg为所述目标检测框中心点像素位置偏移量的损失函数部分的权重比例系数;
所述目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图对应的损失函数部分Losshm定义为:
λhm为目标检测教师模型和第一标签对应的目标检测框中心点像素位置软标签概率矩阵共同指导的子损失函数的权重比例系数。
2.根据权利要求1所述的基于知识蒸馏的目标检测模型训练方法,其特征在于,
其中,N为所述目标检测学生模型预测输出的目标检测框中心点像素位置概率热力图中像素点的个数;
Hxy为将所述目标检测框中心点像素位置硬标签概率矩阵进行坐标变换后得到的目标检测框中心点像素位置软标签概率矩阵中的数位坐标点(x,y)的概率值;
α和β均为指数常数。
4.根据权利要求3所述的基于知识蒸馏的目标检测模型训练方法,其特征在于,
当所述目标检测框中心点像素位置硬标签概率矩阵中概率值为1的数位坐标点为多个时,所述目标检测框中心点像素位置软标签概率矩阵中各数位坐标点(x,y)的概率值Hxy取多个高斯核函数结果值G中的最大者。
8.根据权利要求7所述的基于知识蒸馏的目标检测模型训练方法,其特征在于,
其中,Z为所述训练样本图像中第三标签对应的目标检测框中心点像素位置偏移量的个数;
z指代训练样本图像中任意一个所述第三标签;
Tz为所述训练样本图像中第三标签对应的目标检测框中心点像素位置偏移量的横轴偏移量与纵轴偏移量乘积;
ω为第二间隔常数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111179182.XA CN113610069B (zh) | 2021-10-11 | 2021-10-11 | 基于知识蒸馏的目标检测模型训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111179182.XA CN113610069B (zh) | 2021-10-11 | 2021-10-11 | 基于知识蒸馏的目标检测模型训练方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113610069A CN113610069A (zh) | 2021-11-05 |
CN113610069B true CN113610069B (zh) | 2022-02-08 |
Family
ID=78343524
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111179182.XA Active CN113610069B (zh) | 2021-10-11 | 2021-10-11 | 基于知识蒸馏的目标检测模型训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113610069B (zh) |
Families Citing this family (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114119959A (zh) * | 2021-11-09 | 2022-03-01 | 盛视科技股份有限公司 | 一种基于视觉的垃圾桶满溢检测方法及装置 |
CN115512131B (zh) * | 2022-10-11 | 2024-02-13 | 北京百度网讯科技有限公司 | 图像检测方法和图像检测模型的训练方法 |
CN115496666A (zh) * | 2022-11-02 | 2022-12-20 | 清智汽车科技(苏州)有限公司 | 用于目标检测的热图生成方法和装置 |
CN115984640B (zh) * | 2022-11-28 | 2023-06-23 | 北京数美时代科技有限公司 | 一种基于组合蒸馏技术的目标检测方法、***和存储介质 |
CN118154992B (zh) * | 2024-05-09 | 2024-07-23 | 中国科学技术大学 | 基于知识蒸馏的医学图像分类方法、设备及存储介质 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021189912A1 (zh) * | 2020-09-25 | 2021-09-30 | 平安科技(深圳)有限公司 | 图像中目标物的检测方法、装置、电子设备及存储介质 |
Family Cites Families (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180268292A1 (en) * | 2017-03-17 | 2018-09-20 | Nec Laboratories America, Inc. | Learning efficient object detection models with knowledge distillation |
CN110674688B (zh) * | 2019-08-19 | 2023-10-31 | 深圳力维智联技术有限公司 | 用于视频监控场景的人脸识别模型获取方法、***和介质 |
CN110991556B (zh) * | 2019-12-16 | 2023-08-15 | 浙江大学 | 一种基于多学生合作蒸馏的高效图像分类方法、装置、设备及介质 |
CN112418268B (zh) * | 2020-10-22 | 2024-07-12 | 北京迈格威科技有限公司 | 目标检测方法、装置及电子设备 |
CN112367273B (zh) * | 2020-10-30 | 2023-10-31 | 上海瀚讯信息技术股份有限公司 | 基于知识蒸馏的深度神经网络模型的流量分类方法及装置 |
CN112508169A (zh) * | 2020-11-13 | 2021-03-16 | 华为技术有限公司 | 知识蒸馏方法和*** |
CN112257815A (zh) * | 2020-12-03 | 2021-01-22 | 北京沃东天骏信息技术有限公司 | 模型生成方法、目标检测方法、装置、电子设备及介质 |
CN112990198B (zh) * | 2021-03-22 | 2023-04-07 | 华南理工大学 | 一种用于水表读数的检测识别方法、***及存储介质 |
CN113011356A (zh) * | 2021-03-26 | 2021-06-22 | 杭州朗和科技有限公司 | 人脸特征检测方法、装置、介质及电子设备 |
CN113139500B (zh) * | 2021-05-10 | 2023-10-20 | 重庆中科云从科技有限公司 | 烟雾检测方法、***、介质及设备 |
CN113361384A (zh) * | 2021-06-03 | 2021-09-07 | 深圳前海微众银行股份有限公司 | 人脸识别模型压缩方法、设备、介质及计算机程序产品 |
CN113326852A (zh) * | 2021-06-11 | 2021-08-31 | 北京百度网讯科技有限公司 | 模型训练方法、装置、设备、存储介质及程序产品 |
-
2021
- 2021-10-11 CN CN202111179182.XA patent/CN113610069B/zh active Active
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021189912A1 (zh) * | 2020-09-25 | 2021-09-30 | 平安科技(深圳)有限公司 | 图像中目标物的检测方法、装置、电子设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN113610069A (zh) | 2021-11-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113610069B (zh) | 基于知识蒸馏的目标检测模型训练方法 | |
CN108388927B (zh) | 基于深度卷积孪生网络的小样本极化sar地物分类方法 | |
CN114241282A (zh) | 一种基于知识蒸馏的边缘设备场景识别方法及装置 | |
CN109086811B (zh) | 多标签图像分类方法、装置及电子设备 | |
CN109978893A (zh) | 图像语义分割网络的训练方法、装置、设备及存储介质 | |
CN105446988B (zh) | 预测类别的方法和装置 | |
CN108133172A (zh) | 视频中运动对象分类的方法、车流量的分析方法及装置 | |
CN103942749B (zh) | 一种基于修正聚类假设和半监督极速学习机的高光谱地物分类方法 | |
CN112180471B (zh) | 天气预报方法、装置、设备及存储介质 | |
CN111368634B (zh) | 基于神经网络的人头检测方法、***及存储介质 | |
CN110175657B (zh) | 一种图像多标签标记方法、装置、设备及可读存储介质 | |
CN113128478A (zh) | 模型训练方法、行人分析方法、装置、设备及存储介质 | |
CN110969200A (zh) | 基于一致性负样本的图像目标检测模型训练方法及装置 | |
CN112541639A (zh) | 基于图神经网络和注意力机制的推荐***评分预测方法 | |
CN114782752B (zh) | 基于自训练的小样本图像集成分类方法及装置 | |
CN110263808B (zh) | 一种基于lstm网络和注意力机制的图像情感分类方法 | |
CN115439192A (zh) | 医疗商品信息的推送方法及装置、存储介质、计算机设备 | |
CN113065533B (zh) | 一种特征提取模型生成方法、装置、电子设备和存储介质 | |
CN115063664A (zh) | 用于工业视觉检测的模型学习方法、训练方法及*** | |
CN115017970A (zh) | 一种基于迁移学习的用气行为异常检测方法及*** | |
CN114332457A (zh) | 图像实例分割模型训练、图像实例分割方法和装置 | |
CN116151479B (zh) | 一种航班延误预测方法及预测*** | |
CN116894593A (zh) | 光伏发电功率预测方法、装置、电子设备及存储介质 | |
CN117371511A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
CN113010687B (zh) | 一种习题标签预测方法、装置、存储介质以及计算机设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |