CN110852426B - 基于知识蒸馏的预训练模型集成加速方法及装置 - Google Patents
基于知识蒸馏的预训练模型集成加速方法及装置 Download PDFInfo
- Publication number
- CN110852426B CN110852426B CN201911134079.6A CN201911134079A CN110852426B CN 110852426 B CN110852426 B CN 110852426B CN 201911134079 A CN201911134079 A CN 201911134079A CN 110852426 B CN110852426 B CN 110852426B
- Authority
- CN
- China
- Prior art keywords
- model
- likelihood estimation
- teacher
- student
- estimation probability
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
- G06N5/022—Knowledge engineering; Knowledge acquisition
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Computing Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了基于知识蒸馏的预训练模型集成加速方法及装置,该装置应用该方法,该方法包括定义教师模型集团和学生模型;将已标注分类标签的训练数据输入到教师模型集团和学生模型训练,输出每个教师模型对应的似然估计概率值和学生模型似然估计概率值;对教师模型集团输出的似然估计概率值进行池化,输出池化后的似然估计概率值;衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异值;对学生模型的参数进行更新,最终得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型;将得到的学生模型的特征提取器和特征编码器作为学生预训练模型预测待训练的数据,编码成数据特征向量。
Description
技术领域
本发明属于神经网络数据处理技术领域,具体地说,涉及基于知识蒸馏的预训练模型集成加速方法及装置。
背景技术
近年来,卷积神经网络在人脸检测,图片分类,自然语言处理等计算机视觉领域的相关任务中取得了巨大成就。例如,纽约大学的Yann LeCun等人提出将多层卷积神经网络应用于手写数字识别中,Hinton团队使用深度神经网络在ImageNet图像分类比赛中获得压倒性胜利。
随着卷积神经网络的发展,其层次结构的设计越来越复杂,网络参数数量也越来越多,相应的,训练一个优秀的卷积神经网络所需的训练数据集也更加庞大。这样使得运算过程的时间和空间复杂度以及存储代价都大大增加,导致现有的大型卷积神经网络依赖于运算能力极强的高性能处理器和集群服务器。巨大的运算量、时耗和能源消耗使得卷积神经网络很难再计算资源和能源存储有限的移动设备上进行部署,例如手机,智能穿戴设备等。所以,压缩大型神经网络的参数量及降低运算复杂度是一个重要的研究方向。
发明内容
针对现有技术中上述的不足,本发明提供基于知识蒸馏的预训练模型集成加速方法及装置,该方法将运算量庞大的模型作为教师模型,通过对教师模型集团中各个教师模型的似然估计概率值进行池化操作,对不同教师模型的估计结果做一个归纳,使得对数据的分类概率更准确,以便进一步提高对数据的理解能力;通过对教师模型集团池化后的似然估计概率及学生模型的似然估计概率对比得出二者间的差异值,根据该差异值对学生模型进行更新,得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型,将得到的学生模型的特征提取器和特征编码器作为学生预训练模型,通过对学生模型进行更新的过程将大量教师模型已经学会的知识以及对知识的理解模式迁移到学生模型中,以便既保证复杂教师模型的效果,又保证在真实场景进行训练数据识别时的速度;得到的学生预训练模型将待训练的数据编码成数据特征向量,可以应用在不同的处理任务中,一次处理可以重复应用,减少运算复杂度。
为了达到上述目的,本发明采用的解决方案是:基于知识蒸馏的预训练模型集成加速方法,包括:
定义教师模型集团,所述的教师模型集团包括多个教师模型,每个所述的教师模型均包括第一特征提取器、第一特征编码器和第一分类器,所述的第一特征提取器包括卷积网络特征提取器和长短期记忆网络特征提取器与卷积网络特征提取器的结合;定义学生模型,所述的学生模型包括第二特征提取器、第二特征编码器和第二分类器;教师模型集团包括大量的已经经过训练,具备优秀识别能力的教师模型,教师模型的第一特征提取器和第一编码器都是经过已标注标签的训练数据训练迭代过的,而学生模型则是原始的,未经过训练的第二特征提取器和第二特征编码器。
将已标注分类标签的训练数据分别输入到教师模型集团和学生模型训练,学生模型输出似然估计概率值;教师模型集团输出每个教师模型对应的似然估计概率值;教师模型集团输出的是每个教师模型的输出结果,各个结果可以相互补全映证对分类的判断结果,减少判断失误的情况,提高预测的准确率。
对教师模型集团输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值,池化操作包括求平均操作和加权求平均操作;所述的求平均操作包括:对教师模型集团输出的每个教师模型对应的似然估计概率值求平均;所述的加权求平均操作包括:对教师模型集团输出的每个教师模型对应的似然估计概率值进行加权后求平均,通过计算所有教师模型输出概率的平均值或加权平均值,平滑单一教师模型造成的误差,提高预测的准确率;
衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异值;
采用梯度下降算法计算对学生模型的参数进行更新,使得学生模型似然估计概率值向教师模型集团经过池化后的似然估计概率值迭代,最终得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型;
将得到的学生模型的特征提取器和特征编码器作为学生预训练模型;
学生预训练模型预测待训练的数据,编码成数据特征向量,提供给下游应用,例如分类、聚类、匹配。
所述的衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异采用交叉熵cross entropy损失函数或KL散度(衡量两个概率分布的距离)。
应用基于知识蒸馏的预训练模型集成加速方法的装置,包括教师模型集团、似然估计池化器、学生模型、知识蒸馏装置和学生预训练模型;
所述的教师模型集团包括多个教师模型,用于对已标注分类标签的训练数据进行训练,得到各个教师模型对应的似然估计概率值,多个教师模型可以相同也可以不相同,不相同的教师模型组成的教师模型集团池化后的似然估计概率值有利于学生模型迁移学到更多的知识以及对知识的理解模式;
所述的学生模型用于对已标注分类标签的训练数据进行训练,得到学生模型对应的似然估计概率值;
所述的似然估计池化器用于对教师模型集团输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值,似然估计池化器收集了教师模型集团中若干教师模型的预测结果,通过计算所有教师模型输出概率的平均值的方式,平滑单一模型造成的误差,提高预测的准确率;
所述的知识蒸馏装置用于衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异值,并对学生模型进行参数更新,得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型;
所述的学生预训练模型用于包括得到的学生模型的特征提取器和特征编码器,用于将待训练的数据编码成数据特征向量。学生预训练模型处理出来的数据特征向量可以应用在不同的处理任务中,一次处理可以重复应用,减少运算复杂度。例如应用在聚类装置中,进行聚类;应用在分类器中进行分类;应用在匹配器中,进行匹配。
所述的教师模型包括第一特征提取器、第一特征编码器和第一分类器。
所述的学生模型包括第二特征提取器、第二特征编码器和第二分类器。
本发明的有益效果是:
(1)该方法将运算量庞大的模型作为教师模型,通过对教师模型集团中各个教师模型的似然估计概率值进行池化操作,对不同教师模型的估计结果做一个归纳,使得对数据的分类概率更准确,以便进一步提高对数据的理解能力;通过对教师模型集团池化后的似然估计概率及学生模型的似然估计概率对比得出二者间的差异值,根据该差异值对学生模型进行更新,得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型,将得到的学生模型的特征提取器和特征编码器作为学生预训练模型,通过对学生模型进行更新的过程将大量教师模型已经学会的知识以及对知识的理解模式迁移到学生模型中,以便既保证复杂教师模型的效果,又保证在真实场景进行训练数据识别时的速度;得到的学生预训练模型将待训练的数据编码成数据特征向量,可以应用在不同的处理任务中,一次处理可以重复应用,减少运算复杂度。
附图说明
图1为本发明预训练模型集成加速方法流程图;
图2为本发明预训练模型集成加速装置图;
图3为本发明实施例一智能客服预训练模型集成加速装置图;
图中,100-教师模型集团,110-教师模型,111-第一特征提取器,111A-长短期记忆网络特征提取器,111B-卷积网络特征提取器,112-第一特征编码器,1112B-线性特征编码器,113-第一分类器,200-学生模型,210-第二特征提取器,211-卷积网络特征提取器,220-第二特征编码器,221-线性特征编码器,230-第二分类器,300-似然估计池化器,400-知识蒸馏装置,500-学生预训练模型,510-特征提取器,511-卷积网络特征提取器,520-特征编码器,521-线性特征编码器。
具体实施方式
以下结合附图对本发明作进一步描述:
如图1所示,基于知识蒸馏的预训练模型集成加速方法,包括:
定义教师模型集团100,所述的教师模型集团100包括多个教师模型110,每个所述的教师模型110均包括第一特征提取器111、第一特征编码器112和第一分类器113,所述的第一特征提取器111包括卷积网络特征提取器和长短期记忆网络特征提取器与卷积网络特征提取器的结合;定义学生模型200,所述的学生模型200包括第二特征提取器210、第二特征编码器220和第二分类器230;教师模型集团100包括大量的已经经过训练,具备优秀识别能力的教师模型110,教师模型110的第一特征提取器111和第一编码器112都是经过已标注标签的训练数据训练迭代过的,而学生模型200则使用的是原始的,未经过训练的第二特征提取器210和第二特征编码器220。
将已标注分类标签的训练数据分别输入到教师模型集团100和学生模型200训练,学生模型200不进行迭代直接根据第二特征提取器210和第二特征编码器220的原始参数计算输出似然估计概率值;教师模型集团100输出每个教师模型110对应的似然估计概率值,教师模型110也不需要进行迭代,直接计算训练数据在该教师模型中的似然估计概率值;教师模型集团100输出的是每个教师模型110的输出结果,各个结果可以相互补全映证对分类的判断结果,减少判断失误的情况,提高预测的准确率。
对教师模型集团100输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值,池化操作包括求平均操作和加权求平均操作;所述的求平均操作包括:对教师模型集团100输出的每个教师模型110对应的似然估计概率值求平均;所述的加权求平均操作包括:对教师模型集团100输出的每个教师模型110对应的似然估计概率值进行加权后求平均,通过计算所有教师模型110输出概率的平均值或加权平均值,平滑单一教师模型110造成的误差,提高预测的准确率;
衡量教师模型集团100经过池化后的似然估计概率值与学生模型200似然估计概率值间的差异值;
采用梯度下降算法计算对学生模型200的参数进行更新,使得学生模型200似然估计概率值向教师模型集团100经过池化后的似然估计概率值迭代,最终得到似然估计概率值最接近教师模型集团100池化后的似然估计概率值的学生模型200;
将得到的学生模型200的特征提取器210和特征编码器220作为学生预训练模型500;
学生预训练模型500预测待训练的数据,编码成数据特征向量,提供给下游应用,例如分类、聚类、匹配。
所述的衡量教师模型集团100经过池化后的似然估计概率值与学生模型200似然估计概率值间的差异采用交叉熵cross entropy损失函数或KL散度(衡量两个概率分布的距离)。
如图2所示,应用基于知识蒸馏的预训练模型集成加速方法的装置,包括教师模型集团100、似然估计池化器300、学生模型200、知识蒸馏装置400和学生预训练模型500;
所述的教师模型集团100包括多个教师模型110,用于对已标注分类标签的训练数据进行训练,得到各个教师模型110对应的似然估计概率值,多个教师模型110可以相同也可以不相同,不相同的教师模型110组成的教师模型集团100池化后的似然估计概率值有利于学生模型200迁移学到更多的知识以及对知识的理解模式;
所述的学生模型200用于对已标注分类标签的训练数据进行训练,得到学生模型200对应的似然估计概率值;
所述的似然估计池化器300用于对教师模型集团100输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值,似然估计池化器300收集了教师模型集团100中若干教师模型110的预测结果,通过计算所有教师模型110输出概率的平均值的方式,平滑单一模型造成的误差,提高预测的准确率;
所述的知识蒸馏装置400用于衡量教师模型集团100经过池化后的似然估计概率值与学生模型200似然估计概率值间的差异值,并对学生模型200进行参数更新,得到似然估计概率值最接近教师模型集团100池化后的似然估计概率值的学生模型200;
所述的学生预训练模型500包括得到的学生模型200的特征提取器510和特征编码器520,用于将待训练的数据编码成数据特征向量。学生预训练模型处理出来的数据特征向量可以应用在不同的处理任务中,一次处理可以重复应用,减少运算复杂度。例如应用在聚类装置中,进行聚类;应用在分类器中进行分类;应用在匹配器中,进行匹配。
所述的教师模型110包括第一特征提取器111、第一特征编码器112和第一分类器113。
所述的学生模型200包括第二特征提取器210、第二特征编码器220和第二分类器230。
实施例一
随着电商行业的急速发展,网上购物成为了大多数人必不可少的日常,在网络购物过程中,消费者往往会对产品的性能、商家的服务、产品的合适尺码等等问题咨询商家,因此,在各大电商平台的商家需要招募大量的客服人员对买家进行答疑解惑,与日俱增的咨询量使得商家对客服机器人的需求逐渐提高。在智能客服领域,意图识别是一个重要任务,旨在理解客服场景中买家发来的问题。之后,针对识别到的买家意图,进行后续相关的操作或回复。
现有的常用自然语言理解模型为长短期记忆网络(LSTM),为时序型计算,并行程度低,且难以通过剪枝等策略去减少计算量。最近涌现的自然语言理解的预训练模型(BERT、XLNet等)都具有数亿参数,计算复杂度高,对计算装置的要求高,且响应时间长。简单模型(CNN、Transformer等)在预训练后的效果表现较差。
同时由于咨询量的不断增加,意图识别模块对计算装置的计算需求不断扩大,请求的响应时间成为了瓶颈。因此需要降低计算需求、提高响应速度,降低模型的复杂度、提高计算的并行能力。
如图3所示为应用本申请的预训练模型集成加速方法的智能客服预训练模型集成加速装置的示意图。智能客服预训练模型集成加速装置包括教师模型集团100,学生模型200、似然估计池化器300、知识蒸馏装置400和学生预训练模型500。
教师模型集团100包括多个教师模型110,教师模型110包括长短期记忆网络特征提取器111A、卷积网络特征提取器111B、线性特征编码器112B和第一分类器113,长短期记忆网络特征提取器111A(Long-short Term Memory Network)是一种时序逻辑的计算网络,可以提取文本的上下文关系特征。而卷积网络特征提取器111B可以提取更多文本的局部特征,获取字和字组成词语的关系特征。线性特征编码器112B对上游的特征向量进行进一步压缩编码和特征空间的转换,以便后面更快更好的运算。第一分类器113包含softmax函数,根据上游特征编码给出不同分类的得分,并计算该输入的客户问题对应每个分类的概率。教师模型110都是参数量大、计算复杂度高的模型,且经过了大量多个行业(服装、鞋包、电器、日用品等)的客户问题与对应的意图标签的训练,具备优秀的意图识别能力。多行业客户问题原本为中文文本,转换为对应的字词向量后在神经网络中参与计算,多行业客户问题的字词向量在教师模型110中首先经过计算复杂且无法并行计算的长短期记忆网络特征提取器111A获得足够的上下文特征信息,然后再经过卷积提取更多文本的局部特征、字和字组成词语的关系特征,这些特征经线性特征编码器112A编码后由第一分类器113进行判别,该输入的客户问题对应每个分类的概率。
似然估计池化器300获取教师模型集团100中每一个教师模型110输出的概率进行池化操作,通过计算所有教师模型110输出概率的平均值的方式,平滑单一模型造成的误差,提高预测的准确率。
学生模型200包括卷积网络特征提取器211、线性特征编码器221和第二分类器230,学生模型200对多行业客户问题的字词向量进行卷积提取文本的局部特征、字和字组成词语的关系特征,这些特征经线性特征编码器221编码后由第二分类器230进行判别,该输入的客户问题对应每个分类的概率。
知识蒸馏装置400用于衡量教师模型集团100经过似然估计池化器300后的似然估计概率与学生模型200似然估计概率差异,知识蒸馏装置400会对两者的差异给出数值,经过常用的梯度下降算法针对学生模型200进行参数更新,直到两者的距离不再变小。
将得到的学生模型200的卷积网络特征提取器211和线性特征编码器221B作为学生预训练模型500的卷积网络特征提取器511和线性特征编码器521,这样学生预训练模型500是一个轻量级的文本编码器,可以接收客户问题字/词向量,编码成客户问题特征向量提供给具体任务应用。学生预训练模型500可以作为任意自然语言理解任务的预训练模型使用。
学生预训练模型500编码成的客户问题特征向量可以应用在应用层,应用层可以包括聚类装置、分类器以及匹配器。根据任务需求,聚类装置可以将表示到高维向量空间的客户问题进行聚类操作;匹配器可以将两个客户问题特征向量映射到一个相似度数值上,根据具体的阈值设定来判定两个客户问题是否相似;对于分类任务,如意图识别任务,学生预训练模型500后面可以对接一个分类器,并经过智能客服相应商品类目的数据进行训练学习,获得一个优秀的意图识别***。
教师模型集团100(包含有各种复杂结构的模型)输出的各种概率经过似然估计池化器300后,可以对不同教师模型110的估计结果做一个归纳,使得对客户问题的分类概率更准确,以便进一步提高对真实客户问题的理解能力,并将之传授给学生模型200。根据实验结果,F1分数至少提高1%。
在使用上述智能客服预训练模型集成加速装置时,首先定义出教师模型集团100和学生模型200,然后将已经标注了分类标签的客户问题数据分别输入到教师模型集团100和学生模型200训练,学生模型200输出似然估计概率值;教师模型集团100输出每个教师模型110对应的似然估计概率值;似然估计池化器300对教师模型集团100输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值;知识蒸馏装置400衡量教师模型集团100经过池化后的似然估计概率值与学生模型200似然估计概率值间的差异值;对学生模型200的参数进行更新,使得学生模型200似然估计概率值向教师模型集团100经过池化后的似然估计概率值迭代,最终得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型100;将得到的学生模型100的卷积网络特征提取器211和线性特征编码器221B作为学生预训练模型500;学生预训练模型500预测待客户问题,编码成数据特征向量提供给应用层应用。
采用以上实施例的装置将运算量庞大的模型作为教师模型110,通过对教师模型集团100中各个教师模型110的似然估计概率值进行池化操作,对不同教师模型110的估计结果做一个归纳,使得对客户问题的分类概率更准确,以便进一步提高对真实客户问题的理解能力;通过对教师模型集团100池化后的似然估计概率及学生模型200的似然估计概率对比得出二者间的差异值,根据该差异值对学生模型200进行更新,得到似然估计概率值最接近教师模型集团100池化后的似然估计概率值的学生模型200,将得到的学生模型200的特征提取器510和特征编码器520作为学生预训练模型500,通过对学生模型200进行更新的过程将大量教师模型110已经学会的知识以及对知识的理解模式迁移到学生模型200中,以便既保证复杂教师模型110的效果,又保证在真实场景进行训练数据识别时的速度;得到的学生预训练模型500将待训练的数据编码成数据特征向量,可以应用在不同的处理任务中,一次处理可以重复应用,减少运算复杂度。
以上所述实施例仅表达了本发明的具体实施方式,其描述较为具体和详细,但并不能因此而理解为对本发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干变形和改进,这些都属于本发明的保护范围。
Claims (8)
1.基于知识蒸馏的预训练模型集成加速方法,其特征在于,包括:
定义教师模型集团,教师模型集团定义学生模型;
将已标注分类标签的客户咨询文本训练数据分别输入到教师模型集团和学生模型训练,学生模型输出似然估计概率值;教师模型集团输出每个教师模型对应的似然估计概率值;
所述教师模型集团包括多个教师模型,教师模型包括长短期记忆网络特征提取器、第一卷积网络特征提取器、第一线性特征编码器和第一分类器,长短期记忆网络特征提取器为时序逻辑的计算网络,用于提取文本的上下文关系特征;所述教师模型获取分类概率的方法如下:将客户问题的字词向量在教师模型中首先经过长短期记忆网络特征提取器获得上下文特征信息,再经过卷积提取文本的局部特征、字和字组成词语的关系特征,所述局部特征与关系特征经第一线性特征编码器编码后由第一分类器进行判别,得到输入的客户问题对应每个分类的概率;
所述学生模型包括第二卷积网络特征提取器、第二线性特征编码器和第二分类器,所述学生模型获取分类概率的方法如下:对客户问题的字词向量进行卷积提取文本的局部特征、字和字组成词语的关系特征,所述局部特征与关系特征经第二线性特征编码器编码后由第二分类器进行判别,得到输入的客户问题对应每个分类的概率;
对教师模型集团输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值;
衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异值;
对学生模型的参数进行更新,使得学生模型似然估计概率值向教师模型集团经过池化后的似然估计概率值迭代,最终得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型;
将得到的学生模型的特征提取器和特征编码器作为学生预训练模型;
学生预训练模型预测待训练的客户咨询文本数据,编码成客户咨询文本数据特征向量。
2.根据权利要求1所述的基于知识蒸馏的预训练模型集成加速方法,其特征在于:所述的教师模型集团包括多个教师模型,每个所述的教师模型均包括第一特征提取器、第一特征编码器和第一分类器。
3.根据权利要求2所述的基于知识蒸馏的预训练模型集成加速方法,其特征在于:所述的第一特征提取器包括卷积网络特征提取器和长短期记忆网络特征提取器与卷积网络特征提取器的结合。
4.根据权利要求1所述的基于知识蒸馏的预训练模型集成加速方法,其特征在于:所述的池化操作包括求平均操作和加权求平均操作;所述的求平均操作包括:对教师模型集团输出的每个教师模型对应的似然估计概率值求平均;所述的加权求平均操作包括:对教师模型集团输出的每个教师模型对应的似然估计概率值进行加权后求平均。
5.根据权利要求1所述的基于知识蒸馏的预训练模型集成加速方法,其特征在于:所述的学生模型包括第二特征提取器、第二特征编码器和第二分类器。
6.根据权利要求1所述的基于知识蒸馏的预训练模型集成加速方法,其特征在于:所述的衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异采用交叉熵cross entropy损失函数或KL散度。
7.根据权利要求1所述的基于知识蒸馏的预训练模型集成加速方法,其特征在于:所述的对学生模型的参数进行更新采用梯度下降算法计算。
8.应用权利要求1-7中任意一项所述的基于知识蒸馏的预训练模型集成加速方法的装置,其特征在于:包括教师模型集团、似然估计池化器、学生模型、知识蒸馏装置和学生预训练模型;
所述的教师模型集团包括多个教师模型,用于对已标注分类标签的客户咨询文本训练数据进行训练,得到各个教师模型对应的似然估计概率值;
所述的学生模型用于对已标注分类标签的客户咨询文本训练数据进行训练,得到学生模型对应的似然估计概率值;
所述的似然估计池化器用于对教师模型集团输出的似然估计概率值进行池化操作,输出池化后的似然估计概率值;
所述的知识蒸馏装置用于衡量教师模型集团经过池化后的似然估计概率值与学生模型似然估计概率值间的差异值,并对学生模型进行参数更新,得到似然估计概率值最接近教师模型集团池化后的似然估计概率值的学生模型;
所述的学生预训练模型包括得到的学生模型的特征提取器和特征编码器,用于将待训练的客户咨询文本数据编码成客户咨询文本数据特征向量。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911134079.6A CN110852426B (zh) | 2019-11-19 | 2019-11-19 | 基于知识蒸馏的预训练模型集成加速方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911134079.6A CN110852426B (zh) | 2019-11-19 | 2019-11-19 | 基于知识蒸馏的预训练模型集成加速方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110852426A CN110852426A (zh) | 2020-02-28 |
CN110852426B true CN110852426B (zh) | 2023-03-24 |
Family
ID=69602619
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911134079.6A Active CN110852426B (zh) | 2019-11-19 | 2019-11-19 | 基于知识蒸馏的预训练模型集成加速方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110852426B (zh) |
Families Citing this family (17)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111523324B (zh) * | 2020-03-18 | 2024-01-26 | 大箴(杭州)科技有限公司 | 命名实体识别模型的训练方法及装置 |
CN111506702A (zh) * | 2020-03-25 | 2020-08-07 | 北京万里红科技股份有限公司 | 基于知识蒸馏的语言模型训练方法、文本分类方法及装置 |
CN111611377B (zh) * | 2020-04-22 | 2021-10-29 | 淮阴工学院 | 基于知识蒸馏的多层神经网络语言模型训练方法与装置 |
CN111967224A (zh) * | 2020-08-18 | 2020-11-20 | 深圳市欢太科技有限公司 | 对话文本的处理方法、装置、电子设备及存储介质 |
CN112184508B (zh) * | 2020-10-13 | 2021-04-27 | 上海依图网络科技有限公司 | 一种用于图像处理的学生模型的训练方法及装置 |
CN112465138A (zh) * | 2020-11-20 | 2021-03-09 | 平安科技(深圳)有限公司 | 模型蒸馏方法、装置、存储介质及设备 |
US20220188622A1 (en) * | 2020-12-10 | 2022-06-16 | International Business Machines Corporation | Alternative soft label generation |
CN112836762A (zh) * | 2021-02-26 | 2021-05-25 | 平安科技(深圳)有限公司 | 模型蒸馏方法、装置、设备及存储介质 |
CN112949786B (zh) * | 2021-05-17 | 2021-08-06 | 腾讯科技(深圳)有限公司 | 数据分类识别方法、装置、设备及可读存储介质 |
CN113469977B (zh) * | 2021-07-06 | 2024-01-12 | 浙江霖研精密科技有限公司 | 一种基于蒸馏学习机制的瑕疵检测装置、方法、存储介质 |
CN113836903B (zh) * | 2021-08-17 | 2023-07-18 | 淮阴工学院 | 一种基于情境嵌入和知识蒸馏的企业画像标签抽取方法及装置 |
CN113673254B (zh) * | 2021-08-23 | 2022-06-07 | 东北林业大学 | 基于相似度保持的知识蒸馏的立场检测方法 |
CN113837308B (zh) * | 2021-09-29 | 2022-08-05 | 北京百度网讯科技有限公司 | 基于知识蒸馏的模型训练方法、装置、电子设备 |
CN114241282B (zh) * | 2021-11-04 | 2024-01-26 | 河南工业大学 | 一种基于知识蒸馏的边缘设备场景识别方法及装置 |
WO2023212997A1 (zh) * | 2022-05-05 | 2023-11-09 | 五邑大学 | 基于知识蒸馏的神经网络训练方法、设备及存储介质 |
CN115064155A (zh) * | 2022-06-09 | 2022-09-16 | 福州大学 | 一种基于知识蒸馏的端到端语音识别增量学习方法及*** |
CN114841173B (zh) * | 2022-07-04 | 2022-11-18 | 北京邮电大学 | 基于预训练模型的学术文本语义特征提取方法、***和存储介质 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108921294A (zh) * | 2018-07-11 | 2018-11-30 | 浙江大学 | 一种用于神经网络加速的渐进式块知识蒸馏方法 |
CN109616105A (zh) * | 2018-11-30 | 2019-04-12 | 江苏网进科技股份有限公司 | 一种基于迁移学习的带噪语音识别方法 |
CN109637546A (zh) * | 2018-12-29 | 2019-04-16 | 苏州思必驰信息科技有限公司 | 知识蒸馏方法和装置 |
CN109829038A (zh) * | 2018-12-11 | 2019-05-31 | 平安科技(深圳)有限公司 | 基于深度学习的问答反馈方法、装置、设备及存储介质 |
CN109871851A (zh) * | 2019-03-06 | 2019-06-11 | 长春理工大学 | 一种基于卷积神经网络算法的汉字书写规范性判定方法 |
CN110097178A (zh) * | 2019-05-15 | 2019-08-06 | 电科瑞达(成都)科技有限公司 | 一种基于熵注意的神经网络模型压缩与加速方法 |
CN110135574A (zh) * | 2018-02-09 | 2019-08-16 | 北京世纪好未来教育科技有限公司 | 神经网络训练方法、图像生成方法及计算机存储介质 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11410029B2 (en) * | 2018-01-02 | 2022-08-09 | International Business Machines Corporation | Soft label generation for knowledge distillation |
-
2019
- 2019-11-19 CN CN201911134079.6A patent/CN110852426B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110135574A (zh) * | 2018-02-09 | 2019-08-16 | 北京世纪好未来教育科技有限公司 | 神经网络训练方法、图像生成方法及计算机存储介质 |
CN108921294A (zh) * | 2018-07-11 | 2018-11-30 | 浙江大学 | 一种用于神经网络加速的渐进式块知识蒸馏方法 |
CN109616105A (zh) * | 2018-11-30 | 2019-04-12 | 江苏网进科技股份有限公司 | 一种基于迁移学习的带噪语音识别方法 |
CN109829038A (zh) * | 2018-12-11 | 2019-05-31 | 平安科技(深圳)有限公司 | 基于深度学习的问答反馈方法、装置、设备及存储介质 |
CN109637546A (zh) * | 2018-12-29 | 2019-04-16 | 苏州思必驰信息科技有限公司 | 知识蒸馏方法和装置 |
CN109871851A (zh) * | 2019-03-06 | 2019-06-11 | 长春理工大学 | 一种基于卷积神经网络算法的汉字书写规范性判定方法 |
CN110097178A (zh) * | 2019-05-15 | 2019-08-06 | 电科瑞达(成都)科技有限公司 | 一种基于熵注意的神经网络模型压缩与加速方法 |
Non-Patent Citations (5)
Title |
---|
Improving the interpretability of deep neural networks with knowledge distillation;X Liu等;《2018 IEEE International Conference on Data Mining Workshops (ICDMW)》;20181228;1-8 * |
一种基于模拟退火算法改进的卷积神经网络;满凤环等;《微电子学与计算机》;20171102;第34卷(第9期);58-62 * |
基于深度特征蒸馏的人脸识别;葛仕明等;《北京交通大学学报》;20171215(第06期);32-38+46 * |
基于用户隐性反馈行为的下一个购物篮推荐;李裕礞等;《中文信息学报》;20170915(第05期);220-227 * |
深度神经网络压缩与加速综述;纪荣嵘等;《计算机研究与发展》;20180915;第55卷(第9期);1871-1888 * |
Also Published As
Publication number | Publication date |
---|---|
CN110852426A (zh) | 2020-02-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110852426B (zh) | 基于知识蒸馏的预训练模型集成加速方法及装置 | |
Dong et al. | Automatic age estimation based on deep learning algorithm | |
CN110609891A (zh) | 一种基于上下文感知图神经网络的视觉对话生成方法 | |
CN110969020A (zh) | 基于cnn和注意力机制的中文命名实体识别方法、***及介质 | |
CN113407660B (zh) | 非结构化文本事件抽取方法 | |
CN110781686B (zh) | 一种语句相似度计算方法、装置及计算机设备 | |
CN112784778A (zh) | 生成模型并识别年龄和性别的方法、装置、设备和介质 | |
CN113822776B (zh) | 课程推荐方法、装置、设备及存储介质 | |
Keren et al. | Convolutional neural networks with data augmentation for classifying speakers' native language | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
CN113837308A (zh) | 基于知识蒸馏的模型训练方法、装置、电子设备 | |
CN110110724A (zh) | 基于指数型挤压函数驱动胶囊神经网络的文本验证码识别方法 | |
CN113255366A (zh) | 一种基于异构图神经网络的方面级文本情感分析方法 | |
CN108805280B (zh) | 一种图像检索的方法和装置 | |
CN114036298B (zh) | 一种基于图卷积神经网络与词向量的节点分类方法 | |
CN114117039A (zh) | 一种小样本文本分类方法及模型 | |
CN116883746A (zh) | 一种基于分区池化超图神经网络的图节点分类方法 | |
CN113705197B (zh) | 一种基于位置增强的细粒度情感分析方法 | |
CN116089605A (zh) | 基于迁移学习和改进词袋模型的文本情感分析方法 | |
CN115687620A (zh) | 一种基于三模态表征学习的用户属性检测方法 | |
JP2023017759A (ja) | セマンティック増強に基づく画像識別モデルのトレーニング方法およびトレーニング装置 | |
CN114036947A (zh) | 一种半监督学习的小样本文本分类方法和*** | |
CN114357166A (zh) | 一种基于深度学习的文本分类方法 | |
CN113689301A (zh) | 赔付策略的构建方法、装置、设备及存储介质 | |
CN113255824A (zh) | 训练分类模型和数据分类的方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |