CN114997365A - 图像数据的知识蒸馏方法、装置、终端设备及存储介质 - Google Patents

图像数据的知识蒸馏方法、装置、终端设备及存储介质 Download PDF

Info

Publication number
CN114997365A
CN114997365A CN202210527719.5A CN202210527719A CN114997365A CN 114997365 A CN114997365 A CN 114997365A CN 202210527719 A CN202210527719 A CN 202210527719A CN 114997365 A CN114997365 A CN 114997365A
Authority
CN
China
Prior art keywords
network
output
level
feature
attention
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210527719.5A
Other languages
English (en)
Inventor
陈嘉莉
庞建新
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ubtech Robotics Corp
Original Assignee
Ubtech Robotics Corp
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ubtech Robotics Corp filed Critical Ubtech Robotics Corp
Priority to CN202210527719.5A priority Critical patent/CN114997365A/zh
Publication of CN114997365A publication Critical patent/CN114997365A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/7715Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • General Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Multimedia (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本申请适用于知识蒸馏技术领域,提供一种图像数据的知识蒸馏方法、装置、终端设备及存储介质,通过在利用图像数据和总损失函数使教师网络引导学生网络训练的过程中,采用基于注意力的特征蒸馏方式,对教师网络和学生网络的低、中层级输出的特征图的注意力图进行误差约束,采用基于相似度的特征约束方式,对教师网络和学生网络的高层级输出的特征图的特征向量进行相似度约束,根据注意力图损失函数、相似度损失函数及学生网络的分类层的分类损失函数更新总损失函数,利用图像数据和更新后的损失函数,使教师网络引导学生网络迭代训练,使得进行知识蒸馏后的学生网络的精度、速度及复杂度都能满足图像分类任务的性能要求。

Description

图像数据的知识蒸馏方法、装置、终端设备及存储介质
技术领域
本申请属于知识蒸馏技术领域,尤其涉及一种图像数据的知识蒸馏方法、装置、终端设备及存储介质。
背景技术
知识蒸馏(Knowledge Distillation)是一种在繁琐的大模型(教师网络)中提炼知识并将其压缩为单个小模型(学生网络)、利用大模型引导小模型提高性能的方法,教师网络的输出被用作训练学生网络的软目标(soft-target)。在实际应用中,由于计算力的限制,需要用到小模型来执行图像分类任务,但小模型的表现效果远低于大模型的效果,所以需要利用知识蒸馏方法来进行模型压缩。
利用传统的知识蒸馏方法进行图像分类时,使用最后的分类层的概率分布进行知识传递,然而分类层的训练数据会达到几十万乃至上千万量级,直接使用分类层进行知识蒸馏的效果较差。另外,传统的知识蒸馏方法的教师网络和学生网络一般采用相似的网络结构,而用于进行图像分类的大模型和小模型的网络结构差异较大,传统知识蒸馏方法对于这种异构网络结构的蒸馏效果较差。
发明内容
有鉴于此,本申请实施例提供了一种图像数据的知识蒸馏方法、装置、终端设备及存储介质,以解决利用传统的知识蒸馏方法进行图像分类,效果差的问题。
本申请实施例的第一方面提供一种图像数据的知识蒸馏方法,包括:
将图像数据输入已训练的教师网络和待训练的学生网络;
基于总损失函数,利用所述教师网络引导所述学生网络训练;
获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图;
对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数;
对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数;
根据所述学生网络的分类层的分类损失函数、所述第i注意力图损失函数及所述相似度损失函数,更新所述总损失函数,并返回执行所述基于总损失函数,利用所述教师网络引导所述学生网络训练的步骤;
其中,i=1,2,…,m-1,m为大于1的整数。
本申请实施例的第二方面提供一种图像数据的知识蒸馏装置,包括:
图像数据输入单元,用于将图像数据输入已训练的教师网络和待训练的学生网络;
知识蒸馏单元,用于基于总损失函数,利用所述教师网络引导所述学生网络训练;
注意力图获取单元,用于获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图;
误差约束单元,用于对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数;
相似度约束单元,用于对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数;
总损失函数更新单元,用于根据所述学生网络的分类层的分类损失函数、所述第i注意力图损失函数及所述相似度损失函数,更新所述总损失函数,并返回执行所述基于总损失函数,利用所述教师网络引导所述学生网络训练的步骤;
其中,i=1,2,…,m-1,m为大于1的整数。
本申请实施例的第三方面提供了一种终端设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如本申请实施例的第一方面提供的图像数据的知识蒸馏方法的步骤。
本申请实施例的第四方面提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如本申请实施例的第一方面提供的图像数据的知识蒸馏方法的步骤。
本申请实施例的第一方面提供的图像数据的知识蒸馏方法,通过在利用图像数据和总损失函数使教师网络引导学生网络训练的过程中,采用基于注意力的特征蒸馏方式,对教师网络和学生网络的低、中层级输出的特征图的注意力图进行误差约束,获得注意力图损失函数,采用基于相似度的特征约束方式,对教师网络和学生网络的高层级输出的特征图的特征向量进行相似度约束,获得相似度损失函数,根据注意力图损失函数、相似度损失函数及学生网络的分类层的分类损失函数更新总损失函数,利用图像数据和更新后的总损失函数,使教师网络引导学生网络迭代训练,使得进行知识蒸馏后的学生网络的精度、速度及复杂度都能满足图像分类任务的性能要求。
可以理解的是,上述第二方面至第四方面的有益效果可以参见上述第一方面中的相关描述,在此不再赘述。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的图像数据的知识蒸馏方法的第一种流程示意图;
图2是本申请实施例提供的图像数据的知识蒸馏方法的第二种流程示意图;
图3是本申请实施例提供的图像数据的知识蒸馏方法的第三种流程示意图;
图4是本申请实施例提供的图像数据的知识蒸馏方法的第四种流程示意图;
图5是本申请实施例提供的教师网络和学生网络的结构示意图;
图6是本申请实施例提供的知识蒸馏装置的结构示意图;
图7是本申请实施例提供的终端设备的结构示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定***结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的***、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
应当理解,当在本申请说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本申请说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
另外,在本申请说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本申请说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本申请的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
本申请实施例提供的图像数据的知识蒸馏方法可以由终端设备的处理器在运行具有相应功能的计算机程序时执行,通过利用基于多个阶段的特征图的注意力(Attention)的特征蒸馏方式,获取教师网络和学生网络的低、中层级输出的特征图的注意力图(Attention Maps),使得学生网络的低、中层级能够更好地拟合教师网络的低、中层级的特征提取能力,通过对高层级的特征图进行基于相似度的特征蒸馏,让学生网络的高层级输出的特征图尽可能地去拟合教师网络的高层级输出的特征图,使得学生网络获得与教师网络接近的特征表达能力。
在应用中,图像数据是由多张图像组成的训练数据集。图像数据可以包括任意的需要进行分类或识别的多张图像,图像数据可以包含人脸图像、工业产品图像、动物图像、风景图像等中的至少一种。终端设备可以是具有图像分类、图像识别、机器视觉等中的至少一种功能的计算设备,例如,机器人、手机、个人数字助理(Personal Digital Assistant,PDA)、自助服务终端、摄像机、可穿戴设备、车载设备、增强现实(Augmented Reality,AR)/虚拟现实(Virtual Reality,VR)设备、平板电脑、笔记本电脑、个人计算机(PersonalComputer,PC)、上网本、服务器、门禁***、安防设备等,本申请实施例对终端设备的具体类型不作任何限制。
如图1所示,本申请实施例提供的图像数据的知识蒸馏方法,包括如下步骤S101至S106:
步骤S101、将图像数据输入已训练的教师网络和待训练的学生网络,进入步骤S102。
在应用中,图像数据可以从终端设备的本地存储空间,或者,与终端设备通信的任意其他设备(例如,服务器)的存储空间中获取,例如,从开放图库(Open Image)中获取多张图像构成图像数据。
在应用中,教师网络是已经训练完成,能够满足图像数据的分类或识别要求,也即分类层的分类损失函数收敛或满足精度要求的网络。教师网络可以采用重量级的卷积神经网络(Convolutional Neural Networks,CNN),例如,残差网络(Residual Network,ResNet)、循环神经网络(Recurrent Neural Network,RNN)、稠密卷积网络(DenseConvolutional Network,DenseNet)等。ResNet可以是ResNet 50、ResNet101或ResNet152。
在应用中,学生网络是需要利用知识蒸馏方法进行知识迁移学习,在教师网络的引导下学习图像数据的分类或识别功能的待训练的网络。学生网络可以采用轻量级的卷积神经网络,例如,移动端网络(Mobile Network,MobileNet)、洗牌网络(Shuffle Network,ShuffleNet)、轻量型目标检测网络(Thunder Network,ThunderNet)等。MobileNet可以是MobileNetv1、MobileNetv2或MobileNetv3。当终端设备用于进行人脸图像分类或人脸识别时,MobileNet可以是移动端人脸识别网络MobileFaceNet。ShuffleNet可以是ShuffleNetv1或ShuffleNetv2。
步骤S102、基于总损失函数,利用所述教师网络引导所述学生网络训练,进入步骤S103。
在应用中,教师网络引导学生网络进行迭代训练,训练目标为学生网络的分类层的分类损失函数收敛或满足精度要求。在第一个训练周期或第一阶段的训练周期中,损失函数可以仅由学生网络的分类层的分类损失函数构成,第一阶段的训练周期可以包括多个训练周期。学生网络的分类层的分类损失函数可以设置为与教师网络的分类层的分类损失函数相同,由于学生网络是通过迁移学习获得尽量贴近教师网络的性能,但不一定能完全达到与教师网络相同的性能,因此,学生网络的分类层的分类损失函数也可以设置为与教师网络的分类层的分类损失函数不同,相对于教师网络的分类层的分类损失函数的略大。分类损失函数可以是归一化指数函数(Softmax Faction),当终端设备用于进行人脸图像分类或人脸识别时,分类损失函数可以是CosFace损失函数、ArcFace损失函数等。
步骤S103、获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图,进入步骤S104。
在应用中,可以事先将教师网络和学生网络的堆叠的卷积层中划分出至少m个层级(stage),每个层级包括至少一层卷积层,划分出的这些层级都属于骨干网络(Backbone)。骨架网络用于对图像进行特征提取,获得具有高辨识度和鲁棒性的特征图,提取的这些特征图作为后续检测头网络(Detection Head)的输入。骨架网络可以基于普通卷积层、深度可分离卷积(Depthwise Separable Convolution)层、深度可分离空洞卷积(Depthwise Separable Dilated Convolution)层、批归一化(Batch Normalization)层、Mish激活函数(Mish Activation Function)层、平均池化(AveragePooling)层等构建。
在应用中,获取注意力图的具体方法为:
获取教师网络的第1层级输出的特征图的注意力图和学生网络的第1层级输出的特征图的注意力图;
获取教师网络的第2层级输出的特征图的注意力图和学生网络的第2层级输出的特征图的注意力图;
……;
依此类推,直到获取到教师网络的第m-1层级输出的特征图的注意力图和学生网络的第m-1层级输出的特征图的注意力图时为止,也即令i=1,2,…,m-1,m为大于1的整数。
在应用中,m的数值可以根据实际需要进行设置,例如,m=4。
在一个实施例中,所述教师网络和所述学生网络的低、中层级输出的特征图为局部特征图,所述教师网络和所述学生网络的高层级输出的特征图为全局特征图。
在应用中,学生网络的低、中层级(例如,第1至第m-1层级)用于学习提取图像的局部特征,高层级(例如,第m层级)用于学习图像的全局特征。当终端设备用于进行人脸图像分类或人脸识别时,局部特征可以是人脸关键点信息(例如,边角点、五官(左右两个眼睛的中心、鼻子和左右两个嘴角)),全局特征可是整个人脸的综合性特征。
在应用中,特征图可以包括多个维度,例如,批次(Batch)维度、通道(Channel)维度、空间维度。批次维度反应了特征图的批次大小,也即相应层级输出的特征图的数量。通道维度则反映了特征图的通道数,也即颜色通道的数量和每个颜色通道的位数,例如,特征图包含红(Red,R)、绿(Green,G)、蓝(Blue,B)三个颜色通道,每个颜色通道包括8位。空间维度反应了特征图的空间尺寸的大小,例如,高度(Hight)和宽度(Width)。当特征图包含批次维度、通道维度、空间维度这三个维度时,特征图的维度表示格式可以为(C,H,W)或(B,C,H,W);其中,B表示批次大小,C表示通道数,H表示高度,W表示宽度。
在应用中,基于特征图的维度,可以获取特征图在不同维度方向上的注意力图,例如,通道维度方向上的通道维度注意图和空间维度方向上的空间注意力图中的至少一种。
如图2所示,在一个实施例中,步骤S103包括如下步骤S201和S202:
步骤S201、在通道维度方向上,对所述教师网络的第i层级输出的特征图进行池化,得到所述教师网络的第i层级输出的特征图的通道维度注意力图;
步骤S202、在通道维度方向上,对所述学生网络的第i层级输出的特征图进行池化,得到所述学生网络的第i层级输出的特征图的通道维度注意力图。
在应用中,分别在通道维度方向上,对教师网络的第i层级输出的特征图和学生网络的第i层级输出的特征图进行池化,得到教师网络的第i层级输出的特征图的通道维度注意力图和学生网络的第i层级输出的特征图的通道维度注意力图。当特征图的维度表示格式为(C,H,W)时,通道维度注意力图的维度表示格式为(C,1,1),当特征图的维度表示格式为(B,C,H,W)时,通道维度注意力图的维度表示格式为(B,C,1,1)。
如图3所示,在一个实施例中,步骤S103包括如下步骤S301和S302:
步骤S301、在空间维度方向上,对所述教师网络的第i层级输出的特征图进行池化,得到所述教师网络的第i层级输出的特征图的空间维度注意力图;
步骤S302、在空间维度方向上,对所述学生网络的第i层级输出的特征图进行池化,得到所述学生网络的第i层级输出的特征图的空间维度注意力图。
在应用中,分别在空间维度方向上,对教师网络的第i层级输出的特征图和学生网络的第i层级输出的特征图进行池化(Pooling),得到教师网络的第i层级输出的特征图的空间维度注意力图和学生网络的第i层级输出的特征图的空间维度注意力图。当特征图的维度表示格式为(C,H,W)时,空间维度注意力图的维度表示格式为(1,H,W),当特征图的维度表示格式为(B,C,H,W)时,空间维度注意力图的维度表示格式为(B,1,H,W)。
在应用中,池化可以包括最大池化(max pooling)和平均池化(mean Pooling或average pooling)中的至少一种。池化的作用为对特征图进行降采样,以降低特征图的数据量大小,从而可以提高教师网络和学生网络对图像数据进行分类或识别时的处理速度。通道维度方向上的池化是为了降低通道数,空间维度方向上的池化是为了降低空间尺寸的大小。
步骤S104、对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数,进入步骤S105。
在应用中,依次计算教师网络和学生网络的相同层级输出的特征图的注意力图之间的误差,并采用相应的误差约束条件对每个层级的误差进行约束,得到每个层级对应的注意力图损失函数。
在应用中,获取注意力图损失函数的具体方法为:
对教师网络的第1层级输出的特征图的注意力图和学生网络的第1层级输出的特征图的注意力图进行误差约束,得到第1注意力图损失函数;
对教师网络的第2层级输出的特征图的注意力图和学生网络的第2层级输出的特征图的注意力图进行误差约束,得到第2注意力图损失函数;
……;
依此类推,直到完成对教师网络的第m-1层级输出的特征图的注意力图和学生网络的第m-1层级输出的特征图的注意力图进行误差约束,得到第m-1注意力图损失函数时为止。
如图2所示,在一个实施例中,步骤S104包括:
步骤S401、对所述教师网络的第i层级输出的特征图的通道维度注意力图和所述学生网络的第i层级输出的特征图的通道维度注意力图进行误差约束,得到第i通道维度注意力图损失函数。
在应用中,在通道维度方向上,依次获得教师网络的各层级输出的特征图的通道维度注意力图和学生网络的各层级输出的特征图的通道维度注意力图之后,即进一步获取教师网络的相应层级输出的特征图的通道维度注意力图和学生网络的相应层级输出的特征图的通道维度注意力图进行误差约束,得到相应层级对应的通道维度注意力图损失函数。
如图3所示,在一个实施例中,步骤S104包括:
步骤S402、对所述教师网络的第i层级输出的特征图的空间维度注意力图和所述学生网络的第i层级输出的特征图的空间维度注意力图进行误差约束,得到第i空间维度注意力图损失函数。
在应用中,在空间维度方向上,依次获得教师网络的各层级输出的特征图的空间维度注意力图和学生网络的各层级输出的特征图的空间维度注意力图之后,即进一步获取教师网络的相应层级输出的特征图的空间维度注意力图和学生网络的相应层级输出的特征图的空间维度注意力图进行误差约束,得到相应层级对应的空间维度注意力图损失函数。
在应用中,误差可以包括但不限于均方误差(Mean Square Error,MSE)、随机误差(Random Error)、粗大误差(Gross Conceptual Error,GSE)等中的至少一种。
步骤S105、对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数,进入步骤S106。
在应用中,在依次完成教师网络和学生网络靠前的第1至第m-1层级的输出的特征图的注意力图之间的误差约束之后,进一步地计算教师网络的第m层级输出的特征图和学生网络的第m层级输出的特征图之间的相似度,并采用相应的相似度约束条件对相似度进行约束,得到对应的相似度损失函数。
在应用中,相似度用于反应教师网络的第m层级输出的特征图和学生网络的第m层级输出的特征图之间的接近程度,相似度具体可以包括但不限于余弦相似度(CosineSimilarity)、欧氏距离(Euclidean Distance)、汉明距离(Hamming Distance)、马氏距离(Mahalanobis Distance)等中的至少一种。
如图4所示,在一个实施例中,步骤S105包括如下步骤S501至S504:
步骤S501、获取所述教师网络的第m层级输出的特征图的特征向量并进行归一化,得到第一归一化结果,进入步骤S503;
步骤S502、获取所述学生网络的第m层级输出的特征图的特征向量并进行归一化,得到第二归一化结果,进入步骤S503;
步骤S503、根据所述第一归一化结果和所述第二归一化结果,获取所述教师网络的第m层级输出的特征图的特征向量和所述学生网络的第m层级输出的特征图的特征向量之间的夹角的余弦值,进入步骤S504;
步骤S504、基于预设相似度约束条件对所述余弦值的大小进行约束,得到余弦相似度损失函数。
在应用中,当相似度为余弦相似度时,计算余弦相似度损失函数的方法为:先分别获取教师网络和学生网络的第m层级输出的特征图的特征向量并进行归一化,得到与教师网络对应的第一归一化结果和与学生网络对应的第二归一化结果;然后将第一归一化结果和第二归一化结果相乘,得到教师网络和学生网络的第m层级输出的特征图的特征向量之间的夹角的余弦值;最后利用预先设置的余弦相似度约束条件,对余弦值进行约束,得到余弦相似度损失函数。余弦相似度约束条件基于使得教师网络和学生网络的第m层级输出的特征图的特征向量之间的夹角越接近于0越好的目的设置,也即余弦值越接近于1越好,因此,余弦相似度条件可以设置为余弦值无限趋近于1。
在一个实施例中,所述余弦相似度损失函数的计算公式如下:
Figure BDA0003645237470000121
其中,Lm()表示所述余弦相似度损失函数,d()表示距离度量函数,
Figure BDA0003645237470000122
表示与所述教师网络的第m层级输出的特征图的特征向量对应的特征向量转换函数,
Figure BDA0003645237470000123
表示所述教师网络的第m层级输出的特征图的特征向量,
Figure BDA0003645237470000124
表示与所述学生网络的第m层级输出的特征图的特征向量对应的特征向量转换函数,
Figure BDA0003645237470000125
表示所述学生网络的第m层级输出的特征图的特征向量。
在应用中,特征向量转换函数用于对特征向量进行归一化处理,可以根据实际需要采用具有归一化功能的函数,例如,1×1卷积。
步骤S106、根据所述学生网络的分类层的分类损失函数、所述第i注意力图损失函数及所述相似度损失函数,更新所述总损失函数,并返回执行步骤S102。
在应用中,在获得第1至第m-1层级对应的注意力图损失函数和第m层级对应的相似度损失函数之后,根据学生网络的分类层的分类损失函数、第1至第m-1层级对应的注意力图损失函数及第m层级对应的相似度损失函数,计算新的损失函数,并将在当前的一个训练周期或当前阶段的训练周期中使用的损失函数更新为新的损失函数,以在下一个训练周期或下一阶段的训练周期中使用更新后的损失函数对学生网络进行训练。新的损失函数可以为学生网络的分类层的分类损失函数、第1至第m-1层级对应的注意力图损失函数及第m层级对应的相似度损失函数之和。
在一个实施例中,更新后的所述总损失函数的计算公式如下:
Ltotal=Lcis+Lattention+Lsimilarity
其中,Ltotal表示所述总损失函数,Lcis表示所述分类损失函数,Lattention表示第1至第m-1注意力图损失函数之和,Lsimilarity表示相似度损失函数。
如图5所示,示例性的示出了教师网络和学生网络的结构示意图;其中,左侧为教师网络,右侧为学生网络,标号1至m分别表示第1至第m层级。
在一个实施例中,步骤S106之后,包括:
在所述总损失函数收敛时,确定所述学生网络训练完成;
将待分类或识别的图像数据输入训练完成的所述学生网络,获取所述学生网络输出的图像识别或分类结果。
本申请实施例还提供一种图像数据的知识蒸馏装置,用于执行上述图像数据的知识蒸馏方法实施例中的步骤。图像数据的知识蒸馏装置可以是终端设备中的虚拟装置(virtual appliance),由终端设备的处理器运行,也可以是终端设备本身。
如图6所示,本申请实施例提供的图像数据的知识蒸馏装置100,包括:
图像数据输入单元101,用于将图像数据输入已训练的教师网络和待训练的学生网络;
知识蒸馏单元102,用于基于总损失函数,利用所述教师网络引导所述学生网络训练;
注意力图获取单元103,用于获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图;
误差约束单元104,用于对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数;
相似度约束单元105,用于对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数;
总损失函数更新单元106,用于根据所述学生网络的分类层的分类损失函数、所述第i注意力图损失函数及所述相似度损失函数,更新所述总损失函数,并返回图像数据输入单元101;
其中,i=1,2,…,m-1,m为大于1的整数。
在一个实施例中,所述知识蒸馏装置还包括:
确定单元,用于在所述总损失函数收敛时,确定所述学生网络训练完成;
图像处理单元,用于将待处理的图像数据输入训练完成的所述学生网络,获取所述学生网络输出的图像处理结果,所述待处理的图像数据包括待分类或识别的至少一张图像。
在应用中,图像数据的知识蒸馏装置中的各模块可以为软件程序模块,也可以通过处理器中集成的不同逻辑电路实现,还可以通过多个分布式处理器实现。
如图7所示,本申请实施例还提供一种终端设备200,包括:至少一个处理器201(图7中仅示出一个处理器)、存储器202以及存储在存储器202中并可在至少一个处理器201上运行的计算机程序203,处理器201执行计算机程序203时实现上述各个方法实施例中的步骤。
在应用中,终端设备可包括,但不仅限于,处理器、存储器。本领域技术人员可以理解,图7仅仅是终端设备的举例,并不构成对终端设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如还可以包括输入输出设备、网络接入设备等。
在应用中,处理器可以是中央处理单元(Central Processing Unit,CPU),该处理器还可以是其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
在应用中,存储器在一些实施例中可以是终端设备的内部存储单元,例如终端设备的硬盘或内存。存储器在另一些实施例中也可以是终端设备的外部存储设备,例如,终端设备上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(SecureDigital,SD)卡,闪存卡(Flash Card)等。进一步地,存储器还可以既包括终端设备的内部存储单元也包括外部存储设备。存储器用于存储操作***、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,例如计算机程序的程序代码等。存储器还可以用于暂时地存储已经输出或者将要输出的数据。
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述***中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本申请实施例还提供了一种网络设备,该网络设备包括:至少一个处理器、存储器以及存储在存储器中并可在至少一个处理器上运行的计算机程序,处理器执行计算机程序时实现上述各个方法实施例中的步骤。
本申请实施例还提供了一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,计算机程序被处理器执行时实现可实现上述各个方法实施例中的步骤。
本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行时实现可实现上述各个方法实施例中的步骤。
集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,计算机程序包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质至少可以包括:能够将计算机程序代码携带到装置/终端设备的任何实体或装置、记录介质、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质。例如U盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的实施例中,应该理解到,所揭露的装置/网络设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/网络设备实施例仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个***,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。

Claims (10)

1.一种图像数据的知识蒸馏方法,其特征在于,包括:
将图像数据输入已训练的教师网络和待训练的学生网络;
基于总损失函数,利用所述教师网络引导所述学生网络训练;
获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图;
对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数;
对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数;
根据所述学生网络的分类层的分类损失函数、所述第i注意力图损失函数及所述相似度损失函数,更新所述总损失函数,并返回执行所述基于总损失函数,利用所述教师网络引导所述学生网络训练的步骤;
其中,i=1,2,…,m-1,m为大于1的整数。
2.如权利要求1所述的图像数据的知识蒸馏方法,其特征在于,所述获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图,包括:
在通道维度方向上,对所述教师网络的第i层级输出的特征图进行池化,得到所述教师网络的第i层级输出的特征图的通道维度注意力图;
在通道维度方向上,对所述学生网络的第i层级输出的特征图进行池化,得到所述学生网络的第i层级输出的特征图的通道维度注意力图;
所述对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数,包括:
对所述教师网络的第i层级输出的特征图的通道维度注意力图和所述学生网络的第i层级输出的特征图的通道维度注意力图进行误差约束,得到第i通道维度注意力图损失函数。
3.如权利要求1所述的图像数据的知识蒸馏方法,其特征在于,所述获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图,包括:
在空间维度方向上,对所述教师网络的第i层级输出的特征图进行池化,得到所述教师网络的第i层级输出的特征图的空间维度注意力图;
在空间维度方向上,对所述学生网络的第i层级输出的特征图进行池化,得到所述学生网络的第i层级输出的特征图的空间维度注意力图;
所述对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数,包括:
对所述教师网络的第i层级输出的特征图的空间维度注意力图和所述学生网络的第i层级输出的特征图的空间维度注意力图进行误差约束,得到第i空间维度注意力图损失函数。
4.如权利要求1所述的图像数据的知识蒸馏方法,其特征在于,所述对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数,包括:
获取所述教师网络的第m层级输出的特征图的特征向量并进行归一化,得到第一归一化结果;
获取所述学生网络的第m层级输出的特征图的特征向量并进行归一化,得到第二归一化结果;
根据所述第一归一化结果和所述第二归一化结果,获取所述教师网络的第m层级输出的特征图的特征向量和所述学生网络的第m层级输出的特征图的特征向量之间的夹角的余弦值;
基于预设余弦相似度约束条件对所述余弦值的大小进行约束,得到余弦相似度损失函数。
5.如权利要求1至4任一项所述的图像数据的知识蒸馏方法,其特征在于,所述误差约束包括均方误差约束,所述相似度约束包括余弦相似度约束。
6.如权利要求1至4任一项所述的图像数据的知识蒸馏方法,其特征在于,所述教师网络和所述学生网络的低、中层级输出的特征图为局部特征图,所述教师网络和所述学生网络的高层级输出的特征图为全局特征图。
7.如权利要求1至4任一项所述的图像数据的知识蒸馏方法,其特征在于,所述图像数据包含多张人脸图像。
8.一种图像数据的知识蒸馏装置,其特征在于,包括:
图像数据输入单元,用于将图像数据输入已训练的教师网络和待训练的学生网络;
知识蒸馏单元,用于基于总损失函数,利用所述教师网络引导所述学生网络训练;
注意力图获取单元,用于获取所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图;
误差约束单元,用于对所述教师网络的第i层级输出的特征图的注意力图和所述学生网络的第i层级输出的特征图的注意力图进行误差约束,得到第i注意力图损失函数;
相似度约束单元,用于对所述教师网络的第m层级输出的特征图和所述学生网络的第m层级输出的特征图进行相似度约束,得到相似度损失函数;
总损失函数更新单元,用于根据所述学生网络的分类层的分类损失函数、所述第i注意力图损失函数及所述相似度损失函数,更新所述总损失函数,并返回执行所述基于总损失函数,利用所述教师网络引导所述学生网络训练的步骤;
其中,i=1,2,…,m-1,m为大于1的整数。
9.一种终端设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述图像数据的知识蒸馏方法的步骤。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述图像数据的知识蒸馏方法的步骤。
CN202210527719.5A 2022-05-16 2022-05-16 图像数据的知识蒸馏方法、装置、终端设备及存储介质 Pending CN114997365A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210527719.5A CN114997365A (zh) 2022-05-16 2022-05-16 图像数据的知识蒸馏方法、装置、终端设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210527719.5A CN114997365A (zh) 2022-05-16 2022-05-16 图像数据的知识蒸馏方法、装置、终端设备及存储介质

Publications (1)

Publication Number Publication Date
CN114997365A true CN114997365A (zh) 2022-09-02

Family

ID=83027286

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210527719.5A Pending CN114997365A (zh) 2022-05-16 2022-05-16 图像数据的知识蒸馏方法、装置、终端设备及存储介质

Country Status (1)

Country Link
CN (1) CN114997365A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116306868A (zh) * 2023-03-01 2023-06-23 支付宝(杭州)信息技术有限公司 一种模型的处理方法、装置及设备

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116306868A (zh) * 2023-03-01 2023-06-23 支付宝(杭州)信息技术有限公司 一种模型的处理方法、装置及设备
CN116306868B (zh) * 2023-03-01 2024-01-05 支付宝(杭州)信息技术有限公司 一种模型的处理方法、装置及设备

Similar Documents

Publication Publication Date Title
CN111797893B (zh) 一种神经网络的训练方法、图像分类***及相关设备
EP3968179A1 (en) Place recognition method and apparatus, model training method and apparatus for place recognition, and electronic device
EP3635629A1 (en) Fine-grained image recognition
CN110852311A (zh) 一种三维人手关键点定位方法及装置
CN111831844A (zh) 图像检索方法、图像检索装置、图像检索设备及介质
CN112084849A (zh) 图像识别方法和装置
WO2023124040A1 (zh) 一种人脸识别方法及装置
WO2021190433A1 (zh) 更新物体识别模型的方法和装置
CN115050064A (zh) 人脸活体检测方法、装置、设备及介质
CN110717405B (zh) 人脸特征点定位方法、装置、介质及电子设备
WO2022127333A1 (zh) 图像分割模型的训练方法、图像分割方法、装置、设备
CN114997365A (zh) 图像数据的知识蒸馏方法、装置、终端设备及存储介质
CN111368860B (zh) 重定位方法及终端设备
EP4060526A1 (en) Text processing method and device
CN113592015A (zh) 定位以及训练特征匹配网络的方法和装置
CN113159053A (zh) 图像识别方法、装置及计算设备
CN112488054A (zh) 一种人脸识别方法、装置、终端设备及存储介质
CN113191364B (zh) 车辆外观部件识别方法、装置、电子设备和介质
CN114972146A (zh) 基于生成对抗式双通道权重分配的图像融合方法及装置
CN111275183B (zh) 视觉任务的处理方法、装置和电子***
CN113780066A (zh) 行人重识别方法、装置、电子设备及可读存储介质
CN112949672A (zh) 商品识别方法、装置、设备以及计算机可读存储介质
CN112131902A (zh) 闭环检测方法及装置、存储介质和电子设备
CN116912502B (zh) 全局视角辅助下图像关键解剖结构的分割方法及其设备
CN111597375B (zh) 基于相似图片组代表特征向量的图片检索方法及相关设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination