CN113822434A - 用于知识蒸馏的模型选择学习 - Google Patents

用于知识蒸馏的模型选择学习 Download PDF

Info

Publication number
CN113822434A
CN113822434A CN202010561319.7A CN202010561319A CN113822434A CN 113822434 A CN113822434 A CN 113822434A CN 202010561319 A CN202010561319 A CN 202010561319A CN 113822434 A CN113822434 A CN 113822434A
Authority
CN
China
Prior art keywords
model
training
target
reference models
prediction
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010561319.7A
Other languages
English (en)
Inventor
寿林钧
林武桃
公明
姜大昕
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Microsoft Technology Licensing LLC
Original Assignee
Microsoft Technology Licensing LLC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Microsoft Technology Licensing LLC filed Critical Microsoft Technology Licensing LLC
Priority to CN202010561319.7A priority Critical patent/CN113822434A/zh
Priority to PCT/US2021/026288 priority patent/WO2021257160A1/en
Publication of CN113822434A publication Critical patent/CN113822434A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/004Artificial life, i.e. computing arrangements simulating life
    • G06N3/006Artificial life, i.e. computing arrangements simulating life based on simulated virtual individual or collective life forms, e.g. social simulations or particle swarm optimisation [PSO]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Feedback Control In General (AREA)

Abstract

本公开提供了用于基于知识蒸馏来获得目标模型的方法和装置。可以获得数据集合和一组候选参考模型。可以针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型。可以获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布。可以利用所述一组目标概率分布来训练所述目标模型。

Description

用于知识蒸馏的模型选择学习
背景技术
随着深度学习技术的发展,各种各样的深度预训练模型得以不断地开发,并在诸如自然语言处理和计算机视觉等领域有出色表现。例如,在自然语言处理领域,诸如来自转换器的双向编码器表示(Bidirectional Encoder Resentations from Transformers,BERT)模型、生成式预训练转换器(Generative Pre-trained Transformer,GPT)模型之类的深度预训练模型被证明具有良好的效果。这类深度预训练模型往往是依赖于具有巨量参数的深度网络的复杂模型,例如BERT模型可能包含24个转换器层共3.4亿参数,GPT模型可能包含48个转换器层共15亿参数。训练这样的复杂模型和使用这样的复杂模型进行推断都是十分耗时的,从而难以将其应用于实际的商业场景。通常采用模型压缩方法来获得具有比复杂模型更少参数的、能够部署的简单模型。
发明内容
提供本发明内容以便介绍一组构思,这组构思将在以下的具体实施方式中做进一步描述。本发明内容并非旨在标识所保护主题的关键特征或必要特征,也不旨在用于限制所保护主题的范围。
本公开的实施例提供了用于基于知识蒸馏来获得目标模型的方法和装置。可以获得数据集合和一组候选参考模型。可以针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型。可以获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布。可以利用所述一组目标概率分布来训练所述目标模型。
应当注意,以上一个或多个方面包括以下详细描述以及权利要求中具体指出的特征。下面的说明书及附图详细提出了所述一个或多个方面的某些说明性特征。这些特征仅仅指示可以实施各个方面的原理的多种方式,并且本公开旨在包括所有这些方面和其等同变换。
附图说明
以下将结合附图描述所公开的多个方面,这些附图被提供用以说明而非限制所公开的多个方面。
图1示出了根据本公开实施例的用于获得目标模型的示例性过程。
图2示出了根据本公开实施例的用于选择参考模型的示例性过程。
图3示出了根据本公开实施例的用于训练目标模型的示例性过程
图4示出了根据本公开实施例的用于训练目标模型的具体示例。
图5示出了根据本公开实施例的用于更新策略参数的示例性过程。
图6示出了根据本公开实施例的用于初始化策略参数的示例性过程。
图7是根据本公开实施例的用于基于知识蒸馏来获得目标模型的示例性方法的流程图。
图8示出了根据本公开实施例的用于基于知识蒸馏来获得目标模型的示例性装置。
图9示出了根据本公开实施例的用于基于知识蒸馏来获得目标模型的示例性装置。
具体实施方式
现在将参考若干示例性实施方式来讨论本公开。应当理解,这些实施方式的讨论仅仅用于使得本领域技术人员能够更好地理解并从而实施本公开的实施例,而并非教导对本公开的范围的任何限制。
一种常用的模型压缩方法可以基于知识蒸馏(Knowledge Distillation)。该方法通常是通过简单模型学习复杂模型的输出分布来将知识从复杂模型迁移给简单模型。知识可以被认为是复杂模型的参数以及复杂模型实现的输入到输出的映射。该方法基于老师-学生架构,其中,提供知识的模型可以被认为是老师模型,并且学习知识的模型可以被认为是学生模型。具体地,在训练学生模型时,向学生模型提供不仅具有真实标注,例如人为提供的标注,还具有老师模型输出的概率分布的训练数据。因此,学生模型可以通过学习老师模型输出的概率分布来优化其模型参数,以试图达到老师模型的效果。基于老师模型和学生模型的数量,可以将目前的知识蒸馏方法分为一对一方法、多对多方法和多对一方法。在本文中,一对一方法指一个老师模型向一个学生模型提供知识,多对多方法指多个老师模型向多个学生模型提供知识并在应用时将这多个学生模型组合为学生模型集合,而多对一方法则指多个老师模型向一个学生模型提供知识。近来的研究和实验表明,利用多对一方法来训练学生模型可以更为有效地提升学生模型的性能。
目前的多对一方法通常向各个老师模型指派相同权重,或者向各个老师模型指派不同权重,但这些权重在整个知识蒸馏期间是固定不变。然而,即使采用针对相同任务的一组训练数据来训练学生模型,各个老师模型对于该组训练数据中的不同训练样本的表现也是不同的。以用于预测两个句子之间的语义对等的两个训练样本为例,针对第一训练样本,老师模型A的表现可能优于老师模型B;而针对第二训练样本,老师模型B的表现可能优于老师模型A。此外,在知识蒸馏的各个阶段,学生模型的性能是在逐渐提升的,用于训练该学生模型的老师模型也应当相应发生变化。例如,在知识蒸馏初期,学生模型的性能较弱,其从复杂的老师模型学习的效果可能不佳,这是由于复杂的老师模型捕获了训练数据中的较细粒度模式,这可能会导致学生模型对训练数据中的一些部分过度拟合。而随着训练过程的推进,学生模型已具有较强的性能,若通过与其性能差距不大的老师模型来进行训练,可能难以取得明显效果。因此,在知识蒸馏期间,向各个老师模型指派相同或固定权重可能会限制学生模型的性能提升。
本公开的实施例提出了通过改进的训练过程来提升目标模型的性能。例如,可以通过知识蒸馏,使用一组参考模型来训练目标模型。在本文中,目标模型指期望被训练的结构简单且能够部署的模型,其也可以被称为学生模型,而参考模型指能够用于协助训练目标模型的、具有比目标模型更高的复杂性的模型,其也可以被称为老师模型。
在一个方面,本公开的实施例提出通过强化学习(Reinforcement Learning)来从一组候选参考模型中选择用于训练目标模型的参考模型。例如,可以针对用于训练目标模型的数据集合中的每个训练样本,动态地向各个参考模型指派不同的权重。在本文中,向特定参考模型指派的权重可以被实现为与该参考模型相对应的采样概率,其可以用于确定针对该训练样本是否选择该参考模型来训练目标模型。
在另一个方面,本公开的实施例提出通过策略函数来确定针对当前训练样本的、各个参考模型的采样概率。针对各个参考模型的策略函数例如包括策略参数、以及与当前训练样本以及该参考模型针对当前训练样本的表现有关的信息等。利用这样的策略函数,可以有助于选择出针对当前训练样本表现良好的参考模型。
在另一个方面,本公开的实施例提出基于目标模型的性能来更新策略函数中的策略参数。例如,用于训练目标模型的数据集合可以被划分成多个数据子集。可以在利用一数据子集对目标模型进行训练之后,基于经训练的目标模型的性能来更新策略函数中的策略参数,从而影响针对下一数据子集的、各个参考模型的采样概率。通过基于目标模型的性能来更新策略函数中的策略参数,可以有助于选择出与当前目标模型的性能相匹配的参考模型。
在另一个方面,本公开的实施例提出在通过策略函数来确定各个参考模型的采样概率之前,对策略函数中的策略参数进行初始化。例如,可以通过策略函数来从一组候选参考模型中选择一组参考模型,并根据所选择的一组参考模型的平均性能来初始化该策略函数中的策略参数。
在另一个方面,本公开的实施例提出在使用一组参考模型训练目标模型之前,对目标模型进行预训练。例如,可以使用该组参考模型中的所有参考模型对一数据集合进行评分,并利用经评分的数据集合来预训练目标模型。
图1示出了根据本公开实施例的用于获得目标模型的示例性过程100。目标模型例如为图1中的目标模型160,其可以是具有3层或6层转换器的BERT模型。
可以首先获得用于训练目标模型160的数据集合
Figure BDA0002546193780000041
数据集合
Figure BDA0002546193780000042
可以被划分成多个数据子集
Figure BDA0002546193780000043
其中M表示数据子集的数量。各个数据子集可以包括多个训练样本。在本文中,将用于训练目标模型160的样本称为训练样本。以数据子集
Figure BDA0002546193780000044
为例,其可以包括m个训练样本
Figure BDA0002546193780000045
例如训练样本i 102(xi,yi),其中xi是第i个输入,并且yi是针对xi的真实标注,例如人为提供的标注。
可以获得用于训练目标模型160的一组候选参考模型110。候选参考模型可以是具有比目标模型160更高的复杂性的模型,例如具有12层转换器的BERT模型。可以通过利用针对特定任务的训练数据来对预训练的模型(pre-trained model)进行优化,例如微调(fine-tune),来获得候选参考模型。
还可以获得表示模型120,其可以是能够有效地表示xi的内容的任何预训练的模型。
训练样本i 102可以作为输入提供给一组候选参考模型110中的每个参考模型以及表示模型120,以获得一组状态信息。该组状态信息可以至少包括与训练样本i 102以及各个参考模型针对训练样本i 102输出的目标概率分布有关的信息。后面将结合图2来解释获得状态信息的具体过程。
随后,在130处,可以针对一组候选参考模型110中的每个候选参考模型,通过强化学习来确定是否选择该候选参考模型用于训练目标模型160。例如,可以利用策略函数πθ132来确定是否选择该候选参考模型。可以将选定的参考模型组合成一组选定参考模型140。后面将结合图2来解释选择参考模型的具体过程。
接着,可以获取该组选定参考模型140中的每个参考模型针对训练样本i 102输出的目标概率分布,以获取一组目标概率分布150。用于确定一组选定参考模型140的一组状态信息可以包括与各个参考模型针对训练样本i102输出的目标概率分布有关的信息。可以从这组状态信息中提取出与该组选定参考模型140中的各个参考模型相对应的一组目标概率分布150。
可以利用训练样本i 102和该组目标概率分布150来训练目标模型160。在一种实施方式中,可以在利用单个训练样本执行完上述过程之后,就对目标模型160的参数进行优化,从而获得经训练的目标模型170。在另一种实施方式中,可以在利用数据子集,例如数据子集
Figure BDA0002546193780000051
中的所有训练样本执行完上述过程之后,再对目标模型160的参数进行优化,以获得经训练的目标模型170。在这种情况下,针对相同的数据子集内的所有训练样本,目标模型160的参数保持不变。后面将结合图3和图4来解释训练目标模型160的具体过程。
随后,可以对经训练的目标模型170性能进行评估。可以利用验证样本180来评估经训练的目标模型170的性能。在本文中,将用于评估目标模型的性能的样本称为验证样本,其可以与训练样本相同或不同。所评估的性能可以被转换为奖励190。奖励190随后可以用于更新策略函数πθ132中的策略参数θ。后面将结合图5来解释更新策略参数的具体过程。
在利用单个训练样本来获得经训练的目标模型170的情况下,基于该经训练的目标模型170的性能来更新策略参数可以影响针对下一训练样本的、各个参考模型的采样概率。在利用数据子集
Figure BDA0002546193780000061
中的所有训练样本来获得经训练的目标模型170的情况下,基于该经训练的目标模型170的性能来更新策略参数可以影响针对下一数据子集的、各个参考模型的采样概率。在这种情况下,针对相同的数据子集内的所有训练样本,策略函数πθ的策略参数θ保持不变。
过程100主要包括训练目标模型的过程和更新策略参数的过程,这两个过程可以迭代地执行,直到目标模型的性能收敛。在训练目标模型的过程期间,可以优化目标模型的参数而固定策略参数;并且在更新策略参数的过程期间,可以优化策略参数而固定目标模型的参数。以利用数据子集
Figure BDA0002546193780000062
中的所有训练样本来获得经训练的目标模型并基于该经训练的目标模型的性能来更新策略参数的实施方式为例。可以将目标模型的当前参数设定为
Figure BDA0002546193780000063
并且将当前策略参数设定为θb。可以先将策略参数固定为θb并在利用数据子集
Figure BDA0002546193780000064
训练了目标模型之后,将目标模型的参数更新为
Figure BDA0002546193780000065
接着,可以将目标模型的参数固定为
Figure BDA0002546193780000066
并基于参数为
Figure BDA0002546193780000067
的目标模型的性能来更新策略参数,以将策略参数更新为θb+1;等等。
图2示出了根据本公开实施例的用于选择参考模型的示例性过程200。过程200可以对应于图1中的步骤130。可以首先获得训练样本i 202(xi,yi),其可以对应于图1中的训练样本i 102。还可以获得一组候选参考模型210和表示模型220。该组候选参考模型210可以包括K个参考模型,例如参考模型210-1、参考模型210-2、……、参考模型210-K。该组候选参考模型210可以对应于图1中的一组候选参考模型110,并且表示模型220可以对应于图1中的表示模型120。
过程200可以将针对训练样本i 202的、对应于各个参考模型的状态编码为状态信息,并基于状态信息来确定是否选择该参考模型。状态信息可以包括与训练样本i 202以及该参考模型针对训练样本i 202的表现有关的信息。以参考模型210-k(1≤k≤K)为例,可以将针对参考模型210-k的状态表示为sjk,并将针对状态sjk的状态信息表示为F(sjk)。状态信息F(sjk)可以被实现为实值向量,其例如包括三项特征的级联。
第一项特征可以是训练样本i 202(xi,yi)中的xi的表示。可以例如通过表示模型220来获得xi的向量表示
Figure BDA0002546193780000071
其中d是隐藏大小(hidden size)。
第二项特征可以是参考模型210-k针对训练样本i 202输出的概率分布。以训练样本i 202是针对分类任务的样本为例,参考模型210-k输出的概率分布可以被表示为
Figure BDA0002546193780000072
其中
Figure BDA0002546193780000073
是参考模型210-k输出的xi属于类别c的概率,c是1到C之间的整数,C是类别的数量,并且
Figure BDA0002546193780000074
是参考模型210-k的参数。
第三项特征可以是与该概率分布相对应的预测损失。在一种实施方式中,可以通过交叉熵函数来计算该预测损失。例如,可以通过如下公式来计算与参考模型210-k输出的针对训练样本i 202的概率分布
Figure BDA0002546193780000075
Figure BDA0002546193780000076
预测损失
Figure BDA0002546193780000077
Figure BDA0002546193780000078
其中,
Figure BDA0002546193780000079
是来自真实标注yi的独热向量(one-hot vector)。
参考模型210-k针对训练样本i 202输出的概率分布以及与该概率分布相对应的预测损失可以被视为是参考模型210-k针对训练样本i 202的表现。
可以对表示模型220输出的向量表示
Figure BDA00025461937800000710
参考模型210-k输出的概率分布
Figure BDA00025461937800000711
以及与概率分布
Figure BDA00025461937800000712
相对应的预测损失
Figure BDA00025461937800000713
进行级联,以获得针对参考模型210-k的状态信息230-k F(sjk)。
可以通过上述过程来获得针对该组候选参考模型210中的各个参考模型的一组状态信息230,其例如包括状态信息230-1、状态信息230-2、……、状态信息230-K。
策略函数πθ240可以基于参考模型210-k的状态信息230-k来确定参考模型210-k的采样概率250-k。可以例如采用逻辑函数作为策略函数,如以下公式所示:
Figure BDA0002546193780000081
其中,
Figure BDA0002546193780000082
是状态信息,σ(·)是具有可训练参数
Figure BDA0002546193780000083
Figure BDA0002546193780000084
的sigmoid函数,并且Pθ(ajk|sjk)是采样概率,其表示在状态sjk下选择动作值ajk的概率,ajk∈{0,1}。可以利用Pθ(ajk|sjk)来对动作值ajk进行采样。当ajk被采样为“0”值时,指示不选择参考模型210-k;而当ajk被采样为“1”值时,指示选择参考模型210-k。
可以通过上述过程来获得针对该组候选参考模型210中的各个参考模型的一组采样概率250,其例如包括采样概率250-1、采样概率250-2、……、采样概率250-K,以及一组动作值260,其例如包括动作值260-1、动作值260-2、……、动作值260-K。
在确定了一组动作值260之后,在270处,可以基于一组参考模型210和一组动作值260来确定一组选定参考模型280,其例如包括参考模型280-1、参考模型280-2、……、参考模型280-K’(0≤K’≤K)。该组选定参考模型280中的每个参考模型例如是其动作值被采样为“1”的参考模型。
应当理解,尽管在前述讨论和以下讨论中可能涉及选择了至少一个参考模型来训练目标模型,但是没有任何一个参考模型被选中也是可能的。例如,针对某些训练样本,所有参考模型对其表现都不佳,因此所有参考模型的采样概率都较低,进一步地,根据这些采样概率采样出的动作值可能都为“0”,这将导致没有参考模型被选中。
在确定了一组选定参考模型280之后,可以获取该组选定参考模型280针对训练样本i 202输出的一组目标概率分布。上面确定的一组状态信息230中包括一组候选参考模型中的各个参考模型针对训练样本i 202输出的目标概率分布。可以从这些目标概率分布中提取出与该组选定参考模型280中的各个参考模型相对应的一组目标概率分布。可以利用这组目标概率分布来训练目标模型。
图3示出了根据本公开实施例的用于训练目标模型的示例性过程300。过程300可以利用训练样本i和一组选定参考模型针对训练样本i输出的一组目标概率分布来训练目标模型。训练样本i可以包括真实标注。
在310处,目标模型可以对训练样本i进行评分,以获得该训练样本i的预测概率分布。
在320处,可以基于预测概率分布和一组目标概率分布中的每个目标概率分布来分别计算与该目标概率分布相对应的子预测损失,以获得一组子预测损失。在一种实施方式中,可以通过交叉熵函数来计算子预测损失。
在330处,可以基于该组选定参考模型的数量和该组子预测损失来计算与训练样本i相对应的第一预测损失。在一种实施方式中,可以通过先对该组子预测损失进行求和以获得中间预测损失,然后将中间预测损失除以该组选定参考模型的数量,来计算第一预测损失。
在340处,可以基于预测概率分布和训练样本i中的真实标注来计算与训练样本i相对应的第二预测损失。在一种实施方式中,可以通过交叉熵函数来计算第二预测损失。
在350处,可以基于第一预测损失和第二预测损失来计算与训练样本i相对应的综合预测损失。在一种实施方式中,可以通过对第一预测损失和第二预测损失进行加权求和来计算综合预测损失。
在360处,可以通过使综合预测损失最小化来优化目标模型。
图4示出了根据本公开实施例的用于训练目标模型的具体示例400。在示例400中,用于训练目标模型的训练样本可以例如是训练样本410(xi,yi),其中xi是输入,yi是真实标注。可以利用一组选定参考模型420针对训练样本410输出的一组目标概率分布
Figure BDA0002546193780000091
来训练目标模型430,其中
Figure BDA0002546193780000092
该组选定参考模型420例如包括编号为参考模型420-1、参考模型420-2、……、参考模型420-K’的K’个参考模型。
目标模型430可以对训练样本410进行评分,以获得训练样本410的预测概率分布
Figure BDA0002546193780000093
其中,Ps(yi=c|xi;Θs)表示目标模型430输出的xi属于类别c的概率,c是1到C之间的整数,C是类别的数量,并且Θs是目标模型430的参数。
随后可以基于预测概率分布
Figure BDA0002546193780000094
和一组目标概率分布
Figure BDA0002546193780000095
中的每个目标概率分布
Figure BDA0002546193780000096
来分别计算与该目标概率分布
Figure BDA0002546193780000097
相对应的子预测损失
Figure BDA0002546193780000098
以获得一组子预测损失
Figure BDA0002546193780000099
在一种实施方式中,可以通过交叉熵函数来计算子预测损失
Figure BDA00025461937800000910
如以下公式所示:
Figure BDA0002546193780000101
接着,可以基于一组选定参考模型420中的参考模型的数量K’和这组子预测损失来计算针对训练样本i的第一预测损失
Figure BDA0002546193780000102
在一种实施方式中,可以通过先对该组子预测损失进行求和以获得中间预测损失,然后将中间预测损失除以该组选定参考模型的数量K’,来计算第一预测损失,如以下公式所示:
Figure BDA0002546193780000103
然后,可以基于目标模型输出的预测概率分布
Figure BDA0002546193780000104
Figure BDA0002546193780000105
和训练样本i中的真实标注yi来计算与训练样本i相对应的第二预测损失
Figure BDA0002546193780000106
在一种实施方式中,可以通过交叉熵函数来计算第二预测损失,如以下公式所示:
Figure BDA0002546193780000107
在获得了与训练样本i相对应的第一预测损失
Figure BDA0002546193780000108
和第二预测损失
Figure BDA0002546193780000109
之后,可以计算与训练样本i相对应的综合预测损失
Figure BDA00025461937800001010
在一种实施方式中,可以通过对第一预测损失和第二预测损失进行加权求和来计算综合预测损失,如以下公式所示:
Figure BDA00025461937800001011
其中,α是用于平衡第一预测损失和第二预测损失的超参数。
可以通过使综合预测损失
Figure BDA00025461937800001012
最小化来优化目标模型。
上文结合图3和图4说明的过程是通过使对应于单个训练样本的综合预测损失最小化来训练目标模型的。替代地,为了提高训练效率,可以对数据子集,例如数据子集
Figure BDA00025461937800001013
中的所有训练样本执行上述过程,并获得与数据子集
Figure BDA00025461937800001014
相对应的综合预测损失
Figure BDA00025461937800001015
可以通过使该综合预测损失
Figure BDA00025461937800001016
最小化来优化目标模型。在这种情况下,针对相同的数据子集内的所有训练样本,目标模型的参数保持不变。与数据子集
Figure BDA00025461937800001017
相对应的综合预测损失
Figure BDA00025461937800001018
可以通过如下公式获得:
Figure BDA0002546193780000111
根据本公开的实施例,在利用一训练样本或一数据子集对目标模型进行训练之后,可以基于经训练的目标模型的性能来更新策略函数πθ中的策略参数θ,从而影响针对下一训练样本或下一数据子集的、各个参考模型的采样概率。图5示出了根据本公开实施例的用于更新策略参数的示例性过程500。过程500可以利用验证样本(x′,y′)来对经训练的目标模型的性能进行评估。验证样本可以与用于训练目标模型的训练样本(xi,yi)相同或不同。所评估的性能可以被转换为奖励。奖励随后可以用于更新策略函数πθ中的策略参数θ。
在510处,可以通过目标模型对验证样本(x′,y′)进行评分,以获得验证样本(x′,y′)的预测概率分布
Figure BDA0002546193780000112
其中Θs为目标模型的当前参数。
在520处,可以基于验证样本(x′,y′)中的真实标注y′和预测概率分布
Figure BDA0002546193780000113
来计算与验证样本相对应的预测损失
Figure BDA0002546193780000114
在一种实施方式中,可以通过交叉熵函数来计算第二预测损失,如以下公式所示:
Figure BDA0002546193780000115
在530处,可以基于预测损失
Figure BDA0002546193780000116
来计算与验证样本相对应的奖励vj。在一种实施方式中,可以将奖励vj计算为预测损失
Figure BDA0002546193780000117
的相反数,如以下公式所示:
Figure BDA0002546193780000118
在540处,可以基于奖励vj来更新策略参数θ。在一种实施方式中,可以通过标准策略梯度方法,例如基于蒙特卡罗(Monte-Carlo)的策略梯度方法,来更新策略参数θ,如以下公式所示:
Figure BDA0002546193780000119
其中,β是学习速率,并且πθ(sjk,ajk)是针对第k个参考模型的策略函数。
根据本公开的实施例,在通过策略函数来确定各个参考模型的采样概率之前,可以对策略函数中的策略参数进行初始化。例如,可以通过策略函数来从一组候选参考模型中选择至少一个参考模型,并根据所选择的参考模型的平均性能来初始化该策略函数中的策略参数。
图6示出了根据本公开实施例的用于初始化策略参数的示例性过程600。过程600可以利用初始化样本来对策略参数进行初始化。在文本中,将用于对策略参数进行初始化的样本称为初始化样本,其可以与用于训练目标模型的训练样本相同或不同。初始化样本可以包括输入以及与该输入相对应的真实标注。
在610处,可以针对初始化样本,通过策略函数来确定从一组候选参考模型中选择出的一组选定参考模型。该策略函数可以具有一原始策略参数。
在620处,可以通过该组选定参考模型对初始化样本分别进行评分,以获得初始化样本的一组概率分布。
在630处,可以基于初始化样本中的真实标注和该组概率分布来计算与所述初始化样本相对应的一组预测损失。例如,可以先基于真实标注和该组概率分布中的各个概率分布来计算针对各个概率分布的子预测损失,以获得一组子预测损失。在一种实施方式中,可以通过交叉熵函数来计算针对各个概率分布的预测损失。
在640处,可以基于一组候选参考模型的数量和该组子预测损失来计算与初始化样本相对应的预测损失。在一种实施方式中,可以通过先对该组子预测损失进行求和以获得中间预测损失,然后将中间预测损失除以该组选定参考模型的数量,来计算预测损失。
在650处,可以基于该预测损失来计算与所述初始化样本相对应的奖励。在一种实施方式中,可以将奖励计算为预测损失的相反数。
在660处,可以基于奖励来初始化策略参数。在一种实施方式中,可以通过利用标准策略梯度方法对原始策略参数进行更新来初始化策略参数,如上面的公式(10)所示。
根据本公开的实施例,在使用一组候选参考模型训练目标模型之前,可以对目标模型进行预训练。在一种实施方式中,可以使用该组候选参考模型中的所有参考模型对预训练数据集合进行评分,并利用经评分的预训练数据集合来预训练目标模型。在本文中,将用于预训练目标模型的数据集合称为预训练数据集合。用于对目标模型进行预训练的过程可以与结合图3和图4解释的用于训练目标模型的过程相类似,不同之处在于所涉及的一组目标概率分布为该组候选参考模型中的所有参考模型而不是该组候选参考模型中的选定参考模型针对预训练数据集合中的预训练样本输出的概率分布。
图7是根据本公开实施例的用于基于知识蒸馏来获得目标模型的示例性方法700的流程图。
在步骤710处,可以获得数据集合和一组候选参考模型。
在步骤720处,可以针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型。
在步骤730处,可以获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布。
在步骤740处,可以利用所述一组目标概率分布来训练所述目标模型。
在一种实施方式中,所述确定一组选定参考模型可以包括:针对所述一组候选参考模型中的每个候选参考模型,通过强化学习来确定是否选择所述候选参考模型。
所述确定是否选择所述候选参考模型可以包括:基于策略函数来确定所述候选参考模型的采样概率;基于所述采样概率,对所述候选参考模型的动作值进行采样;以及基于所采样的动作值,选择所述候选参考模型。
所述策略函数可以具有策略参数。所述确定是否选择所述候选参考模型还可以包括:基于所述目标模型的性能来更新所述策略参数。
所述更新所述策略参数可以包括:通过所述目标模型对验证样本进行评分,以获得所述验证样本的预测概率分布;基于所述验证样本中的真实标注和所述预测概率分布来计算与所述验证样本相对应的预测损失;基于所述预测损失来计算与所述验证样本相对应的奖励;以及基于所述奖励来更新所述策略参数。
所述数据集合可以包括多个数据子集。针对相同的数据子集内的所有训练样本,所述策略函数的策略参数可以保持不变。
所述确定采样概率可以针对状态信息来执行。所述状态信息可以至少包括:所述训练样本的表示、所述候选参考模型针对所述训练样本输出的概率分布、以及与所述概率分布相对应的预测损失。
所述策略函数可以具有策略参数。所述策略参数可以通过以下操作来被初始化:针对初始化样本,通过所述策略函数来确定从所述一组候选参考模型中选择出的一组选定参考模型;通过所述一组选定参考模型对所述初始化样本分别进行评分,以获得所述初始化样本的一组概率分布;基于所述初始化样本中的真实标注和所述一组概率分布来计算与所述初始化样本相对应的预测损失;基于所述预测损失来计算与所述初始化样本相对应的奖励;以及基于所述奖励来初始化所述策略参数。
在一种实施方式中,所述训练样本可以包括真实标注。所述训练所述目标模型可以包括:通过所述目标模型对所述训练样本进行评分,以获得所述训练样本的预测概率分布;基于所述预测概率分布和所述一组目标概率分布来计算与所述训练样本相对应的第一预测损失;基于所述预测概率分布和所述真实标注来计算与所述训练样本相对应的第二预测损失;基于所述第一预测损失和所述第二预测损失来计算与所述训练样本相对应的综合预测损失;以及通过使所述综合预测损失最小化来优化所述目标模型。
所述计算第一预测损失可以包括:基于所述预测概率分布和所述一组目标概率分布中的每个目标概率分布来分别计算与所述目标概率分布相对应的子预测损失,以获得一组子预测损失;以及基于所述一组选定参考模型中的参考模型的数量和所述一组子预测损失来计算所述第一预测损失。
所述数据集合可以包括多个数据子集。针对相同的数据子集内的所有训练样本,所述目标模型的参数可以保持不变。
在一种实施方式中,方法700还可以包括:通过所述一组候选参考模型对预训练数据集合进行评分;以及利用经评分的预训练数据集合来预训练所述目标模型。
在一种实施方式中,所述一组候选参考模型可以是具有比所述目标模型更高的复杂性的模型。
应当理解,方法700还可以包括根据上述本公开的实施例的用于基于知识蒸馏来获得目标模型的任何步骤/处理。
图8示出了根据本公开实施例的用于基于知识蒸馏来获得目标模型的示例性装置800。装置800可以包括:获得模块810,用于获得数据集合和一组候选参考模型;参考模型确定模块820,用于针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型;概率分布获取模块830,用于获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布;以及目标模型训练模块840,用于利用所述一组目标概率分布来训练所述目标模型。
在一种实施方式中,所述参考模型确定模块820还可以被配置为:针对所述一组候选参考模型中的每个候选参考模型,通过强化学习来确定是否选择所述候选参考模型。
所述确定是否选择所述候选参考模型可以包括:基于策略函数来确定所述候选参考模型的采样概率;基于所述采样概率,对所述候选参考模型的动作值进行采样;以及基于所采样的动作值,选择所述候选参考模型。
所述策略函数可以具有策略参数。所述确定是否选择所述候选参考模型还可以包括:基于所述目标模型的性能来更新所述策略参数。
所述数据集合可以包括多个数据子集。针对相同的数据子集内的所有训练样本,所述策略函数的策略参数可以保持不变。
所述确定采样概率可以针对状态信息来执行。所述状态信息可以至少包括:所述训练样本的表示、所述候选参考模型针对所述训练样本输出的概率分布、以及与所述概率分布相对应的预测损失。
应当理解,装置800还可以包括根据上述本公开的实施例的被配置用于基于知识蒸馏来获得目标模型的任何其他模块。
图9示出了根据本公开实施例的用于基于知识蒸馏来获得目标模型的示例性装置900。
装置900可以包括至少一个处理器910。装置900还可以包括与处理器910连接的存储器920。存储器920可以存储计算机可执行指令,当所述计算机可执行指令被执行时,使得处理器1910执行根据上述本公开的实施例的用于基于知识蒸馏来获得目标模型的方法的任何操作。
本公开的实施例可以体现在非暂时性计算机可读介质中。所述非暂时性计算机可读介质可以包括指令,所述指令当被执行时,使得一个或多个处理器执行根据如上所述的本公开的实施例的用于基于知识蒸馏来获得目标模型的方法的任何操作。
应当领会,以上描述的方法中的所有操作都仅仅是示例性的,本公开并不限制于方法中的任何操作或这些操作的顺序,而是应当涵盖在相同或相似构思下的所有其他等同变换。
还应当领会,以上描述的装置中的所有模块都可以通过各种方式来实施。这些模块可以被实施为硬件、软件、或其组合。此外,这些模块中的任何模块可以在功能上被进一步划分成子模块或组合在一起。
已经结合各种装置和方法描述了处理器。这些处理器可以使用电子硬件、计算机软件或其任意组合来实施。这些处理器是实施为硬件还是软件将取决于具体的应用以及施加在***上的总体设计约束。作为示例,本公开中给出的处理器、处理器的任意部分、或者处理器的任意组合可以利用微处理器、微控制器、数字信号处理器(DSP)、现场可编程门阵列(FPGA)、可编程逻辑器件(PLD)、状态机、门控逻辑单元、分立硬件电路、以及配置用于执行在本公开中描述的各种功能的其他适合的处理组件来实现。本公开给出的处理器、处理器的任意部分、或者处理器的任意组合的功能可以利用由微处理器、微控制器、DSP或其他适合的平台所执行的软件来实现。
软件应当被广泛地视为意指指令、指令集、代码、代码段、程序代码、程序、子程序、软件模块、应用、软件应用、软件包、例程、子例程、对象、运行线程、过程、函数等。软件可以驻留在计算机可读介质中。计算机可读介质可以包括例如存储器,存储器可以例如为磁性存储设备(例如,硬盘、软盘、磁条)、光盘、智能卡、闪存设备、随机存取存储器(RAM)、只读存储器(ROM)、可编程ROM(PROM)、可擦除PROM(EPROM)、电可擦除PROM(EEPROM)、寄存器或者可移动盘。尽管在本公开给出的多个方面中将存储器示出为是与处理器分离的,但是存储器也可以位于处理器内部,例如高速缓存器或寄存器。
以上描述被提供用于使得本领域任何技术人员能够实践本文所描述的各个方面。对这些方面的各种修改对于本领域技术人员将是显而易见的,并且本文限定的一般性原理可以应用于其他方面。因此,权利要求并非旨在被局限于本文示出的方面。关于本领域普通技术人员已知或即将获知的、对本公开所描述各个方面的元素的所有结构和功能上的等同变换都被明确并入本文并且由权利要求所覆盖。

Claims (20)

1.一种用于基于知识蒸馏来获得目标模型的方法,包括:
获得数据集合和一组候选参考模型;
针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型;
获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布;以及
利用所述一组目标概率分布来训练所述目标模型。
2.根据权利要求1所述的方法,其中,所述确定一组选定参考模型包括:
针对所述一组候选参考模型中的每个候选参考模型,通过强化学习来确定是否选择所述候选参考模型。
3.根据权利要求2所述的方法,其中,所述确定是否选择所述候选参考模型包括:
基于策略函数来确定所述候选参考模型的采样概率;
基于所述采样概率,对所述候选参考模型的动作值进行采样;以及
基于所采样的动作值,选择所述候选参考模型。
4.根据权利要求3所述的方法,其中,所述策略函数具有策略参数,并且所述确定是否选择所述候选参考模型还包括:
基于所述目标模型的性能来更新所述策略参数。
5.根据权利要求4所述的方法,其中,所述更新所述策略参数包括:
通过所述目标模型对验证样本进行评分,以获得所述验证样本的预测概率分布;
基于所述验证样本中的真实标注和所述预测概率分布来计算与所述验证样本相对应的预测损失;
基于所述预测损失来计算与所述验证样本相对应的奖励;以及
基于所述奖励来更新所述策略参数。
6.根据权利要求3所述的方法,其中,所述数据集合包括多个数据子集,并且针对相同的数据子集内的所有训练样本,所述策略函数的策略参数保持不变。
7.根据权利要求3所述的方法,其中,所述确定采样概率是针对状态信息来执行的,所述状态信息至少包括:所述训练样本的表示、所述候选参考模型针对所述训练样本输出的概率分布、以及与所述概率分布相对应的预测损失。
8.根据权利要求3所述的方法,其中,所述策略函数具有策略参数,并且所述策略参数是通过以下操作来被初始化的:
针对初始化样本,通过所述策略函数来确定从所述一组候选参考模型中选择出的一组选定参考模型;
通过所述一组选定参考模型对所述初始化样本分别进行评分,以获得所述初始化样本的一组概率分布;
基于所述初始化样本中的真实标注和所述一组概率分布来计算与所述初始化样本相对应的预测损失;
基于所述预测损失来计算与所述初始化样本相对应的奖励;以及
基于所述奖励来初始化所述策略参数。
9.根据权利要求1所述的方法,其中,所述训练样本包括真实标注,并且所述训练所述目标模型包括:
通过所述目标模型对所述训练样本进行评分,以获得所述训练样本的预测概率分布;
基于所述预测概率分布和所述一组目标概率分布来计算与所述训练样本相对应的第一预测损失;
基于所述预测概率分布和所述真实标注来计算与所述训练样本相对应的第二预测损失;
基于所述第一预测损失和所述第二预测损失来计算与所述训练样本相对应的综合预测损失;以及
通过使所述综合预测损失最小化来优化所述目标模型。
10.根据权利要求9所述的方法,其中,所述计算第一预测损失包括:
基于所述预测概率分布和所述一组目标概率分布中的每个目标概率分布来分别计算与所述目标概率分布相对应的子预测损失,以获得一组子预测损失;以及
基于所述一组选定参考模型中的参考模型的数量和所述一组子预测损失来计算所述第一预测损失。
11.根据权利要求9所述的方法,其中,所述数据集合包括多个数据子集,并且针对相同的数据子集内的所有训练样本,所述目标模型的参数保持不变。
12.根据权利要求1所述的方法,还包括:
通过所述一组候选参考模型对预训练数据集合进行评分;以及
利用经评分的预训练数据集合来预训练所述目标模型。
13.根据权利要求1所述的方法,其中,所述一组候选参考模型是具有比所述目标模型更高的复杂性的模型。
14.一种用于基于知识蒸馏来获得目标模型的装置,包括:
获得模块,用于获得数据集合和一组候选参考模型;
参考模型确定模块,用于针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型;
概率分布获取模块,用于获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布;以及
目标模型训练模块,用于利用所述一组目标概率分布来训练所述目标模型。
15.根据权利要求14所述的装置,其中,所述参考模型确定模块还被配置为:
针对所述一组候选参考模型中的每个候选参考模型,通过强化学习来确定是否选择所述候选参考模型。
16.根据权利要求15所述的装置,其中,所述确定是否选择所述候选参考模型包括:
基于策略函数来确定所述候选参考模型的采样概率;
基于所述采样概率,对所述候选参考模型的动作值进行采样;以及
基于所采样的动作值,选择所述候选参考模型。
17.根据权利要求16所述的装置,其中,所述策略函数具有策略参数,并且所述确定是否选择所述候选参考模型还包括:
基于所述目标模型的性能来更新所述策略参数。
18.根据权利要求16所述的装置,其中,所述数据集合包括多个数据子集,并且针对相同的数据子集内的所有训练样本,所述策略函数的策略参数保持不变。
19.根据权利要求16所述的装置,其中,所述确定采样概率是针对状态信息来执行的,所述状态信息至少包括:所述训练样本的表示、所述候选参考模型针对所述训练样本输出的概率分布、以及与所述概率分布相对应的预测损失。
20.一种用于基于知识蒸馏来获得目标模型的装置,包括:
至少一个处理器;以及
存储计算机可执行指令的存储器,所述计算机可执行指令在被执行时使得所述至少一个处理器:
获得数据集合和一组候选参考模型,
针对所述数据集合中的每个训练样本,确定从所述一组候选参考模型中选择出的一组选定参考模型,
获取所述一组选定参考模型针对所述训练样本输出的一组目标概率分布,以及
利用所述一组目标概率分布来训练所述目标模型。
CN202010561319.7A 2020-06-18 2020-06-18 用于知识蒸馏的模型选择学习 Pending CN113822434A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202010561319.7A CN113822434A (zh) 2020-06-18 2020-06-18 用于知识蒸馏的模型选择学习
PCT/US2021/026288 WO2021257160A1 (en) 2020-06-18 2021-04-08 Model selection learning for knowledge distillation

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010561319.7A CN113822434A (zh) 2020-06-18 2020-06-18 用于知识蒸馏的模型选择学习

Publications (1)

Publication Number Publication Date
CN113822434A true CN113822434A (zh) 2021-12-21

Family

ID=75690703

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010561319.7A Pending CN113822434A (zh) 2020-06-18 2020-06-18 用于知识蒸馏的模型选择学习

Country Status (2)

Country Link
CN (1) CN113822434A (zh)
WO (1) WO2021257160A1 (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115129975B (zh) * 2022-05-13 2024-01-23 腾讯科技(深圳)有限公司 推荐模型训练方法、推荐方法、装置、设备及存储介质
CN115082920B (zh) * 2022-08-16 2022-11-04 北京百度网讯科技有限公司 深度学习模型的训练方法、图像处理方法和装置
CN117806172B (zh) * 2024-02-28 2024-05-24 华中科技大学 一种基于云边协同与自适应知识传递的故障诊断方法

Also Published As

Publication number Publication date
WO2021257160A1 (en) 2021-12-23

Similar Documents

Publication Publication Date Title
Joulin et al. Efficient softmax approximation for GPUs
CN113822434A (zh) 用于知识蒸馏的模型选择学习
CN112257449B (zh) 命名实体识别方法、装置、计算机设备和存储介质
US11334791B2 (en) Learning to search deep network architectures
CN112905795A (zh) 文本意图分类的方法、装置和可读介质
CN110796199A (zh) 一种图像处理方法、装置以及电子医疗设备
CN113987187B (zh) 基于多标签嵌入的舆情文本分类方法、***、终端及介质
CN112215696A (zh) 基于时序归因分析的个人信用评估与解释方法、装置、设备及存储介质
CN112287656B (zh) 文本比对方法、装置、设备和存储介质
CN117611932B (zh) 基于双重伪标签细化和样本重加权的图像分类方法及***
CN110866113A (zh) 基于稀疏自注意力机制微调伯特模型的文本分类方法
CN115222950A (zh) 一种面向嵌入式平台的轻量化目标检测方法
CN112257860A (zh) 基于模型压缩的模型生成
CN113988267A (zh) 用户意图识别模型的生成方法、用户意图识别方法和设备
CN113656563A (zh) 一种神经网络搜索方法及相关设备
CN112712068A (zh) 一种关键点检测方法、装置、电子设备及存储介质
Hu et al. Saliency-based YOLO for single target detection
CN114997287A (zh) 模型训练和数据处理方法、装置、设备及存储介质
Yang et al. Structured pruning via feature channels similarity and mutual learning for convolutional neural network compression
CN111783688A (zh) 一种基于卷积神经网络的遥感图像场景分类方法
CN111259673A (zh) 一种基于反馈序列多任务学习的法律判决预测方法及***
CN116595189A (zh) 基于两阶段的零样本关系三元组抽取方法及***
CN114692615A (zh) 一种针对小语种的小样本意图识别方法
CN114358579A (zh) 评阅方法、评阅装置、电子设备以及计算机可读存储介质
CN113487453A (zh) 基于犯罪要素的法律判决预测方法及***

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination