CN112418291A - 一种应用于bert模型的蒸馏方法、装置、设备及存储介质 - Google Patents

一种应用于bert模型的蒸馏方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN112418291A
CN112418291A CN202011288877.7A CN202011288877A CN112418291A CN 112418291 A CN112418291 A CN 112418291A CN 202011288877 A CN202011288877 A CN 202011288877A CN 112418291 A CN112418291 A CN 112418291A
Authority
CN
China
Prior art keywords
model
original
distillation
layer
target
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011288877.7A
Other languages
English (en)
Inventor
朱桂良
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202011288877.7A priority Critical patent/CN112418291A/zh
Publication of CN112418291A publication Critical patent/CN112418291A/zh
Priority to PCT/CN2021/090524 priority patent/WO2022105121A1/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/16Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Computational Mathematics (AREA)
  • Mathematical Optimization (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Pure & Applied Mathematics (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Feedback Control In General (AREA)

Abstract

本申请实施例属于深度学***衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。

Description

一种应用于BERT模型的蒸馏方法、装置、设备及存储介质
技术领域
本申请涉及深度学习技术领域,尤其涉及一种应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质。
背景技术
近年在计算机视觉、语音识别等诸多领域,在利用深度网络解决问题的时候人们常常倾向于设计更为复杂的网络收集更多的数据以期获得更好的结果。但是,随之而来的是模型的复杂度急剧提升,直观的表现是模参数越来越多、规模越来越大,需要的硬件资源(内存、GPU)越来越高。不利于模型的部署和应用向移动端的推广。
现有一种深度模型蒸馏方法,采用蒸馏模型的优势在进行模型蒸馏时匹配各个中间层之间的数据,已实现压缩模型的目的。
然而,传统的深度模型蒸馏方法普遍不智能,在蒸馏的过程中匹配中间层输出时,往往需要平衡较多损失(loss)参数,例如:下游任务loss、中间层输出loss、相关矩阵loss、注意力矩阵(Attention)loss、等等,从而导致传统的深度模型蒸馏方法存在平衡loss参数较为困难的问题。
发明内容
本申请实施例的目的在于提出一种应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质,以解决传统的深度模型蒸馏方法存在平衡loss参数较为困难的问题。
为了解决上述技术问题,本申请实施例提供一种应用于BERT模型的蒸馏方法,采用了如下所述的技术方案:
接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;
构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;
基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;
在所述本地数据库中获取所述中间精简模型的训练数据;
基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。
为了解决上述技术问题,本申请实施例还提供一种应用于BERT模型的蒸馏装置,采用了如下所述的技术方案:
请求接收模块,用于接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
原始模型获取模块,用于读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;
默认模型构建模块,用于构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;
蒸馏操作模块,用于基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;
训练数据获取模块,用于在所述本地数据库中获取所述中间精简模型的训练数据;
模型训练模块,用于基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。
为了解决上述技术问题,本申请实施例还提供一种计算机设备,采用了如下所述的技术方案:
包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如上所述的应用于BERT模型的蒸馏方法的步骤。
为了解决上述技术问题,本申请实施例还提供一种计算机可读存储介质,采用了如下所述的技术方案:
所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如上所述的应用于BERT模型的蒸馏方法的步骤。
与现有技术相比,本申请实施例提供的应用于BERT模型的蒸馏方法、装置、计算机设备及存储介质主要有以下有益效果:
本申请实施例提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;在所述本地数据库中获取所述中间精简模型的训练数据;基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。
附图说明
为了更清楚地说明本申请中的方案,下面将对本申请实施例描述中所需要使用的附图作一个简单介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例一提供的应用于BERT模型的蒸馏方法的实现流程图;
图2是图1中步骤S104的实现流程图;
图3是图1中步骤S105的实现流程图;
图4是本申请实施例一提供的参数优化操作的实现流程图;
图5是图4中步骤S403的实现流程图;
图6是本申请实施例二提供的应用于BERT模型的蒸馏装置的结构示意图;
图7是根据本申请的计算机设备的一个实施例的结构示意图。
具体实施方式
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同;本文中在申请的说明书中所使用的术语只是为了描述具体的实施例的目的,不是旨在于限制本申请;本申请的说明书和权利要求书及上述附图说明中的术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含。本申请的说明书和权利要求书或上述附图中的术语“第一”、“第二”等是用于区别不同对象,而不是用于描述特定顺序。
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
为了使本技术领域的人员更好地理解本申请方案,下面将结合附图,对本申请实施例中的技术方案进行清楚、完整地描述。
实施例一
如图1所示,示出了根据本申请实施例一提供的应用于BERT模型的蒸馏方法的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在步骤S101中,接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数。
在本申请实施例中,用户终端指的是用于执行本申请提供的预防证件滥用的图像处理方法的终端设备,该当前终端可以是诸如移动电话、智能电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、导航装置等等的移动终端以及诸如数字TV、台式计算机等等的固定终端,应当理解,此处对用户终端的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,蒸馏对象标识主要用于唯一标识需要蒸馏的模型对象,该蒸馏对象标识可以是基于模型名称命名,作为示例,例如:视觉识别模型、语音识别模型等等;该蒸馏对象标识可以是基于名称简称进行命名,作为示例,例如:sjsbmx、yysbmx等等;该蒸馏对象标识还可以是序号进行命名,作为示例,例如:001、002等等,应当理解,此处对蒸馏对象标识的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,蒸馏系数主要用于确认将原始BERT模型的层数缩小的倍数,作为示例,例如:需要将BERT模型从12层蒸馏至4层,那么该蒸馏系数则为3,应当理解,此处对蒸馏系数的举例仅为方便理解,不用于限定本申请。
在步骤S102中,读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵。
在本申请实施例中,本地数据库是指驻留于运行客户应用程序的机器的数据库。本地数据库提供最快的响应时间。因为在客户(应用程序)和服务器之间没有网络转输。该本地数据库预先存储有各式各样的训练好的原始BERT模型,以解决在计算机视觉、语音识别等诸多领域存在的问题。
在本申请实施例中,Bert模型可以分为向量(embedding)层、转换器(transformer)层和预测(prediction)层,每种层是知识的不同表示形式。该原始BERT模型由12层transformer(一种基于“encoder-decoder”结构的模型)组成,该原始BERT模型选用的是交叉熵作为损失函数。该交叉熵主要用于度量两个概率分布间的差异性信息。语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。复杂度的意义是用该模型表示这一文本平均的分支数,其倒数可视为每个词的平均概率。平滑是指对没观察到的N元组合赋予一个概率值,以保证词序列总能通过语言模型得到一个概率值。
在步骤S103中,构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵。
在本申请实施例中,构建出来的默认精简模型保留了与BERT相同的模型结构,不同之处在于transformer层的数量。
在步骤S104中,基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型。
在本申请实施例中,蒸馏操作具体包括蒸馏transformer层以及参数初始化。
在本申请实施例中,蒸馏transformer层指的是倘若蒸馏系数为3,那么训练好的原始BERT模型的第一至第三层将替换至默认精简模型的第一层;训练好的原始BERT模型的第四至第六层将替换至默认精简模型的第二层;训练好的原始BERT模型的第七至第九层将替换至默认精简模型的第三层;训练好的原始BERT模型的第十至第十二层将替换至默认精简模型的第四层。
在本申请实施例中,在进行蒸馏替换的过程中,可采用伯努利分布概率确定每一层被替换的概率。
在本申请实施例中,参数初始化指的是embedding、pooler、全连接层参数依据训练好的原始BERT模型中各层级的参数,替换至默认精简模型对应的参数位置。
在步骤S105中,在本地数据库中获取中间精简模型的训练数据。
在本申请实施例中,精简模型训练数据可以采用训练上述原始BERT模型得到的有标签数据,也可以是额外的无标签数据。
在本审请实施例中,可获取原始BERT模型训练后的原始训练数据;调高原始BERT模型softmax层的温度参数,得到调高BERT模型,将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签;基于标签信息在原始训练数据进行筛选操作,得到带标签的筛选结果标签;基于放大训练数据以及筛选训练数据选取精简模型训练数据。
在步骤S106中,基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。
在本申请实施例中,提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵;构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵;基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型;在本地数据库中获取中间精简模型的训练数据;基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。
继续参阅图2,示出了图1中步骤S104的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,上述步骤S104具体包括:步骤S201、步骤S202以及步骤S203。
在步骤S201中,基于蒸馏系数对原始BERT模型的transformer层进行分组操作,得到分组transformer层。
在本申请实施例中,分组操作指的是transformer层数按照蒸馏系数进行分组,作为示例,例如:transformer层数为12,蒸馏系数为3,分组操作则将12个transformer层划分成4组。
在步骤S202中,基于伯努利分布分别在分组transformer层中进行提取操作,得到待替换transformer层。
在本申请实施例中,伯努利分布指的是对于随机变量X有,参数为p(0<p<1),如果它分别以概率p和1-p取1和0为值。EX=p,DX=p(1-p)。伯努利试验成功的次数服从伯努利分布,参数p是试验成功的概率。伯努利分布是一个离散型机率分布,是N=1时二项分布的特殊情况。
在步骤S203中,将待替换transformer层分别替换至默认精简模型,得到中间精简模型。
在本申请实施例中,基于层替换的蒸馏方式,保留了与BERT相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,由于蒸馏时,小模型的部分层基于伯努利采样,随机初始化成训练好的大模型映射层的权重,使模型收敛更快,减少训练轮数。
继续参阅图3,示出了图1中步骤S105的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,上述步骤S105具体包括:步骤S301、步骤S302、步骤S303、步骤S304以及步骤S305。
在步骤S301中,获取原始BERT模型训练后的原始训练数据。
在本申请实施例中,原始训练数据指的是在获得训练后的原始BERT模型之前,将训练数据输入至未训练的原始BERT模型的训练数据。
在步骤S302中,调高原始BERT模型softmax层的温度参数,得到调高BERT模型。
在本申请实施例中,可将温度参数T调高至一个较大值,作为示例,例如:T=20,应当理解,此处对调高温度参数的举例仅为方便理解,不用于限定本申请。
在步骤S303中,将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签。
在本申请实施例中,每一个原始训练数据在每一个原始BERT模型可以得到其最终的分类概率向量,选取其中概率至最大即为该模型对于当前原始训练数据的判定结果。对于t个原始BERT模型就可以输出t概率向量,然后对t个概率向量求取均值作为当前原始训练数据最后的概率输出向量,当所有原始训练数据完成预测操作之后,得到该原始训练数据对应的均值结果标签。
在步骤S304中,基于标签信息在原始训练数据进行筛选操作,得到带标签的筛选结果标签。
在本申请实施例中,由于在训练原始BERT模型时,会对部分样本数据附上标签数据,为获得有映射关系的训练数据,需要根据是否携带标签数据为条件对原始训练数据进行筛选操作,以得到有映射关系的训练数据,作为该筛选结果标签。
在步骤S305中,基于放大训练数据以及筛选训练数据选取精简模型训练数据。
在本申请实施例中,选取到的精简模型训练数据可表示为:
Target=a*hard_target+b*soft_target(a+b=1)
其中,Target表示最终作为中间精简模型训练数据的标签;hard_target表示筛选结果标签;soft_target表示均值结果标签;a、b表示控制标签融合的权重。
继续参阅图4,示出了本申请实施例一提供的参数优化操作的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,在上述步骤S106之后,上述方法还包括:步骤S401、步骤S402、步骤S403以及步骤S404。
在步骤S401中,在本地数据库中获取优化训练数据。
在本申请实施例中,优化训练数据主要用于优化目标精简模型的参数,该优化训练数据分别输入至训练好的原始BERT模型和目标精简模型,在保证输入数据一致的前提下,可获知原始BERT模型和目标精简模型各个transformer层输出的差异。
在步骤S402中,将优化训练数据分别输入至训练好的原始BERT模型以及目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据。
在步骤S403中,基于搬土距离计算原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据。
在本申请实施例中,搬土距离(EMD)是在一个区域D上两个概率分布之间的距离的度量。可分别获取原始transformer层和目标transformer层分别输出的attention(注意力)矩阵数据,并计算二者attention(注意力)矩阵数据的注意力EMD距离;再获取原始transformer层和目标transformer层分别输出的FFN(全连接前馈神经网络)隐层矩阵数据,并计算二者FFN隐层矩阵数据的FFN隐层EMD距离,以得到该蒸馏损失数据。
在步骤S404中,根据蒸馏损失数据对目标精简模型进行参数优化操作,得到优化精简模型。
在本申请实施例中,在获知蒸馏损失数据(即原始transformer层输出数据以及目标transformer层输出数据的距离度量)后,对目标精简模型的中的参数进行优化,直至蒸馏损失数据小于预设值,或者训练的次数满足预设次数,从而获得该优化精简模型。
在本申请实施例中,由于目标精简模型的transformer层是基于伯努利分布概率进行选取的,从而导致该目标精简模型的参数存在一定的误差,由于Bert模型中的transformer层对模型的贡献最大,包含的信息最丰富,精简模型在该层的学习能力也最为重要,因此通过采用“搬土距离EMD”计算原始BERT模型transformer层的输出以及目标精简模型transformer层的输出之间的损失数据,并基于该损失数据对该目标精简模型的参数进行优化,以提高该目标精简模型的的准确率,能够保证目标模型学习到更多的原始模型的知识。
继续参阅图5,示出了图4中步骤S403的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,上述步骤S403具体包括:步骤S501、步骤S502、步骤503、步骤S504以及步骤S505。
在步骤S501中,获取原始transformer层输出的原始注意力矩阵以及目标transformer层输出的目标注意力矩阵。
在步骤S502中,根据原始注意力矩阵以及目标注意力矩阵计算注意力EMD距离。
在本申请实施例中,注意力EMD距离表示为:
Figure BDA0002783260590000111
其中,Lattn表示注意力EMD距离;AT表示原始注意力矩阵;AS表示目标注意力矩阵;
Figure BDA0002783260590000112
表示原始注意力矩阵与标注意力矩阵之间的均方误差,且
Figure BDA0002783260590000113
Figure BDA0002783260590000114
表示第i层原始transformer层的原始注意力矩阵;
Figure BDA0002783260590000115
表示第j层目标transformer层的目标注意力矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
在步骤S503中,获取原始transformer层输出的原始FFN隐层矩阵以及目标transformer层输出的目标FFN隐层矩阵。
在步骤S504中,根据原始FFN隐层矩阵以及目标FFN隐层矩阵计算FFN隐层EMD距离。
在本申请实施例中,FFN隐层EMD距离表示为:
Figure BDA0002783260590000121
其中,Lffn表示FFN隐层EMD距离;HT表示原始transformer层的原始FFN隐层矩阵;HS表示目标transformer层的目标FFN隐层矩阵;
Figure BDA0002783260590000122
表示原始FFN隐层矩阵与目标FFN隐层矩阵之间的均方误差,且
Figure BDA0002783260590000123
Figure BDA0002783260590000124
表示第j层目标transformer层的目标FFN隐层矩阵;Wh表示转换矩阵;
Figure BDA0002783260590000125
表示第i层原始transformer层的原始FFN隐层矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
在步骤S505中,基于注意力EMD距离以及FFN隐层EMD距离获得蒸馏损失数据。
在本申请实施例中,transformer层是Bert模型中的重要组成部分,通过自注意力机制可以捕获长距离依赖关系,一个标准的transformer主要包括两部分:多头注意力机制(Multi-Head Attention,MHA)和全连接前馈神经网络(FFN)。EMD是使用线性规划计算两个分布之间最优距离的方法,可以使知识的蒸馏更加合理。
在本申请实施例一的一些可选的实现方式中,注意力EMD距离表示为:
Figure BDA0002783260590000126
其中,Lattn表示注意力EMD距离;AT表示原始注意力矩阵;AS表示目标注意力矩阵;
Figure BDA0002783260590000127
表示原始注意力矩阵与标注意力矩阵之间的均方误差,且
Figure BDA0002783260590000128
Figure BDA0002783260590000129
表示第i层原始transformer层的原始注意力矩阵;
Figure BDA00027832605900001210
表示第j层目标transformer层的目标注意力矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
在本申请实施例一的一些可选的实现方式中,FFN隐层EMD距离表示为:
Figure BDA0002783260590000131
其中,Lffn表示FFN隐层EMD距离;HT表示原始transformer层的原始FFN隐层矩阵;HS表示目标transformer层的目标FFN隐层矩阵;
Figure BDA0002783260590000132
表示原始FFN隐层矩阵与目标FFN隐层矩阵之间的均方误差,且
Figure BDA0002783260590000133
Figure BDA0002783260590000134
表示第j层目标transformer层的目标FFN隐层矩阵;Wh表示转换矩阵;
Figure BDA0002783260590000135
表示第i层原始transformer层的原始FFN隐层矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
综上,本申请实施例一提供了一种应用于BERT模型的蒸馏方法,接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵;构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵;基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型;在本地数据库中获取中间精简模型的训练数据;基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。另外,基于层替换的蒸馏方式,保留了与BERT相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,由于蒸馏时,小模型的部分层基于伯努利采样,随机初始化成训练好的大模型映射层的权重,使模型收敛更快,减少训练轮数。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机可读指令来指令相关的硬件来完成,该计算机可读指令可存储于一计算机可读取存储介质中,该计算机可读指令在执行时,可包括如上述各方法的实施例的流程。其中,前述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)等非易失性存储介质,或随机存储记忆体(Random Access Memory,RAM)等。
应该理解的是,虽然附图的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,其可以以其他的顺序执行。而且,附图的流程图中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,其执行顺序也不必然是依次进行,而是可以与其他步骤或者其他步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
实施例二
进一步参考图6,作为对上述图1所示方法的实现,本申请提供了一种应用于BERT模型的蒸馏装置的一个实施例,该装置实施例与图1所示的方法实施例相对应,该装置具体可以应用于各种电子设备中。
如图6所示,本实施例的应用于BERT模型的蒸馏装置100包括:请求接收模块110、原始模型获取模块120、默认模型构建模块130、蒸馏操作模块140、训练数据获取模块150以及模型训练模块160。其中:
请求接收模块110,用于接收用户终端发送的模型蒸馏请求,模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
原始模型获取模块120,用于读取本地数据库,在本地数据库中获取与蒸馏对象标识相对应的训练好的原始BERT模型,原始BERT模型的损失函数为交叉熵;
默认模型构建模块130,用于构建与训练好的原始BERT模型结构一致的待训练的默认精简模型,默认精简模型的损失函数为交叉熵;
蒸馏操作模块140,用于基于蒸馏系数对默认精简模型进行蒸馏操作,得到中间精简模型;
训练数据获取模块150,用于在本地数据库中获取中间精简模型的训练数据;
模型训练模块160,用于基于训练数据对中间精简模型进行模型训练操作,得到目标精简模型。
在本申请实施例中,用户终端指的是用于执行本申请提供的预防证件滥用的图像处理方法的终端设备,该当前终端可以是诸如移动电话、智能电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、导航装置等等的移动终端以及诸如数字TV、台式计算机等等的固定终端,应当理解,此处对用户终端的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,蒸馏对象标识主要用于唯一标识需要蒸馏的模型对象,该蒸馏对象标识可以是基于模型名称命名,作为示例,例如:视觉识别模型、语音识别模型等等;该蒸馏对象标识可以是基于名称简称进行命名,作为示例,例如:sjsbmx、yysbmx等等;该蒸馏对象标识还可以是序号进行命名,作为示例,例如:001、002等等,应当理解,此处对蒸馏对象标识的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,蒸馏系数主要用于确认将原始BERT模型的层数缩小的倍数,作为示例,例如:需要将BERT模型从12层蒸馏至4层,那么该蒸馏系数则为3,应当理解,此处对蒸馏系数的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,本地数据库是指驻留于运行客户应用程序的机器的数据库。本地数据库提供最快的响应时间。因为在客户(应用程序)和服务器之间没有网络转输。该本地数据库预先存储有各式各样的训练好的原始BERT模型,以解决在计算机视觉、语音识别等诸多领域存在的问题。
在本申请实施例中,Bert模型可以分为向量(embedding)层、转换器(transformer)层和预测(prediction)层,每种层是知识的不同表示形式。该原始BERT模型由12层transformer(一种基于“encoder-decoder”结构的模型)组成,该原始BERT模型选用的是交叉熵作为损失函数。该交叉熵主要用于度量两个概率分布间的差异性信息。语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。复杂度的意义是用该模型表示这一文本平均的分支数,其倒数可视为每个词的平均概率。平滑是指对没观察到的N元组合赋予一个概率值,以保证词序列总能通过语言模型得到一个概率值。
在本申请实施例中,构建出来的默认精简模型保留了与BERT相同的模型结构,不同之处在于transformer层的数量。
在本申请实施例中,蒸馏操作具体包括蒸馏transformer层以及参数初始化。
在本申请实施例中,蒸馏transformer层指的是倘若蒸馏系数为3,那么训练好的原始BERT模型的第一至第三层将替换至默认精简模型的第一层;训练好的原始BERT模型的第四至第六层将替换至默认精简模型的第二层;训练好的原始BERT模型的第七至第九层将替换至默认精简模型的第三层;训练好的原始BERT模型的第十至第十二层将替换至默认精简模型的第四层。
在本申请实施例中,在进行蒸馏替换的过程中,可采用伯努利分布概率确定每一层被替换的概率。
在本申请实施例中,参数初始化指的是embedding、pooler、全连接层参数依据训练好的原始BERT模型中各层级的参数,替换至默认精简模型对应的参数位置。
在本申请实施例中,精简模型训练数据可以采用训练上述原始BERT模型得到的有标签数据,也可以是额外的无标签数据。
在本审请实施例中,可获取原始BERT模型训练后的原始训练数据;调高原始BERT模型softmax层的温度参数,得到调高BERT模型,将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签;基于标签信息在原始训练数据进行筛选操作,得到带标签的筛选结果标签;基于放大训练数据以及筛选训练数据选取精简模型训练数据。
在本申请实施例中,提供了一种应用于BERT模型的蒸馏装置,由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。
在本申请实施例二的一些可选的实现方式中,上述蒸馏操作模块140具体包括:分组操作子模块、提取操作子模块以及替换操作子模块。其中:
分组操作子模块,用于基于蒸馏系数对原始BERT模型的transformer层进行分组操作,得到分组transformer层;
提取操作子模块,用于基于伯努利分布分别在分组transformer层中进行提取操作,得到待替换transformer层;
替换操作子模块,用于将待替换transformer层分别替换至默认精简模型,得到中间精简模型。
在本申请实施例二的一些可选的实现方式中,上述训练数据获取模块150具体包括:原始训练数据获取子模块、参数子调高模型、预测操作子模块、筛选操作子模块以及训练数据获取子模块。其中:
原始训练数据获取子模块,用于获取原始BERT模型训练后的原始训练数据;
参数子调高模型,用于调高原始BERT模型softmax层的温度参数,得到调高BERT模型;
预测操作子模块,用于将原始训练数据输入至调高BERT模型进行预测操作,得到均值结果标签;
筛选操作子模块,用于基于标签信息在原始训练数据进行筛选操作,得到带标签的筛选结果标签;
训练数据获取子模块,用于基于放大训练数据以及筛选训练数据选取精简模型训练数据。
在本申请实施例二的一些可选的实现方式中,上述应用于BERT模型的蒸馏装置100还包括:优化训练数据获取模块、蒸馏损失数据计算模块以及参数优化模块。其中:
优化训练数据获取模块,用于在本地数据库中获取优化训练数据;
优化训练数据输入模块,用于将优化训练数据分别输入至训练好的原始BERT模型以及目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;
蒸馏损失数据计算模块,用于基于搬土距离计算原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据;
参数优化模块,用于根据蒸馏损失数据对目标精简模型进行参数优化操作,得到优化精简模型。
在本申请实施例二的一些可选的实现方式中,上述蒸馏损失数据计算模块具体包括:目标注意力矩阵获取子模块、注意力EMD距离计算子模块、目标FFN隐层矩阵获取子模块、FFN隐层EMD距离计算子模块以及蒸馏损失数据获取子模块。其中:
目标注意力矩阵获取子模块,用于获取原始transformer层输出的原始注意力矩阵以及目标transformer层输出的目标注意力矩阵;
注意力EMD距离计算子模块,用于根据原始注意力矩阵以及目标注意力矩阵计算注意力EMD距离;
目标FFN隐层矩阵获取子模块,用于获取原始transformer层输出的原始FFN隐层矩阵以及目标transformer层输出的目标FFN隐层矩阵;
FFN隐层EMD距离计算子模块,用于根据原始FFN隐层矩阵以及目标FFN隐层矩阵计算FFN隐层EMD距离;
蒸馏损失数据获取子模块,用于基于注意力EMD距离以及FFN隐层EMD距离获得蒸馏损失数据。
在本申请实施例二的一些可选的实现方式中,注意力EMD距离表示为:
Figure BDA0002783260590000191
其中,Lattn表示注意力EMD距离;AT表示原始注意力矩阵;AS表示目标注意力矩阵;
Figure BDA0002783260590000192
表示原始注意力矩阵与标注意力矩阵之间的均方误差,且
Figure BDA0002783260590000193
Figure BDA0002783260590000194
表示第i层原始transformer层的原始注意力矩阵;
Figure BDA0002783260590000195
表示第j层目标transformer层的目标注意力矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
在本申请实施例二的一些可选的实现方式中,FFN隐层EMD距离表示为:
Figure BDA0002783260590000196
其中,Lffn表示FFN隐层EMD距离;HT表示原始transformer层的原始FFN隐层矩阵;HS表示目标transformer层的目标FFN隐层矩阵;
Figure BDA0002783260590000197
表示原始FFN隐层矩阵与目标FFN隐层矩阵之间的均方误差,且
Figure BDA0002783260590000198
Figure BDA0002783260590000199
表示第j层目标transformer层的目标FFN隐层矩阵;Wh表示转换矩阵;
Figure BDA00027832605900001910
表示第i层原始transformer层的原始FFN隐层矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
综上,本申请实施例二提供了一种应用于BERT模型的蒸馏装置,由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。另外,基于层替换的蒸馏方式,保留了与BERT相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,由于蒸馏时,小模型的部分层基于伯努利采样,随机初始化成训练好的大模型映射层的权重,使模型收敛更快,减少训练轮数。
为解决上述技术问题,本申请实施例还提供计算机设备。具体请参阅图7,图7为本实施例计算机设备基本结构框图。
所述计算机设备200包括通过***总线相互通信连接存储器210、处理器220、网络接口230。需要指出的是,图中仅示出了具有组件210-230的计算机设备200,但是应理解的是,并不要求实施所有示出的组件,可以替代的实施更多或者更少的组件。其中,本技术领域技术人员可以理解,这里的计算机设备是一种能够按照事先设定或存储的指令,自动进行数值计算和/或信息处理的设备,其硬件包括但不限于微处理器、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程门阵列(Field-Programmable Gate Array,FPGA)、数字处理器(Digital Signal Processor,DSP)、嵌入式设备等。
所述计算机设备可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机设备可以与用户通过键盘、鼠标、遥控器、触摸板或声控设备等方式进行人机交互。
所述存储器210至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、硬盘、多媒体卡、卡型存储器(例如,SD或DX存储器等)、随机访问存储器(RAM)、静态随机访问存储器(SRAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、可编程只读存储器(PROM)、磁性存储器、磁盘、光盘等。在一些实施例中,所述存储器210可以是所述计算机设备200的内部存储单元,例如该计算机设备200的硬盘或内存。在另一些实施例中,所述存储器210也可以是所述计算机设备200的外部存储设备,例如该计算机设备200上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。当然,所述存储器210还可以既包括所述计算机设备200的内部存储单元也包括其外部存储设备。本实施例中,所述存储器210通常用于存储安装于所述计算机设备200的操作***和各类应用软件,例如应用于BERT模型的蒸馏方法的计算机可读指令等。此外,所述存储器210还可以用于暂时地存储已经输出或者将要输出的各类数据。
所述处理器220在一些实施例中可以是中央处理器(Central Processing Unit,CPU)、控制器、微控制器、微处理器、或其他数据处理芯片。该处理器220通常用于控制所述计算机设备200的总体操作。本实施例中,所述处理器220用于运行所述存储器210中存储的计算机可读指令或者处理数据,例如运行所述应用于BERT模型的蒸馏方法的计算机可读指令。
所述网络接口230可包括无线网络接口或有线网络接口,该网络接口230通常用于在所述计算机设备200与其他电子设备之间建立通信连接。
本申请提供的应用于BERT模型的蒸馏方法,由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。
本申请还提供了另一种实施方式,即提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可读指令,所述计算机可读指令可被至少一个处理器执行,以使所述至少一个处理器执行如上述的应用于BERT模型的蒸馏方法的步骤。
本申请提供的应用于BERT模型的蒸馏方法,由于精简BERT模型保留了与原始BERT模型相同的模型结构,差异是层数的不同,使得代码改动量较小,而且大模型与小模型的预测代码是一致的,可以复用原代码,使得模型在蒸馏的过程中,无需平衡各个loss参数的权重,进而降低深度模型蒸馏方法的困难程度,同时,训练精简BERT模型各个阶段的任务均保持一致性,使得精简BERT模型收敛得更加稳定。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本申请各个实施例所述的方法。
显然,以上所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例,附图中给出了本申请的较佳实施例,但并不限制本申请的专利范围。本申请可以以许多不同的形式来实现,相反地,提供这些实施例的目的是使对本申请的公开内容的理解更加透彻全面。尽管参照前述实施例对本申请进行了详细的说明,对于本领域的技术人员来而言,其依然可以对前述各具体实施方式所记载的技术方案进行修改,或者对其中部分技术特征进行等效替换。凡是利用本申请说明书及附图内容所做的等效结构,直接或间接运用在其他相关的技术领域,均同理在本申请专利保护范围之内。

Claims (10)

1.一种应用于BERT模型的蒸馏方法,其特征在于,包括下述步骤:
接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;
构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;
基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;
在所述本地数据库中获取所述中间精简模型的训练数据;
基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。
2.根据权利要求1所述的应用于BERT模型的蒸馏方法,其特征在于,所述基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型的步骤,具体包括:
基于所述蒸馏系数对所述原始BERT模型的transformer层进行分组操作,得到分组transformer层;
基于伯努利分布分别在所述分组transformer层中进行提取操作,得到待替换transformer层;
将所述待替换transformer层分别替换至所述默认精简模型,得到所述中间精简模型。
3.根据权利要求1所述的应用于BERT模型的蒸馏方法,其特征在于,所述在所述本地数据库中获取所述中间精简模型的训练数据的步骤,具体包括:
获取所述原始BERT模型训练后的原始训练数据;
调高所述原始BERT模型softmax层的温度参数,得到调高BERT模型;
将所述原始训练数据输入至所述调高BERT模型进行预测操作,得到均值结果标签;
基于标签信息在所述原始训练数据进行筛选操作,得到带标签的筛选结果标签;
基于所述放大训练数据以及所述筛选训练数据选取所述精简模型训练数据。
4.根据权利要求1所述的应用于BERT模型的蒸馏方法,其特征在于,在所述基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型的步骤之后还包括:
在所述本地数据库中获取优化训练数据;
将所述优化训练数据分别输入至所述训练好的原始BERT模型以及所述目标精简模型中,分别得到原始transformer层输出数据以及目标transformer层输出数据;
基于搬土距离计算所述原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据;
根据所述蒸馏损失数据对所述目标精简模型进行参数优化操作,得到优化精简模型。
5.根据权利要求4所述的应用于BERT模型的蒸馏方法,其特征在于,所述基于搬土距离计算所述原始transformer层输出数据以及目标transformer层输出数据的蒸馏损失数据的步骤,具体包括:
获取所述原始transformer层输出的原始注意力矩阵以及所述目标transformer层输出的目标注意力矩阵;
根据所述原始注意力矩阵以及所述目标注意力矩阵计算注意力EMD距离;
获取所述原始transformer层输出的原始FFN隐层矩阵以及所述目标transformer层输出的目标FFN隐层矩阵;
根据所述原始FFN隐层矩阵以及所述目标FFN隐层矩阵计算FFN隐层EMD距离;
基于所述注意力EMD距离以及所述FFN隐层EMD距离获得所述蒸馏损失数据。
6.根据权利要求5所述的应用于BERT模型的蒸馏方法,其特征在于,所述注意力EMD距离表示为:
Figure FDA0002783260580000031
其中,Lattn表示注意力EMD距离;AT表示原始注意力矩阵;AS表示目标注意力矩阵;
Figure FDA0002783260580000032
表示原始注意力矩阵与标注意力矩阵之间的均方误差,且
Figure FDA0002783260580000033
Figure FDA0002783260580000034
表示第i层原始transformer层的原始注意力矩阵;
Figure FDA0002783260580000035
表示第j层目标transformer层的目标注意力矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
7.根据权利要求5所述的应用于BERT模型的蒸馏方法,其特征在于,所述FFN隐层EMD距离表示为:
Figure FDA0002783260580000036
其中,Lffn表示FFN隐层EMD距离;HT表示原始transformer层的原始FFN隐层矩阵;HS表示目标transformer层的目标FFN隐层矩阵;
Figure FDA0002783260580000037
表示原始FFN隐层矩阵与目标FFN隐层矩阵之间的均方误差,且
Figure FDA0002783260580000038
Figure FDA0002783260580000039
表示第j层目标transformer层的目标FFN隐层矩阵;Wh表示转换矩阵;
Figure FDA00027832605800000310
表示第i层原始transformer层的原始FFN隐层矩阵;fij表示从第i层原始transformer层迁移到第j层目标transformer层的知识量;M表示原始transformer层的层数;N表示目标transformer层的层数。
8.一种应用于BERT模型的蒸馏装置,其特征在于,包括:
请求接收模块,用于接收用户终端发送的模型蒸馏请求,所述模型蒸馏请求至少携带有蒸馏对象标识以及蒸馏系数;
原始模型获取模块,用于读取本地数据库,在所述本地数据库中获取与所述蒸馏对象标识相对应的训练好的原始BERT模型,所述原始BERT模型的损失函数为交叉熵;
默认模型构建模块,用于构建与所述训练好的原始BERT模型结构一致的待训练的默认精简模型,所述默认精简模型的损失函数为交叉熵;
蒸馏操作模块,用于基于所述蒸馏系数对所述默认精简模型进行蒸馏操作,得到中间精简模型;
训练数据获取模块,用于在所述本地数据库中获取所述中间精简模型的训练数据;
模型训练模块,用于基于所述训练数据对所述中间精简模型进行模型训练操作,得到目标精简模型。
9.一种计算机设备,包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如权利要求1至7中任一项所述的应用于BERT模型的蒸馏方法的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如权利要求1至7中任一项所述的应用于BERT模型的蒸馏方法的步骤。
CN202011288877.7A 2020-11-17 2020-11-17 一种应用于bert模型的蒸馏方法、装置、设备及存储介质 Pending CN112418291A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202011288877.7A CN112418291A (zh) 2020-11-17 2020-11-17 一种应用于bert模型的蒸馏方法、装置、设备及存储介质
PCT/CN2021/090524 WO2022105121A1 (zh) 2020-11-17 2021-04-28 一种应用于bert模型的蒸馏方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011288877.7A CN112418291A (zh) 2020-11-17 2020-11-17 一种应用于bert模型的蒸馏方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN112418291A true CN112418291A (zh) 2021-02-26

Family

ID=74832129

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011288877.7A Pending CN112418291A (zh) 2020-11-17 2020-11-17 一种应用于bert模型的蒸馏方法、装置、设备及存储介质

Country Status (2)

Country Link
CN (1) CN112418291A (zh)
WO (1) WO2022105121A1 (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2022105121A1 (zh) * 2020-11-17 2022-05-27 平安科技(深圳)有限公司 一种应用于bert模型的蒸馏方法、装置、设备及存储介质
US11526774B2 (en) * 2020-12-15 2022-12-13 Zhejiang Lab Method for automatically compressing multitask-oriented pre-trained language model and platform thereof

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116402811B (zh) * 2023-06-05 2023-08-18 长沙海信智能***研究院有限公司 一种打架斗殴行为识别方法及电子设备

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110188360A (zh) * 2019-06-06 2019-08-30 北京百度网讯科技有限公司 模型训练方法和装置
CN111553479A (zh) * 2020-05-13 2020-08-18 鼎富智能科技有限公司 一种模型蒸馏方法、文本检索方法及装置
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10607598B1 (en) * 2019-04-05 2020-03-31 Capital One Services, Llc Determining input data for speech processing
CN112418291A (zh) * 2020-11-17 2021-02-26 平安科技(深圳)有限公司 一种应用于bert模型的蒸馏方法、装置、设备及存储介质

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110188360A (zh) * 2019-06-06 2019-08-30 北京百度网讯科技有限公司 模型训练方法和装置
CN111553479A (zh) * 2020-05-13 2020-08-18 鼎富智能科技有限公司 一种模型蒸馏方法、文本检索方法及装置
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2022105121A1 (zh) * 2020-11-17 2022-05-27 平安科技(深圳)有限公司 一种应用于bert模型的蒸馏方法、装置、设备及存储介质
US11526774B2 (en) * 2020-12-15 2022-12-13 Zhejiang Lab Method for automatically compressing multitask-oriented pre-trained language model and platform thereof

Also Published As

Publication number Publication date
WO2022105121A1 (zh) 2022-05-27

Similar Documents

Publication Publication Date Title
CN108536679B (zh) 命名实体识别方法、装置、设备及计算机可读存储介质
CN111753060B (zh) 信息检索方法、装置、设备及计算机可读存储介质
CN111581229B (zh) Sql语句的生成方法、装置、计算机设备及存储介质
CN109190120B (zh) 神经网络训练方法和装置及命名实体识别方法和装置
CN112418291A (zh) 一种应用于bert模型的蒸馏方法、装置、设备及存储介质
CN111259625A (zh) 意图识别方法、装置、设备及计算机可读存储介质
CN114780727A (zh) 基于强化学习的文本分类方法、装置、计算机设备及介质
CN113837308B (zh) 基于知识蒸馏的模型训练方法、装置、电子设备
CN112861012B (zh) 基于上下文和用户长短期偏好自适应学习的推荐方法及装置
CN111078847A (zh) 电力用户意图识别方法、装置、计算机设备和存储介质
WO2020215683A1 (zh) 基于卷积神经网络的语义识别方法及装置、非易失性可读存储介质、计算机设备
CN112084752B (zh) 基于自然语言的语句标注方法、装置、设备及存储介质
CN111429204A (zh) 酒店推荐方法、***、电子设备和存储介质
CN115115914B (zh) 信息识别方法、装置以及计算机可读存储介质
CN114781611A (zh) 自然语言处理方法、语言模型训练方法及其相关设备
CN112632227A (zh) 简历匹配方法、装置、电子设备、存储介质和程序产品
CN113239157A (zh) 对话模型的训练方法、装置、设备和存储介质
CN116152833A (zh) 基于图像的表格还原模型的训练方法及表格还原方法
CN114861758A (zh) 多模态数据处理方法、装置、电子设备及可读存储介质
CN114238656A (zh) 基于强化学习的事理图谱补全方法及其相关设备
CN114187486A (zh) 模型训练方法及相关设备
CN112559877A (zh) 基于跨平台异构数据及行为上下文的ctr预估方法及***
CN115545035B (zh) 一种文本实体识别模型及其构建方法、装置及应用
CN116684903A (zh) 小区参数处理方法、装置、设备及存储介质
CN115618043A (zh) 文本操作图互检方法及模型训练方法、装置、设备、介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination