CN111640425B - 一种模型训练和意图识别方法、装置、设备及存储介质 - Google Patents

一种模型训练和意图识别方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN111640425B
CN111640425B CN202010444204.XA CN202010444204A CN111640425B CN 111640425 B CN111640425 B CN 111640425B CN 202010444204 A CN202010444204 A CN 202010444204A CN 111640425 B CN111640425 B CN 111640425B
Authority
CN
China
Prior art keywords
training
model
network
target
distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010444204.XA
Other languages
English (en)
Other versions
CN111640425A (zh
Inventor
王晶
彭程
罗雪峰
王健飞
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202010444204.XA priority Critical patent/CN111640425B/zh
Publication of CN111640425A publication Critical patent/CN111640425A/zh
Application granted granted Critical
Publication of CN111640425B publication Critical patent/CN111640425B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G10MUSICAL INSTRUMENTS; ACOUSTICS
    • G10LSPEECH ANALYSIS TECHNIQUES OR SPEECH SYNTHESIS; SPEECH RECOGNITION; SPEECH OR VOICE PROCESSING TECHNIQUES; SPEECH OR AUDIO CODING OR DECODING
    • G10L15/00Speech recognition
    • G10L15/06Creation of reference templates; Training of speech recognition systems, e.g. adaptation to the characteristics of the speaker's voice
    • G10L15/063Training
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F3/00Input arrangements for transferring data to be processed into a form capable of being handled by the computer; Output arrangements for transferring data from processing unit to output unit, e.g. interface arrangements
    • G06F3/01Input arrangements or combined input and output arrangements for interaction between user and computer
    • G06F3/011Arrangements for interaction with the human body, e.g. for user immersion in virtual reality
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G10MUSICAL INSTRUMENTS; ACOUSTICS
    • G10LSPEECH ANALYSIS TECHNIQUES OR SPEECH SYNTHESIS; SPEECH RECOGNITION; SPEECH OR VOICE PROCESSING TECHNIQUES; SPEECH OR AUDIO CODING OR DECODING
    • G10L15/00Speech recognition
    • G10L15/08Speech classification or search
    • G10L15/16Speech classification or search using artificial neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Engineering & Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Computational Linguistics (AREA)
  • Human Computer Interaction (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Acoustics & Sound (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请公开了一种模型训练和意图识别方法、装置、设备及存储介质,涉及人工智能技术领域。其中,模型训练方法为:根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到预训练模型的强化模型;其中,各次沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络;将强化模型中的至少两个网络作为目标网络,并根据目标网络构建蒸馏模型,其中,目标网络包含特征识别网络和所述预测层网络;特征识别网络至少包括底层网络;通过强化模型的目标网络,抽取训练任务数据集的目标知识;根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型,以提高目标学习模型预测的效率和准确性。

Description

一种模型训练和意图识别方法、装置、设备及存储介质
技术领域
本申请实施例涉及计算机技术领域,具体涉及人工智能技术。
背景技术
随着人工智能技术的发展,深度学习模型在人机交互领域的应用越来越广泛。预训练模型作为深度学习模型的一种,其结构复杂,模型参数庞大,所以预训练模型可能在运行阶段耗时长,速度慢。为了提高预训练模型的响应速度,现有技术通常需要研发人员从预训练模型中,人工选出权重值较小的网络层,并将其从预训练模型中裁剪掉,以实现对预训练模型的压缩,降低预训练模型的结构复杂度。但是采用该方式裁剪后的预训练模型,受人为因素影响较大,准确性较低,严重影响人机交互效果,亟需改进。
发明内容
提供了一种模型训练和意图识别方法、装置、设备及存储介质。
根据第一方面,提供了一种基于知识蒸馏的模型训练方法,该方法包括:
根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到所述预训练模型的强化模型;其中,各次所述沉淀训练的训练对象至少包括所述底层网络、预测层网络和逐次递减的中高层网络,所述预训练模型自底向上包括所述底层网络、至少一个所述中高层网络和所述预测层网络;
将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,其中,所述目标网络包含特征识别网络和所述预测层网络;所述特征识别网络至少包括所述底层网络;
通过所述强化模型的目标网络,抽取所述训练任务数据集的目标知识;
根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型。
根据第二方面,提供了一种意图识别方法,该方法包括:
获取人机交互设备采集的用户语音数据;
将所述用户语音数据输入目标学习模型,以获取所述目标学习模型输出的用户意图识别结果;其中,所述目标学习模型基于本申请任一实施例所述的基于知识蒸馏的模型训练方法训练而确定;
根据所述用户意图识别结果确定所述人机交互设备的响应结果。
根据第三方面,提供了一种基于知识蒸馏的模型训练装置,该装置包括:
沉淀训练模块,用于根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到所述预训练模型的强化模型;其中,各次所述沉淀训练的训练对象至少包括所述底层网络、预测层网络,和逐次递减的中高层网络,所述预训练模型自底向上包括所述底层网络、至少一个所述中高层网络和所述预测层网络;
蒸馏模型构建模块,用于将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,其中,所述目标网络包含特征识别网络和所述预测层网络;所述特征识别网络至少包括所述底层网络;
目标知识抽取模块,用于通过所述强化模型的目标网络,抽取所述训练任务数据集的目标知识;
蒸馏模型训练模块,用于根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型。
根据第四方面,提供了一种意图识别装置,该装置包括:
语音数据获取模块,用于获取人机交互设备采集的用户语音数据;
意图识别模块,用于将所述用户语音数据输入目标学习模型,以获取所述目标学习模型输出的用户意图识别结果;其中,所述目标学习模型基于本申请任一实施例所述的基于知识蒸馏的模型训练方法训练而确定;
响应结果确定模块,用于根据所述用户意图识别结果确定所述人机交互设备的响应结果。
根据第五方面,提供了一种电子设备,该电子设备包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本申请任一实施例所述的基于知识蒸馏的模型训练方法或意图识别方法。
根据第六方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质。所述计算机指令用于使所述计算机执行本申请任一实施例所述的基于知识蒸馏的模型训练方法或意图识别方法。
根据本申请实施例的技术解决了现有技术人工压缩预训练模型,准确性低的问题,能够通过低成本自动压缩训练出高精度的目标学习模型,以提高人机交互效果。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本申请的限定。其中:
图1A是根据本申请实施例提供的一种基于知识蒸馏的模型训练方法的流程图;
图1B是根据本申请实施例提供的一种预训练模型的网络结构示意图;
图2是根据本申请实施例提供的另一种基于知识蒸馏的模型训练方法的流程图;
图3是根据本申请实施例提供的另一种基于知识蒸馏的模型训练方法的流程图;
图4-图5是根据本申请实施例提供的两种基于知识蒸馏的模型训练方法的流程图;
图6A是根据本申请实施例提供的另一种基于知识蒸馏的模型训练方法的流程图;
图6B是根据本申请实施例提供的对蒸馏模型进行训练的原理结构示意图;
图7是根据本申请实施例提供的一种基于知识蒸馏的模型训练方法的流程图;
图8是根据本申请实施例提供的一种意图识别方法的流程图;
图9是根据本申请实施例提供的一种视频处理装置的结构示意图;
图10是根据本申请实施例提供的一种意图识别装置的结构示意图;
图11是用来实现本申请实施例的基于知识蒸馏的模型训练方法或意图识别方法的电子设备的框图。
具体实施方式
以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
图1A是根据本申请实施例提供的一种基于知识蒸馏的模型训练方法的流程图;图1B是根据本申请实施例提供的一种预训练模型的网络结构示意图。本实施例适用于基于知识蒸馏技术将网络结构复杂的预训练模型压缩训练成一个网络结构简单的目标学习模型的情况。该实施例可以由电子设备中配置的基于知识蒸馏的模型训练装置来执行,该装置可以采用软件和/或硬件来实现。如图1A-1B所示,该方法包括:
S101,根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到预训练模型的强化模型。
其中,本申请实施例的任务训练数据集可以是根据预训练模型需要执行的预测任务,获取与该预测任务相关的样本数据作为训练任务数据集。例如,若预训练模型需要执行的预测任务是对购物平台A中的用户语音数据进行意图识别,则此时可以获取购物平台A中所有的历史用户语音数据,进行相关处理(如打标签、删除无效数据等处理)后得到该预测任务对应的训练任务数据集。
本申请实施例的预训练模型可以是基于深度学习架构搭建的,且已径采用海量数据训练好的,可以执行某一学习任务的高精度模型,该预训练模型通常具有网络层数较深、每一层网络的维度较宽,且模型参数较多等特点。该与训练模型可以是用户自己采用大量样本数据训练的,也可以是直接从预训练模型数据库中获取的,对此本实施例不进行限定。可选的,该预训练模型自底向上可以包括底层网络、至少一个中高层网络和预测层网络。所述底层网络和所述中高层网络用于进行特征识别;所述预测层网络用于根据识别的特征进行任务预测。其中,底层网络通常用于识别简单的特征,中高层网络通常用于从简单特征中抽象出复杂特征。例如,若预训练模型为进行意图识别的bert模型,则该bert模型的底层网络通常用于识别较简单的语法特征;中高层网络通常用于从语法特征中抽象出复杂特征。预测层网络用于根据底层网络和中高层网络识别出的特征进行任务预测。可选的,本申请实施例的预训练模型可以是bert模型。
示例性的,图1B所示的预训练模型1由12个网络层构成,其中第1网络层至第3网络层为底层网络10,第4网络层至第11网络层为中高层网络11,第12网络层为预测层网络12,其中,中高层网络11又包括中层网络110(即第4网络层至第7网络层)和高层网络111(即第8网络层至第11网络层)。
可选的,通常情况下,预训练模型的中高层网络抽象的复杂特征,与预测任务本身的相关性较低,要精准完成预测任务,主要依靠的还是底层网络。所以本操作可以对预训练模型的底层网络进行多次沉淀训练,在多次沉淀训练的过程中,不断调整训练对象(即预训练模型中需要训练的网络层)。其中,各次沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络。也就是说,本实施例虽然是着重对预训练模型的底层网络进行沉淀训练,但是为了保证训练结果的准确性,每次训练的对象中至少要包括底层网络和预测层网络,对于中高层网络,其随着沉淀训练次数的增加,训练对象中包含的中高层网络的层数呈递减趋势。例如,假设对图1B所示的预训练模型1的底层网络10进行五次沉淀训练,则五次沉淀训练的训练对象中都包含底层网络10和预测层网络12,对于中间层网络,第一次沉淀训练的训练对象中可能包含所有网络层;第二次沉淀训练的训练对象可能递减为包含第4网络层至第9网络层;第三次沉淀训练的训练对象可能再次递减为包含第4网络层至第7网络层,依次递减,到第五次沉淀训练时,训练对象可能已经递减为不包含中高层网络11。
具体的,本步骤在根据训练任务数据集对预训练模型的底层网络进行至少两次深度训练时,可以是每次向预训练模型输入一部分训练任务数据集,并用这部分训练任务数据集对本次训练对象进行一次沉淀训练,然后将经过多次沉淀训练后的预训练模型作为强化模型。由于本步骤训练对象中的中高层网络随着训练次数的增加,逐次递减,因此本操作可以随着沉淀训练次数的增加,对底层网络进行更为精准的训练更新,使得底层网络的参数越来越精确。也就是说本实施例的强化模型与预训练模型相比,网络结构并没有发生变化,如若预训练模型为bert模型,则沉淀训练后的强化模型也是bert模型,只是底层网络的网络参数更为精准。
S102,将强化模型中的至少两个网络作为目标网络,并根据目标网络构建蒸馏模型。
其中,目标网络可以是从强化模型所包含的网络中筛选出的完成当前预测任务所需的网络。该目标网络包含特征识别网络和预测层网络;特征识别网络是用于进行特征识别的网络,本申请实施例中特征识别网络至少包括底层网络。可选的,除了底层网络之外,特征识别网络也可以包括部分或全部的中高层网络,对此本实施例不进行限定。
可选的,本实施例可以是从强化模型中选择至少两个网络作为目标网络,其中,若选择两个网络时,这两个网络为强化模型的底层网络和预测层网络,此时目标网络中的特征识别网络仅包括底层网络;若选择三个或三个以上网络时,可以是在选择底层网络和预测层网络的基础上,从中高层网络中选择剩余的网络,此时目标网络中的特征网络除了包括底层网络外,还包括至少一个中高层网络。需要说明的是,是否将强化模型的中高层网络作为目标网络的特征识别网络,可以综合实际预测任务、后续要抽取的目标知识的类型等因素而定。对此本实施例不进行限定。
可选的,本实施例根据目标网络构建蒸馏模型时,可以是根据目标网络,构建一个同样包含目标网络的蒸馏模型。需要说明的是,本步骤构建的蒸馏模型的目标网络的网络类型要与强化模型的目标网络的类型相同。具体的,蒸馏模型中也要包括预测层网络和特征识别网络,关于特征识别网络,若从强化模型中选择的特征识别网络中只包括底层网络,则构建的蒸馏模型的特征识别网络中也只包括底层网络;若从强化模型中选择的特征识别网络中不但包括底层网络,还包括中高层网络中的中层网络,则构建的蒸馏模型的特征识别网络中也同样包括底层网络和中层网络。
可选的,本操作根据目标网络构建蒸馏模型时,可以是结合强化模型的目标网络的网络层结构,构建一个与强化模型具有相同结构的蒸馏模型,即强化模型的同构模型。例如,若强化模型为bert模型,构建的蒸馏模型为只包含强化模型中的目标网络结构的bert模型。构建的蒸馏模型也可以与强化模型的结构不同,但同样要包含强化模型的目标网络类型,即强化模型的异构模型。例如,强化模型为bert模型,构建的蒸馏模型为CNN模型,但是该CNN模型中同样包括与强化模型相同类型的目标网络。具体如何构建同构或异构蒸馏模型的方法,将在后续实施例进行详细介绍。
需要说明的是,本操作构建的蒸馏模型可以是机器学习模型,也可以是基于神经网络的小模型,该蒸馏模型的特点是参数少,推理速度块,移植性好。
S103,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
其中,目标知识可以是强化模型中的目标网络对训练任务数据集处理后得到的结果,该目标知识用于后续注入到蒸馏模型中,作为蒸馏模型训练时的监督信号。
可选的,本步骤在抽取训练任务数据集的目标知识时,可以是将训练任务数据集作为强化模型的输入,获取强化模型的特征识别网络输出的第一数据特征表示,和强化模型的预测层网络输出的第一预测概率表示;并将获取的第一数据特征表示和第一预测概率表示作为训练任务数据集的目标知识。具体的,可以是将训练任务数据集按照预设尺寸,如batch_size大小,划分为多份。然后将划分后的每份训练任务数据输入到强化模型中,运行强化模型,获取强化模型的特征识别网络输出的特征表示作为第一数据特征表示。其中,若特征识别网络只有底层网络,则第一数据特征表示只是底层网络输出的特征表示;若特征识别网络包括底层网络和一部分中高层网络,则第一数据特征表示不但包括底层网络输出的特征表示,还包括该部分中高层网络输出的特征表示。获取强化模型的预测层网络输出的特征表示,如预测概率值,作为第一预测概率表示,进而将获取的第一数据特征表示和第一预测概率表示作为本次输入的训练任务数据对应的目标知识。
S104,根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。
可选的,本操作将S103获取的目标知识作为训练蒸馏模型的监督信号,基于训练任务数据集,诱导蒸馏模型进行训练,从而实现在训练的过程中,将目标知识迁移到蒸馏模型中,以使蒸馏模型学会强化模型的预测任务。具体的,本步骤可以是根据目标知识中的数据特征表示和预测概率表示,和蒸馏模型处理训练任务数据得到的数据特征表示和预测概率表示计算软监督标签,根据蒸馏模型处理训练任务数据的处理结果,计算硬监督标签,进而将软件监督标签与硬监督标签结合,通过更少的训练任务数据,更高效的学习效率对蒸馏模型进行蒸馏训练。具体如何计算硬监督标签和软件度标签,以及具体如何根据这两种监督标签进行蒸馏训练的过程将在后续实施例进行详细介绍。
可选的,本步骤对蒸馏模型进行训练,训练后的蒸馏模型即为目标学习模型,由于该目标学习模型是对预训练模型进行知识蒸馏得到的,所以该目标学习模型可以精准执行预训练模型的预测任务,且目标学习模型相对于预训练模型结构简单,所以执行预测任务时,耗时短,响应速度快。
可选的,本实施例在训练得到目标学习模型后,可以将该目标学习模型部署到实际人机交互领域中,以执行线上任务的预测。优选的,若所述预训练模型和目标学习模型是用于进行意图识别的模型,相应的,在根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型之后,本申请实施例还可以:将目标学习模型部署到人机交互设备中,以对人机交互设备实时获取的用户语音数据进行意图识别。具体的,目标学习模型在部署到人机交互设备中之后,人机交互设备在获取到用户语音数据后,会将该用户语音数据传输给该目标学习模型,目标学习模型会对输入的用户语音数据进行意图识别,并将意图识别结果反馈给人机交互设备,人机交互设备会根据目标学习模型的意图识别结果生成用户语音数据对应的响应结果,反馈给用户。本申请实施例的方案是通过知识蒸馏的方式训练得到目标学习模型的,其网络结构相比于预训练模型更为简单,且预测效果可逼近与复杂的预训练模型,可以实现快速且准确的进行意图识别,以满足人机交互设备实时响应的需求。
本实施例的技术方案,根据训练任务数据集,以底层网络、预测层网络和逐次递减的中高层网络为训练对象,对预训练模型进行至少两次沉淀训练,得到强化模型;根据从强化模型中确定的目标网络,络构建蒸馏模型。再通过强化模型的目标网络,抽取训练任务数据集的目标知识;基于抽取的目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。本实施例采用逐次递减中高层网络的方式对预训练模型的底层网络进行多次沉淀训练,可以使得预训练模型的底层网络的参数更为精准。后续至少根据沉淀后精准的底层网络和预测层网络构建蒸馏模型,并基于提取出的目标知识对蒸馏模型进行蒸馏训练,使得从预训练模型中蒸馏出的目标学习模型在精简了网络结构的同时,保留了预训练模型的预测精准性,还实现了提高模型的泛化能力。且整个蒸馏过程不受人为因素的影响,将该目标学习模型部署到人机交互设备中,可以实现快速准确的执行任务,以满足人机交互设备实时响应的需求。
可选的,本申请实施例中的预训练模型是已经训练好的可执行某种预测任务的模型,当该预测任务覆盖领域比较广时,预训练模型可能在多领域都能进行任务预测,但是对于其中某一领域而言,预测效果可能就不是很好了。例如,若预训练模型是用于意图识别的模型,其可以对购物、业务办理以及智能家具控制等众多领域的用户语音进行意图识别,但是对于其中具体的某些领域而言,可能预测效果并不是很准确。针对该情况,本实施例可以是在根据训练任务数据集,对预训练模型进行至少两次沉淀训练之前,执行根据训练领域数据集,对预训练模型进行领域训练,更新所述预训练模型。
具体的,训练领域数据集可以是基于预训练模型待部署的工作领域,专门获取与该领域相关的样本数据作为训练领域数据集,例如,若预训练模型需要执行的是购物领域用户语音的意图识别,则此时可以是将各个购物平台的所有语音数据进行相关处理(如打标签、删除无效数据等处理)后得到该领域对应的训练领域数据集。将训练领域数据集输入到预训练模型中,以对预训练模型针对该领域进行更新训练,微调预训练模型的参数,使得更新后的预训练模型能够更为精准的执行该领域的预测任务。本实施例在根据训练领域数据集,对预训练模型进行领域训练,更新预训练模型后,对更新后的预训练模型执行S101的沉淀训练操作,这样设置的好处是极大的提高了预训练模型在其预测任务所属领域的预测精度。为后续基于该预训练模型蒸馏出精准的目标学习模型提供了保障。
图2是根据本申请实施例提供的另一种基于知识蒸馏的模型训练方法的流程图;本实施例在上述实施例的基础上,进行了进一步的优化,给出了根据训练任务数据集,对预训练模型进行至少两次沉淀训练的具体情况介绍。如图2所示,该方法包括:
S201,将训练任务数据集进行划分,以确定多份训练数据子集。
可选的,本操作可以是依据预设的沉淀策略,如每次抽取知识的网络层数,将训练任务数据集划分为多份训练数据子集。例如,若预训练模型为图1B所示的模型,且沉淀策略是每次抽取一层网络的知识,则此时可以是将训练任务数据集划分为12份。其中,训练数据子集的划分份数小于等于预训练模型的总层数。例如,当预训练模型的总层数N时,本步骤划分的训练数据子集的份数K可以等于总层数N的一半。可选的,划分后的各份训练数据子集中的训练数据数量可以相同,也可以不同,对此本实施例不进行限定。
S202,根据设定沉淀训练次数,确定每份训练数据子集各自对应的训练对象。
其中,训练对象可以是每次进行沉淀训练时,预训练模型中需要进行训练的网络层。本申请实施例每份训练数据子集对应的训练对象不同。具体的,各份训练数据子集对应的训练对象包括预训练模型的底层网络、中高层网络和预测层网络,且包括的中高层网络的层数与沉淀训练的顺序呈反比。且各训练对象包括的中高层网络是与底层网络相邻且向上连续的网络层。也就是说,各份训练数据子集对应的训练对象中,底层网络和预测层网络保持不变,中高层网络的层数随着训练数据子集对应的沉淀训练顺序的后移,中高层网络的层数自上而下层逐次减少。可选的,且基于沉淀训练次数的增加,训练对象中包括的中高层网络的层数递减为零。从而实现随着沉淀训练次数的增加,最后只对底层网络进行更新训练。
可选的,本实施例在确定每份训练数据子集各自对应的训练对象时,底层网络和预测层网络是不变的,可以根据预训练模型的总层数和每份训练数据子集对应的沉淀训练顺序,确定每份训练数据子集对应的训练对象中的中高层网络的层数。例如,若预训练模型的总层数为N,某一份训练数据子集对应的沉淀训练顺序为第k次,则该份训练数据子集对应的训练对象中包含的中高层网络的最高层数为S=N-2*k,即,S层以下的中高层网络都是该份训练数据集对应的训练对象。
S203,根据每份训练数据子集,对预训练模型中,该份训练数据子集对应的训练对象进行一次沉淀训练,得到预训练模型的强化模型。
可选的,确定出划分后的每份训练数据子集对应的训练对象后,可以是按照各份训练数据子集对应的沉淀训练顺序,依次将各份训练数据子集输入到预训练模型中,利用输入的训练数据子集,对预训练模型中本次训练对象对应的各网络层进行训练,更新训练对象对应的各网络层的参数。由于本实施例各份训练数据子集对应的训练对象中的中高层网络的层数随着沉淀训练次数的增加,逐次递减,所以多次沉淀训练的过程中,更新的中高层网络的参数越来越少,实现训练过程逐渐向底层网络进行集中,经过多次沉淀训练,会使预训练模型的底层网络越来越精确,此时可以将经过多次沉淀训练后的预训练模型作为强化模型。
S204,将强化模型中的至少两个网络作为目标网络,并根据目标网络构建蒸馏模型。
其中,目标网络包含特征识别网络和所述预测层网络;特征识别网络至少包括底层网络。
S205,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
S206,根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。
本实施例的技术方案,将训练任务数据集划分为多份训练数据子集,并基于训练对象中的中高层网络的层数与沉淀训练的顺序呈反比的原则,确定每份训练数据子集的训练对象,根据划分后下每份训练数据子集,对其训练对象进行一次沉淀训练,得到强化模型。根据该强化模型,络构建蒸馏模型,以及抽取目标知识;进而基于抽取的目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。本实施例基于训练对象中的中高层网络的层数与沉淀训练的顺序呈反比的原则,确定每次沉淀训练的训练对象,使得多次沉淀训练后的预训练模型的底层网络的参数更为精准。为知识蒸馏过程的知识沉淀操作提供了一种新思路。为知识蒸馏训练目标学习模型的后续操作提供了保障。
图3是根据本申请实施例提供的另一种基于知识蒸馏的模型训练方法的流程图,本实施例在上述实施例的基础上,进行了进一步的优化,给出了对预训练模型进行多次沉淀训练的过程中,何时得到预训练模型的强化模型的具体情况介绍。如图3所示,该方法包括:
S301,根据训练任务数据集,对预训练模型逐次进行沉淀训练。
需要说明的是,本步骤对预训练模型逐次进行沉淀训练的具体实现方式在上述实施例中已经进行了详细介绍,在此不进行赘述。
S302,根据测试任务数据集,对沉淀训练后的预训练模型进行测试。
其中,测试任务数据集可以是用于测试沉淀训练后的预训练模型是否可以精准完成预测任务的测试数据。可选的,可以根据预训练模型需要执行的预测任务,获取与该预测任务相关的样本数据,然后将其分为两份,一份作为本申请实施例的训练任务数据集,另一份作为本申请实施例的测试任务数据集。
可选的,本实施例可以是将测试任务数据集输入到S301经过多次沉淀训练后的预训练模型中,得到沉淀训练后的预训练模型基于测试任务数据输出的预测结果,最后根据测试任务数据中的真实标签对预测结果进行分析,计算表征多次沉淀训练后的预训练模型输出结果是否准确的评价指标值,并将该评价指标值作为测试结果。可选的,该评价指标值可以是根据预测任务而定,如可以是多次沉淀训练后的预训练模型输出结果的准确率、精确率和召回率等。
可选的,为了保证测试结果的准确性,本实施例可以是采用多组测试任务数据集对沉淀训练后的预训练模型进行多次测试,根据多次测试结果来确定最终的测试结果。
S303,若测试结果满足沉淀结束条件,则将沉淀训练后的预训练模型作为强化模型。
其中,沉淀结束条件可以是判断经过多次沉淀训练后的预训练模型是否满足作为强化模型的判断条件。具体的,可以是测试结果中的评价指标值对应的指标阈值。
可选的,本实施例可以是将S302对沉淀训练后的预训练模型进行测试,得到的测试结果(即评价指标值)与沉淀结束条件中的指标阈值进行比较,若评价指标值满足指标阈值,则说明测试结果满足沉淀结束条件,此时可以将沉淀训练后的预训练模型作为强化模型;若测试结果不满足沉淀结束条件,则需要返回S301继续根据训练任务数据集,对预训练模型逐次进行沉淀训练,直到测试结果满足沉淀结束条件。
S304,将强化模型中的至少两个网络作为目标网络,并根据目标网络构建蒸馏模型。
其中,目标网络包含特征识别网络和所述预测层网络;特征识别网络至少包括底层网络。
S305,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
S306,根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。
本实施例的技术方案,根据训练任务数据集,对预训练模型的底层网络进行多次沉淀训练后,根据测试任务数据集对沉淀训练后的预训练模型进行测试,如果测试通过,方可将沉淀训练后的预训练模型作为强化模型。进而根据该强化模型,络构建蒸馏模型,以及抽取目标知识;并基于抽取的目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。本实施例通过对知识沉淀后的预训练模型进行测试,来确定知识沉淀是否达到沉淀训练的预期效果,只有达到预期效果,才可以作为强化模型,保证了得到的强化模型的底层网络参数的精准性。为知识蒸馏训练目标学习模型的后续操作提供了保障。
可选的,上述实施例介绍了在对预训练模型的底层网络进行多次沉淀训练的过程中,何时得到预训练模型的强化模型的确定过程,同理,本实施例在根据目标知识和训练任务数据集,对蒸馏模型进行训练的过程中,也可以采用类似的方法,来判断蒸馏模型是否训练完成,可得到目标学习模型。具体的:本申请实施例可以是在执行根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型时,具体执行:根据目标知识和训练任务数据集,对蒸馏模型进行训练;根据测试任务数据集,对训练后的蒸馏模型进行测试;若测试结果满足训练结束条件,则将训练后的蒸馏模型作为目标学习模型。需要说明的是,根据训练任务数据集对训练后的蒸馏模型进行测试的过程,与上述实施例介绍的根据训练任务数据集对沉淀训练后的预训练模型进行测试的过程相似,如可以是将测试任务数据集输入到训练后的蒸馏模型中,根据训练后的蒸馏模型的输出的预测结果与测试任务数据集的真实标签计算评价指标值,若评价指标值满足训练结束条件中的指标阈值,则说明对训练后的蒸馏模型的测试结果满足训练结束条件,可以将本次训练后的蒸馏模型作为目标学习模型。这样设置的好处是通过对训练后的蒸馏模型进行测试,来确定训练后的蒸馏模型的任务预测精度是否达到预期效果,只有达到预期效果,才可以将其作为最终的目标学习模型,提高了基于知识蒸馏技术蒸馏到的目标学习模型的准确性。
图4-图5是根据本申请实施例提供的两种基于知识蒸馏的模型训练方法的流程图,本实施例在上述实施例的基础上,进行了进一步的优化,给出了根据目标网络构建蒸馏模型的两种具体实施方式的介绍。
可选的,图4示出的是根据目标网络构建与强化模型同结构的蒸馏模型的可实施方式,具体的:
S401,根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到预训练模型的强化模型。
其中,各次沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络,预训练模型自底向上包括底层网络、至少一个所述中高层网络和预测层网络。
S402,将强化模型中的至少两个网络作为目标网络,并获取目标网络的网络结构块。
其中,目标网络包含强化模型的特征识别网络和预测层网络;所述特征识别网络至少包括强化模型的底层网络。可选的,还可以包括强化模型的部分或全部的中高层网络。由于本实施例构建的蒸馏模型的网络结构要比强化模型简单,所以通常情况下,本申请实施例的目标网络的特征识别网络中通常不包含或者仅包含少量的中高层网络。网络结构块可以是对强化模型中的一个或多个网络层的网络结构进行封装后得到的,例如,假设本实施例的强化模型是对图1B示出的预训练模型进行沉淀训练后得到的,那强化模型的网络结构应该也如图1B所示,此时可以将图1B中的第1网络层至第3网络层的网络结构封装为底层网络10的网络结构块;将第4网络层至第7网络层的网络结构封装为中层网络110的网络结构块;将第8网络层至第11网络层的网络结构封装为高层网络111的网络结构块;将第12网络层的网络结构封装为预测层网络12的网络结构块。
可选的,如果要构建与强化模型具有相同结构的蒸馏模型时,可以是在从强化模型中选择出目标网络后,获取目标网络在强化模型中对应的网络结构块。例如,将强化模型中的底层网络和预测层网络作为目标网络,则可以是底层网络的网络结构块,和预测层网络的网络结构块,作为目标网络的网络结构块。
S403,根据获取的网络结构块,构建与强化模型同结构的蒸馏模型。
可选的,由于目标网络对应的是强化模型中的至少两个网络,所以获取的网络结构块也是至少两个网络的网络结构块,本步骤可以是对至少两个网络结构块按照其在强化模型中自下而上的顺序,进行排列,并将位于下方的网络结构块的输出作为与其相邻的上方网络结构块的输入,从而形成一个由目标网络构成的新模型,该新模型即为构建的蒸馏模型。
例如,假设S402获取的是图1B中的底层网络10和预测层网络12的网络结构块,由于底层网络10位于预测层网络12的下方,则此时可以是将底层网络10的网络结构块置于预测层网络12的网络结构块下方,并将底层网络10的网络结构块的输出与预测层网络12的网络结构块的输入连接,从而生成一个由底层网络10的网络结构块和预测层网络12的网络结构块构成的蒸馏模型。同理,如果S402获取的是底层网络10、中层网络110和预测层网络12的网络结构块,则此时可以是将底层网络10的网络结构块位于最下方,将中层网络110的网络结构块位于中间,将预测层网络12的网络结构块位于最上方,将底层网络10的网络结构块的输出连接中层网络110的网络结构块的输入,将中层网络110的网络结构块的输出连接预测层网络12的网络结构块的输入,从而生成一个由底层网络10的网络结构块、中层网络110的网络结构块和预测层网络12的网络结构块构成的蒸馏模型。
S404,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
S405,根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。
可选的,图5示出的是根据目标网络构建与强化模型结构不同的蒸馏模型的可实施方式,具体的:
S501,根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到预训练模型的强化模型。
其中,各次沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络,预训练模型自底向上包括底层网络、至少一个中高层网络和预测层网络。
S502,将强化模型中的至少两个网络作为目标网络。
可选的,从强化模型中选择目标网络的过程在上述实施例已经进行了介绍,在此本实施例不进行赘述。
S503,根据目标网络,选择与强化模型结构不同的神经网络模型作为蒸馏模型。
其中,神经网络模型的输出层网络与目标网络中预测层网络的类型一致,神经网络模型的非输出层网络与目标网络中特征识别网络的类型一致。所谓预测层网络的类型是指网络的类型属于预测型,即进行任务预测类型。所谓特征识别网络的类型包括:底层网络、中层网络,以及高层网络等。
可选的,由于本可实施方式构建的蒸馏模型与强化模型的结构不同,所以,此时可以根据需求选择一个结构简单,且可用于实现预测任务的神经网络模型作为蒸馏模型。其中,可以选作蒸馏模型的神经网络模型的结构通常比较简单,层数较少,但是需要神经网络模型的输出层与目标网络中预测层网络的类型一致,非输出层网络与目标网络中特征识别网络的类型一致。即需要神经网络模型的输出层为可进行任务预测的网络,而其非输出层需要为与目标网络的特征识别网络的类型一致,例如,若目标网络的特征识别网络的类型为底层网络,则该神经网络模型的非输出层的类型也应该为底层网络;若目标网络的特征识别网络的类型为底层网络和中层网络,则神经网络模型的非输出层的类型也应该为底层网络和中层网络。
本步骤构建的该蒸馏模型由于结构单元,层数交少,所以通常情况下与结构复杂的强化模型是异构模型的关系。例如,假设强化模型为bert模型,此时可以是选择CNN模型作为蒸馏模型。
S504,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
S505,根据目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。
本申请实施例的技术方案,给出了基于知识蒸馏技术,训练预训练模型的目标学习模型过程中,根据沉淀训练后的强化模型的目标网络,构建与强化模型结构相同或不同的两种蒸馏模型的具体执行方式。若构建与强化模型同结构的蒸馏模型,由于蒸馏模型保留了强化模型的网络结构块,所以同构的蒸馏模型更容易蒸馏训练到强化模型的预测效果;若构建与强化模型同结构的蒸馏模型时,异构的蒸馏模型可以学习到与强化模型不同的特征,提高模型的泛化能力。本申请实施例可以根据实际需求进行选择,灵活性强。
图6A是根据本申请实施例提供的另一种基于知识蒸馏的模型训练方法的流程图;图6B是根据本申请实施例提供的对蒸馏模型进行训练的原理结构示意图。本实施例在上述实施例的基础上,进行了进一步的优化,给出了根据目标知识和训练任务数据集,对蒸馏模型进行训练的具体情况介绍。如图6A-6B所示,该方法包括:
S601,根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到预训练模型的强化模型。
其中,各次沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络,预训练模型自底向上包括底层网络、至少一个中高层网络和预测层网络。
S602,将强化模型中的至少两个网络作为目标网络,并根据目标网络构建蒸馏模型。
其中,目标网络包含特征识别网络和预测层网络;特征识别网络至少包括底层网络。
示例性的,假设从图6B示出的强化模型中,选择出的目标网络为底层网络,中层网络和预测层网络,并根据这三种网络构建了图6B中所示的蒸馏模型。
S603,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
示例性的,本操作可以是将训练任务数据集中预设大小,如batch_size大小的训练数据输入到图6B所示的强化模型中,获取强化模型的底层网络输出的特征表示(knowledge_seql)和中层网络输出的特征表示(knowledge_seqm),作为第一数据特征表示(knowledge_seq);获取强化模型的预测层网络输出的特征表示(knowledge_predict)作为第一预测概率表示。本步骤获取的第一数据特征表示和第一预测概率表示即为抽取到的目标知识。
S604,将训练任务数据集输入蒸馏模型中,并根据蒸馏模型对训练任务数据集的处理结果和目标知识,确定软监督标签和硬监督标签。
其中,软监督标签和硬监督标签是对蒸馏模型进行训练的过程中的两种监督信号。其中软监督标签是基于抽取的目标知识计算出的,硬监督标签是基于训练任务数据集中的实际标签计算出的。
可选的,本实施例可以是将训练任务数据集输入到蒸馏模型中,蒸馏模型对输入的训练数据集进行处理,得到蒸馏模型的各网络层的输出结果,该输出结果一方面用于结合目标知识确定软监督标签。另一方面用于结合训练任务数据集的相关信息计算硬监督标签。具体的确定过程包括以下三个子步骤:
S6041,将训练任务数据集输入蒸馏模型,得到蒸馏模型的特征识别网络输出的第二数据特征表示,和蒸馏模型的预测层网络输出的第二预测概率表示。
具体的,将训练任务数据中预设大小,如batch_size大小的训练数据输入到蒸馏模型之后,获取蒸馏模型中的预测层网络输出预测结果(即特征表示)作为第二预测概率表示;如果蒸馏模型中的特征识别网络中只有底层网络,则获取底层网络输出的特征表示作为第二数据特征表示;如果蒸馏模型中的特征识别网络中除了底层网络外,还包括部分中高层网络,则此时获取底层网络和该部分中高层网络输出的特征表示一并作为第二数据特征表示。示例性的,如图6B所示,将训练任务数据集输入蒸馏模型,由于图6B中的目标网络的特征识别网络中包括底层网络和中层网络,所以,需要将蒸馏模型处理训练任务数据集后,底层网络输出的特征表示(samll_seql)和中层网络输出的特征表示(samll_seqm)作为第二数据特征表示(samll_seq);将蒸馏模型的预测层网络输出的特征表示(small_predict)作为第二预测概率表示。
S6042,根据目标知识、第二数据特征表示和第二预测概率表示,确定软监督标签。
可选的,由于目标知识是由第一数据特征表示和第一预测概率表示构成的,本实施例可以是按照预设的算法,对第一数据特征表示、第一预测概率表示、第二数据特征表示和第二预测概率表示进行计算,得到软监督标签。具体的计算算法本实施例不进行限定。如可以是将目标知识中的第一数据特征表示和第二数据特征表示的均值方差作为数据特征标签;将目标知识中的第一预测概率表示和第二预测概率表示的均值方差作为概率预测标签;然后根据强化模型的特征识别网络的权重值,对所述数据特征标签和所述概率预测标签进行标签融合,得到软监督标签。本实施例根据强化模型和蒸馏模型基于相同的训练任务数据集,输出的特征表示来确定软监督标签,使得确定出的软监督标签更为准确,进而提高后续训练出的目标学习模型的准确性。
具体的,可以是按照下述公式(1)计算数据特征标签,按照下述公式(2)计算概率预测标签;最后按照下述公式(3)计算软监督标签。
loss_i=MSE(knowledge_seq,small_seq) (1)
loss_p=MSE(knowledge_predict,small_predict) (2)
loss_soft=Wi*loss_i+loss_p (3)
其中,loss_i为数据特征标签;MSE()为均值方差函数;knowledge_seq为第一数据特征表示;small_seq为第二数据特征表示;loss_p为概率预测标签;knowledge_predict为第一预测概率表示;small_predict为第二预测概率表示;loss_soft为软监督标签;Wi为特征识别网络的权重值。
可选的,当特征识别网络包括多个网络(如底层网络和中层网络)时,第一数据特征表示和第而数据特征表示都是由多个网络层输出的特征表示构成,此时可以是针对每个网络层输出的特征表示,都按照公式(1)计算出一个数据特征标签。例如,如图6B所示,第一数据特征表示包括:knowledge_seql和knowledge_seqm,第二数据特征表示包括:samll_seql和samll_seqm。此时可以是根据knowledge_seql和samll_seql,计算底层网络的数据特征标签loss_il,根据knowledge_seqm和samll_seqm计算中层网络的数据特征标签loss_im。相应的,此时在计算软监督标签时,可以是将各网络的权重值与其数据特征标签的乘积,以及概率预测标签进行求和,得到最终的软监督标签。例如针对图6B所示的场景,软监督标签的计算公式可以是loss_soft=Wl*loss_il+Wm*loss_im+loss_p。
S6043,根据第二预测概率表示和训练任务数据集信息,确定硬监督标签。
其中,训练任务数据集信息包括:训练任务数据集中训练样本数量、训练标签数量和实际标签值。
可选的,本子步骤可以是根据下述公式(4)计算硬监督标签。
其中,loss_hart为硬监督标签;N为训练任务数据集中的训练样本数量;M为训练标签数量,i为第i个训练样本;c为第c为训练标签;yic为第i个样本属于第c个训练标签的实际标签值;small_predictic为蒸馏模型的预测网络层输出的第i个训练样本属于第c个训练标签的概率值。可选的,yic的取值可以为0或1。
本实施例S6041-S6043根据强化模型和蒸馏模型基于相同的训练任务数据集,输出的特征表示来确定软监督标签,根据训练任务数据的实际标签值和蒸馏模型的预测概率来确定硬监督标签,为软监督标签和硬监督标签的确定提供了一种新思路,提高了软硬监督标签的准确性。
S605,根据软监督标签和硬监督标签,确定目标标签。
其中,目标标签是结合了软监督标签和硬监督标签的特性后,确定出的最终用于监督蒸馏模型训练的标签值。可选的,本步骤根据下述公式(5)确定目标标签:
loss=alpha*loss_soft+(1-alpha)*loss_hart (5)
其中,loss为目标标签,alpha为参数变量;loss_soft为软监督标签;loss_hart为硬监督标签。
可选的上述公式(5)中的参数变量可以是基于预设规则设置的常量,也可以是随蒸馏模型一起进行训练的变量。对此本实施例不进行限定。
S606,根据目标标签,对蒸馏模型的参数进行迭代更新,得到目标学习模型。
可选的,本实施例可以是根据S605确定出的目标标签,按照预设规则,如反向传播算法(BP算法),对蒸馏模型的参数进行更新调整,从而完成对蒸馏模型参数的一次迭代更新。然后再从训练任务数据集中,获取下一组预设大小,如batch_size大小的训练数据输入到蒸馏模型中,返回执行S603-S606的操作,对蒸馏模型的参数进行下一次的迭代更新,从而完成对蒸馏模型训练。对蒸馏模型训练多次后,可以通过测试任务数据集对训练的蒸馏模型进行测试,如果满足训练结束条件,则说明蒸馏模型已经训练好,可将训练后的蒸馏模型作为目标学习模型。
本实施例的技术方案,根据对预训练模型的底层网络进行沉淀训练得到强化模型,构建蒸馏模型以及抽取目标知识;根据蒸馏模型对任务训练数据的处理结果和抽取的目标知识,确定软监督标签和硬监督标签,进而基于软硬监督标签确定出目标标签来对蒸馏模型的参数进行迭代更新,得到目标学习模型。本实施例将软监督标签和硬监督标签结合来训练蒸馏模型,使得训练的蒸馏模型在逼近预训练模型预测效果的同时,还提高了蒸馏模型的泛化能力。从而更好的满足人机交互设备实时响应的需求。
图7是根据本申请实施例提供的一种基于知识蒸馏的模型训练方法的流程图。本实施例在上述各实施例的基础上,提供了一种优选实例,具体的,如图7所示,该方法包括:
S701,获取预训练模型。
可选的,本步骤获取的预训练模型是已经基于海量训练样本训练好的模型,该预训练模型能够较好的完成线上预测任务。
S702,根据训练领域数据集,对预训练模型进行领域训练,更新预训练模型。
S703,根据训练任务数据集,对预训练模型进行逐次沉淀训练。
其中,各次沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络,预训练模型自底向上包括底层网络、至少一个中高层网络和预测层网络。
S704,根据测试任务数据集,对沉淀训练后的预训练模型进行测试。
S705,判断测试结果是否满足沉淀结束条件,若是,则执行S706,若否,则返回执行S702。
可选的,如果测试结果满足沉淀结束条件,则说明沉淀训练已经达到预期效果,可以执行S706将其作为强化模型,否则,说明沉淀训练不充分,需要返回S702基于训练领域数据集,对预训练模型的参数进行更新调整。
S706,若测试结果满足沉淀结束条件,则将沉淀训练后的预训练模型作为强化模型。
S707,将强化模型中的至少两个网络作为目标网络,并根据目标网络构建蒸馏模型。
S708,通过强化模型的目标网络,抽取训练任务数据集的目标知识。
S709,根据目标知识和训练任务数据集,对蒸馏模型进行训练。
S710,根据测试任务数据集,对训练后的蒸馏模型进行测试。
S711,判断测试结果是否满足训练结束条件,若是,则执行S712,若否,则返回执行S709。
S712,若测试结果满足训练结束条件,则将训练后的蒸馏模型作为目标学习模型。
本申请实施例的技术方案,给出了基于知识蒸馏技术,从预训练模型中国蒸馏出目标学习模型的具体实现方案,该方案蒸馏出的目标学习模型在保留预训练模型的精准预测能力的同时,精简了网络结构分支,提高了模型的泛化能力。将该目标学习模型部署到人机交互设备中,可以实现快速准确的执行任务,以满足人机交互设备实时响应的需求。
图8是根据本申请实施例提供的一种意图识别方法的流程图。本实施例适用于基于上述各实施例训练的目标学习模型,进行意图识别的情况。该实施例可以由电子设备中配置的意图识别装置来执行,该装置可以采用软件和/或硬件来实现。可选的,该电子设备可以是人机交互设备或与人机交互设备通信交互的服务端。该人机交互设备可以是智能机器人、智能音箱、智能手机等。如图8所示,该方法包括:
S801,获取人机交互设备采集的用户语音数据。
可选的,本申请实施例的人机交互设备可以通过其内部配置的语音采集装置(如麦克风),实时采集环境中的用户语音数据。若本实施例的执行主体为人机交互设备,则该人机交互设备采集了用户语音数据后可直接进行下述S802的操作。若本实施例的执行主体为与人机交互设备通信交互的服务端,则人机交互设备在采集到用户语音数据后,会将该用户语音数据传输至其通信交互的服务端,由服务端获取用户语音数据后执行下述S802的操作。
S802,将用户语音数据输入目标学习模型,以获取目标学习模型输出的用户意图识别结果。
其中,本实施例中的目标学习模型是基于上述任一实施例所述的基于知识蒸馏的模型训练方法训练而确定。且本实施例的目标学习模型是用于执行意图识别的模型。
可选的,人机交互设备或与其通信交互的服务端在获取用户语音数据后,会将获取的用户语音数据输入到目标学习模型中,此时目标学习模型会基于输入的用户语音数据,采用训练时的算法对该用户语音数据进行线上分析预测,输出用户意图识别结果,此时人机交互设备或与其通信交互的服务端获取目标学习模型输出的用户意图识别结果。
S803,根据用户意图识别结果确定人机交互设备的响应结果。
可选的,人机交互设备或与其通信交互的服务端会基于获取的用户意图识别结果,确定该用户意图识别结果所对应的目标人机交互响应规则,并基于该目标人机交互响应规则,确定本次响应结果,并将响应结果反馈给用户,以实现基于用户语音数据进行人机交互。
本申请实施例的技术方案,将基于上述任意实施例所述的基于知识蒸馏的模型训练方法训练的用于意图识别的目标学习模型,部署到人机交互设备或与人机交互设备通信交互的服务端中,人机交互设备或与其通信交互的服务端可以获取用户语音数据输入到目标学习模型中,并基于目标学习模型输出的用户意图识别结果,确定本次响应结果。本申请实施例部署到人机交互设备或与其通信交互的服务端中的目标学习模型是通过知识蒸馏的方式训练得到的,其网络结构相比于预训练模型更为简单,且预测效果可逼近与复杂的预训练模型,可以实现快速且准确的进行意图识别,以满足人机交互设备实时响应的需求。
图9是根据本申请实施例提供的一种视频处理装置的结构示意图,本实施例适用于基于知识蒸馏技术将网络结构复杂的预训练模型压缩训练成一个网络结构简单的目标学习模型的情况。该装置可实现本申请任意实施例所述的基于知识蒸馏的模型训练方法,该装置900具体包括如下:
沉淀训练模块901,用于根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到所述预训练模型的强化模型;其中,各次所述沉淀训练的训练对象至少包括所述底层网络、预测层网络和逐次递减的中高层网络,所述预训练模型自底向上包括所述底层网络、至少一个所述中高层网络和所述预测层网络;
蒸馏模型构建模块902,用于将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,其中,所述目标网络包含特征识别网络和所述预测层网络;所述特征识别网络至少包括所述底层网络;
目标知识抽取模块903,用于通过所述强化模型的目标网络,抽取所述训练任务数据集的目标知识;
蒸馏模型训练模块904,用于根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型。
进一步的,所述底层网络和所述中高层网络用于进行特征识别;所述预测层网络用于根据识别的特征进行任务预测。
进一步的,所述沉淀训练模块901包括:
数据子集划分单元,用于将所述训练任务数据集进行划分,以确定多份训练数据子集;
训练对象确定单元,用于根据设定沉淀训练次数,确定每份训练数据子集各自对应的训练对象;其中,各份训练数据子集对应的训练对象包括所述预训练模型的底层网络、中高层网络和预测层网络,且包括的所述中高层网络的层数与沉淀训练的顺序呈反比;
沉淀训练单元,用于根据所述每份训练数据子集,对所述预训练模型中,该份训练数据子集对应的训练对象进行一次沉淀训练;
其中,训练数据子集的划分份数小于等于所述预训练模型的总层数。
进一步的,各所述训练对象包括的中高层网络是与底层网络相邻且向上连续的网络层;且基于所述沉淀训练次数的增加,所述训练对象中包括的中高层网络的层数递减为零。
进一步的,所述沉淀训练模块901具体用于:
根据所述训练任务数据集,对所述预训练模型逐次进行沉淀训练;
根据测试任务数据集,对沉淀训练后的预训练模型进行测试;
若测试结果满足沉淀结束条件,则将所述沉淀训练后的预训练模型作为强化模型。
进一步的,所述装置还包括:
领域训练模型,用于在根据训练任务数据集,对预训练模型进行至少两次沉淀训练之前,根据训练领域数据集,对预训练模型进行领域训练,更新所述预训练模型。
进一步的,所述蒸馏模型构建模块902具体用于:
将所述强化模型中的至少两个网络作为目标网络,并获取所述目标网络的网络结构块;
根据获取的所述网络结构块,构建与所述强化模型同结构的蒸馏模型。
进一步的,所述蒸馏模型构建模块902还具体用于:
将所述强化模型中的至少两个网络作为目标网络;
根据所述目标网络,选择与所述强化模型结构不同的神经网络模型作为蒸馏模型,其中,所述神经网络模型的输出层网络与所述目标网络中预测层网络的类型一致,所述神经网络模型的非输出层网络与所述目标网络中特征识别网络的类型一致。
进一步的,所述目标知识抽取模块903具体用于:
将所述训练任务数据集作为所述强化模型的输入,获取所述强化模型的特征识别网络输出的第一数据特征表示,和所述强化模型的预测层网络输出的第一预测概率表示;
将获取的所述第一数据特征表示和所述第一预测概率表示作为所述训练任务数据集的目标知识。
进一步的,所述蒸馏模型训练模块904包括:
监督标签确定单元,用于将所述训练任务数据集输入所述蒸馏模型中,并根据所述蒸馏模型对所述训练任务数据集的处理结果和所述目标知识,确定软监督标签和硬监督标签;
目标标签确定单元,用于根据所述软监督标签和所述硬监督标签,确定目标标签;
模型参数更新单元,用于根据所述目标标签,对所述蒸馏模型的参数进行迭代更新。
进一步的,所述监督标签确定单元具体包括:
输出获取子单元,用于将所述训练任务数据集输入所述蒸馏模型,得到所述蒸馏模型的特征识别网络输出的第二数据特征表示,和所述蒸馏模型的预测层网络输出的第二预测概率表示;
软标签确定子单元,用于根据所述目标知识、所述第二数据特征表示和所述第二预测概率表示,确定软监督标签;
硬标签确定子单元,用于根据所述第二预测概率表示和所述训练任务数据集信息,确定硬监督标签。
进一步的,所述训练任务数据集信息包括:训练任务数据集中训练样本数量、训练标签数量和实际标签值。
进一步的,所述软标签确定子单元具体用于:
将所述目标知识中的第一数据特征表示和所述第二数据特征表示的均值方差作为数据特征标签;
将所述目标知识中的第一预测概率表示和所述第二预测概率表示的均值方差作为概率预测标签;
根据所述强化模型的特征识别网络的权重值,对所述数据特征标签和所述概率预测标签进行标签融合,得到软监督标签。
进一步的,所述蒸馏模型训练模块904具体用于:
根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练;
根据测试任务数据集,对训练后的蒸馏模型进行测试;
若测试结果满足训练结束条件,则将所述训练后的蒸馏模型作为目标学习模型。
进一步的,所述预训练模型为bert模型。
进一步的,所述预训练模型和目标学习模型是用于进行意图识别的模型;
相应的,所述装置还包括:
模型部署模块,用于将所述目标学习模型部署到人机交互设备中,以对所述人机交互设备实时获取的用户语音数据进行意图识别。
本实施例的技术方案,根据训练任务数据集,以底层网络、预测层网络和逐次递减的中高层网络为训练对象,对预训练模型进行至少两次沉淀训练,得到强化模型;根据从强化模型中确定的目标网络,络构建蒸馏模型。再通过强化模型的目标网络,抽取训练任务数据集的目标知识;基于抽取的目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。本实施例采用逐次递减中高层网络的方式对预训练模型的底层网络进行多次沉淀训练,可以使得预训练模型的底层网络的参数更为精准。后续至少根据沉淀后精准的底层网络和预测层网络构建蒸馏模型,并基于提取出的目标知识对蒸馏模型进行蒸馏训练,使得从预训练模型中蒸馏出的目标学习模型在精简了网络结构的同时,保留了预训练模型的预测精准性,还实现了提高模型的泛化能力。且整个蒸馏过程不受人为因素的影响,将该目标学习模型部署到人机交互设备中,可以实现快速准确的执行任务,以满足人机交互设备实时响应的需求。
图10是根据本申请实施例提供的一种意图识别装置的结构示意图,本实施例可适用于基于上述各实施例训练的目标学习模型,进行意图识别的情况。该装置可实现本申请任意实施例所述的意图识别方法,该装置1000具体包括如下:
语音数据获取模块1001,用于获取人机交互设备采集的用户语音数据;
意图识别模块1002,用于将所述用户语音数据输入目标学习模型,以获取所述目标学习模型输出的用户意图识别结果;其中,所述目标学习模型基于上述任一实施例所述的基于知识蒸馏的模型训练方法训练而确定;
响应结果确定模块1003,用于根据所述用户意图识别结果确定人机交互设备的响应结果。
进一步的,所述装置配置于所述人机交互设备中,或与所述人机交互设备通信交互的服务端。
本申请实施例的技术方案,将基于上述任意实施例所述的基于知识蒸馏的模型训练方法训练的用于意图识别的目标学习模型,部署到人机交互设备或与人机交互设备通信交互的服务端中,人机交互设备或与其通信交互的服务端可以获取用户语音数据输入到目标学习模型中,并基于目标学习模型输出的用户意图识别结果,确定本次响应结果。本申请实施例部署到人机交互设备或与其通信交互的服务端中的目标学习模型是通过知识蒸馏的方式训练得到的,其网络结构相比于预训练模型更为简单,且预测效果可逼近与复杂的预训练模型,可以实现快速且准确的进行意图识别,以满足人机交互设备实时响应的需求。
根据本申请的实施例,本申请还提供了一种电子设备和一种可读存储介质。
如图11所示,是根据本申请实施例的基于知识蒸馏的模型训练方法或意图识别方法的电子设备的框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本申请的实现。
如图11所示,该电子设备包括:一个或多个处理器1101、存储器1102,以及用于连接各部件的接口,包括高速接口和低速接口。各个部件利用不同的总线互相连接,并且可以被安装在公共主板上或者根据需要以其它方式安装。处理器可以对在电子设备内执行的指令进行处理,包括存储在存储器中或者存储器上以在外部输入/输出装置(诸如,耦合至接口的显示设备)上显示GUI的图形信息的指令。在其它实施方式中,若需要,可以将多个处理器和/或多条总线与多个存储器和多个存储器一起使用。同样,可以连接多个电子设备,各个设备提供部分必要的操作(例如,作为服务器阵列、一组刀片式服务器、或者多处理器***)。图11中以一个处理器1101为例。
存储器1102即为本申请所提供的非瞬时计算机可读存储介质。其中,所述存储器存储有可由至少一个处理器执行的指令,以使所述至少一个处理器执行本申请所提供的基于知识蒸馏的模型训练方法或意图识别方法。本申请的非瞬时计算机可读存储介质存储计算机指令,该计算机指令用于使计算机执行本申请所提供的基于知识蒸馏的模型训练方法或意图识别方法。
存储器1102作为一种非瞬时计算机可读存储介质,可用于存储非瞬时软件程序、非瞬时计算机可执行程序以及模块,如本申请实施例中的基于知识蒸馏的模型训练方法或意图识别方法对应的程序指令/模块(例如,附图9所示的沉淀训练模块901、蒸馏模型构建模块902、目标知识抽取模块903和蒸馏模型训练模块904;或附图10所示的语音数据获取模块1001、意图识别模块1002和响应结果确定模块1003)。处理器1101通过运行存储在存储器1102中的非瞬时软件程序、指令以及模块,从而执行服务器的各种功能应用以及数据处理,即实现上述方法实施例中的基于知识蒸馏的模型训练方法或意图识别方法。
存储器1102可以包括存储程序区和存储数据区,其中,存储程序区可存储操作***、至少一个功能所需要的应用程序;存储数据区可存储根据基于知识蒸馏的模型训练方法或意图识别方法的电子设备的使用所创建的数据等。此外,存储器1102可以包括高速随机存取存储器,还可以包括非瞬时存储器,例如至少一个磁盘存储器件、闪存器件、或其他非瞬时固态存储器件。在一些实施例中,存储器1102可选包括相对于处理器1101远程设置的存储器,这些远程存储器可以通过网络连接至基于知识蒸馏的模型训练方法或意图识别方法的电子设备。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
基于知识蒸馏的模型训练方法或意图识别方法的电子设备还可以包括:输入装置1103和输出装置1104。处理器1101、存储器1102、输入装置1103和输出装置1104可以通过总线或者其他方式连接,图11中以通过总线连接为例。
输入装置1103可接收输入的数字或字符信息,以及产生与基于知识蒸馏的模型训练方法或意图识别方法的电子设备的用户设置以及功能控制有关的键信号输入,例如触摸屏、小键盘、鼠标、轨迹板、触摸板、指示杆、一个或者多个鼠标按钮、轨迹球、操纵杆等输入装置。输出装置1104可以包括显示设备、辅助照明装置(例如,LED)和触觉反馈装置(例如,振动电机)等。该显示设备可以包括但不限于,液晶显示器(LCD)、发光二极管(LED)显示器和等离子体显示器。在一些实施方式中,显示设备可以是触摸屏。
此处描述的***和技术的各种实施方式可以在数字电子电路***、集成电路***、专用ASIC(专用集成电路)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程***上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储***、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储***、该至少一个输入装置、和该至少一个输出装置。
这些计算程序(也称作程序、软件、软件应用、或者代码)包括可编程处理器的机器指令,并且可以利用高级过程和/或面向对象的编程语言、和/或汇编/机器语言来实施这些计算程序。如本文使用的,术语“机器可读介质”和“计算机可读介质”指的是用于将机器指令和/或数据提供给可编程处理器的任何计算机程序产品、设备、和/或装置(例如,磁盘、光盘、存储器、可编程逻辑装置(PLD)),包括,接收作为机器可读信号的机器指令的机器可读介质。术语“机器可读信号”指的是用于将机器指令和/或数据提供给可编程处理器的任何信号。
为了提供与用户的交互,可以在计算机上实施此处描述的***和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的***和技术实施在包括后台部件的计算***(例如,作为数据服务器)、或者包括中间件部件的计算***(例如,应用服务器)、或者包括前端部件的计算***(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的***和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算***中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将***的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机***可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
本实施例的技术方案,根据训练任务数据集,以底层网络、预测层网络和逐次递减的中高层网络为训练对象,对预训练模型进行至少两次沉淀训练,得到强化模型;根据从强化模型中确定的目标网络,络构建蒸馏模型。再通过强化模型的目标网络,抽取训练任务数据集的目标知识;基于抽取的目标知识和训练任务数据集,对蒸馏模型进行训练,得到目标学习模型。本实施例采用逐次递减中高层网络的方式对预训练模型的底层网络进行多次沉淀训练,可以使得预训练模型的底层网络的参数更为精准。后续至少根据沉淀后精准的底层网络和预测层网络构建蒸馏模型,并基于提取出的目标知识对蒸馏模型进行蒸馏训练,使得从预训练模型中蒸馏出的目标学习模型在精简了网络结构的同时,保留了预训练模型的预测精准性,还实现了提高模型的泛化能力。且整个蒸馏过程不受人为因素的影响,将该目标学习模型部署到人机交互设备中,可以实现快速准确的执行任务,以满足人机交互设备实时响应的需求。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发申请中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本申请公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本申请保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本申请的精神和原则之内所作的修改、等同替换和改进等,均应包含在本申请保护范围之内。

Claims (38)

1.一种基于知识蒸馏的模型训练方法,所述方法包括:
根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到所述预训练模型的强化模型;其中,各次所述沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络,所述预训练模型自底向上包括所述底层网络、至少一个所述中高层网络和所述预测层网络;
将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,其中,所述目标网络包含特征识别网络和所述预测层网络;所述特征识别网络至少包括所述底层网络;
通过所述强化模型的目标网络,抽取所述训练任务数据集的目标知识;
根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型;所述目标学习模型用于对用户语音数据进行意图识别。
2.根据权利要求1所述的方法,其中,所述底层网络和所述中高层网络用于进行特征识别;所述预测层网络用于根据识别的特征进行任务预测。
3.根据权利要求1所述的方法,其中,根据训练任务数据集,对预训练模型进行至少两次沉淀训练,包括:
将所述训练任务数据集进行划分,以确定多份训练数据子集;
根据设定沉淀训练次数,确定每份训练数据子集各自对应的训练对象;其中,各份训练数据子集对应的训练对象包括所述预训练模型的底层网络、中高层网络和预测层网络,且包括的所述中高层网络的层数与沉淀训练的顺序呈反比;
根据所述每份训练数据子集,对所述预训练模型中,该份训练数据子集对应的训练对象进行一次沉淀训练;
其中,训练数据子集的划分份数小于等于所述预训练模型的总层数。
4.根据权利要求3所述的方法,其中,各所述训练对象包括的中高层网络是与底层网络相邻且向上连续的网络层;且基于所述沉淀训练次数的增加,所述训练对象中包括的中高层网络的层数递减为零。
5.根据权利要求1所述的方法,其中,根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到所述预训练模型的强化模型,包括:
根据所述训练任务数据集,对所述预训练模型逐次进行沉淀训练;
根据测试任务数据集,对沉淀训练后的预训练模型进行测试;
若测试结果满足沉淀结束条件,则将所述沉淀训练后的预训练模型作为强化模型。
6.根据权利要求1所述的方法,其中,根据训练任务数据集,对预训练模型进行至少两次沉淀训练之前,还包括:
根据训练领域数据集,对预训练模型进行领域训练,更新所述预训练模型。
7.根据权利要求1所述的方法,其中,将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,包括:
将所述强化模型中的至少两个网络作为目标网络,并获取所述目标网络的网络结构块;
根据获取的所述网络结构块,构建与所述强化模型同结构的蒸馏模型。
8.根据权利要求1所述的方法,其中,将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,包括:
将所述强化模型中的至少两个网络作为目标网络;
根据所述目标网络,选择与所述强化模型结构不同的神经网络模型作为蒸馏模型,其中,所述神经网络模型的输出层网络与所述目标网络中预测层网络的类型一致,所述神经网络模型的非输出层网络与所述目标网络中特征识别网络的类型一致。
9.根据权利要求1所述的方法,其中,通过所述强化模型的目标网络,抽取所述训练任务数据集的目标知识,包括:
将所述训练任务数据集作为所述强化模型的输入,获取所述强化模型的特征识别网络输出的第一数据特征表示,和所述强化模型的预测层网络输出的第一预测概率表示;
将获取的所述第一数据特征表示和所述第一预测概率表示作为所述训练任务数据集的目标知识。
10.根据权利要求1所述的方法,其中,根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,包括:
将所述训练任务数据集输入所述蒸馏模型中,并根据所述蒸馏模型对所述训练任务数据集的处理结果和所述目标知识,确定软监督标签和硬监督标签;
根据所述软监督标签和所述硬监督标签,确定目标标签;
根据所述目标标签,对所述蒸馏模型的参数进行迭代更新。
11.根据权利要求10所述的方法,其中,将所述训练任务数据集输入所述蒸馏模型中,并根据所述蒸馏模型对所述训练任务数据集的处理结果和所述目标知识,确定软监督标签和硬监督标签,包括:
将所述训练任务数据集输入所述蒸馏模型,得到所述蒸馏模型的特征识别网络输出的第二数据特征表示,和所述蒸馏模型的预测层网络输出的第二预测概率表示;
根据所述目标知识、所述第二数据特征表示和所述第二预测概率表示,确定软监督标签;
根据所述第二预测概率表示和所述训练任务数据集信息,确定硬监督标签。
12.根据权利要求11所述的方法,其中,所述训练任务数据集信息包括:训练任务数据集中训练样本数量、训练标签数量和实际标签值。
13.根据权利要求11所述的方法,其中,根据所述目标知识、所述第二数据特征表示和所述第二预测概率表示,确定软监督标签,包括:
将所述目标知识中的第一数据特征表示和所述第二数据特征表示的均值方差作为数据特征标签;
将所述目标知识中的第一预测概率表示和所述第二预测概率表示的均值方差作为概率预测标签;
根据所述强化模型的特征识别网络的权重值,对所述数据特征标签和所述概率预测标签进行标签融合,得到软监督标签。
14.根据权利要求1所述的方法,其中,根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型,包括:
根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练;
根据测试任务数据集,对训练后的蒸馏模型进行测试;
若测试结果满足训练结束条件,则将所述训练后的蒸馏模型作为目标学习模型。
15.根据权利要求1-14中任一项所述的方法,其中,所述预训练模型为bert模型。
16.根据权利要求1-14中任一项所述的方法,其中,所述预训练模型和目标学习模型是用于进行意图识别的模型;
相应的,在根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型之后,还包括:
将所述目标学习模型部署到人机交互设备中,以对所述人机交互设备实时获取的用户语音数据进行意图识别。
17.一种意图识别方法,所述方法包括:
获取人机交互设备采集的用户语音数据;
将所述用户语音数据输入目标学习模型,以获取所述目标学习模型输出的用户意图识别结果;其中,所述目标学习模型基于权利要求1-16任一所述的基于知识蒸馏的模型训练方法训练而确定;
根据所述用户意图识别结果确定所述人机交互设备的响应结果。
18.根据权利要求17所述的方法,其中,所述方法的执行主体为所述人机交互设备或与所述人机交互设备通信交互的服务端。
19.一种基于知识蒸馏的模型训练装置,所述装置包括:
沉淀训练模块,用于根据训练任务数据集,对预训练模型进行至少两次沉淀训练,得到所述预训练模型的强化模型;其中,各次所述沉淀训练的训练对象至少包括底层网络、预测层网络和逐次递减的中高层网络,所述预训练模型自底向上包括所述底层网络、至少一个所述中高层网络和所述预测层网络;
蒸馏模型构建模块,用于将所述强化模型中的至少两个网络作为目标网络,并根据所述目标网络构建蒸馏模型,其中,所述目标网络包含特征识别网络和所述预测层网络;所述特征识别网络至少包括所述底层网络;
目标知识抽取模块,用于通过所述强化模型的目标网络,抽取所述训练任务数据集的目标知识;
蒸馏模型训练模块,用于根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练,得到目标学习模型;所述目标学习模型用于对用户语音数据进行意图识别。
20.根据权利要求19所述的装置,其中,所述底层网络和所述中高层网络用于进行特征识别;所述预测层网络用于根据识别的特征进行任务预测。
21.根据权利要求19所述的装置,其中,所述沉淀训练模块包括:
数据子集划分单元,用于将所述训练任务数据集进行划分,以确定多份训练数据子集;
训练对象确定单元,用于根据设定沉淀训练次数,确定每份训练数据子集各自对应的训练对象;其中,各份训练数据子集对应的训练对象包括所述预训练模型的底层网络、中高层网络和预测层网络,且包括的所述中高层网络的层数与沉淀训练的顺序呈反比;
沉淀训练单元,用于根据所述每份训练数据子集,对所述预训练模型中,该份训练数据子集对应的训练对象进行一次沉淀训练;
其中,训练数据子集的划分份数小于等于所述预训练模型的总层数。
22.根据权利要求21所述的装置,其中,各所述训练对象包括的中高层网络是与底层网络相邻且向上连续的网络层;且基于所述沉淀训练次数的增加,所述训练对象中包括的中高层网络的层数递减为零。
23.根据权利要求19所述的装置,其中,所述沉淀训练模块具体用于:
根据所述训练任务数据集,对所述预训练模型逐次进行沉淀训练;
根据测试任务数据集,对沉淀训练后的预训练模型进行测试;
若测试结果满足沉淀结束条件,则将所述沉淀训练后的预训练模型作为强化模型。
24.根据权利要求19所述的装置,其中,还包括:
领域训练模型,用于在根据训练任务数据集,对预训练模型进行至少两次沉淀训练之前,根据训练领域数据集,对预训练模型进行领域训练,更新所述预训练模型。
25.根据权利要求19所述的装置,其中,所述蒸馏模型构建模块具体用于:
将所述强化模型中的至少两个网络作为目标网络,并获取所述目标网络的网络结构块;
根据获取的所述网络结构块,构建与所述强化模型同结构的蒸馏模型。
26.根据权利要求19所述的装置,其中,所述蒸馏模型构建模块还具体用于:
将所述强化模型中的至少两个网络作为目标网络;
根据所述目标网络,选择与所述强化模型结构不同的神经网络模型作为蒸馏模型,其中,所述神经网络模型的输出层网络与所述目标网络中预测层网络的类型一致,所述神经网络模型的非输出层网络与所述目标网络中特征识别网络的类型一致。
27.根据权利要求19所述的装置,其中,所述目标知识抽取模块具体用于:
将所述训练任务数据集作为所述强化模型的输入,获取所述强化模型的特征识别网络输出的第一数据特征表示,和所述强化模型的预测层网络输出的第一预测概率表示;
将获取的所述第一数据特征表示和所述第一预测概率表示作为所述训练任务数据集的目标知识。
28.根据权利要求19所述的装置,其中,所述蒸馏模型训练模块,包括:
监督标签确定单元,用于将所述训练任务数据集输入所述蒸馏模型中,并根据所述蒸馏模型对所述训练任务数据集的处理结果和所述目标知识,确定软监督标签和硬监督标签;
目标标签确定单元,用于根据所述软监督标签和所述硬监督标签,确定目标标签;
模型参数更新单元,用于根据所述目标标签,对所述蒸馏模型的参数进行迭代更新。
29.根据权利要求28所述的装置,其中,所述监督标签确定单元具体包括:
输出获取子单元,用于将所述训练任务数据集输入所述蒸馏模型,得到所述蒸馏模型的特征识别网络输出的第二数据特征表示,和所述蒸馏模型的预测层网络输出的第二预测概率表示;
软标签确定子单元,用于根据所述目标知识、所述第二数据特征表示和所述第二预测概率表示,确定软监督标签;
硬标签确定子单元,用于根据所述第二预测概率表示和所述训练任务数据集信息,确定硬监督标签。
30.根据权利要求29所述的装置,其中,所述训练任务数据集信息包括:训练任务数据集中训练样本数量、训练标签数量和实际标签值。
31.根据权利要求29所述的装置,其中,所述软标签确定子单元具体用于:
将所述目标知识中的第一数据特征表示和所述第二数据特征表示的均值方差作为数据特征标签;
将所述目标知识中的第一预测概率表示和所述第二预测概率表示的均值方差作为概率预测标签;
根据所述强化模型的特征识别网络的权重值,对所述数据特征标签和所述概率预测标签进行标签融合,得到软监督标签。
32.根据权利要求19所述的装置,其中,所述蒸馏模型训练模块还用于:
根据所述目标知识和所述训练任务数据集,对所述蒸馏模型进行训练;
根据测试任务数据集,对训练后的蒸馏模型进行测试;
若测试结果满足训练结束条件,则将所述训练后的蒸馏模型作为目标学习模型。
33.根据权利要求19-32中任一项所述的装置,其中,所述预训练模型为bert模型。
34.根据权利要求19-32中任一项所述的装置,其中,所述预训练模型和目标学习模型是用于进行意图识别的模型;
相应的,还包括:
模型部署模块,用于将所述目标学习模型部署到人机交互设备中,以对所述人机交互设备实时获取的用户语音数据进行意图识别。
35.一种意图识别装置,所述装置包括:
语音数据获取模块,用于获取人机交互设备采集的用户语音数据;
意图识别模块,用于将所述用户语音数据输入目标学习模型,以获取所述目标学习模型输出的用户意图识别结果;其中,所述目标学习模型基于权利要求1-16任一所述的基于知识蒸馏的模型训练方法训练而确定;
响应结果确定模块,用于根据所述用户意图识别结果确定所述人机交互设备的响应结果。
36.根据权利要求35所述的装置,其中,所述装置配置于所述人机交互设备中,或与所述人机交互设备通信交互的服务端。
37. 一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-16中任一项所述的基于知识蒸馏的模型训练方法,或执行权利要求17-18中任一项所述的意图识别方法。
38.一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行权利要求1-16中任一项所述的基于知识蒸馏的模型训练方法,或执行权利要求17-18中任一项所述的意图识别方法。
CN202010444204.XA 2020-05-22 2020-05-22 一种模型训练和意图识别方法、装置、设备及存储介质 Active CN111640425B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010444204.XA CN111640425B (zh) 2020-05-22 2020-05-22 一种模型训练和意图识别方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010444204.XA CN111640425B (zh) 2020-05-22 2020-05-22 一种模型训练和意图识别方法、装置、设备及存储介质

Publications (2)

Publication Number Publication Date
CN111640425A CN111640425A (zh) 2020-09-08
CN111640425B true CN111640425B (zh) 2023-08-15

Family

ID=72333280

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010444204.XA Active CN111640425B (zh) 2020-05-22 2020-05-22 一种模型训练和意图识别方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN111640425B (zh)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20220188643A1 (en) * 2020-12-11 2022-06-16 International Business Machines Corporation Mixup data augmentation for knowledge distillation framework
US11960842B2 (en) * 2021-02-27 2024-04-16 Walmart Apollo, Llc Methods and apparatus for natural language understanding in conversational systems using machine learning processes
CN113160801B (zh) * 2021-03-10 2024-04-12 云从科技集团股份有限公司 语音识别方法、装置以及计算机可读存储介质
CN113157183B (zh) * 2021-04-15 2022-12-16 成都新希望金融信息有限公司 深度学习模型构建方法、装置、电子设备及存储介质
CN113204614B (zh) * 2021-04-29 2023-10-17 北京百度网讯科技有限公司 模型训练方法、优化训练数据集的方法及其装置
CN113239272B (zh) * 2021-05-12 2022-11-29 烽火通信科技股份有限公司 一种网络管控***的意图预测方法和意图预测装置

Citations (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107247989A (zh) * 2017-06-15 2017-10-13 北京图森未来科技有限公司 一种神经网络训练方法及装置
CN109543817A (zh) * 2018-10-19 2019-03-29 北京陌上花科技有限公司 用于卷积神经网络的模型蒸馏方法及装置
CN109637546A (zh) * 2018-12-29 2019-04-16 苏州思必驰信息科技有限公司 知识蒸馏方法和装置
WO2019143946A1 (en) * 2018-01-19 2019-07-25 Visa International Service Association System, method, and computer program product for compressing neural network models
CN110084368A (zh) * 2018-04-20 2019-08-02 谷歌有限责任公司 用于正则化神经网络的***和方法
CN110162018A (zh) * 2019-05-31 2019-08-23 天津开发区精诺瀚海数据科技有限公司 基于知识蒸馏与隐含层共享的增量式设备故障诊断方法
CN110807515A (zh) * 2019-10-30 2020-02-18 北京百度网讯科技有限公司 模型生成方法和装置
CN110832596A (zh) * 2017-10-16 2020-02-21 因美纳有限公司 基于深度学习的深度卷积神经网络训练方法
CN110837761A (zh) * 2018-08-17 2020-02-25 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
CN110909775A (zh) * 2019-11-08 2020-03-24 支付宝(杭州)信息技术有限公司 一种数据处理方法、装置及电子设备
CN111062951A (zh) * 2019-12-11 2020-04-24 华中科技大学 一种基于语义分割类内特征差异性的知识蒸馏方法
CN111062495A (zh) * 2019-11-28 2020-04-24 深圳市华尊科技股份有限公司 机器学习方法及相关装置
CN111079938A (zh) * 2019-11-28 2020-04-28 百度在线网络技术(北京)有限公司 问答阅读理解模型获取方法、装置、电子设备及存储介质
EP3648014A1 (en) * 2018-10-29 2020-05-06 Fujitsu Limited Model training method, data identification method and data identification device

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20160350834A1 (en) * 2015-06-01 2016-12-01 Nara Logics, Inc. Systems and methods for constructing and applying synaptic networks

Patent Citations (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107247989A (zh) * 2017-06-15 2017-10-13 北京图森未来科技有限公司 一种神经网络训练方法及装置
CN110832596A (zh) * 2017-10-16 2020-02-21 因美纳有限公司 基于深度学习的深度卷积神经网络训练方法
WO2019143946A1 (en) * 2018-01-19 2019-07-25 Visa International Service Association System, method, and computer program product for compressing neural network models
CN110084368A (zh) * 2018-04-20 2019-08-02 谷歌有限责任公司 用于正则化神经网络的***和方法
CN110837761A (zh) * 2018-08-17 2020-02-25 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
CN109543817A (zh) * 2018-10-19 2019-03-29 北京陌上花科技有限公司 用于卷积神经网络的模型蒸馏方法及装置
EP3648014A1 (en) * 2018-10-29 2020-05-06 Fujitsu Limited Model training method, data identification method and data identification device
CN109637546A (zh) * 2018-12-29 2019-04-16 苏州思必驰信息科技有限公司 知识蒸馏方法和装置
CN110162018A (zh) * 2019-05-31 2019-08-23 天津开发区精诺瀚海数据科技有限公司 基于知识蒸馏与隐含层共享的增量式设备故障诊断方法
CN110807515A (zh) * 2019-10-30 2020-02-18 北京百度网讯科技有限公司 模型生成方法和装置
CN110909775A (zh) * 2019-11-08 2020-03-24 支付宝(杭州)信息技术有限公司 一种数据处理方法、装置及电子设备
CN111062495A (zh) * 2019-11-28 2020-04-24 深圳市华尊科技股份有限公司 机器学习方法及相关装置
CN111079938A (zh) * 2019-11-28 2020-04-28 百度在线网络技术(北京)有限公司 问答阅读理解模型获取方法、装置、电子设备及存储介质
CN111062951A (zh) * 2019-12-11 2020-04-24 华中科技大学 一种基于语义分割类内特征差异性的知识蒸馏方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
Yuuki Tachioka.Knowledge Distillation Using Soft and Hard Labels and Annealing for Acoustic Model Training.《2019 IEEE 8th Global Conference on Consumer Electronics》.2019,全文. *

Also Published As

Publication number Publication date
CN111640425A (zh) 2020-09-08

Similar Documents

Publication Publication Date Title
CN111640425B (zh) 一种模型训练和意图识别方法、装置、设备及存储介质
CN111639710B (zh) 图像识别模型训练方法、装置、设备以及存储介质
CN112270379B (zh) 分类模型的训练方法、样本分类方法、装置和设备
CN111598216B (zh) 学生网络模型的生成方法、装置、设备及存储介质
CN111539227B (zh) 训练语义表示模型的方法、装置、设备和计算机存储介质
CN110175628A (zh) 一种基于自动搜索与知识蒸馏的神经网络剪枝的压缩算法
CN111582479B (zh) 神经网络模型的蒸馏方法和装置
CN111259671B (zh) 文本实体的语义描述处理方法、装置及设备
CN111737954B (zh) 文本相似度确定方法、装置、设备和介质
CN111667056B (zh) 用于搜索模型结构的方法和装置
CN111831813B (zh) 对话生成方法、装置、电子设备及介质
EP3961476A1 (en) Entity linking method and apparatus, electronic device and storage medium
CN112560985B (zh) 神经网络的搜索方法、装置及电子设备
CN111326251B (zh) 一种问诊问题输出方法、装置以及电子设备
CN111639753B (zh) 用于训练图像处理超网络的方法、装置、设备以及存储介质
CN110675954A (zh) 信息处理方法及装置、电子设备、存储介质
KR102293791B1 (ko) 반도체 소자의 시뮬레이션을 위한 전자 장치, 방법, 및 컴퓨터 판독가능 매체
CN113705628B (zh) 预训练模型的确定方法、装置、电子设备以及存储介质
JP2022078310A (ja) 画像分類モデル生成方法、装置、電子機器、記憶媒体、コンピュータプログラム、路側装置およびクラウド制御プラットフォーム
CN112329453B (zh) 样本章节的生成方法、装置、设备以及存储介质
CN115455171B (zh) 文本视频的互检索以及模型训练方法、装置、设备及介质
CN114715145B (zh) 一种轨迹预测方法、装置、设备及自动驾驶车辆
CN112288483A (zh) 用于训练模型的方法和装置、用于生成信息的方法和装置
CN112580723B (zh) 多模型融合方法、装置、电子设备和存储介质
CN112487239B (zh) 视频检索方法、模型训练方法、装置、设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant