CN115908955A - 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置 - Google Patents

基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置 Download PDF

Info

Publication number
CN115908955A
CN115908955A CN202310202396.7A CN202310202396A CN115908955A CN 115908955 A CN115908955 A CN 115908955A CN 202310202396 A CN202310202396 A CN 202310202396A CN 115908955 A CN115908955 A CN 115908955A
Authority
CN
China
Prior art keywords
network
bird
teacher
gradient
student
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202310202396.7A
Other languages
English (en)
Other versions
CN115908955B (zh
Inventor
苏慧
鲍虎军
杨非
程乐超
宋明黎
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202310202396.7A priority Critical patent/CN115908955B/zh
Publication of CN115908955A publication Critical patent/CN115908955A/zh
Application granted granted Critical
Publication of CN115908955B publication Critical patent/CN115908955B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置,通过构建鸟类图像分类数据集;在训练阶段,从鸟类图像分类数据集抽取支撑集s和预测集q,经教师网络后,分别输出的特征向量进行匹配,得到预测集q的类别预测结果,并利用所述预测结果与预测集q的类别真值构建教师网络交叉熵损失函数,训练教师网络;获取鸟类图像经过教师网络、学生网络各个网络层的特征,并利用各层特征的和,作为损失值反向传播,得到输入的鸟类图像基于损失值的梯度信息,构建梯度损失函数,使教师网络和学生网络输入的鸟类图像的梯度信息相匹配;梯度损失函数叠加学生网络交叉熵损失函数,训练学生网络,用于鸟类图像分类。

Description

基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置
技术领域
本发明涉及少样本的图像分类识别,尤其是涉及基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置。
背景技术
对于鸟类图像的分类识别,在机器学习的训练阶段,通常要对数以万计的鸟类样本图像进行标注,来保证深度学习模型的泛化能力,然而,有些鸟类之间差异很小,需要大量高质量的鸟类图像,并且鸟类种类繁多,需要保证每种鸟类一定数量的样本图像作为训练集,然而,标注如此大量高质量的鸟类样本需要很高的成本,因此,仅通过少量鸟类样本图像,便能进行深度学习,同时还能保证鸟类分类识别的准确率,成为了鸟类学家通过分类***对自然保护区中鸟类种类、数量进行统计、分析的关键。目前主流的少样本学习方法有孪生网络(Siamese Net)、原生网络(ProtoNet)、关联网络(RelationNet)等等。
在少样本鸟类图像学习中,更深更复杂的网络,能够减少鸟类的类内差异,而减少鸟类的类内差异,能够提升少样本鸟类图像的分类识别能力,但更深更复杂的网络,在实际应用中会提交计算成本、网络复杂度,从而增加能耗和计算时间的消耗,由此,如何有效地使简单的鸟类图像分类模型拥有和复杂模型一致的类内差异和鸟类识别能力,是一个值得探讨的问题。
知识蒸馏是一种“教师-学生网络思想”的训练方法,目前主流的蒸馏方法有KD(Distilling the Knowledge in a Neural Network,在神经网络中提取知识)、PKT(Probabilistic Knowledge Transfer for deep representation learning面向深度表征的概率知识迁移)、RKD(《Relational Knowledge Disitllation,关系知识词典)、CC(Correlation Congruence for Knowledge Distillation,相关一致性知识蒸馏)等等,主要可以划分为基于类间关系蒸馏、基于特征蒸馏,基于样本结构蒸馏。这些方法主要针对神经网络的低阶信息进行蒸馏,未在蒸馏过程中考虑梯度信息对神经网络精度的影响,也较少分析蒸馏方法在少样本图像分类学习中的提升情况。
发明内容
为解决现有技术的不足,实现基于少量鸟类图像样本,实现精确且高效的鸟类分类的目的,本发明采用如下的技术方案:
基于梯度蒸馏的少样本学习的鸟类分类训练方法,包括如下步骤:
步骤S1:构建鸟类图像分类数据集;
表示鸟类图像分类数据集,xi表示第i张鸟类图像,yi∈{1,2,…,}表示分类标签,N表示鸟类图像数量,表示鸟类类别数。
步骤S2:利用原生网络(Prototypical Networks)作为训练框架,训练复杂教师网络,在训练阶段,从鸟类图像分类数据集抽取C个类别,每个类别K个样本(C-way K-shot),作为分类***的支撑集s;再从类别对应的剩余数据中,抽取样本作为分类***的预测集q;训练过程中,将支撑集s和预测集q经教师网络后,分别输出的特征向量进行匹配,得到预测集q的类别预测结果,并利用所述预测结果与预测集q的类别真值构建教师网络交叉熵损失函数,训练教师网络;
步骤S3:获取鸟类图像经过教师网络、学生网络各个网络层的特征,并对各层特征求和,利用各层特征的和,作为损失值反向传播,得到输入的鸟类图像基于损失值的梯度信息,共同训练所述教师网络和学生网络;
步骤S4:构建梯度损失函数,使教师网络和学生网络输入的鸟类图像的梯度信息相匹配;
步骤S5:梯度损失函数叠加学生网络交叉熵损失函数,训练学生网络。
进一步地,所述步骤S2中,教师网络预测结果为,其中表示教师网络的网络映射函数,xi表示第i张鸟类图像,表示教师网络的参数,下标t表示教师网络,支持集s、预测集q经教师网络后,分别得到特征向量如下:
其中表示分类***支撑集s中类别为c的鸟类样本图像,c∈C,表示 经过教师网络后得到的特征向量的均值,即支持集类别为c的特征中心,kc表示抽取的类别为c的样本量,表示预测集q中的鸟类样本图像,表示经过教师网络后得到的特征向量,上标t表示教师网络。
进一步地,所述步骤S2中,利用平方欧式距离匹配支撑集s和预测集q经过教师网络后得到的特征向量,获得预测集s的类别预测结果:
其中表示的平方欧式距离,exp(·)表示指数函数。
进一步地,所述步骤S2中,利用预测集的类别预测结果和预测集类别真值构建教师网络交叉熵损失函数,训练教师网络:
其中表示预测集的类别预测结果,表示鸟类图像分类数据集的分类标签的独热编码形式,T表示矩阵的转置。
进一步地,所述步骤S3中,单张鸟类图像传入教师网络,每一层的特征结果为,其中表示教师网络前j层的特征映射函数,表示教师网络前j层参数,下标t表示教师网络;学生网络每一层特征结果为,其中表示学生网络前j层的特征映射函数,表示学生网络前j层参数,下标s表示学生网络;鸟类图像经过教师网络、学生网络后得到各个网络层的特征,将各层特征的和作为损失函数值,并利用损失函数值反向传播,得到输入的鸟类样本图像基于该损失函数值的梯度信息:
其中表示教师或学生网络t/s的第j层特征,表示第j层特征的像素数,表示网络总共计算的特征层数,表示基于教师或学生网络特征的损失函数值,表示利用损失函数值进行梯度反向后,得到的教师或学生网络基于鸟类图像的梯度信息,表示求导操作。
进一步地,所述步骤S4中,构建梯度损失函数,使教师网络和学生网络输入的鸟类样本图像的梯度信息相匹配:
其中表示取绝对值操作,表示L1正则化操作,表示教师或学生网络基于鸟类图像的梯度信息。
进一步地,所述步骤S5中,结合梯度损失函数和学生网络交叉熵损失函数,训练学生网路:
其中表示训练学生网络所用的总损失函数,表示学生网络交叉熵损失函数,表示权重参数。
基于梯度蒸馏的少样本学习的鸟类分类方法,根据所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,还包括步骤S6:将待分类的鸟类图像,输入训练好的学生网络,进行图像分类。
基于梯度蒸馏的少样本学习的鸟类分类装置,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现所述的基于梯度蒸馏的少样本学习的鸟类分类方法。
基于梯度蒸馏的少样本学习的鸟类分类***,包括鸟类图像分类数据集、教师网络和学生网络;
所述教师网络,是以原生网络(Prototypical Networks)作为训练框架,在训练阶段,从鸟类图像分类数据集抽取C个类别,每个类别K个样本(C-way K-shot),作为分类***的支撑集s;再从类别对应的剩余数据中,抽取样本作为分类***的预测集q;训练过程中,将支撑集s和预测集q经教师网络后,分别输出的特征向量进行匹配,得到预测集q的类别预测结果,并利用所述预测结果与预测集q的类别真值构建教师网络交叉熵损失函数,训练教师网络;
所述教师网络和学生网络,将鸟类图像经各个网络层进行特征提取并求和,利用各层特征的和,作为损失值反向传播,得到输入的鸟类图像基于损失值的梯度信息,再通过构建梯度损失函数,使教师网络和学生网络输入的鸟类图像的梯度信息相匹配;
所述学生网络,利用梯度损失函数叠加学生网络交叉熵损失函数进行训练,训练好后对鸟类图像进行分类。
本发明的优势和有益效果在于:
本发明的基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置,结合“教师-学生网络思想”的训练方法,在蒸馏过程中综合考虑梯度信息对神经网络精度的影响,在少样本鸟类图像学习中,通过更深更复杂的网络,既减少鸟类的类内差异,提升少样本鸟类图像的分类识别能力,又保持较低的能耗和计算时间,从而基于少量鸟类图像样本,实现精确且高效的鸟类分类的目的。
附图说明
图1是本发明实施例中分类方法的流程图。
图2是本发明实施例中分类***的结构示意图。
图3是本发明实施例中装置的结构示意图。
具体实施方式
以下结合附图对本发明的具体实施方式进行详细说明。应当理解的是,此处所描述的具体实施方式仅用于说明和解释本发明,并不用于限制本发明。
本发明的实施例使用Pytorch框架,在CUB(Caltech-UCSDBirds 加州理工学院鸟类数据库)上采用5类1样本(5-way 1-shot)进行少样本学习训练。少样本学习训练框架使用原生网络(Prototypical Networks),教师网络和学生网络的基础结构分别采用Conv6、Conv4结构和ResNet18、ResNet10结构。使用初始学习率为0.001的Adam优化器,鸟类样本图像尺寸为84*84,训练600次epoch,训练过程中支持集(support set)每类1个(5-Way),查询集(query set)每类16个(5-Way)。本发明实施例中均从零开始训练。通过将教师网络和学生网络的梯度信息做匹配,引导学生网络关注的注意力区域和教师网络关注的注意力区域尽可能一致。首先进行网络前向得到的各个网络层的鸟类样本图像特征,对各层鸟类样本图像特征求和作为损失函数值,利用损失函数值反向传播,获得输入的鸟类样本图像基于各层图像特征损失函数值的梯度信息。将教师网络和学生网络的输入数据梯度信息做匹配,使学生网络关注的注意力区域更有利于识别,从而提升学生网络的识别效果。
如图1所示,基于梯度蒸馏的少样本学习的鸟类分类方法,具体包括如下步骤:
步骤S1:构建鸟类图像分类数据集共1.2万张,鸟类分类数据集共200个类别,每个类别60个样本。(CUB公开数据集)
表示鸟类图像分类数据集,xi表示第i张鸟类样本图像,yi∈{1,2,…,  }表示分类标签,N=12000表示鸟类样本图像数量, =200表示鸟类类别数。
步骤S2:利用原生网络(Prototypical Networks)作为训练框架,训练复杂教师网络,在训练阶段,对鸟类样本图像随机抽取C=5个类别,每个类别K=1个样本(C-way K-shot),作为分类***的支撑集s(support set);再从总数量N的类别对应剩余的数据中,随机抽取批量样本作为分类***的预测集q(query set);训练过程中,将支撑集和预测集的网络结果进行匹配,交叉熵损失(entropy loss)实现少样本分类训练。
教师网络测试结果为,其中表示教师网络的网络映射函数,表示教师网络的参数,下标t表示教师网络。分类***的支持集、预测集经过教师网络后,分别得到特征向量如下:
其中表示分类***支撑集s中类别为c的鸟类样本图像,c∈C,表示经过教师网络后得到的特征向量的均值,即支持集类别为的特征中心。kc表示抽取的类别为c的样本量,表示分类***预测集q中的鸟类样本图像,表示经过教师网络后得到的特征向量。上标t表示教师网络。
利用平方欧式距离匹配支撑集和预测集经过教师网络后得到的特征向量,获得预测集的类别预测结果。并利用预测集的类别预测结果和预测集类别真值构建交叉熵损失函数,训练教师网络。
其中表示的平方欧式距离,exp(·)表示指数函数,表示预测集的类别预测结果,的独热编码形式,T表示矩阵的转置。
步骤S3:获得鸟类样本图像经过教师网络、学生网络后各个网络层的特征,并对各层特征求和。利用各层特征的和,作为损失值反向传播,得到输入的鸟类样本图像基于损失值的梯度信息,教师网络和学生网络一起训练收敛,由于教师网络比学生网络收敛的更快更好,由此引导学生网络。
单个鸟类样本图像传入教师网络,每一层的特征结果为,其中表示教师网络前j层的特征映射函数,表示教师网络前j层参数,下标t表示教师网络;学生网络每一层特征结果为,其中表示学生网络前j层的特征映射函数,表示学生网络前j层参数,下标s表示学生网络。鸟类样本图像经过教师网络、学生网络后得到各个网络层的特征,将各层特征的和作为损失函数值,并利用损失函数值反向传播,得到输入的鸟类样本图像基于该损失函数值的梯度信息:
其中表示教师或学生网络t/s的第j层特征,表示第j层特征的像素数,表示网络总共计算的特征层数,表示基于特征的损失函数值。表示利用损失函数值进行梯度反向后,得到的鸟类样本图像的梯度信息,表示求导操作。
步骤S4:构建梯度损失函数,使教师网络和学生网络输入的鸟类样本图像的梯度信息相匹配。
构建梯度损失函数,使教师网络和学生网络输入的鸟类样本图像的梯度信息相匹配:
其中表示取绝对值操作,表示L1正则化操作。
步骤S5:梯度损失函数叠加entropy loss交叉熵损失函数,训练学生网络。
结合梯度损失函数和交叉熵损失函数,训练学生网路:
其中表示训练学生网络所用的总损失函数,表示交叉熵损失函数,表示权重参数(本发明实施例中,)。
步骤S6:将待分类的鸟类图像,输入训练好的学生网络,如图1所示,学生模型输出的logits即分类结果。
如表1所示,本方法在CUB鸟类数据集上,利用原生网络作为训练框架,5-Way 1-Shot训练的测试结果。本发明方法分别展示传统方法的学生网络Conv4、教师网络Conv6和本发明方法的学生网络Conv4的鸟类分类识别的准确率;以及传统的学生网络resnet10训练、教师网络resnet18训练和本发明方法学生网络Resnet10的鸟类分类识别的准确率。
表1各方法在CUB数据集上5-Way 1-Shot训练的测试结果
如图2所示,基于梯度蒸馏的少样本学习的鸟类分类***,包括鸟类图像分类数据集、教师网络和学生网络。
教师网络是以原生网络作为训练框架,在训练阶段,从鸟类图像分类数据集抽取C个类别,每个类别K个样本,作为分类***的支撑集s;再从类别对应的剩余数据中,抽取样本作为分类***的预测集q;训练过程中,将支撑集s和预测集q经教师网络后,分别输出的特征向量进行匹配,得到预测集q的类别预测结果,并利用所述预测结果与预测集q的类别真值构建教师网络交叉熵损失函数,训练教师网络;
教师网络和学生网络,将鸟类图像经各个网络层进行特征提取并求和,利用各层特征的和,作为损失值反向传播,得到输入的鸟类图像基于损失值的梯度信息,再通过构建梯度损失函数,使教师网络和学生网络输入的鸟类图像的梯度信息相匹配;
学生网络利用梯度损失函数叠加学生网络交叉熵损失函数进行训练,训练好后对鸟类图像进行分类。
与前述基于梯度蒸馏的少样本学习的鸟类分类方法的实施例相对应,本发明还提供了基于梯度蒸馏的少样本学习的鸟类分类装置的实施例。
参见图3,本发明实施例提供的基于梯度蒸馏的少样本学习的鸟类分类装置,包括存储器和一个或多个处理器,存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现上述实施例中的基于梯度蒸馏的少样本学习的鸟类分类方法。
本发明基于梯度蒸馏的少样本学习的鸟类分类装置的实施例可以应用在任意具备数据处理能力的设备上,该任意具备数据处理能力的设备可以为诸如计算机等设备或装置。装置实施例可以通过软件实现,也可以通过硬件或者软硬件结合的方式实现。以软件实现为例,作为一个逻辑意义上的装置,是通过其所在任意具备数据处理能力的设备的处理器将非易失性存储器中对应的计算机程序指令读取到内存中运行形成的。从硬件层面而言,如图3所示,为本发明基于梯度蒸馏的少样本学习的鸟类分类装置所在任意具备数据处理能力的设备的一种硬件结构图,除了图3所示的处理器、内存、网络接口、以及非易失性存储器之外,实施例中装置所在的任意具备数据处理能力的设备通常根据该任意具备数据处理能力的设备的实际功能,还可以包括其他硬件,对此不再赘述。
上述装置中各个单元的功能和作用的实现过程具体详见上述方法中对应步骤的实现过程,在此不再赘述。
对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本发明方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
本发明实施例还提供一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述实施例中的基于梯度蒸馏的少样本学习的鸟类分类方法。
所述计算机可读存储介质可以是前述任一实施例所述的任意具备数据处理能力的设备的内部存储单元,例如硬盘或内存。所述计算机可读存储介质也可以是任意具备数据处理能力的设备的外部存储设备,例如所述设备上配备的插接式硬盘、智能存储卡(Smart Media Card,SMC)、SD卡、闪存卡(Flash Card)等。进一步的,所述计算机可读存储介质还可以既包括任意具备数据处理能力的设备的内部存储单元也包括外部存储设备。所述计算机可读存储介质用于存储所述计算机程序以及所述任意具备数据处理能力的设备所需的其他程序和数据,还可以用于暂时地存储已经输出或者将要输出的数据。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的范围。

Claims (10)

1.基于梯度蒸馏的少样本学习的鸟类分类训练方法,其特征在于包括如下步骤:
步骤S1:构建鸟类图像分类数据集;
步骤S2:利用原生网络作为训练框架,训练复杂教师网络,在训练阶段,从鸟类图像分类数据集抽取C个类别,每个类别K个样本,作为分类***的支撑集s;再从类别对应的剩余数据中,抽取样本作为分类***的预测集q;训练过程中,将支撑集s和预测集q经教师网络后,分别输出的特征向量进行匹配,得到预测集q的类别预测结果,并利用所述预测结果与预测集q的类别真值构建教师网络交叉熵损失函数,训练教师网络;
步骤S3:获取鸟类图像经过教师网络、学生网络各个网络层的特征,并对各层特征求和,利用各层特征的和,作为损失值反向传播,得到输入的鸟类图像基于损失值的梯度信息,共同训练所述教师网络和学生网络;
步骤S4:构建梯度损失函数,使教师网络和学生网络输入的鸟类图像的梯度信息相匹配;
步骤S5:梯度损失函数叠加学生网络交叉熵损失函数,训练学生网络。
2.根据权利要求1所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,其特征在于:所述步骤S2中,教师网络预测结果为ht (xi | θt),其中ht(·)表示教师网络的网络映射函数,xi表示第i张鸟类图像,θt表示教师网络的参数,下标t表示教师网络,支持集s、预测集q经教师网络后,分别得到特征向量如下:
其中xsc表示分类***支撑集s中类别为c的鸟类样本图像,c∈C,表示xsc经过教师网络后得到的特征向量的均值,即支持集类别为c的特征中心,kc表示抽取的类别为c的样本量,xqi表示预测集q中的鸟类样本图像,表示xqi经过教师网络后得到的特征向量,上标t表示教师网络。
3.根据权利要求1所述的基于梯度蒸馏的少样本学***方欧式距离匹配支撑集s和预测集q经过教师网络后得到的特征向量,获得预测集s的类别预测结果:
其中表示的平方欧式距离,exp(·)表示指数函数。
4.根据权利要求1所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,其特征在于:所述步骤S2中,利用预测集的类别预测结果和预测集类别真值构建教师网络交叉熵损失函数,训练教师网络:
其中表示预测集的类别预测结果,表示鸟类图像分类数据集的分类标签yi的独热编码形式,T表示矩阵的转置。
5.根据权利要求1所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,其特征在于:所述步骤S3中,单张鸟类图像传入教师网络,每一层的特征结果为,其中表示教师网络前j层的特征映射函数,表示教师网络前j层参数,下标t表示教师网络;学生网络每一层特征结果为,其中表示学生网络前j层的特征映射函数,表示学生网络前j层参数,下标s表示学生网络;鸟类图像经过教师网络、学生网络后得到各个网络层的特征,将各层特征的和作为损失函数值,并利用损失函数值反向传播,得到输入的鸟类样本图像基于该损失函数值的梯度信息:
其中表示教师或学生网络t/s的第j层特征,nj表示第j层特征的像素数,表示网络总共计算的特征层数,表示基于教师或学生网络特征的损失函数值,表示利用损失函数值进行梯度反向后,得到的教师或学生网络基于鸟类图像xi的梯度信息,表示求导操作。
6.根据权利要求1所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,其特征在于:所述步骤S4中,构建梯度损失函数,使教师网络和学生网络输入的鸟类样本图像的梯度信息相匹配:
其中表示取绝对值操作,||·||表示L1正则化操作,表示教师或学生网络基于鸟类图像的梯度信息。
7.根据权利要求1所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,其特征在于:所述步骤S5中,结合梯度损失函数和学生网络交叉熵损失函数,训练学生网路:
其中ls表示训练学生网络所用的总损失函数,表示学生网络交叉熵损失函数,α表示权重参数。
8.基于梯度蒸馏的少样本学习的鸟类分类方法,其特征在于:根据权利要求1所述的基于梯度蒸馏的少样本学习的鸟类分类训练方法,还包括步骤S6:将待分类的鸟类图像,输入训练好的学生网络,进行图像分类。
9.基于梯度蒸馏的少样本学习的鸟类分类装置,其特征在于,包括存储器和一个或多个处理器,所述存储器中存储有可执行代码,所述一个或多个处理器执行所述可执行代码时,用于实现权利要求8所述的基于梯度蒸馏的少样本学习的鸟类分类方法。
10.基于梯度蒸馏的少样本学习的鸟类分类***,包括鸟类图像分类数据集、教师网络和学生网络,其特征在于:
所述教师网络,是以原生网络作为训练框架,在训练阶段,从鸟类图像分类数据集抽取C个类别,每个类别K个样本,作为分类***的支撑集s;再从类别对应的剩余数据中,抽取样本作为分类***的预测集q;训练过程中,将支撑集s和预测集q经教师网络后,分别输出的特征向量进行匹配,得到预测集q的类别预测结果,并利用所述预测结果与预测集q的类别真值构建教师网络交叉熵损失函数,训练教师网络;
所述教师网络和学生网络,将鸟类图像经各个网络层进行特征提取并求和,利用各层特征的和,作为损失值反向传播,得到输入的鸟类图像基于损失值的梯度信息,再通过构建梯度损失函数,使教师网络和学生网络输入的鸟类图像的梯度信息相匹配;
所述学生网络,利用梯度损失函数叠加学生网络交叉熵损失函数进行训练,训练好后对鸟类图像进行分类。
CN202310202396.7A 2023-03-06 2023-03-06 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置 Active CN115908955B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310202396.7A CN115908955B (zh) 2023-03-06 2023-03-06 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310202396.7A CN115908955B (zh) 2023-03-06 2023-03-06 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置

Publications (2)

Publication Number Publication Date
CN115908955A true CN115908955A (zh) 2023-04-04
CN115908955B CN115908955B (zh) 2023-06-20

Family

ID=86496470

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310202396.7A Active CN115908955B (zh) 2023-03-06 2023-03-06 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置

Country Status (1)

Country Link
CN (1) CN115908955B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116152575A (zh) * 2023-04-18 2023-05-23 之江实验室 基于类激活采样引导的弱监督目标定位方法、装置和介质

Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
CN112766413A (zh) * 2021-02-05 2021-05-07 浙江农林大学 一种基于加权融合模型的鸟类分类方法及***
CN112784964A (zh) * 2021-01-27 2021-05-11 西安电子科技大学 基于桥接知识蒸馏卷积神经网络的图像分类方法
CN113222035A (zh) * 2021-05-20 2021-08-06 浙江大学 基于强化学***衡故障分类方法
WO2021197223A1 (zh) * 2020-11-13 2021-10-07 平安科技(深圳)有限公司 一种模型压缩方法、***、终端及存储介质
CN113705646A (zh) * 2021-08-18 2021-11-26 西安交通大学 基于半监督元学习的射频细微特征信息提取方法及***
US11200497B1 (en) * 2021-03-16 2021-12-14 Moffett Technologies Co., Limited System and method for knowledge-preserving neural network pruning
CN114330580A (zh) * 2021-12-31 2022-04-12 之江实验室 基于歧义指导互标签更新的鲁棒知识蒸馏方法
CN114863181A (zh) * 2022-05-19 2022-08-05 杭州登虹科技有限公司 一种基于预测概率知识蒸馏的性别分类方法和***
CN114912612A (zh) * 2021-06-25 2022-08-16 江苏大学 鸟类识别方法、装置、计算机设备及存储介质
CN115424288A (zh) * 2022-06-09 2022-12-02 南开大学 一种基于多维度关系建模的视觉Transformer自监督学习方法及***

Patent Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
WO2021197223A1 (zh) * 2020-11-13 2021-10-07 平安科技(深圳)有限公司 一种模型压缩方法、***、终端及存储介质
CN112784964A (zh) * 2021-01-27 2021-05-11 西安电子科技大学 基于桥接知识蒸馏卷积神经网络的图像分类方法
CN112766413A (zh) * 2021-02-05 2021-05-07 浙江农林大学 一种基于加权融合模型的鸟类分类方法及***
US11200497B1 (en) * 2021-03-16 2021-12-14 Moffett Technologies Co., Limited System and method for knowledge-preserving neural network pruning
CN113222035A (zh) * 2021-05-20 2021-08-06 浙江大学 基于强化学***衡故障分类方法
CN114912612A (zh) * 2021-06-25 2022-08-16 江苏大学 鸟类识别方法、装置、计算机设备及存储介质
CN113705646A (zh) * 2021-08-18 2021-11-26 西安交通大学 基于半监督元学习的射频细微特征信息提取方法及***
CN114330580A (zh) * 2021-12-31 2022-04-12 之江实验室 基于歧义指导互标签更新的鲁棒知识蒸馏方法
CN114863181A (zh) * 2022-05-19 2022-08-05 杭州登虹科技有限公司 一种基于预测概率知识蒸馏的性别分类方法和***
CN115424288A (zh) * 2022-06-09 2022-12-02 南开大学 一种基于多维度关系建模的视觉Transformer自监督学习方法及***

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
DMITRY MEDVEDEV: "Learning to Generate Synthetic Training Data using Gradient Matching and Implicit Differentiation", 《HTTP:ARXIV:2203.08559V1》 *
ZUNLEI FENG ETAL.: "Model Doctor: A Simple Gradient Aggregation Strategy for Diagnosing and Treating CNN Classifiers", 《HTTP:ARXIV:2112.04934V1》 *
李 东等: "基于知识蒸馏的轻量化残差网络在塑料颗粒高速检测中的运用", 《机电工程技术》 *

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116152575A (zh) * 2023-04-18 2023-05-23 之江实验室 基于类激活采样引导的弱监督目标定位方法、装置和介质

Also Published As

Publication number Publication date
CN115908955B (zh) 2023-06-20

Similar Documents

Publication Publication Date Title
Tingting et al. Three‐stage network for age estimation
Mou et al. Vehicle instance segmentation from aerial image and video using a multitask learning residual fully convolutional network
CN111353542B (zh) 图像分类模型的训练方法、装置、计算机设备和存储介质
Sahasrabudhe et al. Self-supervised nuclei segmentation in histopathological images using attention
US20220222918A1 (en) Image retrieval method and apparatus, storage medium, and device
CN111832440B (zh) 人脸特征提取模型的构建方法、计算机存储介质及设备
CN111898703B (zh) 多标签视频分类方法、模型训练方法、装置及介质
CN113378938B (zh) 一种基于边Transformer图神经网络的小样本图像分类方法及***
US20240054760A1 (en) Image detection method and apparatus
CN113392866A (zh) 一种基于人工智能的图像处理方法、装置及存储介质
CN109325513A (zh) 一种基于海量单类单幅图像的图像分类网络训练方法
CN115908955A (zh) 基于梯度蒸馏的少样本学习的鸟类分类***、方法与装置
CN115359074A (zh) 基于超体素聚类及原型优化的图像分割、训练方法及装置
Tan et al. Style interleaved learning for generalizable person re-identification
Munir et al. Multi branch siamese network for person re-identification
CN116777006A (zh) 基于样本缺失标签增强的多标签学习方法、装置和设备
Li et al. Lightweight automatic identification and location detection model of farmland pests
Mi et al. Principal component analysis based on block-norm minimization
CN117611838A (zh) 一种基于自适应超图卷积网络的多标签图像分类方法
Liang et al. Adaptive Cycle-consistent Adversarial Network for Malaria Blood Cell Image Synthetization
Li et al. Dynamic information enhancement for video classification
Lee et al. Learning non-homogenous textures and the unlearning problem with application to drusen detection in retinal images
CN116311504A (zh) 一种小样本行为识别方法、***及设备
Cao et al. Multi-label image recognition with two-stream dynamic graph convolution networks
CN114612961A (zh) 一种多源跨域表情识别方法、装置及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant