CN114330580A - 基于歧义指导互标签更新的鲁棒知识蒸馏方法 - Google Patents
基于歧义指导互标签更新的鲁棒知识蒸馏方法 Download PDFInfo
- Publication number
- CN114330580A CN114330580A CN202111676330.9A CN202111676330A CN114330580A CN 114330580 A CN114330580 A CN 114330580A CN 202111676330 A CN202111676330 A CN 202111676330A CN 114330580 A CN114330580 A CN 114330580A
- Authority
- CN
- China
- Prior art keywords
- network
- data set
- ambiguity
- sample
- teacher
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了基于歧义指导互标签更新的鲁棒知识蒸馏方法,包括:A、构建训练数据集;B、构建教师网络和学生网络;C、通过鲁棒学习方法对教师网络进行训练;D、对数据集中的每个样本进行歧义感知权重估计和权重分配;E、根据小损失标准对数据集中的样本进行标签重新标注,结合教师网络特征的标签传播算法更新标签,再计算损失和更新网络参数;F、在学生网络和教师网络之间进行互标签传播算法,并更新样本标签、计算损失和更新网络参数;G、将测试图像数据导入学生网络,由其得到预测结果且用于图像分类;本方案可以有效地提升知识蒸馏对噪声标签的鲁棒性,从而可以在噪声标签的环境下获取一个高性能的轻量级网络,其更能适用于实际情况。
Description
技术领域
本发明涉及计算机视觉技术领域,尤其涉及基于歧义指导互标签更新的鲁棒知识蒸馏方法。
背景技术
近些年,许多模型压缩方法被提出用于减少卷积神经网络的参数量从而实现模型加速的目的。在这些方法当中,知识蒸馏扮演了一个重要的角色。知识蒸馏通常包含一个教师网络和一个学生网络,学生网络通过学习教师网络的输出中蕴含的“暗知识“,泛化能力得到了明显增强。但是在实际情况中,训练数据集中往往含有大量的标签噪声,对这些标签噪声的过拟合会显著影响知识蒸馏的性能。
发明内容
有鉴于此,本发明的目的在于提出一种基于歧义指导互标签更新的鲁棒知识蒸馏方法,该方案基于标签更新策略,在知识蒸馏的过程中动态地更新标签,从而大大降低了噪声标签对知识蒸馏的影响,提升了模型的鲁棒性。
为了实现上述的技术目的,本发明所采用的技术方案为:
一种基于歧义指导互标签更新的鲁棒知识蒸馏方法,其包括:
A、构建训练数据集,并按预设条件对其进行预处理;
B、构建教师网络和学生网络;
C、通过鲁棒学习方法将训练数据集导入教师网络中进行训练,获得预设性能的教师模型;
D、将训练数据集导入学生网络中,然后对训练数据集中的每个样本进行歧义感知权重估计和权重分配;
E、学生网络按预设条件根据小损失标准对训练数据集中的样本进行标签重新标注,再结合教师网络特征的标签传播算法更新标签,然后计算损失,更新网络参数;
F、在学生网络和教师网络之间进行互标签传播算法,并更新样本标签,然后计算损失,更新网络参数;
G、将测试图像数据导入学生网络,由学生网络的学生模型前向传播得到预测结果且将其用于图像分类。
作为一种可能的实施方式,进一步,步骤A中,所述训练数据集包括具有噪声标签数据的噪声数据集和/或无噪声标签数据的无噪声数据集;其中,无噪声标签数据的无噪声数据集经注入噪声处理后,生成合成噪声数据集。
作为其中一种较优的无噪声数据集处理方法,优选的,所述无噪声数据集被等分为两份,其中一份经注入噪声处理后,用于学生网络训练,另一份不作处理,用于教师网络训练,,该噪声被设为C-N噪声,所述教师模型使用标准交叉熵损失训练。
作为另一种较优的无噪声数据集处理方法,优选的,所述注入噪声处理的方法为在无噪声数据集中加入对称和/或非对称噪声。
作为一种较优的选择实施方式,优选的,所述噪声数据集包括ANIMAL-10N数据集、Clothing1M数据集中的一种以上,所述无噪声数据集包括CIFAR-100数据集。
作为一种较优的选择实施方式,优选的,步骤B中,所述教师网络、学生网络为如下之一:
(1)所述教师网络为层数为40,宽度系数为2的宽残差网络,所述学生网络为层数为16,宽度系数为2的宽残差网络;
(2)所述教师网络为层数为40,宽度系数为2的宽残差网络,所述学生网络为层数为40,宽度系数为1的宽残差网络;
(3)所述教师网络为层数为56的残差网络,所述学生网络为层数为20的残差网络。
作为一种较优的选择实施方式,优选的,步骤D中,通过歧义感知权重估计模块对训练数据集中的每个样本进行歧义感知权重估计和权重分配,该歧义感知权重估计模块包括两个全连接层,且两个全连接层之间还设有PRelu层,步骤D具体包括:
将训练数据集中的所有样本导入学生网络,得到它们的特征,然后计算每个类别的原型特征,其公式如下:
按如下公式计算每个样本的特征分布得分:
将在第t轮的标签和特征分布得分拼接起来得到歧义特征向量,其公式如下:
将权重写为矩阵形式,其公式如下:
作为一种较优的选择实施方式,优选的,步骤E具体包括:
使用教师网络的特征构建k-nn图G=<V,E>,其中V和E分别表示顶点集合和边集合,顶点之间的相似度矩阵被描述如下:
其中代表样本xi在教师网络下的特征,NNk(xi)表示样本xi的k近邻,然后,可得到一个对称邻接矩阵继而进行归一化Wt得到其中,D为对角度矩阵;同时,根据小损失标准,训练数据集原始的标注将被根据学生网络小损失标准重新标注,其公式如下:
使用mixup算法得到混合样本数据,所述混合样本数据为虚拟样本,其公式如下:
最后,定义了如下损失使学生网络模仿教师网络的样本间的相似度,其公式如下:
作为一种较优的选择实施方式,优选的,步骤F包括:
最后,通过公式(4)计算损失更新网络参数。
基于上述方案,本发明还提供一种计算机可读的存储介质,所述的存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述的至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行实现上述所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法。
采用上述的技术方案,本发明与现有技术相比,其具有的有益效果为:本方案基于教师-学生网络,提出了一种包括小损失选择的标签传播和互标签传播两个阶段的二阶段标签更新方法,基于该设计的标签更新策略,可以有效地提升知识蒸馏对噪声标签的鲁棒性,从而可以在噪声标签的环境下获取一个高性能的轻量级网络,相比于传统的知识蒸馏方法,本方案考虑到了更为实际的噪声标签问题,使得知识蒸馏算法更能适用于实际情况。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明方法的简要实施流程示意图。
图2是本发明方法的简要原理流程示意图。
具体实施方式
下面结合附图和实施例,对本发明作进一步的详细描述。特别指出的是,以下实施例仅用于说明本发明,但不对本发明的范围进行限定。同样的,以下实施例仅为本发明的部分实施例而非全部实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
如图1或图2所示,本方案一种基于歧义指导互标签更新的鲁棒知识蒸馏方法,其包括:
A、构建训练数据集,并按预设条件对其进行预处理;
本步骤中,使用的训练数据集为常见的三个图像分类数据集,其分别为CIFAR-100数据集、ANIMAL-10N数据集、Clothing1M数据集,其中,CIFAR-100数据集为不含有噪声标签的无噪声数据集,其可以通过添加对称、非对称噪声使其成为合成噪声数据集;除此之外,本方案还提供一种C-N噪声,即将无噪声数据集等分为两份,一份注入噪声(对称、非对称噪声)用于学生网络训练,另一份不做处理用于教师网络训练,该C-N噪声的标准交叉熵损失用于训练教师网络;另外,ANIMAL-10N数据集,Clothing1M数据集为真实场景下的数据集,其分别含有约8%和38%的噪声标签数据。此外,在训练过程中,还可以采用图像旋转,翻转等方式用于数据增强。为简化描述,下述使用表示训练学生网络所使用的数据集。
B、构建教师网络和学生网络;
本步骤中,教师网络相比于学生网络一般具有更复杂的模型结构,本方案可以采取如下三对知识蒸馏中常用的网络结构,其分别是:
(1)所述教师网络为层数为40,宽度系数为2的宽残差网络,所述学生网络为层数为16,宽度系数为2的宽残差网络,即WRN_40_2-WRN_16_2;
(2)所述教师网络为层数为40,宽度系数为2的宽残差网络,所述学生网络为层数为40,宽度系数为1的宽残差网络,即WRN_40_2-WRN_40_1;
(3)所述教师网络为层数为56的残差网络,所述学生网络为层数为20的残差网络,即resent56-resnet20。
C、通过鲁棒学习方法将训练数据集导入教师网络中进行训练,获得预设性能的教师模型;
本步骤中,所导入的训练数据集为C-N噪声时,该C-N噪声的标准交叉熵损失用于训练教师网络,对于其余的噪声类型(对称、非对称,真实噪声),本方案使用经典鲁棒学习算法DivideMix(J.Li,R.Socher,and S.C.Hoi,“Dividemix:Learning with noisy labelsas semi-supervised learning,”in Int.Conf.Learn.Represent.,2019.)预训练一个教师模型。
D、将训练数据集导入学生网络中,然后对训练数据集中的每个样本进行歧义感知权重估计和权重分配;
本步骤中,通过歧义感知权重估计模块对训练数据集中的每个样本进行歧义感知权重估计和权重分配,该歧义感知权重估计模块包括两个全连接层,且两个全连接层之间还设有PRelu层,本步骤具体包括:
将训练数据集中的所有样本导入学生网络,得到它们的特征,然后计算每个类别的原型(prototype)特征,其公式如下:
接下来,按如下公式计算每个样本的特征分布得分:
然后,将在第t轮的标签和特征分布得分拼接起来得到歧义特征向量,其公式如下:
其中,为两个全连接层,σ表示PreLU操作,该双层感知机网络为前述歧义感知权重估计模块,其包含两层全连接层,输出为一个标量(权重)。具体来说,本方案将标签和相似度得分拼接起来得到歧义特征然后将其送入歧义感知权重估计模块计算该样本的权重;
将权重写为矩阵形式,其公式如下:
E、学生网络按预设条件根据小损失标准对训练数据集中的样本进行标签重新标注,再结合教师网络特征的标签传播算法更新标签,然后计算损失,
更新网络参数;
本步骤具体包括:
首先,使用教师网络的特征构建k-nn图G=<V,E>,其中,V和E分别表示顶点集合和边集合,顶点之间的相似度矩阵被描述如下:
其中代表样本xi在教师网络下的特征,NNk(xi)表示样本xi的k近邻,然后,可得到一个对称邻接矩阵继而进行归一化Wt得到其中,D为对角度矩阵;同时,根据小损失标准,训练数据集原始的标注将被根据学生网络小损失标准重新标注,其公式如下:
为了进一步提高蒸馏的鲁棒性,使用mixup算法(Hongyi Zhang,MoustaphaCisse,Yann N Dauphin,and David Lopez-Paz.mixup:Beyond empirical riskminimization.In ICLR,2018)得到混合样本数据,所述混合样本数据为虚拟样本,其公式如下:
最后,为了使得学生网络具有更好的特征表示能力,定义了如下损失使学生网络模仿教师网络的样本间的相似度,其公式如下:
F、在学生网络和教师网络之间进行互标签传播算法,并更新样本标签,然后计算损失,更新网络参数;
本步骤具体包括:
最后,通过公式(4)计算损失更新网络参数。
G、将测试图像数据导入学生网络,由学生网络的学生模型前向传播得到预测结果且将其用于图像分类。
本步骤作为本方案的推理阶段,其只用到训练好的学生网络,将测试图像送入学生网络,学生网络得到输出完成推理。
另外,在本发明各个实施方式中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本发明各个实施方式方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述仅为本发明的部分实施例,并非因此限制本发明的保护范围,凡是利用本发明说明书及附图内容所作的等效装置或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (10)
1.一种基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,其包括:
A、构建训练数据集,并按预设条件对其进行预处理;
B、构建教师网络和学生网络;
C、通过鲁棒学习方法将训练数据集导入教师网络中进行训练,获得预设性能的教师模型;
D、将训练数据集导入学生网络中,然后对训练数据集中的每个样本进行歧义感知权重估计和权重分配;
E、学生网络按预设条件根据小损失标准对训练数据集中的样本进行标签重新标注,再结合教师网络特征的标签传播算法更新标签,然后计算损失,更新网络参数;
F、在学生网络和教师网络之间进行互标签传播算法,并更新样本标签,然后计算损失,更新网络参数;
G、将测试图像数据导入学生网络,由学生网络的学生模型前向传播得到预测结果且将其用于图像分类。
2.如权利要求1所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,步骤A中,所述训练数据集包括具有噪声标签数据的噪声数据集和/或无噪声标签数据的无噪声数据集;其中,无噪声标签数据的无噪声数据集经注入噪声处理后,生成合成噪声数据集。
3.如权利要求2所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,所述无噪声数据集被等分为两份,其中一份经注入噪声处理后,用于学生网络训练,另一份不作处理,用于教师网络训练,该噪声被设为C-N噪声,所述教师模型使用标准交叉熵损失训练。
4.如权利要求2所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,所述注入噪声处理的方法为在无噪声数据集中加入对称和/或非对称噪声。
5.如权利要求4所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,所述噪声数据集包括ANIMAL-10N数据集、Clothing1M数据集中的一种以上,所述无噪声数据集包括CIFAR-100数据集。
6.如权利要求1所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,步骤B中,所述教师网络、学生网络为如下之一:
(1)所述教师网络为层数为40,宽度系数为2的宽残差网络,所述学生网络为层数为16,宽度系数为2的宽残差网络;
(2)所述教师网络为层数为40,宽度系数为2的宽残差网络,所述学生网络为层数为40,宽度系数为1的宽残差网络;
(3)所述教师网络为层数为56的残差网络,所述学生网络为层数为20的残差网络。
7.如权利要求1所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,步骤D中,通过歧义感知权重估计模块对训练数据集中的每个样本进行歧义感知权重估计和权重分配,该歧义感知权重估计模块包括两个全连接层,且两个全连接层之间还设有PRelu层,步骤D具体包括:
将训练数据集中的所有样本导入学生网络,得到它们的特征,然后计算每个类别的原型特征,其公式如下:
按如下公式计算每个样本的特征分布得分:
将在第t轮的标签和特征分布得分拼接起来得到歧义特征向量,其公式如下:
将权重写为矩阵形式,其公式如下:
8.如权利要求7所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法,其特征在于,步骤E具体包括:
使用教师网络的特征构建k-nn图G=<V,E>,其中V和E分别表示顶点集合和边集合,顶点之间的相似度矩阵被描述如下:
其中,代表样本xi在教师网络下的特征,NNk(xi)表示样本xi的k近邻,然后,可得到一个对称邻接矩阵继而进行归一化Wt得到其中,D为对角度矩阵;同时,根据小损失标准,训练数据集原始的标注将被根据学生网络小损失标准重新标注,其公式如下:
使用mixup算法得到混合样本数据,所述混合样本数据为虚拟样本,其公式如下:
最后,定义了如下损失使学生网络模仿教师网络的样本间的相似度,其公式如下:
10.一种计算机可读的存储介质,其特征在于:所述的存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述的至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行实现如权利要求1至9之一所述的基于歧义指导互标签更新的鲁棒知识蒸馏方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111676330.9A CN114330580A (zh) | 2021-12-31 | 2021-12-31 | 基于歧义指导互标签更新的鲁棒知识蒸馏方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111676330.9A CN114330580A (zh) | 2021-12-31 | 2021-12-31 | 基于歧义指导互标签更新的鲁棒知识蒸馏方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114330580A true CN114330580A (zh) | 2022-04-12 |
Family
ID=81022006
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111676330.9A Pending CN114330580A (zh) | 2021-12-31 | 2021-12-31 | 基于歧义指导互标签更新的鲁棒知识蒸馏方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114330580A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114299349A (zh) * | 2022-03-04 | 2022-04-08 | 南京航空航天大学 | 一种基于多专家***和知识蒸馏的众包图像学习方法 |
CN115908955A (zh) * | 2023-03-06 | 2023-04-04 | 之江实验室 | 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置 |
CN117058437A (zh) * | 2023-06-16 | 2023-11-14 | 江苏大学 | 一种基于知识蒸馏的花卉分类方法、***、设备及介质 |
CN117237984A (zh) * | 2023-08-31 | 2023-12-15 | 江南大学 | 基于标签一致性的mt腿部识别方法、***、介质和设备 |
-
2021
- 2021-12-31 CN CN202111676330.9A patent/CN114330580A/zh active Pending
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114299349A (zh) * | 2022-03-04 | 2022-04-08 | 南京航空航天大学 | 一种基于多专家***和知识蒸馏的众包图像学习方法 |
CN115908955A (zh) * | 2023-03-06 | 2023-04-04 | 之江实验室 | 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置 |
CN117058437A (zh) * | 2023-06-16 | 2023-11-14 | 江苏大学 | 一种基于知识蒸馏的花卉分类方法、***、设备及介质 |
CN117058437B (zh) * | 2023-06-16 | 2024-03-08 | 江苏大学 | 一种基于知识蒸馏的花卉分类方法、***、设备及介质 |
CN117237984A (zh) * | 2023-08-31 | 2023-12-15 | 江南大学 | 基于标签一致性的mt腿部识别方法、***、介质和设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111368996B (zh) | 可传递自然语言表示的重新训练投影网络 | |
US10803591B2 (en) | 3D segmentation with exponential logarithmic loss for highly unbalanced object sizes | |
CN114330580A (zh) | 基于歧义指导互标签更新的鲁棒知识蒸馏方法 | |
CN109583501B (zh) | 图片分类、分类识别模型的生成方法、装置、设备及介质 | |
CN110188358B (zh) | 自然语言处理模型的训练方法及装置 | |
Sau et al. | Deep model compression: Distilling knowledge from noisy teachers | |
KR20210029785A (ko) | 활성화 희소화를 포함하는 신경 네트워크 가속 및 임베딩 압축 시스템 및 방법 | |
CN110874439B (zh) | 一种基于评论信息的推荐方法 | |
CN114398961A (zh) | 一种基于多模态深度特征融合的视觉问答方法及其模型 | |
Kumar | Machine Learning Quick Reference: Quick and essential machine learning hacks for training smart data models | |
KR102203253B1 (ko) | 생성적 적대 신경망에 기반한 평점 증강 및 아이템 추천 방법 및 시스템 | |
CN113392317A (zh) | 一种标签配置方法、装置、设备及存储介质 | |
CN115311506B (zh) | 基于阻变存储器的量化因子优化的图像分类方法及装置 | |
CN111382619B (zh) | 图片推荐模型的生成、图片推荐方法、装置、设备及介质 | |
CN115204301A (zh) | 视频文本匹配模型训练、视频文本匹配方法和装置 | |
CN112801092B (zh) | 一种自然场景图像中字符元素检测方法 | |
CN112396091B (zh) | 社交媒体图像流行度预测方法、***、存储介质及应用 | |
CN114282528A (zh) | 一种关键词提取方法、装置、设备及存储介质 | |
CN117994570A (zh) | 基于模型无关适配器提高复杂多样数据分布的识别方法 | |
CN116486285B (zh) | 一种基于类别掩码蒸馏的航拍图像目标检测方法 | |
CN111611796A (zh) | 下位词的上位词确定方法、装置、电子设备及存储介质 | |
CN114861671A (zh) | 模型训练方法、装置、计算机设备及存储介质 | |
CN113010772B (zh) | 一种数据处理方法、相关设备及计算机可读存储介质 | |
CN114065840A (zh) | 一种基于集成学习的机器学习模型调整方法及设备 | |
CN113268657A (zh) | 基于评论和物品描述的深度学习推荐方法及*** |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |