CN110674880B - 用于知识蒸馏的网络训练方法、装置、介质与电子设备 - Google Patents

用于知识蒸馏的网络训练方法、装置、介质与电子设备 Download PDF

Info

Publication number
CN110674880B
CN110674880B CN201910923038.9A CN201910923038A CN110674880B CN 110674880 B CN110674880 B CN 110674880B CN 201910923038 A CN201910923038 A CN 201910923038A CN 110674880 B CN110674880 B CN 110674880B
Authority
CN
China
Prior art keywords
data
network
student
teacher
sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN201910923038.9A
Other languages
English (en)
Other versions
CN110674880A (zh
Inventor
田野
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Megvii Technology Co Ltd
Original Assignee
Beijing Megvii Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Megvii Technology Co Ltd filed Critical Beijing Megvii Technology Co Ltd
Priority to CN201910923038.9A priority Critical patent/CN110674880B/zh
Publication of CN110674880A publication Critical patent/CN110674880A/zh
Application granted granted Critical
Publication of CN110674880B publication Critical patent/CN110674880B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本公开提供了一种用于知识蒸馏的网络训练方法、装置、存储介质与电子设备,涉及人工智能技术领域。该方法包括:将样本数据输入教师网络,获得所述样本数据对应的软标签数据,将所述样本数据输入学生网络,获得所述样本数据对应的预测数据;基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据,构建损失函数;根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数。本公开可以对教师网络和学生网络同步训练,降低学生网络对于教师网络的依赖度,提高学生网络的训练效果,并且可以加速训练过程,提高效率。

Description

用于知识蒸馏的网络训练方法、装置、介质与电子设备
技术领域
本公开涉及人工智能技术领域,尤其涉及一种用于知识蒸馏的网络训练方法、用于知识蒸馏的网络训练装置、计算机可读存储介质与电子设备。
背景技术
深度学习作为人工智能领域的一个重要分支,近年来得到了快速的发展,出现了很多改进的深度学习方法,其中就包括知识蒸馏(Knowledge Distillation)。
知识蒸馏是模型压缩技术的一种具体实现方法,引入教师网络和学生网络,教师网络是相对复杂的网络模型,学生网络是相对精简的网络模型,利用样本数据训练教师网络,再以教师网络的输出训练学生网络,从而在学生网络上实现教师网络的处理功能,达到网络模型的精简等目的。
然而,在现有的知识蒸馏方法中,学生网络的训练极大地依赖于教师网络的质量,且由于网络设计、参数初值等因素的影响,学生网络可能无法很好的适应教师网络,这些问题都不利于学生网络训练的进行,导致无法得到高质量的网络模型。
需要说明的是,在上述背景技术部分公开的信息仅用于加强对本公开的背景的理解,因此可以包括不构成对本领域普通技术人员已知的现有技术的信息。
发明内容
本公开提供了一种用于知识蒸馏的网络训练方法、用于知识蒸馏的网络训练装置、计算机可读存储介质与电子设备,进而至少在一定程度上改善现有技术中存在的学生网络训练效果较差的问题。
本公开的其他特性和优点将通过下面的详细描述变得显然,或部分地通过本公开的实践而习得。
根据本公开的第一方面,提供一种用于知识蒸馏的网络训练方法,包括:将样本数据输入教师网络,获得所述样本数据对应的软标签数据,将所述样本数据输入学生网络,获得所述样本数据对应的预测数据;基于所述预测数据、所述软标签数据和所述硬标签数据之间的误差,构建损失函数;根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数。
可选的,所述软标签数据包括通过所述教师网络对所述样本数据进行分类得到的第一概率数据,所述预测数据包括通过所述学生网络对所述样本数据进行分类得到的第二概率数据。
可选的,所述基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据,构建损失函数,包括:根据所述预测数据和所述硬标签数据,构建第一子损失;根据所述预测数据和所述软标签数据,构建第二子损失;根据所述第一子损失和所述第二子损失,确定所述损失函数。
可选的,所述样本数据包括正样本,所述根据所述预测数据和所述软标签数据,构建第二子损失,包括:根据所述正样本对应的预测数据和所述正样本对应的软标签数据,构建所述第二子损失。
可选的,所述根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数,包括:根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数;根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数。
可选的,所述正样本对应的预测数据包括对所述正样本的学生预测值和所述学生预测值对应的概率,所述根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数,包括:根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1。
可选的,所述正样本对应的软标签数据包括对所述正样本的教师预测值和所述教师预测值对应的概率,所述根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1,包括:根据所述损失函数对所述学生预测值的梯度,以及所述学生预测值和所述教师预测值之间的误差,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1和所述教师预测值对应的概率。
可选的,所述根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数,包括:根据所述损失函数对所述教师预测值的梯度,更新所述教师网络中的参数,使所述教师预测值对应的概率趋近于1。
可选的,所述损失函数为:
Figure BDA0002218172530000031
Figure BDA0002218172530000032
其中,L为所述损失函数,i表示所述硬标签数据的类别,yi为第i类硬标签数据,
Figure BDA0002218172530000033
为第i类硬标签数据对应的预测数据,
Figure BDA0002218172530000034
为第i类硬标签数据对应的软标签数据;∈为经验参数,min(yi)<∈<max(yi);a、b、c均为非负的权重参数,b不为0,且a和c中至少一个不为0。
根据本公开的第二方面,提供一种用于知识蒸馏的网络训练装置,包括:处理模块,用于将样本数据输入教师网络,获得所述样本数据对应的软标签数据,将所述样本数据输入学生网络,获得所述样本数据对应的预测数据;构建模块,用于基于所述预测数据、所述软标签数据和所述硬标签数据之间的误差,构建损失函数;训练模块,用于根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数。
可选的,所述软标签数据包括通过所述教师网络对所述样本数据进行分类得到的第一概率数据,所述预测数据包括通过所述学生网络对所述样本数据进行分类得到的第二概率数据。
可选的,所述构建模块包括:第一子损失单元,用于根据所述预测数据和所述硬标签数据,构建第一子损失;第二子损失单元,用于根据所述预测数据和所述软标签数据,构建第二子损失;损失函数确定单元,用于根据所述第一子损失和所述第二子损失,确定所述损失函数。
可选的,所述样本数据包括正样本,所述第二子损失单元,还用于根据所述正样本对应的预测数据和所述正样本对应的软标签数据,构建所述第二子损失。
可选的,所述训练模块包括:学生网络训练单元,用于根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数;教师网络训练单元,用于根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数。
可选的,所述正样本对应的预测数据包括对所述正样本的学生预测值和所述学生预测值对应的概率,所述学生网络训练单元,还用于根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1。
可选的,所述正样本对应的软标签数据包括对所述正样本的教师预测值和所述教师预测值对应的概率,所述学生网络训练单元,还用于根据所述损失函数对所述学生预测值的梯度,以及所述学生预测值和所述教师预测值之间的误差,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1和所述教师预测值对应的概率。
可选的,所述教师网络训练单元,还用于根据所述损失函数对所述教师预测值的梯度,更新所述教师网络中的参数,使所述教师预测值对应的概率趋近于1。
可选的,所述损失函数为:
Figure BDA0002218172530000041
Figure BDA0002218172530000042
其中,L为所述损失函数,i表示所述硬标签数据的类别,yi为第i类硬标签数据,
Figure BDA0002218172530000043
为第i类硬标签数据对应的预测数据,
Figure BDA0002218172530000044
为第i类硬标签数据对应的软标签数据;∈为经验参数,min(yi)<∈<max(yi);a、b、c均为非负的权重参数,b不为0,且a和c中至少一个不为0。
根据本公开的第三方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述任意一种网络训练方法。
根据本公开的第四方面,提供一种电子设备,包括:处理器;以及存储器,用于存储所述处理器的可执行指令;其中,所述处理器配置为经由执行所述可执行指令来执行上述任意一种网络训练方法。
本公开的技术方案具有以下有益效果:
根据上述用于知识蒸馏的网络训练方法、装置、存储介质和电子设备,在获取样本数据和硬标签数据后,通过教师网络处理样本数据得到对应的软标签数据,通过学生网络处理样本数据得到对应的预测数据,再基于预测数据、软标签数据和硬标签数据构建损失函数,通过损失函数更新教师网络和学生网络中的参数。一方面,在知识蒸馏的模型中,对教师网络和学生网络同步训练,使得教师网络可以在训练过程中进一步优化,降低学生网络对于教师网络的依赖度,特别是教师网络的初始质量较差,或者学生网络对于教师网络的适应性较差的情况,可以通过同步训练而逐渐改善,提高学生网络的训练效果。另一方面,在训练过程中,学生网络和教师网络的拟合是双向进行的,不仅学生网络会逐渐拟合教师网络,教师网络也会不断的调整以利于学生网络的拟合,从而加速训练过程,提高效率。再一方面,通过上述方式有利于得到高质量的学生网络,可以部署在客户端等轻量级应用的场景中,提供图像分类、图像识别等服务,且有利于提高服务质量。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施方式,并与说明书一起用于解释本公开的原理。显而易见地,下面描述中的附图仅仅是本公开的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出本示例性实施方式中一种用于知识蒸馏的网络训练方法的流程图;
图2示出本示例性实施方式中一种用于知识蒸馏的网络训练方法的子流程图;
图3示出本示例性实施方式中一种数据处理流程的示意图;
图4示出本示例性实施方式中一种用于知识蒸馏的网络训练装置的结构框图;
图5示出本示例性实施方式中一种用于实现上述方法的计算机可读存储介质;
图6示出本示例性实施方式中一种用于实现上述方法的电子设备。
具体实施方式
现在将参考附图更全面地描述示例实施方式。然而,示例实施方式能够以多种形式实施,且不应被理解为限于在此阐述的范例;相反,提供这些实施方式使得本公开将更加全面和完整,并将示例实施方式的构思全面地传达给本领域的技术人员。所描述的特征、结构或特性可以以任何合适的方式结合在一个或更多实施方式中。在下面的描述中,提供许多具体细节从而给出对本公开的实施方式的充分理解。然而,本领域技术人员将意识到,可以实践本公开的技术方案而省略所述特定细节中的一个或更多,或者可以采用其它的方法、组元、装置、步骤等。在其它情况下,不详细示出或描述公知技术方案以避免喧宾夺主而使得本公开的各方面变得模糊。
此外,附图仅为本公开的示意性图解,并非一定是按比例绘制。图中相同的附图标记表示相同或类似的部分,因而将省略对它们的重复描述。附图中所示的一些方框图是功能实体,不一定必须与物理或逻辑上独立的实体相对应。可以采用软件形式来实现这些功能实体,或在一个或多个硬件模块或集成电路中实现这些功能实体,或在不同网络和/或处理器装置和/或微控制器装置中实现这些功能实体。
本公开的示例性实施方式首先提供了一种用于知识蒸馏的网络训练方法。图1示出了该方法的一种流程,可以包括以下步骤S110~S140:
步骤S110,将样本数据输入教师网络,获得样本数据对应的软标签数据,将样本数据输入学生网络,获得样本数据对应的预测数据。
其中,教师网络和学生网络是知识蒸馏中的一组神经网络模型,是待训练的网络,可以是程序人员初始设计的网络,也可以是任意现有的网络;或者教师网络可以是经过预训练的网络,学生网络是经过重置的网络。举例来说:将知识蒸馏应用于性别识别的场景中,可以采用flop数(表示网络模型的计算量)为54M的ResNet(Residual Neural Network,残差神经网络)作为教师网络,该网络包括5个stage(表示层),每个stage分别包括2,4,4,4,2个bottleneck(瓶颈结构,表示跨越多层的直连)形式的block(残差块),其输出通道数分别为16,32,64,128,256;每个block内部残差分支包括3个conv layer(卷积层),其中前两个conv layer的输出通道数与输入通道数相同,最后一个conv layer的输出通道数与所在block的输出通道数相同;对应的,采用flop数为20M的ResNet作为学生网络,与教师网络的stage数,每个stage包括的block数,以及每个stage的输出通道数相同;学生网络每个block内部残差分支也包括3个conv layer,将前两个conv layer的输出通道数设置为输入通道数的一半,最后一个conv layer的输出通道数与所在block的输出通道数相同;对教师网络和学生网络随机初始化参数,获得上述待训练的教师网络和学生网络。
样本数据用于训练教师网络和学生网络,本示例性实施方式中,样本数据是训练中所用的输入数据,硬标签数据即通常所说的标签数据,是真实标签。样本数据可以根据教师网络和学生网络中输入层的数据格式进行预处理,得到规则化的样本数据,可以通过人工对样本数据打标签的方式获得硬标签数据,也可以从现有的数据集中获取样本数据和硬标签数据。样本数据和硬标签数据的具体内容与教师网络和学生网络的具体应用场景相关,例如:在对象分类的应用场景中,样本数据可以是预先选取的样本对象的特征数据,硬标签数据可以是样本对象的分类标签;更具体的,在图像分类的应用场景中,样本数据可以是样本图片,硬标签数据可以是样本图片的分类标签。
本示例性实施方式中,软标签数据包括教师网络对输入的样本数据进行处理后得到的概率数据,实质上是教师网络根据样本数据得到的预测数据,在知识蒸馏中,该预测数据可以用于学生网络的训练,其性质接近于标签数据,但与真实标签存在不同,因此称为软标签数据,与上述硬标签数据相对应。预测数据包括学生网络对输入的样本数据进行处理后得到的概率数据。以上述对象分类的应用场景为例,软标签数据可以是教师网络对样本对象进行分类的第一概率数据,表示教师网络判定样本对象属于每个分类的概率,预测数据可以是学生网络对样本对象进行分类的第二概率数据,表示学生网络判定样本对象属于每个分类的概率;更具体的,在图像分类的应用场景中,软标签数据可以是教师网络识别样本图片中是否存在目标对象(如猫、狗或汽车等)的概率,预测数据可以是学生网络识别样本图片中是否存在目标对象的概率。软标签数据和预测数据可以是教师网络和学生网络的全连接层或Softmax(归一化指数函数)层输出的概率值等。
需要说明的是,若教师网络中设置Softmax层,则可以引入知识蒸馏中的温度参数,在进行Softmax计算前,将前一层(通常是全连接层)输出的中间数据除以温度参数,再进行Softmax计算,得到软标签数据。
需要补充的是,若样本数据包括多对样本数据和硬标签数据,则可以将多组样本数据输入教师网络和学生网络,例如一个batch(表示一次训练迭代中所用的样本数据,通常是32、64或128组样本数据等),分别得到软标签数据的数组和预测数据的数组。
步骤S120,基于预测数据、软标签数据和硬标签数据之间的误差,构建损失函数。
理想的学生网络和教师网络下,其处理样本数据得到的预测数据、软标签数据和硬标签数据一致或接近一致,即在样本数据上达到较高的准确率。因此,可以基于预测数据、软标签数据和硬标签数据之间的误差,构建损失函数,具体而言,可以构建如下三部分误差:
第一子损失,为预测数据和硬标签数据之间的误差;
第二子损失,为预测数据和软标签数据之间的误差;
第三子损失,为软标签数据和硬标签数据之间的误差。
本示例性实施方式中,为了对教师网络和学生网络进行同步训练,可以从上述三个误差项中选取至少两个,相加或加权相加后得到损失函数。例如,采用平方损失函数计算误差,可以得到损失函数如下:
Figure BDA0002218172530000081
其中,L为损失函数,j表示样本数据的序数,即第j个样本数据,
Figure BDA0002218172530000082
为第j个样本数据对应的预测数据,
Figure BDA0002218172530000083
为第j个样本数据对应的软标签数据,yj为第j个样本数据对应的硬标签数据;
Figure BDA0002218172530000084
为第一子损失,
Figure BDA0002218172530000085
为第二子损失,
Figure BDA0002218172530000086
为第三子损失;a、b、c均为非负的权重参数,分别对应于第一、第二和第三子损失的权重,可以根据经验或实际需要设定其值,a、b、c中最多只有一个为0,在训练过程中也可以根据实际训练情况调整权重参数。
在一种可选的实施方式中,步骤S120可以包括:
根据预测数据和硬标签数据,构建第一子损失;
根据预测数据和软标签数据,构建第二子损失;
根据第一子损失和第二子损失,确定损失函数。
即利用上述第一子损失和第二子损失构建损失函数,以体现出学生网络、教师网络和硬标签数据之间的误差。
步骤S130,根据损失函数更新教师网络中的参数和学生网络中的参数。
本示例性实施方式中,损失函数直接或间接地体现预测数据和硬标签数据之间的误差、以及软标签数据和硬标签数据之间的误差,因此根据损失函数可以更新优化教师网络和学生网络中的参数,可以同时减小上述两部分误差,即可以对教师网络和学生网络进行同步训练。具体而言,在训练过程中的每次迭代中,通过最小化损失函数或以其他方式调整损失函数的值,对应更新教师网络和学生网络的参数,通过对教师网络和/或学生网络进行多次迭代训练,以逐步调整教师网络和学生网络的参数值,趋向于拟合,该训练过程即监督学习的过程。
基于上述内容,本示例性实施方式中,获取样本数据和硬标签数据后,通过教师网络处理样本数据得到对应的软标签数据,通过学生网络处理样本数据得到对应的预测数据,再基于预测数据、软标签数据和硬标签数据构建损失函数,通过损失函数更新教师网络和学生网络中的参数。一方面,在知识蒸馏的模型中,对教师网络和学生网络同步训练,使得教师网络可以在训练过程中进一步优化,降低学生网络对于教师网络的依赖度,特别是教师网络的初始质量较差,或者学生网络对于教师网络的适应性较差的情况,可以通过同步训练而逐渐改善,提高学生网络的训练效果。另一方面,在训练过程中,学生网络和教师网络的拟合是双向进行的,不仅学生网络会逐渐拟合教师网络,教师网络也会不断的调整以利于学生网络的拟合,从而加速训练过程,提高效率。再一方面,通过上述方式有利于得到高质量的学生网络,可以部署在客户端等轻量级应用的场景中,提供图像分类、图像识别等服务,且有利于提高服务质量。
在一种可选的实施方式中,考虑到教师网络和学生网络中采用Softmax层对输出数据进行处理,为了提高预测的准确度,可以采用交叉熵计算误差。同时为了简化损失函数,仅保留上述第一和第二子损失,则损失函数可以由第一子损失和第二子损失组成,第一子损失可以是预测数据和硬标签数据之间的交叉熵,第二子损失可以是预测数据和软标签数据之间的交叉熵。以二分类问题为例,硬标签数据为0或1,则损失函数可以表示如下:
L1=-(1*logpA+0*log(1-pA))=-logpA; (2)
L2=-(pT·logpA+(1-pT)·log(1-pA)); (3)
L=L1+L2=-(logpA+pT·logpA+(1-pT)·log(1-pA)); (4)
其中,L1和L2分别为上述第一子损失和第二子损失,pA为预测数据,pT为软标签数据。该损失函数可以更加准确地体现出教师网络和学生网络的综合误差状况,且更加易于计算。
在一种可选的实施方式中,在将样本数据输入教师网络和学生网络后,若教师网络的输出、学生网络的输出和该样本数据对应的硬标签数据一致,则将该样本数据确定为正样本数据;由此可以从样本数据中获取多组正样本数据及其硬标签数据,并根据正样本对应的预测数据和正样本对应的软标签数据,构建第二子损失,使第二子损失为正样本数据对应的预测数据和正样本数据对应的软标签数据之间的误差(如交叉熵)。在正样本数据下表征两个网络之间的误差,有利于后续加速网络模型的拟合,提高训练的效率。
在一种可选的实施方式中,参考图2所示,步骤S130可以具体包括以下步骤S210和S220:
步骤S210,根据损失函数和正样本对应的预测数据,更新学生网络中的参数;
步骤S220,根据损失函数和正样本对应的软标签数据,更新教师网络中的参数。
其中,可以利用损失函数对预测数据求导,得到损失函数的第一梯度,基于第一梯度的下降,更新学生网络中的参数,以及利用损失函数对软标签数据求导,得到损失函数的第二梯度,基于第二梯度的下降,更新教师网络中的参数。其中,第一梯度为针对于学生网络的梯度,利用梯度下降法可以训练学生网络;第二梯度为针对于教师网络的梯度,利用梯度下降法可以训练教师网络。以上述损失函数(4)为例,分别对预测数据pA和软标签数据pT求导,可以得到如下第一梯度Grad1和第二梯度Grad2:
Figure BDA0002218172530000111
Figure BDA0002218172530000112
由此,第一梯度和第二梯度分别给出了学生网络和教师网络的训练调整方向,进而可以分别调整两网络中的参数。
在一种可选的实施方式中,正样本对应的预测数据可以包括对正样本的学生预测值和该学生预测值对应的概率,学生预测值即学生网络对正样本进行处理后输出的结果。步骤S210可以包括:根据损失函数对学生预测值的梯度,更新学生网络中的参数,使学生预测值对应的概率趋近于1。其中,该梯度可以如上述第一梯度,是针对于学生网络的梯度,利用梯度下降更新学生网络中的参数,可以实现损失函数的优化,使学生预测值对应的概率趋近于1,即在正样本上不断拟合真实标签,且预测准确度不断提高。
此外,正样本对应的软标签数据可以包括对正样本的教师预测值和该教师预测值对应的概率,教师预测值即教师网络对正样本进行处理后输出的结果。在上述训练学生网络的过程中,也可以根据损失函数对学生预测值的梯度,以及学生预测值和教师预测值之间的误差,更新学生网络中的参数,使学生预测值对应的概率趋近于1和教师预测值对应的概率。这样学生网络可以同时向真实标签和教师网络拟合。
在一种可选的实施方式中,步骤S220可以包括:根据损失函数对教师预测值的梯度,更新教师网络中的参数,使教师预测值对应的概率趋近于1。其中,该梯度可以如上述第二梯度,是针对于教师网络的梯度,利用梯度下降更新教师网络中的参数,可以实现损失函数的优化,使教师预测值对应的概率趋近于1,即在正样本上不断拟合真实标签,且预测准确度不断提高。作为补充的,在通过梯度下降训练教师网络时,也可以结合学生预测值和教师预测值之间的误差,使教师网络同时向学生网络拟合,可以进一步加速训练。
以在二分类问题为例,硬标签数据包括两个类别,为0或1。从样本数据中获取第一类正样本和第二类正样本,第一类正样本和第二类正样本满足以下关系:
y(x1)=1; (7)
pA(x1)>1-pA(x1),pT(x1)>1-pT(x1);(8)
y(x2)=0; (9)
pA(x2)<1-pA(x2),pT(x2)<1-pT(x2);(10)
其中,x1为第一类正样本,x2为第二类正样本,y(x1)为x1对应的硬标签数据,y(x2)为x2对应的硬标签数据,pA(x1)为x1对应的预测数据,pA(x2)为x2对应的预测数据。换而言之,x1实际上是学生网络正确分类的正标签(1)样本数据,x2是学生网络正确分类的负标签(0)样本数据。基于此,可以根据pA(x1)和pT(x1)分别计算第一梯度和第二梯度的下降,以分别更新学生网络和教师网络中的参数,使pA(x1)趋近于pT(x1)和y(x1),且pT(x1)趋近于pA(x1)和y(x1);并根据pA(x2)和pT(x2)分别计算第一梯度和第二梯度的下降,以分别更新学生网络和教师网络中的参数,使pA(x2)趋近于pT(x2)和y(x2),且pT(x2)趋近于pA(x2)和y(x2)。
其中,pT(x1)为x1对应的软标签数据,pT(x2)为x2对应的软标签数据。需要说明的是,上述两个步骤分别是在x1和x2上进行的调整,可以是相互独立的步骤。下面通过x1上的调整过程做进一步说明,x2上的调整过程基本相同,因而不再赘述。
在第一梯度下降的过程中,学生网络主要做两点调整:逐渐增大pA(x1),使其趋近于y(x1)(即1);使pA(x1)逐渐趋近于pT(x1),即希望学生网络逐渐拟合教师网络。相当于学生网络同时受到硬标签数据和教师网络的监督。
在第二梯度下降的过程中,教师网络主要是在x1上逐渐增大pT(x1),使pT(x1)逐渐趋近于y(x1)(即1),同时也逐渐拟合pA(x1)。
教师网络所采用的第二梯度为对数函数,相比于学生网络,教师网络在梯度下降的过程中较为稳定,其调整量小于学生网络,这符合知识蒸馏本身的原理。随着训练过程的进行,第一类正样本逐渐接近正例样本,即两个数据集重合度越来越高,教师网络接受到的监督信息越来越精确,反过来也会给学生网络提供越来越准确的监督信息。
在一种可选的实施方式中,可以将交叉熵形式的误差进一步优化,以体现出在各个类别样本上的信息。基于此,损失函数可以表示如下:
Figure BDA0002218172530000131
其中,各项参数的含义与损失函数(1)中基本相同,需要注意的是,此处的i表示硬标签数据的类别,yi为第i类硬标签数据,
Figure BDA0002218172530000133
为第i类硬标签数据对应的预测数据,
Figure BDA0002218172530000134
为第i类硬标签数据对应的软标签数据。特别的,∈为经验参数,在硬标签数据的最大值与最小值之间,即满足min(yi)<∈<max(yi),例如当硬标签数据为0或1时,∈满足0<∈<1,且通常是值较小的正数,防止误差值在负例样本处溢出,例如可以为10-6;三个权重参数中,b不为0,即损失函数包含第二子损失,且a和c中至少一个不为0,即损失函数包含第一和第三子损失的至少一个。
通过上述损失函数(11),在训练网络的过程中,损失函数每次都针对于拟合程度较差的样本类别(如正例或负例)进行优化,而不仅仅是优化正例样本,最终在各个类别上都能够很好地实现拟合,进一步提高教师网络和学生网络的质量。
在一种可选的实施方式中,也可以采用损失函数(11)的变化形式,参考图3所示,在通过教师网络(Teacher Network)和学生网络(Apprentice Network)处理样本数据(Input Data)x时,将x分别输入教师网络和学生网络,在Softmax层之前得到的中间数据(通常是全连接层的输出)分别为zT和zA,可以和温度参数T一起进行Softmax计算,即知识蒸馏的处理过程,得到软标签数据pT,教师网络实际的输出为zT进行常规Softmax计算后得到的数据qT;学生网络的中间数据zA和温度参数T一起进行Softmax计算得到预测数据pA,zA进行常规Softmax计算后输出数据qA。因此,可以基于pT、qT、pA、qA以及硬标签数据(HardLabel)y计算损失函数Loss,如下所示:
Figure BDA0002218172530000132
换而言之,第一和第三子损失中的用于拟合的输出数据可以采用常规Softmax输出的数据,更有利于教师网络和学生网络向硬标签数据的拟合。
在上述性别识别的场景中,采用损失函数(11)结合第一梯度(5)和第二梯度(6)的训练策略,可以将20M的学生网络ResNet的FPR(False Positive Rate,假阳率)从0.63%降低到0.47%,优于现有技术的水平。
本公开的示例性实施方式还提供了一种用于知识蒸馏的网络训练装置,如图4所示,该网络训练装置400可以包括:处理模块410,用于将样本数据输入教师网络,获得样本数据对应的软标签数据,将样本数据输入学生网络,获得样本数据对应的预测数据;构建模块420,用于基于预测数据、软标签数据和硬标签数据之间的误差,构建损失函数;训练模块430,用于根据损失函数更新教师网络中的参数和学生网络中的参数。
在一种可选的实施方式中,软标签数据可以包括通过教师网络对样本数据进行分类得到的第一概率数据,预测数据可以包括通过学生网络对样本数据进行分类得到的第二概率数据。
在一种可选的实施方式中,构建模块420包括:第一子损失单元,用于根据预测数据和硬标签数据,构建第一子损失;第二子损失单元,用于根据预测数据和软标签数据,构建第二子损失;损失函数确定单元,用于根据第一子损失和第二子损失,确定损失函数。
在一种可选的实施方式中,样本数据可以包括正样本,第二子损失单元,还可以用于根据正样本对应的预测数据和正样本对应的软标签数据,构建第二子损失。
在一种可选的实施方式中,训练模块430可以包括:学生网络训练单元,用于根据损失函数和正样本对应的预测数据,更新学生网络中的参数;教师网络训练单元,用于根据损失函数和正样本对应的软标签数据,更新教师网络中的参数。
在一种可选的实施方式中,正样本对应的预测数据可以包括对正样本的学生预测值和学生预测值对应的概率;学生网络训练单元,还可以用于根据损失函数对学生预测值的梯度,更新学生网络中的参数,使学生预测值对应的概率趋近于1。
在一种可选的实施方式中,正样本对应的软标签数据可以包括对正样本的教师预测值和教师预测值对应的概率;学生网络训练单元,还可以用于根据损失函数对学生预测值的梯度,以及学生预测值和教师预测值之间的误差,更新学生网络中的参数,使学生预测值对应的概率趋近于1和教师预测值对应的概率。
在一种可选的实施方式中,教师网络训练单元,还可以用于根据损失函数对教师预测值的梯度,更新教师网络中的参数,使教师预测值对应的概率趋近于1。
在一种可选的实施方式中,损失函数可以是:
Figure BDA0002218172530000151
Figure BDA0002218172530000152
其中,L为损失函数,i表示硬标签数据的类别,yi为第i类硬标签数据,
Figure BDA0002218172530000153
为第i类硬标签数据对应的预测数据,
Figure BDA0002218172530000154
为第i类硬标签数据对应的软标签数据;∈为经验参数,min(yi)<∈<max(yi);a、b、c均为非负的权重参数,b不为0,且a和c中至少一个不为0。
上述装置中各模块的具体细节在方法部分实施方式中已经详细说明,未披露的方案细节可以参见方法部分的实施方式内容,因而不再赘述。
所属技术领域的技术人员能够理解,本公开的各个方面可以实现为***、方法或程序产品。因此,本公开的各个方面可以具体实现为以下形式,即:完全的硬件实施方式、完全的软件实施方式(包括固件、微代码等),或硬件和软件方面结合的实施方式,这里可以统称为“电路”、“模块”或“***”。
本公开的示例性实施方式还提供了一种计算机可读存储介质,其上存储有能够实现本说明书上述方法的程序产品。在一些可能的实施方式中,本公开的各个方面还可以实现为一种程序产品的形式,其包括程序代码,当程序产品在终端设备上运行时,程序代码用于使终端设备执行本说明书上述“示例性方法”部分中描述的根据本公开各种示例性实施方式的步骤。
参考图5所示,描述了根据本公开的示例性实施方式的用于实现上述方法的程序产品500,其可以采用便携式紧凑盘只读存储器(CD-ROM)并包括程序代码,并可以在终端设备,例如个人电脑上运行。然而,本公开的程序产品不限于此,在本文件中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行***、装置或者器件使用或者与其结合使用。
程序产品可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以为但不限于电、磁、光、电磁、红外线、或半导体的***、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
计算机可读信号介质可以包括在基带中或者作为载波一部分传播的数据信号,其中承载了可读程序代码。这种传播的数据信号可以采用多种形式,包括但不限于电磁信号、光信号或上述的任意合适的组合。可读信号介质还可以是可读存储介质以外的任何可读介质,该可读介质可以发送、传播或者传输用于由指令执行***、装置或者器件使用或者与其结合使用的程序。
可读介质上包含的程序代码可以用任何适当的介质传输,包括但不限于无线、有线、光缆、RF等等,或者上述的任意合适的组合。
可以以一种或多种程序设计语言的任意组合来编写用于执行本公开操作的程序代码,程序设计语言包括面向对象的程序设计语言—诸如Java、C++等,还包括常规的过程式程序设计语言—诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。在涉及远程计算设备的情形中,远程计算设备可以通过任意种类的网络,包括局域网(LAN)或广域网(WAN),连接到用户计算设备,或者,可以连接到外部计算设备(例如利用因特网服务提供商来通过因特网连接)。
本公开的示例性实施方式还提供了一种能够实现上述方法的电子设备。下面参照图6来描述根据本公开的这种示例性实施方式的电子设备600。图6显示的电子设备600仅仅是一个示例,不应对本公开实施方式的功能和使用范围带来任何限制。
如图6所示,电子设备600可以以通用计算设备的形式表现。电子设备600的组件可以包括但不限于:至少一个处理单元610、至少一个存储单元620、连接不同***组件(包括存储单元620和处理单元610)的总线630和显示单元640。
存储单元620存储有程序代码,程序代码可以被处理单元610执行,使得处理单元610执行本说明书上述“示例性方法”部分中描述的根据本公开各种示例性实施方式的步骤。例如,处理单元610可以执行图1或图2所示的方法步骤等。
存储单元620可以包括易失性存储单元形式的可读介质,例如随机存取存储单元(RAM)621和/或高速缓存存储单元622,还可以进一步包括只读存储单元(ROM)623。
存储单元620还可以包括具有一组(至少一个)程序模块625的程序/实用工具624,这样的程序模块625包括但不限于:操作***、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。
总线630可以为表示几类总线结构中的一种或多种,包括存储单元总线或者存储单元控制器、***总线、图形加速端口、处理单元或者使用多种总线结构中的任意总线结构的局域总线。
电子设备600也可以与一个或多个外部设备700(例如键盘、指向设备、蓝牙设备等)通信,还可与一个或者多个使得用户能与该电子设备600交互的设备通信,和/或与使得该电子设备600能与一个或多个其它计算设备进行通信的任何设备(例如路由器、调制解调器等等)通信。这种通信可以通过输入/输出(I/O)接口650进行。并且,电子设备600还可以通过网络适配器660与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图所示,网络适配器660通过总线630与电子设备600的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备600使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、RAID***、磁带驱动器以及数据备份存储***等。
通过以上的实施方式的描述,本领域的技术人员易于理解,这里描述的示例性实施方式可以通过软件实现,也可以通过软件结合必要的硬件的方式来实现。因此,根据本公开实施方式的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中或网络上,包括若干指令以使得一台计算设备(可以是个人计算机、服务器、终端装置、或者网络设备等)执行根据本公开示例性实施方式的方法。
此外,上述附图仅是根据本公开示例性实施方式的方法所包括的处理的示意性说明,而不是限制目的。易于理解,上述附图所示的处理并不表明或限制这些处理的时间顺序。另外,也易于理解,这些处理可以是例如在多个模块中同步或异步执行的。
应当注意,尽管在上文详细描述中提及了用于动作执行的设备的若干模块或者单元,但是这种划分并非强制性的。实际上,根据本公开的示例性实施方式,上文描述的两个或更多模块或者单元的特征和功能可以在一个模块或者单元中具体化。反之,上文描述的一个模块或者单元的特征和功能可以进一步划分为由多个模块或者单元来具体化。
本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本公开的其他实施方式。本申请旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施方式仅被视为示例性的,本公开的真正范围和精神由权利要求指出。
应当理解的是,本公开并不局限于上面已经描述并在附图中示出的精确结构,并且可以在不脱离其范围进行各种修改和改变。本公开的范围仅由所附的权利要求来限。

Claims (10)

1.一种用于知识蒸馏的网络训练方法,其特征在于,包括:
将样本数据输入教师网络,获得所述样本数据对应的软标签数据,将所述样本数据输入学生网络,获得所述样本数据对应的预测数据;
基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据,构建损失函数;
根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数;
其中,所述教师网络和所述学生网络用于图像分类,所述样本数据包括样本图片,所述硬标签数据包括所述样本图片的分类标签,所述软标签数据包括通过所述教师网络识别所述样本图片中存在目标对象的概率数据,所述预测数据包括通过所述学生网络识别所述样本图片中存在目标对象的概率数据;
所述损失函数为:
Figure FDA0003786395410000011
其中,L为所述损失函数,i表示所述硬标签数据的类别,yi为第i类硬标签数据,
Figure FDA0003786395410000012
为第i类硬标签数据对应的预测数据,
Figure FDA0003786395410000013
为第i类硬标签数据对应的软标签数据;∈为经验参数,min(yi)<∈<max(yi);a、b、c均为非负的权重参数,b不为0,且a和c中至少一个不为0。
2.根据权利要求1所述的方法,其特征在于,所述基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据,构建损失函数,包括:
根据所述预测数据和所述硬标签数据,构建第一子损失;
根据所述预测数据和所述软标签数据,构建第二子损失;
根据所述第一子损失和所述第二子损失,确定所述损失函数。
3.根据权利要求2所述的方法,其特征在于,所述样本数据包括正样本;所述根据所述预测数据和所述软标签数据,构建第二子损失,包括:
根据所述正样本对应的预测数据和所述正样本对应的软标签数据,构建所述第二子损失。
4.根据权利要求3所述的方法,其特征在于,所述根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数,包括:
根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数;
根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数。
5.根据权利要求4所述的方法,其特征在于,所述正样本对应的预测数据包括对所述正样本的学生预测值和所述学生预测值对应的概率;所述根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数,包括:
根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1。
6.根据权利要求5所述的方法,其特征在于,所述正样本对应的软标签数据包括对所述正样本的教师预测值和所述教师预测值对应的概率,所述根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1,包括:
根据所述损失函数对所述学生预测值的梯度,以及所述学生预测值和所述教师预测值之间的误差,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1和所述教师预测值对应的概率。
7.根据权利要求5所述的方法,其特征在于,所述正样本对应的软标签数据包括对所述正样本的教师预测值和所述教师预测值对应的概率,所述根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数,包括:
根据所述损失函数对所述教师预测值的梯度,更新所述教师网络中的参数,使所述教师预测值对应的概率趋近于1。
8.一种用于知识蒸馏的网络训练装置,其特征在于,包括:
处理模块,用于将样本数据输入教师网络,获得所述样本数据对应的软标签数据,将所述样本数据输入学生网络,获得所述样本数据对应的预测数据;
构建模块,用于基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据之间的误差,构建损失函数;
训练模块,用于根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数;
其中,所述教师网络和所述学生网络用于图像分类,所述样本数据包括样本图片,所述硬标签数据包括所述样本图片的分类标签,所述软标签数据包括通过所述教师网络识别所述样本图片中存在目标对象的概率数据,所述预测数据包括通过所述学生网络识别所述样本图片中存在目标对象的概率数据;
所述损失函数为:
Figure FDA0003786395410000031
其中,L为所述损失函数,i表示所述硬标签数据的类别,yi为第i类硬标签数据,
Figure FDA0003786395410000032
为第i类硬标签数据对应的预测数据,
Figure FDA0003786395410000033
为第i类硬标签数据对应的软标签数据;∈为经验参数,min(yi)<∈<max(yi);a、b、c均为非负的权重参数,b不为0,且a和c中至少一个不为0。
9.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-7任一项所述的方法。
10.一种电子设备,其特征在于,包括:
处理器;以及
存储器,用于存储所述处理器的可执行指令;
其中,所述处理器配置为经由执行所述可执行指令来执行权利要求1-7任一项所述的方法。
CN201910923038.9A 2019-09-27 2019-09-27 用于知识蒸馏的网络训练方法、装置、介质与电子设备 Active CN110674880B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910923038.9A CN110674880B (zh) 2019-09-27 2019-09-27 用于知识蒸馏的网络训练方法、装置、介质与电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910923038.9A CN110674880B (zh) 2019-09-27 2019-09-27 用于知识蒸馏的网络训练方法、装置、介质与电子设备

Publications (2)

Publication Number Publication Date
CN110674880A CN110674880A (zh) 2020-01-10
CN110674880B true CN110674880B (zh) 2022-11-11

Family

ID=69079897

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910923038.9A Active CN110674880B (zh) 2019-09-27 2019-09-27 用于知识蒸馏的网络训练方法、装置、介质与电子设备

Country Status (1)

Country Link
CN (1) CN110674880B (zh)

Families Citing this family (49)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108664893B (zh) * 2018-04-03 2022-04-29 福建海景科技开发有限公司 一种人脸检测方法及存储介质
CN111260056B (zh) * 2020-01-17 2024-03-12 北京爱笔科技有限公司 一种网络模型蒸馏方法及装置
CN111291886B (zh) * 2020-02-28 2022-02-18 支付宝(杭州)信息技术有限公司 神经网络模型的融合训练方法及装置
CN111460991A (zh) * 2020-03-31 2020-07-28 科大讯飞股份有限公司 异常检测方法、相关设备及可读存储介质
CN111639744B (zh) * 2020-04-15 2023-09-22 北京迈格威科技有限公司 学生模型的训练方法、装置及电子设备
CN111598216B (zh) * 2020-04-16 2021-07-06 北京百度网讯科技有限公司 学生网络模型的生成方法、装置、设备及存储介质
CN111582101B (zh) * 2020-04-28 2021-10-01 中国科学院空天信息创新研究院 一种基于轻量化蒸馏网络的遥感图像目标检测方法及***
CN111753878A (zh) * 2020-05-20 2020-10-09 济南浪潮高新科技投资发展有限公司 一种网络模型部署方法、设备及介质
CN111369576B (zh) * 2020-05-28 2020-09-18 腾讯科技(深圳)有限公司 图像分割模型的训练方法、图像分割方法、装置及设备
CN111639710B (zh) * 2020-05-29 2023-08-08 北京百度网讯科技有限公司 图像识别模型训练方法、装置、设备以及存储介质
CN111667728B (zh) * 2020-06-18 2021-11-30 思必驰科技股份有限公司 语音后处理模块训练方法和装置
CN111753092B (zh) * 2020-06-30 2024-01-26 青岛创新奇智科技集团股份有限公司 一种数据处理方法、模型训练方法、装置及电子设备
CN111783898B (zh) * 2020-07-09 2021-09-14 腾讯科技(深圳)有限公司 图像识别模型的训练、图像识别方法、装置及设备
CN111898735A (zh) * 2020-07-14 2020-11-06 上海眼控科技股份有限公司 蒸馏学习方法、装置、计算机设备和存储介质
CN112183718B (zh) * 2020-08-31 2023-10-10 华为技术有限公司 一种用于计算设备的深度学习训练方法和装置
CN112184508B (zh) * 2020-10-13 2021-04-27 上海依图网络科技有限公司 一种用于图像处理的学生模型的训练方法及装置
CN113392864B (zh) * 2020-10-13 2024-06-28 腾讯科技(深圳)有限公司 模型生成方法及视频筛选方法、相关装置、存储介质
CN112348167B (zh) * 2020-10-20 2022-10-11 华东交通大学 一种基于知识蒸馏的矿石分选方法和计算机可读存储介质
CN112308237B (zh) * 2020-10-30 2023-09-26 平安科技(深圳)有限公司 一种问答数据增强方法、装置、计算机设备及存储介质
CN112381209B (zh) * 2020-11-13 2023-12-22 平安科技(深圳)有限公司 一种模型压缩方法、***、终端及存储介质
CN112560693B (zh) * 2020-12-17 2022-06-17 华中科技大学 基于深度学习目标检测的高速公路异物识别方法和***
CN112734046A (zh) * 2021-01-07 2021-04-30 支付宝(杭州)信息技术有限公司 模型训练及数据检测方法、装置、设备及介质
CN112801298B (zh) * 2021-01-20 2023-09-01 北京百度网讯科技有限公司 异常样本检测方法、装置、设备和存储介质
CN112861936B (zh) * 2021-01-26 2023-06-02 北京邮电大学 一种基于图神经网络知识蒸馏的图节点分类方法及装置
CN112967088A (zh) * 2021-03-03 2021-06-15 上海数鸣人工智能科技有限公司 基于知识蒸馏的营销活动预测模型结构和预测方法
CN113158902B (zh) * 2021-04-23 2023-08-11 深圳龙岗智能视听研究院 一种基于知识蒸馏的自动化训练识别模型的方法
CN113159073B (zh) * 2021-04-23 2022-11-18 上海芯翌智能科技有限公司 知识蒸馏方法及装置、存储介质、终端
CN113239985B (zh) * 2021-04-25 2022-12-13 北京航空航天大学 一种面向分布式小规模医疗数据集的分类检测方法
CN113222139B (zh) * 2021-04-27 2024-06-14 商汤集团有限公司 神经网络训练方法和装置、设备,及计算机存储介质
CN113052144B (zh) * 2021-04-30 2023-02-28 平安科技(深圳)有限公司 活体人脸检测模型的训练方法、装置、设备及存储介质
CN113344213A (zh) * 2021-05-25 2021-09-03 北京百度网讯科技有限公司 知识蒸馏方法、装置、电子设备及计算机可读存储介质
CN113283386A (zh) * 2021-05-25 2021-08-20 中国矿业大学(北京) 一种基于知识蒸馏的煤矿井下采煤机的设备故障诊断方法
CN113344205A (zh) * 2021-06-16 2021-09-03 广东电网有限责任公司 一种基于蒸馏关系的抽取加速方法及装置
CN113326940A (zh) * 2021-06-25 2021-08-31 江苏大学 基于多重知识迁移的知识蒸馏方法、装置、设备及介质
CN113343898B (zh) * 2021-06-25 2022-02-11 江苏大学 基于知识蒸馏网络的口罩遮挡人脸识别方法、装置及设备
CN113660038B (zh) * 2021-06-28 2022-08-02 华南师范大学 基于深度强化学习和知识蒸馏的光网络路由方法
CN113361710B (zh) * 2021-06-29 2023-11-24 北京百度网讯科技有限公司 学生模型训练方法、图片处理方法、装置及电子设备
CN113469977B (zh) * 2021-07-06 2024-01-12 浙江霖研精密科技有限公司 一种基于蒸馏学习机制的瑕疵检测装置、方法、存储介质
CN113554716A (zh) * 2021-07-28 2021-10-26 广东工业大学 基于知识蒸馏的瓷砖色差检测方法及装置
CN113609965B (zh) * 2021-08-03 2024-02-13 同盾科技有限公司 文字识别模型的训练方法及装置、存储介质、电子设备
CN113762737A (zh) * 2021-08-19 2021-12-07 北京邮电大学 网络服务质量预测的方法及***
CN113486185B (zh) * 2021-09-07 2021-11-23 中建电子商务有限责任公司 一种基于联合训练的知识蒸馏方法、处理器及存储介质
CN113487614B (zh) * 2021-09-08 2021-11-30 四川大学 胎儿超声标准切面图像识别网络模型的训练方法和装置
CN113505797B (zh) * 2021-09-09 2021-12-14 深圳思谋信息科技有限公司 模型训练方法、装置、计算机设备和存储介质
CN114241282B (zh) * 2021-11-04 2024-01-26 河南工业大学 一种基于知识蒸馏的边缘设备场景识别方法及装置
CN113869464B (zh) * 2021-12-02 2022-03-18 深圳佑驾创新科技有限公司 图像分类模型的训练方法及图像分类方法
CN115687914B (zh) * 2022-09-07 2024-01-30 中国电信股份有限公司 模型蒸馏方法、装置、电子设备及计算机可读介质
CN115544277A (zh) * 2022-12-02 2022-12-30 东南大学 一种基于迭代蒸馏的快速知识图谱嵌入模型压缩方法
CN116030323B (zh) * 2023-03-27 2023-08-29 阿里巴巴(中国)有限公司 图像处理方法以及装置

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109637546A (zh) * 2018-12-29 2019-04-16 苏州思必驰信息科技有限公司 知识蒸馏方法和装置
CN110147456A (zh) * 2019-04-12 2019-08-20 中国科学院深圳先进技术研究院 一种图像分类方法、装置、可读存储介质及终端设备
CN110162018A (zh) * 2019-05-31 2019-08-23 天津开发区精诺瀚海数据科技有限公司 基于知识蒸馏与隐含层共享的增量式设备故障诊断方法
CN110223281A (zh) * 2019-06-06 2019-09-10 东北大学 一种数据集中含有不确定数据时的肺结节图像分类方法

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
US11410029B2 (en) * 2018-01-02 2022-08-09 International Business Machines Corporation Soft label generation for knowledge distillation

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109637546A (zh) * 2018-12-29 2019-04-16 苏州思必驰信息科技有限公司 知识蒸馏方法和装置
CN110147456A (zh) * 2019-04-12 2019-08-20 中国科学院深圳先进技术研究院 一种图像分类方法、装置、可读存储介质及终端设备
CN110162018A (zh) * 2019-05-31 2019-08-23 天津开发区精诺瀚海数据科技有限公司 基于知识蒸馏与隐含层共享的增量式设备故障诊断方法
CN110223281A (zh) * 2019-06-06 2019-09-10 东北大学 一种数据集中含有不确定数据时的肺结节图像分类方法

Also Published As

Publication number Publication date
CN110674880A (zh) 2020-01-10

Similar Documents

Publication Publication Date Title
CN110674880B (zh) 用于知识蒸馏的网络训练方法、装置、介质与电子设备
US9990558B2 (en) Generating image features based on robust feature-learning
US20210142181A1 (en) Adversarial training of machine learning models
CN111523640B (zh) 神经网络模型的训练方法和装置
US20200265301A1 (en) Incremental training of machine learning tools
WO2020048389A1 (zh) 神经网络模型压缩方法、装置和计算机设备
CN111602148A (zh) 正则化神经网络架构搜索
US10929448B2 (en) Determining a category of a request by word vector representation of a natural language text string with a similarity value
CN113344206A (zh) 融合通道与关系特征学习的知识蒸馏方法、装置及设备
CN110929802A (zh) 基于信息熵的细分类识别模型训练、图像识别方法及装置
US20220351634A1 (en) Question answering systems
US20220309292A1 (en) Growing labels from semi-supervised learning
US20230186668A1 (en) Polar relative distance transformer
EP3832485A1 (en) Question answering systems
US20210081800A1 (en) Method, device and medium for diagnosing and optimizing data analysis system
CN117611932B (zh) 基于双重伪标签细化和样本重加权的图像分类方法及***
WO2023231753A1 (zh) 一种神经网络的训练方法、数据的处理方法以及设备
CN113434683A (zh) 文本分类方法、装置、介质及电子设备
CN114386409A (zh) 基于注意力机制的自蒸馏中文分词方法、终端及存储介质
CN111950647A (zh) 分类模型训练方法和设备
CN116684330A (zh) 基于人工智能的流量预测方法、装置、设备及存储介质
CN109272165B (zh) 注册概率预估方法、装置、存储介质及电子设备
CN111161238A (zh) 图像质量评价方法及装置、电子设备、存储介质
WO2024114659A1 (zh) 一种摘要生成方法及其相关设备
CN113239883A (zh) 分类模型的训练方法、装置、电子设备以及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant