CN114722805A - 基于大小导师知识蒸馏的少样本情感分类方法 - Google Patents

基于大小导师知识蒸馏的少样本情感分类方法 Download PDF

Info

Publication number
CN114722805A
CN114722805A CN202210653730.6A CN202210653730A CN114722805A CN 114722805 A CN114722805 A CN 114722805A CN 202210653730 A CN202210653730 A CN 202210653730A CN 114722805 A CN114722805 A CN 114722805A
Authority
CN
China
Prior art keywords
model
sample
instructor
training
tutor
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210653730.6A
Other languages
English (en)
Other versions
CN114722805B (zh
Inventor
李寿山
常晓琴
周国栋
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Suzhou University
Original Assignee
Suzhou University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Suzhou University filed Critical Suzhou University
Priority to CN202210653730.6A priority Critical patent/CN114722805B/zh
Publication of CN114722805A publication Critical patent/CN114722805A/zh
Application granted granted Critical
Publication of CN114722805B publication Critical patent/CN114722805B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/279Recognition of textual entities
    • G06F40/284Lexical analysis, e.g. tokenisation or collocates
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2155Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • Biophysics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Biomedical Technology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Other Investigation Or Analysis Of Materials By Electrical Means (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明涉及一种基于大小导师知识蒸馏的少样本情感分类方法,包括收集大量情感分类任务上的未标注样本和有标注样本,使用有标注样本训练大导师模型和小导师模型;全部未标注样本经过小导师模型得到每个样本概率的不确定性,然后根据阈值筛选出样本概率高度不确定的样本再次经过大导师模型;结合大导师模型和小导师模型的概率输出形成软标签来蒸馏学生模型,使用蒸馏后的学生模型进行分类预测。本发明减少了访问大导师模型的频率,减少了训练学生模型过程中的蒸馏时间,减少资源消耗的同时提升了分类识别的正确率。

Description

基于大小导师知识蒸馏的少样本情感分类方法
技术领域
本发明涉及自然语言处理技术领域,尤其是指一种基于大小导师知识蒸馏的少样本情感分类方法。
背景技术
情感分类任务旨在对文本表达的情感极性(如:消极和积极)进行自动判断。该任务是自然语言处理研究领域中的研究热点,并在意见挖掘、信息检索和问答***等众多应用***中广泛应用,是这些应用***的基础环节。情感分类只中的少样本情感分类是指在训练分类器时仅有少量的标注样本可以使用。
在进行少样本情感分类时,人工智能领域通常使用机器学***衡的少量标注样本的语料;(2)基于提示的大规模预训练语言模型(比如GPT-3)利用少量的标注样本训练模型,获得分类模型;(3)使用分类模型对某个未知标签的文本进行测试,获得该文本段的极性标签。测试过程中,每次输入分类模型的是单个文本。其中第(2)步的基于提示的大规模预训练语言模型的网络结构如图1所示,图中[CLS] x [SEP]是输入语句,[CLS]标志句首,[SEP]标志句子与句子的分隔,x是原始预训练模型预测句子的分类。图1中“MLM head”是掩码语言模型在基于提示的大规模预训练语言模型中的固定用法。通过“MLM head”得到积极标签“好”,从而得到输入语句“[CLS]我会把他们推荐给每一个人!它 [MASK]。[SEP]”的反馈输出为“我会把他们推荐给每一个人!它好。”。
少样本情感分类由于训练样本很少,常见的浅层神经网络(例如CNN、LSTM等)和深度预训练语言模型(例如BERT、RoBERTa等)很难对某些文本的语义做出正确判断,分类的识别率不够高。现有技术GPT-3大型模型的参数量达1750亿,通过增加一些输入和相应输出的实例作为上下文,能够在少样本学习任务上表现优异。但是由于参数量过于庞大,调用模型需要耗费昂贵的计算资源,推理速度也很慢,给实际应用带来了阻碍。
发明内容
为此,本发明所要解决的技术问题在于克服现有技术中的不足,提供一种基于大小导师知识蒸馏的少样本情感分类方法,可以有效减少访问大导师模型的频率和训练学生模型过程中的蒸馏时间,并在减少资源消耗的同时提升分类识别的正确率。
为解决上述技术问题,本发明提供了一种基于大小导师知识蒸馏的少样本情感分类方法,包括以下步骤:
S1:将样本分为有标注样本x u 和未标注样本x u ′,收集大量情感分类任务上的未标 注样本x u ′,建立有标注样本的集合
Figure 300169DEST_PATH_IMAGE001
和未标注样本的集合D u ={x u ′};
S2:构建大导师模型和小导师模型,使用有标注样本集合D l 训练大导师模型得到训练完成的大导师模型M L ,使用有标注样本集合D l 训练小导师模型得到训练完成的小导师模型M B
S3:使用训练完成的小导师模型M B 预测全部未标注样本x u ′得到样本概率
Figure 113405DEST_PATH_IMAGE002
,计算 每个样本概率的不确定性
Figure 565377DEST_PATH_IMAGE003
S4:将不确定性
Figure 602603DEST_PATH_IMAGE004
与预设阈值threshold比较,筛选出样本概率高度 不确定的样本x u ″;
S5:将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签P,将样本x u ″输入训练完成的大导师模型M L 得到大导师模型的软标签P′,结合小导师模型的软标签P 和大导师模型的软标签P′得到最终的软标签
Figure 22083DEST_PATH_IMAGE005
S6:构建学生模型,使用所述未标注样本集合D u 和所述软标签
Figure 373430DEST_PATH_IMAGE005
蒸馏学生模型,得 到蒸馏完成的学生模型;
S7:使用蒸馏完成的学生模型对测试集进行分类预测。
作为优选的,所述大导师模型和所述小导师模型均为由基于提示的预训练语言模型M组成的教师模型,所述大导师模型的参数量大于所述小导师模型的参数量。
作为优选的,所述使用有标注样本集合D l 训练大导师模型得到训练完成的大导师模型M L ,具体为:
S21:训练集D l ={x u }={x,y}中,x表示输入样例,y表示真实标签;对输入样例x添加提示模板转化成完形填空任务形式:
P(x)=[CLS] x It is [MASK].[SEP],其中[MASK]为填充词,P(x)是语言模型的输入,It is [MASK].是输入文本添加的提示模板;
S22:将L作为分类任务的标签集合,V作为分类任务的标签词集合,构造标签映射 函数:
Figure 194755DEST_PATH_IMAGE006
P(x)作为语言模型的输入,通过基于提示的预训练语言模型M得到[MASK]对应位 置在不同标签
Figure 589833DEST_PATH_IMAGE007
上的得分
Figure 496610DEST_PATH_IMAGE008
其中
Figure 448385DEST_PATH_IMAGE009
Figure 124217DEST_PATH_IMAGE010
表示标签l对应的标签词,k为标签词的长度;
S23:通过softmax层建立预测[MASK]在不同标签l上的类别概率,通过类别概率得 到输入样例x的情感类别
Figure 706508DEST_PATH_IMAGE011
S24:建立大导师模型输出层的损失函数;
S25:重复S22~S24,直到大导师模型收敛,结束训练,得到训练完成的大导师模型M L
所述使用有标注样本集合D l 训练小导师模型得到训练完成的小导师模型M B ,具体为:
S26:训练集D l ={x u }={x,y}中,x表示输入样例,y表示真实标签;对输入样例x添加提示模板转化成完形填空任务形式:
P(x)=[CLS] x It is [MASK].[SEP],其中[MASK]为填充词;
S27:将L作为分类任务的标签集合,V作为分类任务的标签词集合,构造标签映射 函数:
Figure 585733DEST_PATH_IMAGE012
通过基于提示的预训练语言模型M得到[MASK]对应位置在不同标签
Figure 278883DEST_PATH_IMAGE013
上的 得分
Figure 74801DEST_PATH_IMAGE014
其中
Figure 359151DEST_PATH_IMAGE015
Figure 240520DEST_PATH_IMAGE016
表示标签l对应的标签词,k为标签词的长度;
S28:通过softmax层建立预测[MASK]在不同标签l上的类别概率,通过类别概率得 到输入样例x的情感类别
Figure 986628DEST_PATH_IMAGE017
S29:建立小导师模型的输出层的损失函数;
S210:重复S27~S29,直到小导师模型收敛,结束训练,得到训练完成的小导师模型M B
作为优选的,所述使用训练完成的小导师模型M B 预测全部未标注样本x u ′得到样本 概率
Figure 637052DEST_PATH_IMAGE018
,计算每个样本概率的不确定性
Figure 561146DEST_PATH_IMAGE019
,具体为:
S31:将全部未标注样本x u ′输入训练完成的小导师模型M B ,预测得到的概率分布为
Figure 664231DEST_PATH_IMAGE020
S32:计算每个样本概率的不确定性
Figure 449915DEST_PATH_IMAGE021
,计算公式为:
Figure 220425DEST_PATH_IMAGE022
其中|L|为分类任务中标签的类别个数。
作为优选的,所述预设阈值threshold的取值范围为
Figure 49841DEST_PATH_IMAGE023
作为优选的,所述将不确定性
Figure 640222DEST_PATH_IMAGE024
与预设阈值threshold比较,筛选出样 本概率高度不确定的样本x u ″,具体为:
若样本概率的不确定性
Figure 728133DEST_PATH_IMAGE025
大于threshold,则将此样本作为样本概率高 度不确定的样本x u ″。
作为优选的,所述将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标 签P,将样本x u ″输入训练完成的大导师模型M L 得到大导师模型的软标签P′,结合小导师模型 的软标签P和大导师模型的软标签P′得到最终的软标签
Figure 353149DEST_PATH_IMAGE005
,具体为:
S51:将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签
Figure 619046DEST_PATH_IMAGE026
S52:将样本x u ″输入训练完成的大导师模型M L 得到大导师模型的软标签
Figure 696723DEST_PATH_IMAGE027
S53:
Figure 824210DEST_PATH_IMAGE028
的表达式为:
Figure 303733DEST_PATH_IMAGE030
作为优选的,所述使用所述未标注样本集合D u 和所述软标签
Figure 740531DEST_PATH_IMAGE031
蒸馏学生模型,得 到蒸馏完成的学生模型,具体过程为:
S61:将未标注样本集合D u 作为蒸馏学生模型的训练集,经过学生模型的向量表示 为
Figure 305504DEST_PATH_IMAGE032
,其中g( )表示学生模型的网络函数,A u 为未标注样本集合D u 对应的词向 量矩阵,上标s表示学生模型,
Figure 282687DEST_PATH_IMAGE033
表示学生模型的可学习参数;
S62:建立学生模型输出层的损失函数
Figure 865984DEST_PATH_IMAGE034
,其中n表示批大小,
Figure 208104DEST_PATH_IMAGE035
表示经过学生模型的第i个样本的预测概率,
Figure 260374DEST_PATH_IMAGE036
表示最终的样本概率
Figure 510089DEST_PATH_IMAGE037
中第i个样本的 预测概率,T是蒸馏模型的温度参数,DKL表示KL散度损失函数;
S63:
Figure 183778DEST_PATH_IMAGE038
依次经过线性层和softmax激活层,得到未标注样本集合D u 的概率输出
Figure 962379DEST_PATH_IMAGE039
W s 表示学生模型的线性层上待学习的权重矩阵;
S64:使用损失函数LKD更新学生模型的可学习参数;
S65:重复S61~S64直到损失函数LKD收敛,得到蒸馏完成的学生模型。
作为优选的,所述词向量矩阵A u 中,每一行是输入样本x u ′中每个字符的字向量表示,每个字符的字向量通过word2vec或Glove模型训练获得。
作为优选的,所述KL散度损失函数的表达式为
Figure 501944DEST_PATH_IMAGE040
,其 中|L|为分类任务中标签的类别个数。
本发明的上述技术方案相比现有技术具有以下优点:
本发明通过建立大导师模型和小导师模型对学生模型进行蒸馏,使得样本经过小导师模型筛选后再经过大导师模型,可以有效减少对学生模型的蒸馏时间,从而减少资源消耗;同时,在大导师模型和小导师模型减少资源消耗的情况下,可以收集情感分类任务中的大量未标注样本,从而提高分类识别的正确率。
附图说明
为了使本发明的内容更容易被清楚的理解,下面根据本发明的具体实施例并结合附图,对本发明作进一步详细的说明,其中
图1是基于提示的大规模预训练语言模型的网络结构;
图2是传统单一教师和单一学生知识蒸馏方法的结构示意图;
图3是本发明中基于大小导师机制的知识蒸馏方法的结构示意图;
图4是本发明实施例中YELP和IMDB数据集在BERT模型上的实验结果图;
图5是本发明实施例中YELP和IMDB数据集在RoBERTa模型上的实验结果图。
具体实施方式
下面结合附图和具体实施例对本发明作进一步说明,以使本领域的技术人员可以更好地理解本发明并能予以实施,但所举实施例不作为对本发明的限定。
在模型的优化过程中,大模型往往是单个复杂网络或者是若干网络的集合,拥有良好的性能和泛化能力;而小模型因为网络规模较小,表达能力有限。因此,可以利用大模型(教师模型)学习到的知识去指导小模型(学生模型)训练,使得小模型具有与大模型相当的性能,但是参数数量大幅降低,从而实现模型压缩与加速,这个过程称为蒸馏。
与图2所示的传统单一教师和单一学生知识蒸馏方法相比,图3所示的本发明方法 在传统方法的基础上使用了大量未标注样本,并引入基于提示的大导师模型和小导师模型 两个教师模型,图中
Figure 820930DEST_PATH_IMAGE041
为学生模型的输出概率。
本发明一种基于大小导师知识蒸馏的少样本情感分类方法,包括以下步骤:
S1:将样本分为有标注样本x u 和未标注样本x u ′,收集大量情感分类任务上的未标 注样本x u ′,建立有标注样本的集合
Figure 129552DEST_PATH_IMAGE001
和未标注样本的集合
Figure 62742DEST_PATH_IMAGE042
,有标注样本x u 为 含有标签的样本,未标注样本x u ′为无标签的样本,样本中有少量的有标注样本x u 和大量的 未标注样本x u ′。
S2:构建大导师模型和小导师模型,大导师模型和小导师模型均为由基于提示的预训练语言模型M(即prompt方法)组成的教师模型,所述大导师模型的参数量大于所述小导师模型的参数量,本实施例中所述大导师模型的参数量远远大于所述小导师模型的参数量。使用有标注样本集合D l 训练大导师模型得到训练完成的大导师模型M L ,使用有标注样本集合D l 训练小导师模型得到训练完成的小导师模型M B
使用有标注样本集合D l 分别训练大导师模型和小导师模型,大导师模型和小导师模型的训练过程类似,具体过程为:
S21:训练集D l ={x u }={x,y}中,x表示输入样例,y表示真实标签;对输入样例x添加提示模板转化成完形填空任务形式:
P(x)=[CLS] x It is [MASK].[SEP],其中[MASK]为填充词,目的是让基于提示的预训练语言模型M来决定[MASK]处的填充词,将分类任务转化成完形填空任务。输入文本添加提示模板“It is [MASK].”,[MASK]对应分类任务的不同标签,将新的输入通过语言模型,让语言模型决定[MASK]处的填充词,从而实现对文本的分类。
S22:将L作为分类任务的标签集合,V作为分类任务的标签词集合,构造标签映射 函数:
Figure 89603DEST_PATH_IMAGE006
;用于将任务标签映射到基于提示的预训练语言模型M的词表中的某个词或多 个词。例如:情感二分类任务中用0类别对应词表中的单词“terrible”,1类别对应词表中的 单词“great”。P(x)作为语言模型的输入,通过基于提示的预训练语言模型M得到[MASK]对 应位置在不同标签
Figure 149963DEST_PATH_IMAGE013
上的得分
Figure 109829DEST_PATH_IMAGE043
其中
Figure 495811DEST_PATH_IMAGE044
Figure 495122DEST_PATH_IMAGE045
表示标签l对应的标签词,k为标签词的长度。
S23:通过softmax层建立预测[MASK]在不同标签l上的类别概率,通过类别概率得 到输入样例x的情感类别
Figure 93594DEST_PATH_IMAGE046
S24:建立大导师模型的输出层的损失函数;本实施例中损失函数为交叉熵函数, 用来衡量训练样本的真实标签y和预测概率
Figure 376808DEST_PATH_IMAGE047
之间的差异。
S25:重复S22~S24,直到大导师模型收敛结束训练,得到训练完成的大导师模型
Figure 668112DEST_PATH_IMAGE048
S26:训练集D l ={x u }={x,y}中,x表示输入样例,y表示真实标签;对输入样例x添加提示模板转化成完形填空任务形式:
P(x)=[CLS] x It is [MASK].[SEP],其中[MASK]为填充词。
S27:将L作为分类任务的标签集合,V作为分类任务的标签词集合,构造标签映射 函数:
Figure 200724DEST_PATH_IMAGE049
通过基于提示的预训练语言模型M得到[MASK]对应位置在不同标签
Figure 117733DEST_PATH_IMAGE013
上的 得分
Figure 255454DEST_PATH_IMAGE050
其中
Figure 717659DEST_PATH_IMAGE015
Figure 940830DEST_PATH_IMAGE016
表示标签l对应的标签词,k为标签词的长度。
S28:通过softmax层建立预测[MASK]在不同标签l上的类别概率,通过类别概率得 到输入样例x的情感类别
Figure 897416DEST_PATH_IMAGE051
S29:建立小导师模型的输出层的损失函数。
S210:重复S27~S29,直到小导师模型收敛,结束训练,得到训练完成的小导师模型M B
S3:使用训练完成的小导师模型M B 预测全部未标注样本x u ′得到样本概率
Figure 624063DEST_PATH_IMAGE052
,计算 每个样本概率的不确定性
Figure 522749DEST_PATH_IMAGE053
S31:将全部未标注样本x u ′输入训练完成的小导师模型M B ,预测得到的概率分布为
Figure 233216DEST_PATH_IMAGE054
S32:计算每个样本概率的不确定性
Figure 226449DEST_PATH_IMAGE055
,计算公式为:
Figure 338762DEST_PATH_IMAGE056
其中|L|为分类任务中标签的类别个数,通过不确定性
Figure 939507DEST_PATH_IMAGE057
可以衡量样本 预测概率的质量。
S4:将不确定性
Figure 340533DEST_PATH_IMAGE058
与预设阈值threshold比较,筛选出样本概率高度不 确定的样本,预设阈值threshold的取值范围为
Figure 684926DEST_PATH_IMAGE023
若样本概率的不确定性
Figure 390759DEST_PATH_IMAGE059
大于threshold,则将此样本作为样本概率高 度不确定的样本。样本概率的不确定性
Figure 631247DEST_PATH_IMAGE060
大于threshold,说明小导师对样本x u ′ 的分类概率结果置信度不够,需要再次经过大导师模型得到新的概率分布。
S5:将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签P,将样本x u ″输入训练完成的大导师模型M L 得到大导师模型的软标签P′,结合小导师模型的软标签P 和大导师模型的软标签P′得到最终的软标签
Figure 50727DEST_PATH_IMAGE005
S51:将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签
Figure 402074DEST_PATH_IMAGE061
S52:将样本x u ″输入训练完成的大导师模型M L 得到大导师模型的软标签
Figure 207088DEST_PATH_IMAGE062
S53:
Figure 618478DEST_PATH_IMAGE063
的表达式为:
Figure 525254DEST_PATH_IMAGE064
S6:构建学生模型,本实施例中的学生模型由小型的浅层神经网络模型组成。使用 未标注样本集合D u 和软标签
Figure 414713DEST_PATH_IMAGE065
蒸馏学生模型,得到蒸馏完成的学生模型。
S61:将未标注样本集合D u 作为蒸馏学生模型的训练集,经过学生模型的向量表示 为
Figure 106856DEST_PATH_IMAGE066
,其中g( )表示学生模型的网络函数,A u 为未标注样本集合D u 对应的词向量 矩阵,未标注样本x u ′的长度为k,字向量的维度为d,则
Figure 689147DEST_PATH_IMAGE067
;上标s表示学生模型,
Figure 83220DEST_PATH_IMAGE068
表示学生模型的可学习参数。
词向量矩阵A u 中,每一行是输入样本x u ′中每个字符的字向量表示,每个字符的字向量通过word2vec或Glove模型训练获得。
S62:建立学生模型输出层的损失函数,即教师模型蒸馏学生模型时使用的损失函 数
Figure 776369DEST_PATH_IMAGE069
,其中n表示批大小,
Figure 103445DEST_PATH_IMAGE070
表示经过学生模型的第i个样本的预测概 率,
Figure 105905DEST_PATH_IMAGE071
表示最终的样本概率
Figure 987274DEST_PATH_IMAGE072
中第i个样本的预测概率,T是蒸馏模型的温度参数,T是蒸馏 模型自带的参数,T越大,softmax的概率分布就越趋于平滑,分布的熵也就越大,携带的信 息越多,DKL表示KL散度损失函数。
KL散度损失函数的表达式为
Figure 484114DEST_PATH_IMAGE073
,其中|L|为分类任务 中标签的类别个数。
S63:
Figure 134538DEST_PATH_IMAGE074
依次经过线性层和softmax激活层,得到未标注样本集合D u 的概率输出
Figure 12627DEST_PATH_IMAGE075
Figure 381291DEST_PATH_IMAGE076
表示学生模型的线性层上待学习的权重矩阵;
S64:使用损失函数LKD更新学生模型的可学习参数;
S65:重复S61~S64直到损失函数LKD收敛,得到蒸馏完成的学生模型。
S7:使用蒸馏完成的学生模型对测试集进行分类预测。
本发明的有益效果:
本发明通过建立大导师模型和小导师模型对学生模型进行蒸馏,使得样本经过小导师模型筛选后再经过大导师模型,可以有效减少对学生模型的蒸馏时间,从而减少资源消耗;同时,在大导师模型和小导师模型减少资源消耗的情况下,可以收集情感分类任务中的大量未标注样本,从而提高分类识别的正确率。
为了进一步说明本发明的有益效果,本实施例中将测试集输入到训练完成的学生模型中得到预测概率。从(1)学生模型对测试集进行预测得到的分类结果的正确率、(2)蒸馏模型中教师模型对所有未标注样本的预测时间和(3)对大导师模型访问率减少的比例这三方面来分析本发明的效果。
本实施例中使用了句子级YELP数据集(详见文献“Zhang X, Zhao J, LeCun Y.Character-level convolutional networks for text classification[J]. Advancesin neural information processing systems, 2015, 28: 649-657.”)和篇章级IMDB数据集(详见文献“Maas A L , Daly R E , Pham P T , et al. Learning Word Vectorsfor Sentiment Analysis[C]// Meeting of the ACL: Human Language Technologies.ACL, 2011.”)作为测试集,分别进行仿真实验,YELP数据集包含用户对餐馆、购物中心、酒店、旅游等领域的商户的评价以及正负情感倾向,IMDB数据集包含来自互联网电影数据库的两级分化的评论。实验过程中,每个数据集分别选取正负平衡的8个样本作为训练集和验证集,正负500个样本作为测试集。此外,YELP数据集的未标注样本数为10万,IMDB数据集的未标注样本数为9.8万。
为了模拟大小导师机制的知识蒸馏过程,本发明在BERT、RoBERTa模型下分别设置大导师模型和小导师模型,使用BERT-large(BERT下的大导师模型)、BERT-base(BERT下的小导师模型)和RoBERTa-large(RoBERTa下的大导师模型)、RoBERTa-base(RoBERTa下的小导师模型)表示。训练教师模型时,标签词分别为“terrible”和“great;批大小设置为4或8;优化器使用AdamW,其中,学习率选择{1e-5, 2e-5, 5e-5}中的一个,权重衰减设置为1e-3,批大小和学习率根据网格搜索超参数的方式确定。学生模型为CNN模型,使用3种不同尺寸的卷积核,分别为(3, 50)、(4, 50)和(5, 50)。每种卷积核的数量为100;每个CNN 使用Glove.6B.50d作为词向量;批大小设置为128;优化器使用Adam,其中,学习率设置为1e-3,权重衰减为1e-5。为防止神经网络模型训练过程出现过拟合现象,设置Dropout参数为0.5。
YELP和IMDB数据集在BERT模型上的实验结果如图4所示,其中YELP和IMDB数据集上的不确定性阈值均设为0.85;YELP和IMDB数据集在RoBERTa模型上的实验结果如图5所示,其中YELP数据集上的不确定性阈值设为0.6,IMDB数据集上的不确定性阈值设为0.9。图4和图5中Fine-tuning表示使用标准微调预训练语言模型,LM-BFF表示使用基于提示微调预训练语言模型,LM-BFF蒸馏CNN表示使用基于提示的预训练语言模型蒸馏CNN模型。由于少样本学***均值(5次结果的方差)”的形式表示。
从图4的分类结果的正确率可以看出,本发明方法与BERT-large模型的蒸馏性能相比,在YELP数据集下提高了91.13%-90.64%=0.49%、在IMDB数据集下提高了84.14%-84.08%=0.06%。并且,本发明方法与BERT-base模型的蒸馏性能相比,在YELP数据集下91.13% > 87.18%、在IMDB数据集下84.14%>84.08%,本发明方法的结果远优于BERT-base模型的蒸馏性能。从图4的预测时间可以看出,本发明方法蒸馏花费的时间与BERT-large教师模型相比,在YELP数据集下提高了91.93s/163.27s=56.31%、在IMDB数据集下提高了962.37s/1598.34s=60.21%。同时,仿真程序统计出,本发明方法对大导师模型访问率减少的比例(对大导师模型访问率减少的比例,即大小导师机制下未标注样本经过大导师模型的次数相比全部经过大导师模型减少的次数占比)与BERT-large相比,在YELP数据集下减少了74.40%、在IMDB数据集下减少了72.42%。
从图5的分类结果的正确率可以看出,本发明方法与RoBERTa-large模型的蒸馏性能相比,在YELP数据集下提高了93.16%-92.80%=0.36%、在IMDB数据集下提高了87.84%-87.64%=0.2%。并且,本发明方法与RoBERTa-base模型的蒸馏性能相比,在YELP数据集下93.16% > 91.82%、在IMDB数据集下87.84%>87.64%,本发明方法的结果优于RoBERTa-base模型的蒸馏性能。从图5的预测时间可以看出,本发明方法蒸馏花费的时间与RoBERTa-large教师模型相比,在YELP数据集下提高了75.59s/163.32s=46.28%、在IMDB数据集下提高了912.65s/1594.93s=57.22%。同时,仿真程序统计出,本发明方法对大导师模型访问率减少的比例(对大导师模型访问率减少的比例,即大小导师机制下未标注样本经过大导师模型的次数相比全部经过大导师模型减少的次数占比)与RoBERTa-large相比,在YELP数据集下减少了84.65%、在IMDB数据集下减少了75.56%。
访问频率和占用的资源成正比的。全部未标注样本都需要经过小导师模型,阈值筛选出的少部分样本再经过大导师模型。相比全部未标注样本都经过大导师模型可以减少大量资源消耗,小导师模型参数量相对较小、占用的计算资源也少,因此仿真中只对大导师模型访问率减少的比例进行分析。
仿真结果进一步说明了本发明可以有效减少访问大导师模型的频率和训练学生模型过程中的蒸馏时间,并在减少资源消耗的同时提升分类识别的正确率。
本领域内的技术人员应明白,本申请的实施例可提供为方法、***、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请实施例的方法、设备(***)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
显然,上述实施例仅仅是为清楚地说明所作的举例,并非对实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式变化或变动。这里无需也无法对所有的实施方式予以穷举。而由此所引申出的显而易见的变化或变动仍处于本发明创造的保护范围之中。

Claims (10)

1.一种基于大小导师知识蒸馏的少样本情感分类方法,其特征在于,包括以下步骤:
S1:将样本分为有标注样本x u 和未标注样本x u ′,收集大量情感分类任务上的未标注样本x u ′,建立有标注样本的集合D l ={x u }和未标注样本的集合D u ={x u ′};
S2:构建大导师模型和小导师模型,使用有标注样本集合D l 训练大导师模型得到训练完成的大导师模型M L ,使用有标注样本集合D l 训练小导师模型得到训练完成的小导师模型M B
S3:使用训练完成的小导师模型M B 预测全部未标注样本x u ′得到样本概率
Figure 56166DEST_PATH_IMAGE001
,计算每个 样本概率的不确定性
Figure 826676DEST_PATH_IMAGE002
S4:将不确定性
Figure 231043DEST_PATH_IMAGE003
与预设阈值threshold比较,筛选出样本概率高度不确 定的样本x u ″;
S5:将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签P,将样本x u ″输 入训练完成的大导师模型M L 得到大导师模型的软标签P′,结合小导师模型的软标签P和大 导师模型的软标签P′得到最终的软标签
Figure 555845DEST_PATH_IMAGE004
S6:构建学生模型,使用所述未标注样本集合D u 和所述软标签
Figure 394488DEST_PATH_IMAGE005
蒸馏学生模型,得到蒸 馏完成的学生模型;
S7:使用蒸馏完成的学生模型对测试集进行分类预测。
2.根据权利要求1所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于:所述大导师模型和所述小导师模型均为由基于提示的预训练语言模型M组成的教师模型,所述大导师模型的参数量大于所述小导师模型的参数量。
3.根据权利要求2所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于:所述使用有标注样本集合D l 训练大导师模型得到训练完成的大导师模型M L ,具体为:
S21:训练集D l ={x u }={x,y}中,x表示输入样例,y表示真实标签;对输入样例x添加提示模板转化成完形填空任务形式:
P(x)=[CLS] x It is [MASK].[SEP],其中[MASK]为填充词,P(x)是语言模型的输入,It is [MASK].是输入文本添加的提示模板;
S22:将L作为分类任务的标签集合,V作为分类任务的标签词集合,构造标签映射函数:
Figure 3193DEST_PATH_IMAGE006
通过基于提示的预训练语言模型M得到[MASK]对应位置在不同标签
Figure 269089DEST_PATH_IMAGE007
上的得分
Figure 346767DEST_PATH_IMAGE008
其中
Figure 723521DEST_PATH_IMAGE009
Figure 953777DEST_PATH_IMAGE010
表示标签l对应的标签词,k为标签词的长度;
S23:通过softmax层建立预测[MASK]在不同标签l上的类别概率,通过类别概率得到输入样例x的情感类别
Figure 593836DEST_PATH_IMAGE011
S24:建立大导师模型输出层的损失函数;
S25:重复S22~S24,直到大导师模型收敛,结束训练,得到训练完成的大导师模型M L
所述使用有标注样本集合D l 训练小导师模型得到训练完成的小导师模型M B ,具体为:
S26:训练集D l ={x u }={x,y}中,x表示输入样例,y表示真实标签;对输入样例x添加提示模板转化成完形填空任务形式:
P(x)=[CLS] x It is [MASK].[SEP],其中[MASK]为填充词;
S27:将L作为分类任务的标签集合,V作为分类任务的标签词集合,构造标签映射函数:
Figure 408078DEST_PATH_IMAGE012
;通过基于提示的预训练语言模型M得到[MASK]对应位置在不同标签
Figure 588523DEST_PATH_IMAGE013
上的 得分
Figure 922553DEST_PATH_IMAGE014
其中
Figure 264672DEST_PATH_IMAGE015
Figure 67674DEST_PATH_IMAGE016
表示标签l对应的标签词,k为标签词的长度;
S28:通过softmax层建立预测[MASK]在不同标签l上的类别概率,通过类别概率得到输入样例x的情感类别
Figure 520652DEST_PATH_IMAGE017
S29:建立小导师模型的输出层的损失函数;
S210:重复S27~S29,直到小导师模型收敛,结束训练,得到训练完成的小导师模型M B
4.根据权利要求3所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于: 所述使用训练完成的小导师模型M B 预测全部未标注样本x u ′得到样本概率
Figure 709188DEST_PATH_IMAGE018
,计算每个样本 概率的不确定性
Figure 2635DEST_PATH_IMAGE019
,具体为:
S31:将全部未标注样本x u ′输入训练完成的小导师模型M B ,预测得到的概率分布为
Figure 276622DEST_PATH_IMAGE020
S32:计算每个样本概率的不确定性
Figure 798870DEST_PATH_IMAGE021
,计算公式为:
Figure 373071DEST_PATH_IMAGE022
其中|L|为分类任务中标签的类别个数。
5.根据权利要求1所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于: 所述预设阈值threshold的取值范围为
Figure 10988DEST_PATH_IMAGE023
6.根据权利要求1所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于: 所述将不确定性
Figure 506691DEST_PATH_IMAGE024
与预设阈值threshold比较,筛选出样本概率高度不确定 的样本x u ″,具体为:
若样本概率的不确定性
Figure 81898DEST_PATH_IMAGE025
大于threshold,则将此样本作为样本概率高度 不确定的样本x u ″。
7.根据权利要求3所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于: 所述将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签P,将样本x u ″输入训 练完成的大导师模型M L 得到大导师模型的软标签P′,结合小导师模型的软标签P和大导师 模型的软标签P′得到最终的软标签
Figure 245026DEST_PATH_IMAGE026
,具体为:
S51:将样本x u ′输入训练完成的小导师模型M B 得到小导师模型的软标签
Figure 365429DEST_PATH_IMAGE027
S52:将样本x u ″输入训练完成的大导师模型M L 得到大导师模型的软标签
Figure 879587DEST_PATH_IMAGE028
S53:
Figure 494370DEST_PATH_IMAGE029
的表达式为:
Figure 246425DEST_PATH_IMAGE030
8.根据权利要求1-7任一项所述的基于大小导师知识蒸馏的少样本情感分类方法,其 特征在于:所述使用所述未标注样本集合D u 和所述软标签
Figure 803308DEST_PATH_IMAGE005
蒸馏学生模型,得到蒸馏完成 的学生模型,具体过程为:
S61:将未标注样本集合D u 作为蒸馏学生模型的训练集,经过学生模型的向量表示为
Figure 804762DEST_PATH_IMAGE031
,其中g( )表示学生模型的网络函数,A u 为未标注样本集合D u 对应的词向 量矩阵,上标s表示学生模型,
Figure 721772DEST_PATH_IMAGE032
表示学生模型的可学习参数;
S62:建立学生模型输出层的损失函数
Figure 531596DEST_PATH_IMAGE033
,其中n表示批大小,
Figure 10113DEST_PATH_IMAGE034
表示经过学生模型的第i个样本的预测概率,
Figure 498863DEST_PATH_IMAGE035
表示最终的样本概率
Figure 439137DEST_PATH_IMAGE005
中第i个样本的 预测概率,T是蒸馏模型的温度参数,DKL表示KL散度损失函数;
S63:
Figure 431364DEST_PATH_IMAGE036
依次经过线性层和softmax激活层,得到未标注样本集合D u 的概率输出
Figure 313738DEST_PATH_IMAGE037
W s 表示学生模型的线性层上待学习的权重矩阵;
S64:使用损失函数LKD更新学生模型的可学习参数;
S65:重复S61~S64直到损失函数LKD收敛,得到蒸馏完成的学生模型。
9.根据权利要求8所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在于:所述词向量矩阵A u 中,每一行是输入样本x u ′中每个字符的字向量表示,每个字符的字向量通过word2vec或Glove模型训练获得。
10.根据权利要求8所述的基于大小导师知识蒸馏的少样本情感分类方法,其特征在 于:所述KL散度损失函数的表达式为
Figure 758626DEST_PATH_IMAGE038
,其中|L| 为分类任务中标签的类别个数。
CN202210653730.6A 2022-06-10 2022-06-10 基于大小导师知识蒸馏的少样本情感分类方法 Active CN114722805B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210653730.6A CN114722805B (zh) 2022-06-10 2022-06-10 基于大小导师知识蒸馏的少样本情感分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210653730.6A CN114722805B (zh) 2022-06-10 2022-06-10 基于大小导师知识蒸馏的少样本情感分类方法

Publications (2)

Publication Number Publication Date
CN114722805A true CN114722805A (zh) 2022-07-08
CN114722805B CN114722805B (zh) 2022-08-30

Family

ID=82232411

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210653730.6A Active CN114722805B (zh) 2022-06-10 2022-06-10 基于大小导师知识蒸馏的少样本情感分类方法

Country Status (1)

Country Link
CN (1) CN114722805B (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115186083A (zh) * 2022-07-26 2022-10-14 腾讯科技(深圳)有限公司 一种数据处理方法、装置、服务器、存储介质及产品
CN116186200A (zh) * 2023-01-19 2023-05-30 北京百度网讯科技有限公司 模型训练方法、装置、电子设备和存储介质
CN116861302A (zh) * 2023-09-05 2023-10-10 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113762144A (zh) * 2021-09-05 2021-12-07 东南大学 一种基于深度学习的黑烟车检测方法
CN113886562A (zh) * 2021-10-02 2022-01-04 智联(无锡)信息技术有限公司 一种ai简历筛选方法、***、设备和存储介质
CN114168844A (zh) * 2021-11-11 2022-03-11 北京快乐茄信息技术有限公司 在线预测方法、装置、设备及存储介质
CN114283402A (zh) * 2021-11-24 2022-04-05 西北工业大学 基于知识蒸馏训练与时空联合注意力的车牌检测方法

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113762144A (zh) * 2021-09-05 2021-12-07 东南大学 一种基于深度学习的黑烟车检测方法
CN113886562A (zh) * 2021-10-02 2022-01-04 智联(无锡)信息技术有限公司 一种ai简历筛选方法、***、设备和存储介质
CN114168844A (zh) * 2021-11-11 2022-03-11 北京快乐茄信息技术有限公司 在线预测方法、装置、设备及存储介质
CN114283402A (zh) * 2021-11-24 2022-04-05 西北工业大学 基于知识蒸馏训练与时空联合注意力的车牌检测方法

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115186083A (zh) * 2022-07-26 2022-10-14 腾讯科技(深圳)有限公司 一种数据处理方法、装置、服务器、存储介质及产品
CN115186083B (zh) * 2022-07-26 2024-05-24 腾讯科技(深圳)有限公司 一种数据处理方法、装置、服务器、存储介质及产品
CN116186200A (zh) * 2023-01-19 2023-05-30 北京百度网讯科技有限公司 模型训练方法、装置、电子设备和存储介质
CN116186200B (zh) * 2023-01-19 2024-02-09 北京百度网讯科技有限公司 模型训练方法、装置、电子设备和存储介质
CN116861302A (zh) * 2023-09-05 2023-10-10 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法
CN116861302B (zh) * 2023-09-05 2024-01-23 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法

Also Published As

Publication number Publication date
CN114722805B (zh) 2022-08-30

Similar Documents

Publication Publication Date Title
CN114722805B (zh) 基于大小导师知识蒸馏的少样本情感分类方法
CN111177374B (zh) 一种基于主动学习的问答语料情感分类方法及***
CN110188358B (zh) 自然语言处理模型的训练方法及装置
CN110188272B (zh) 一种基于用户背景的社区问答网站标签推荐方法
CN110619044B (zh) 一种情感分析方法、***、存储介质及设备
CN111914885A (zh) 基于深度学习的多任务人格预测方法和***
CN109062958B (zh) 一种基于TextRank和卷积神经网络的小学作文自动分类方法
CN115270752A (zh) 一种基于多层次对比学习的模板句评估方法
CN113988079A (zh) 一种面向低数据的动态增强多跳文本阅读识别处理方法
Jishan et al. Natural language description of images using hybrid recurrent neural network
Cai Automatic essay scoring with recurrent neural network
CN112364743A (zh) 一种基于半监督学习和弹幕分析的视频分类方法
Nassiri et al. Arabic L2 readability assessment: Dimensionality reduction study
CN115391520A (zh) 一种文本情感分类方法、***、装置及计算机介质
WO2020240572A1 (en) Method for training a discriminator
Arifin et al. Automatic essay scoring for Indonesian short answers using siamese Manhattan long short-term memory
US20220253694A1 (en) Training neural networks with reinitialization
Ma et al. Enhanced hierarchical structure features for automated essay scoring
CN114997175A (zh) 一种基于领域对抗训练的情感分析方法
CN113821571A (zh) 基于bert和改进pcnn的食品安全关系抽取方法
Rawat et al. A Systematic Review of Question Classification Techniques Based on Bloom's Taxonomy
CN112200268A (zh) 一种基于编码器-解码器框架的图像描述方法
LU504829B1 (en) Text classification method, computer readable storage medium and system
Guan et al. Understanding lexical features for Chinese essay grading
Chen et al. An effective relation-first detection model for relational triple extraction

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant