CN117807235B - 一种基于模型内部特征蒸馏的文本分类方法 - Google Patents

一种基于模型内部特征蒸馏的文本分类方法 Download PDF

Info

Publication number
CN117807235B
CN117807235B CN202410064744.3A CN202410064744A CN117807235B CN 117807235 B CN117807235 B CN 117807235B CN 202410064744 A CN202410064744 A CN 202410064744A CN 117807235 B CN117807235 B CN 117807235B
Authority
CN
China
Prior art keywords
model
student
output
loss
teacher
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202410064744.3A
Other languages
English (en)
Other versions
CN117807235A (zh
Inventor
王绍强
靳晓娇
申向峰
戴银飞
王艳柏
刘玉宝
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Changchun University
Original Assignee
Changchun University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Changchun University filed Critical Changchun University
Priority to CN202410064744.3A priority Critical patent/CN117807235B/zh
Publication of CN117807235A publication Critical patent/CN117807235A/zh
Application granted granted Critical
Publication of CN117807235B publication Critical patent/CN117807235B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • G06F16/353Clustering; Classification into predefined classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • G06N3/0442Recurrent networks, e.g. Hopfield networks characterised by memory or gating, e.g. long short-term memory [LSTM] or gated recurrent units [GRU]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/096Transfer learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Biophysics (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Databases & Information Systems (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提出一种基于模型内部特征蒸馏的文本分类方法,属于自然语言处理领域;包括:首先对微博文本数据集进行预处理,使用tokenizer将文本转换为模型所需的特征;其次将特征分别传入学生模型和教师模型提取预测结果;然后将教师模型中内部的特征蒸馏出来作为软标签,与学生模型中内部的特征计算损失loss1,同时利用学生模型和教师模型的预测结果计算损失loss3,利用学生模型预测结果与学生模型的标签计算损失loss2;最后将得到的三个损失求和,在网路中进行反向传播,优化网络。本发明在文本分类任务中,压缩了模型的内存,提升了模型的性能,更好的权衡了模型大小和模型性能。

Description

一种基于模型内部特征蒸馏的文本分类方法
技术领域
本发明属于自然语言处理领域,尤其涉及一种基于模型内部特征蒸馏的文本分类方法。
背景技术
随着人工智能的迅猛发展,深度学习在各个领域取得了巨大的成功。为了解决更为复杂的问题和提升模型的训练效果,研究者们逐渐将模型的网络结构设计得更加深度和复杂。在自然语言处理(NLP)领域,从最初的循环神经网络(RNN)、长短时记忆网络(LSTM)、预训练词向量(ELMo)到如今备受瞩目的BERT,模型结构逐渐演变为更为深奥和复杂的形态。然而,这些复杂的模型设计需要大量的计算资源和时间,无法满足当前移动计算发展对低资源、低功耗的迫切需求,因此获得轻量高效的语言模型显得尤为紧迫。
知识蒸馏作为一种新兴而通用的模型压缩和迁移学习架构,在最近几年表现出了蓬勃的活力。本质上,知识蒸馏属于迁移学习的范畴。其主要思路是将一个已经训练完善的模型作为教师模型,通过调整输出的“温度”来“蒸馏”出模型的“知识”,然后将这些知识用于训练轻量级的学生模型。希望通过这样的过程,简化模型结构,提高模型性能。
在2019年,知识蒸馏与BERT结合被提出,通过将BERT-large蒸馏到单层的Bi-LSTM中,其效果接近于ELMo模型,提高模型的运行速率,同时减少模型参数量,但由于学生模型的单一化,使得模型在蒸馏最后一层时出现拟合过快现象。
随着研究方法的丰富,蒸馏方式逐渐丰富。提出BERT-PKD,通过蒸馏教师模型中间层的知识,有效的防止了模型过拟合现象。DistillBERT在预训练阶段进行蒸馏,进一步压缩模型,减少模型参数。
尽管知识蒸馏与预训练模型结合的研究已经取得了显著进展,但仍然存在一些问题,目前教师模型设计的非常复杂,导致学生模型很容易过度拟合,模型的泛化能力也比较差。
发明内容
鉴于上述问题,本发明的目的是提供一种基于模型内部特征蒸馏的文本分类方法,以解决现有模型结构设计复杂深奥、计算资源和时间损耗大等的技术问题。
为实现上述目的,本发明采用如下技术方案:
一种基于模型内部特征蒸馏的文本分类方法,具体包括以下步骤:
步骤一、获取微博文本数据集,将微博文本数据集进行划分和预处理,其中,将获取的微博文本数据集划分出训练集和测试集,并使用tokenizer工具对训练集和测试集进行数据预处理,将文本转换为模型输入的包含特征向量的可识别字典;
步骤二、基于预训练模型,构建教师模型和学生模型;
构建教师模型,教师模型基于BERT预训练模型,主干网络利用GRU模型结合双向Bi-LSTM模型和Transformer_Attention模型搭建成带有三条路径的教师模型,三条路径产生的特征进行拼接和全连接后,得到最终的预测结果;其中教师模型是基于BERT预训练模型并将GRU、Bi-LSTM和Transformer_Attention机制相结合,可以更好的获取文本的局部特征和上下文特征;
搭建学生模型,学生模型基于RoBERTa预训练模型,主干网络仅使用GRU模型搭建成带有一条路径的学生模型;使用GRU模型搭建的学生模型简化了模型结构,在实验过程中证明所提出的教师模型的有效性;最后对教师模型和学生模型进行内部特征蒸馏,构建新的损失函数进行训练,压缩模型内存,提升模型性能。
步骤三、利用步骤二中的教师模型和学生模型进行知识蒸馏;
在训练过程中选择相应的损失函数计算模型损失,并保存最优模型;其中,相应的损失函数为torch.nn.functional模型库中的交叉熵CE函数、KL散度函数和均方误差MSE函数;
将教师模型中内部的特征蒸馏出来作为软标签,与学生模型中内部的特征计算损失,得到损失值loss1;利用学生模型预测结果与学生模型的标签计算损失,得到损失值loss2;利用学生模型和教师模型的预测结果计算损失,得到损失值loss3;最后将三个损失相加得到模型的最终损失值Loss;通过反向传播优化参数,训练该模型,通过比对损失值的大小决定是否保存训练得到的模型及参数,训练过程中只保存损失值最小的模型结构和模型参数。
作为本发明的优选,在步骤一中,所述微博文本数据集划分的训练集和测试集比例为8:2,所述tokenizer处理后的字典包括input_ids、token_type_ids和attention_mask。
作为本发明的优选,在步骤二中搭建教师模型的第一条路径,首先将数据送入BERT预训练模型得到输出,取得该输出的最后一个隐藏层状态作为tokens,将tokens作为Bi-LSTM模型的输入,训练过程中将Bi-LSTM模型输出的结果取最后一个时间步的隐藏状态作为输出,输出的维度是[20,640],目的是为了更好的捕获前后文的内在关系。
作为本发明的优选,在步骤二中搭建教师模型的第二条路径,将BERT预训练模型产生的tokens作为Transformer_Attention模型自注意力机制的输入,通过三个线性层得到张量K、张量Q和张量V,同时利用预训练模型权重中的config.hidden_size平方分之一作为自注意力的权重,训练过程中利用Transformer_Attention模型自注意力机制产生注意力特征输出,该输出的维度是[20,300],目的是为了通过不同权重更好地捕捉输入序列中不同部分之间的语义关系。
作为本发明的优选,在步骤二中搭建教师模型的第三条路径,将BERT预模型输出的tokens作为GRU模型的输入,通过运行GRU模型,获得每个序列在序列的最后一个时间步的隐藏状态作为输出,该输出的特征维度是[20,320],目的是为了捕获输入序列中的长期依赖关系。
作为本发明的优选,先将Bi-LSTM模型的输出与Transformer_Attention模型自注意力机制的输出在维度1上进行拼接,然后将其结果与GRU模型的输出在维度1上进行拼接,最后通过一个包含多层的全连接网络得到教师模型的输出:教师_输出[20,2]。
作为本发明的优选,将RoBERTa预训练模型和GRU模型进行结合构建成学生模型,将RoBERTa预训练模型的输出作为GRU模型的输入,运行GRU模型,然后通过一个全连接网络得到学生模型的输出:学生_输出[20,2]。
作为本发明的优选,步骤三选用三个不同的损失函数,在教师模型中内部蒸馏出的软标签与学生模型中内部的特征计算损失时,使用交叉熵损失函数;利用学生模型和教师模型的预测结果计算损失时,使用KL散度损失函数;学生模型预测结果与学生模型的标签计算损失时,使用MSE损失函数;最后将三个损失相加得到模型的最终损失值Loss,具体计算公式如下:
………(1)
……… (2)
………(3)
………(4)
上述公式中、、/>均表示样本数量,/>表示第/>个样本量;公式(1)中,/>表示教师模型中的GRU通过全连接层的输出值gru_out,/>表示学生模型的输出值student_out;公式(2)中,/>表示学生模型的输出student_out的值,/>表示学生标签的真实值;公式(3)中,/>表示教师模型的输出值teacher_out,/>表示学生模型的输出值student_out。
本发明的优点及积极效果是:
(1)针对现有无法平衡模型复杂性与模型性能等问题,本发明提出借助BERT预训练模型参数,将GRU、Bi-LSTM和Transformer_Attention机制结合的方法,通过Bi-LSTM优化GRU有效捕捉长距离上下文依赖关系,通过Transformer_Attention捕获带有权重的语义关系,有效提高模型文本分类精度。
(2)针对文本分类模型结构复杂,导致计算资源和时间损耗大等问题,本发明提出基于模型内部特征蒸馏方法,将基于BERT的预训练模型与GRU、Bi-LSTM、Transformer_Attention机制结合构建出教师模型,将基于RoBERTa的预训练模型与GRU结合构建学生模型,在教师模型内部特征中蒸馏出来作为软标签,与学生模型进行损失计算,有效防止最后一层蒸馏出现过快拟合现象,有效的提高模型运行速率。
(3)本发明的模型中,通过教师模型和学生模型的内部特征蒸馏,实现压缩模型内存的同时提高模型运行速率,此外通过蒸馏技术,使得GRU模型的性能接近于教师模型的效果。
附图说明
图1为本发明一种基于模型内部特征蒸馏的文本分类方法的未特征蒸馏的准确率和损失曲线图;
图2为本发明一种基于模型内部特征蒸馏的文本分类方法的内部特征蒸馏后的准确率和损失曲线图;
图3为本发明一种基于模型内部特征蒸馏的文本分类方法的网络架构图。
具体实施方式
为了更好地了解本发明的目的、结构及功能,下面结合附图,对本发明提出的一种基于模型内部特征蒸馏的文本分类方法做进一步详细的描述。
如图1-图3所示,本发明通过模型内部特征蒸馏的文本分类方法,在保证模型获取充分特征外,还压缩了模型的内存,减小了模型的复杂度,同时也加强了模型的学习能力。
实施例1
如图3所示,是本发明实施例提供的一种基于模型内部特征蒸馏的文本分类方法架构图,需要说明的是,本架构图仅示出了本实施例所述方法的逻辑顺序,在互不冲突的前提下,在本发明其它可能的实施例中,可以以不同于图1-图3所示的顺序完成所示出或描述的步骤。
参见图3,所述基于模型内部特征蒸馏的文本分类方法具体包含如下步骤:
步骤一:获取数据集,将数据集进行划分和预处理。
步骤一中的数据集使用微博文本数据集,该数据集是专门为文本分类任务设计的,数据集中包含了微博中用户评论的文本,每条文本都对分类做了类别的标签。训练集和测试集中的文本通过使用tokenizer对文本进行处理,得到模型输入时可识别的字典。
步骤二:搭建学生模型和教师模型,其中学生模型使用Robert作为预训练模型,主干网络采用GRU模型搭建。教师模型使用Bert作为预训练模型,主干网络由三条路径组成,第一条路径为BILSTM路径,第二条路径为Transformer Attention路径,第三条路径为Gru路径,三条路径产生的特征进行拼接和全连接后,得到最终的预测结果。
步骤二中的学生模型,仅使用RoBert预训练模型结合Gru模型构建,模型中Gru使用torch.nn库中的函数,参数设置为input_size=768,hidden_size=320,num_layers=1,batch_first=True。学生模型仅用于验证教师模型的性能和蒸馏的有效性。
步骤二中的教师模型,经过Bert预训练模型产生tokens经过三条路径,第一条路径为BILSTM路径,BILSTM的构建时调用torch.nn库搭建的,参数设置为input_size=self.input_size,hidden_size=320,num_layers=1,batch_first=True,bidirectional=True。第二条路径为Transformer Attention路径,通过获取Bert模型的config.hidden_size作为线性层的参数,传入Bert模型的tokens作为线性层的输入,分别得到K,Q,V三个张量,张量的维度都是([20, 134, 768])。使用torch中的permute(0,2,1)调整K的维度,同时通过计算Bert的config.hidden_size平方分之一作为自注意力权重,将K、Q和自注意力权重相乘后在最后一维度做激活,得到attention。将得到的attention与V相乘,得到该路径的输出。第三条路径为GRU模型,与学生模型的主干网络相同。最后将三个路径的特征进行1维度上的拼接后,再经过一次全连接层得到教师模型的输出。
步骤三:选择合适的损失函数,通过反向传播优化参数,训练轮数设定为50轮,训练过程中只保存最优的模型结构和模型参数。
步骤三中的损失函数选择了三种,首先是教师模型内部特征蒸馏的软标签与学生模型的损失计算,使用交叉熵CE损失函数,将教师模型中的gru_out与student_out输入到交叉熵CE损失函数中进行计算;然后是学生模型和学生标签的损失计算,使用均方误差MSE损失函数,将学生模型的student_out与学生模型真实标签y输入到均方误差MSE损失函数中进行计算;最后是学生模型与教师模型蒸馏的损失计算,使用KL散度损失函数,将教师模型的teacher_out与学生模型的student_out输入到KL散度损失函数中进行计算。模型整体损失值是将三者损失相加得到的。
实施例2
本实施例提供一种基于模型内部特征蒸馏的文本分类方法的具体技术方案如下:
步骤1:获取微博文本数据集,并对微博文本数据集进行处理。获取对应模型的输入数据集,并将其划分出训练集和测试集,其中训练集为20%,测试集为80%。
步骤2:加载BERT模型并进行推理,获取模型的输出。下载BERT模型数据包,使用BertTokenizer对文本进行标记化,将文本转换成模型可以理解的格式;使用BertModel加载预训练的BERT模型,其中包含预训练的权重;使用tokenizer将文本转换成模型可以接受的数据格式;运行模型得到输出值Tokens[20,133,768]。
步骤3:定义Bi-LSTM模型,将BERT模型的输出值Tokens作为输入,用于捕捉上下文关系。其中LSTM单元中隐藏状态的维度大小为320,层数设置为1,bidirectional设置为True。
步骤4:运行Bi-LSTM模型,并输出每个序列在最后一个时间步的隐藏状态。输出形状为[20,640],其中20表示批次大小,640表示每个序列在最后一个时间步的隐藏状态的维度。
步骤5:基于自注意力机制(Self-Attention)的操作,通过卷积神经网络(CNN)处理得到注意力输出。使用key_layer、query_layer、value_layer将Tokens值分别输入给键K、查询Q、值V,通过批量矩阵相乘(torch.bmm)计算查询和键的点积,再乘以1/√(d_k ),其中就是k_dim,目的是缩放注意力权重,使其更稳定;然后应用Softmax函数在最后一个维度上进行归一化;最后乘以一个标准化因子a,得到注意力矩阵。具体计算公式如(1):
……… (1)
步骤6:使用注意力矩阵对值进行加权求和,得到注意力加权的输出。具体计算公式如(2):
……… (2)
步骤7:将注意力输出增加一个维度,使其适应后续卷积操作;然后对增加维度后的注意力输出进行卷积操作,并通过池化操作得到一系列卷积特征;最后将这些特征拼接在一起,形成最终的输出注意力机制输出,其输出形状为[20,300]。
步骤8:将Bi-LSTM的输出和注意力机制输出在维度1(列方向)上进行拼接,得到输出[20,940],目的使将不同处理方式得到的特征融合在一起,以供后续的任务使用。
步骤9:定义GRU模型,将BERT模型的输出值Tokens作为输入,用于捕捉输入序列中的长期依赖关系。其中GRU单元中隐藏状态的维度大小为320,层数设置为1。
步骤10:运行GRU模型,并得到GRU的输出,其输出形状为[20,133,320]。
步骤11:在GRU输出中选择每个序列在序列的最后一个时间步的隐藏状态,其输出形状为[20,320]。
步骤12:将Bi-LSTM的输出和注意力机制输出拼接后的输出与GRU的输出在维度1上进行拼接,得到输出[20,1260]。
步骤13:定义一个全连接网络,其中Dropout层设值为0.5;第一个线性层为(1260,128);第二个线性层为(128,16);第三个线性层为(16,2);softmax层对输出在第一个维度上进行归一化。
步骤14:将步骤12的输出送入步骤13定义的全连接网络中,得到新的输出[20,2]。
从步骤2到步骤14建立完整的教师模型,并得到教师模型的输出。
步骤15:加载RoBERTa模型并进行推理,获取模型的输出值Tokens[20,133,768]。
步骤16:定义GRU模型,将RoBERTa模型的输出值Tokens作为输入,用于捕捉输入序列中的长期依赖关系。其中GRU单元中隐藏状态的维度大小为320,层数设置为1。
步骤17:运行GRU模型,并得到GRU的输出,其输出形状为[20,133,320]。
步骤18:在GRU输出中选择每个序列在序列的最后一个时间步的隐藏状态,其输出形状为[20,320]。
步骤19:定义包含多个层的全连接神经网络,并使用了一个softmax激活函数。其中Dropout层设值为0.6,防止过拟合;第一个线性层为(320,128);第二个线性层为(128,64);第三个线性层为(64,16);第四个线性层为(16,2);softmax层对输出在第一个维度上进行归一化,将其转化为表示类别概率的形式。
步骤20:将步骤18的输出送入步骤19定义的全连接神经网络进行前向传播,得到模型的输出。
从步骤15到步骤20建立完整的学生模型,并得到学生模型的输出。
步骤21:将步骤11的输出送入步骤19定义的全连接神经网络进行前向传播,得到模型的输出,即得到样本相应类别的概率。
步骤22:将步骤20和步骤21的输出进行损失计算,得到,具体的计算公式如(3):
… (3)
步骤23:将步骤20的输出和样本真实标签进行损失计算,得到具体计算公式如(4):
……… (4)
步骤24:将步骤14和步骤20的输出进行损失计算,得到,具体计算公式如(5):
………(5)
步骤25:计算模型的整体损失值和准确率,具体计算公式如(6):
……… (6)。
实施例3
获取公开的微博数据集weibo_senti_100k,并对数据集进行划分,分为训练集train_set和测试集test_set,其中训练集占80%。将数据集传入蒸馏前后的模型中进行训练,实验结果如下表所示:
Model GRU1 GRU2 GRU3 GRU-AVR GRU+Bi-LSTM+Attention
Un-Distilling 94.6 94.1 96.9 95.2 -
Distilling 97.9 98.5 98.1 98.3 98.5
如图1和2所示,通过未特征蒸馏的准确率和损失曲线图,与使用内部特征蒸馏后的准确率和损失曲线图对比。实验结果表明,本发明提出的教师模型准确率达到98.5,通过内部特征蒸馏技术提高了学生模型的准确率,使得学生模型准确率接近教师模型准确率,从而压缩模型的复杂度。
以上,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (2)

1.一种基于模型内部特征蒸馏的文本分类方法,其特征在于,包括以下步骤:
步骤一、获取微博文本数据集,将微博文本数据集进行划分和预处理,其中,将获取的微博文本数据集划分出训练集和测试集,并使用tokenizer对训练集和测试集进行数据预处理,将文本转换为模型输入的包含特征向量的可识别字典;
步骤二、基于预训练模型,构建教师模型和学生模型;
构建教师模型,教师模型基于BERT预训练模型,主干网络利用GRU模型结合双向Bi-LSTM模型和Transformer_Attention模型搭建成带有三条路径的教师模型,三条路径产生的特征进行拼接和全连接后,得到最终的预测结果;
搭建教师模型的第一条路径,首先将数据送入BERT预训练模型得到输出,取得该输出的最后一个隐藏层状态作为tokens,将tokens作为Bi-LSTM模型的输入,训练过程中将Bi-LSTM模型输出的结果取最后一个时间步的隐藏状态作为输出,输出的维度是[20,640];
搭建教师模型的第二条路径,将BERT预训练模型产生的tokens作为Transformer_Attention模型自注意力机制的输入,通过三个线性层得到张量K、张量Q和张量V,同时利用预训练模型权重中的config.hidden_size平方分之一作为自注意力的权重,训练过程中利用Transformer_Attention模型自注意力机制产生注意力特征输出,该输出的维度是[20,300];
搭建教师模型的第三条路径,将BERT预模型输出的tokens作为GRU模型的输入,通过运行GRU模型,获得每个序列在序列的最后一个时间步的隐藏状态作为输出,该输出的特征维度是[20,320];
先将Bi-LSTM模型的输出与Transformer_Attention模型自注意力机制的输出在维度1上进行拼接,然后将其结果与GRU模型的输出在维度1上进行拼接,最后通过一个包含多层的全连接网络得到教师模型的输出:教师_输出[20,2];
搭建学生模型,学生模型基于RoBERTa预训练模型,主干网络仅使用GRU模型搭建成带有一条路径的学生模型;
将RoBERTa预训练模型和GRU模型进行结合构建成学生模型,将RoBERTa预训练模型的输出作为GRU模型的输入,运行GRU模型,然后通过一个全连接网络得到学生模型的输出:学生输出[20,2]
步骤三、利用步骤二中的教师模型和学生模型进行知识蒸馏;
在训练过程中选择相应的损失函数计算模型损失,并保存最优模型;其中,相应的损失函数为torch.nn.functional模型库中的交叉熵CE函数、KL散度函数和均方误差MSE函数;
将教师模型中内部的特征蒸馏出来作为软标签,与学生模型中内部的特征计算损失,得到损失值loss1;利用学生模型预测结果与学生模型的标签计算损失,得到损失值loss2;利用学生模型和教师模型的预测结果计算损失,得到损失值loss3;最后将三个损失相加得到模型的最终损失值Loss;通过反向传播优化参数,训练该模型,通过比对损失值的大小决定是否保存训练得到的模型及参数,训练过程中只保存损失值最小的模型结构和模型参数;
步骤三选用三个不同的损失函数,在教师模型中内部蒸馏出的软标签与学生模型中内部的特征计算损失时,使用交叉熵损失函数;利用学生模型和教师模型的预测结果计算损失时,使用KL散度损失函数;学生模型预测结果与学生模型的标签计算损失时,使用MSE损失函数;最后将三个损失相加得到模型的最终损失值Loss,具体计算公式如下:
Loss=loss1+loss2+loss3………(4)
上述公式中x、n均表示样本数量,i表示第i个样本量;公式(1)中,g表示教师模型中的GRU通过全连接层的输出值gru_out,q表示学生模型的输出值student_out;公式(2)中,f(x)表示学生模型的输出student_out的值,y表示学生标签的真实值;公式(3)中,p表示教师模型的输出值teacher_out,q表示学生模型的输出值student_out。
2.根据权利要求1所述的一种基于模型内部特征蒸馏的文本分类方法,其特征在于,在步骤一中,所述微博文本数据集划分的训练集和测试集比例为8:2,所述tokenizer处理后的字典包括input_ids、token_type_ids和attention_mask。
CN202410064744.3A 2024-01-17 2024-01-17 一种基于模型内部特征蒸馏的文本分类方法 Active CN117807235B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202410064744.3A CN117807235B (zh) 2024-01-17 2024-01-17 一种基于模型内部特征蒸馏的文本分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202410064744.3A CN117807235B (zh) 2024-01-17 2024-01-17 一种基于模型内部特征蒸馏的文本分类方法

Publications (2)

Publication Number Publication Date
CN117807235A CN117807235A (zh) 2024-04-02
CN117807235B true CN117807235B (zh) 2024-05-10

Family

ID=90431980

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202410064744.3A Active CN117807235B (zh) 2024-01-17 2024-01-17 一种基于模型内部特征蒸馏的文本分类方法

Country Status (1)

Country Link
CN (1) CN117807235B (zh)

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113610126A (zh) * 2021-07-23 2021-11-05 武汉工程大学 基于多目标检测模型无标签的知识蒸馏方法及存储介质
CN114170655A (zh) * 2021-11-29 2022-03-11 西安电子科技大学 一种基于知识蒸馏的人脸伪造线索迁移方法
WO2022126683A1 (zh) * 2020-12-15 2022-06-23 之江实验室 面向多任务的预训练语言模型自动压缩方法及平台
CN114818902A (zh) * 2022-04-21 2022-07-29 浪潮云信息技术股份公司 基于知识蒸馏的文本分类方法及***
WO2023024427A1 (zh) * 2021-08-24 2023-03-02 平安科技(深圳)有限公司 适用于bert模型的蒸馏方法、装置、设备及存储介质
CN116595167A (zh) * 2023-03-29 2023-08-15 光控特斯联(重庆)信息技术有限公司 一种基于集成知识蒸馏网络的意图识别方法
CN117217223A (zh) * 2023-07-24 2023-12-12 湖南中医药大学 基于多特征嵌入的中文命名实体识别方法及***

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2022126683A1 (zh) * 2020-12-15 2022-06-23 之江实验室 面向多任务的预训练语言模型自动压缩方法及平台
CN113610126A (zh) * 2021-07-23 2021-11-05 武汉工程大学 基于多目标检测模型无标签的知识蒸馏方法及存储介质
WO2023024427A1 (zh) * 2021-08-24 2023-03-02 平安科技(深圳)有限公司 适用于bert模型的蒸馏方法、装置、设备及存储介质
CN114170655A (zh) * 2021-11-29 2022-03-11 西安电子科技大学 一种基于知识蒸馏的人脸伪造线索迁移方法
CN114818902A (zh) * 2022-04-21 2022-07-29 浪潮云信息技术股份公司 基于知识蒸馏的文本分类方法及***
CN116595167A (zh) * 2023-03-29 2023-08-15 光控特斯联(重庆)信息技术有限公司 一种基于集成知识蒸馏网络的意图识别方法
CN117217223A (zh) * 2023-07-24 2023-12-12 湖南中医药大学 基于多特征嵌入的中文命名实体识别方法及***

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
基于BERT模型的旅游文本情感分类研究与实现;王诗艺;《万方学位论文》;20231002;16-30 *
基于CenterNet的多教师联合知识蒸馏;刘绍华;《***工程与电子技术》;20230525;1174-1184 *

Also Published As

Publication number Publication date
CN117807235A (zh) 2024-04-02

Similar Documents

Publication Publication Date Title
CN112667818B (zh) 融合gcn与多粒度注意力的用户评论情感分析方法及***
Wang et al. Research on Web text classification algorithm based on improved CNN and SVM
CN109885756B (zh) 基于cnn和rnn的序列化推荐方法
CN111274375B (zh) 一种基于双向gru网络的多轮对话方法及***
US20220351043A1 (en) Adaptive high-precision compression method and system based on convolutional neural network model
CN110457661B (zh) 自然语言生成方法、装置、设备及存储介质
CN114398976A (zh) 基于bert与门控类注意力增强网络的机器阅读理解方法
CN112347756A (zh) 一种基于序列化证据抽取的推理阅读理解方法及***
CN112000770A (zh) 面向智能问答的基于语义特征图的句子对语义匹配方法
CN112199503A (zh) 一种基于特征增强的非平衡Bi-LSTM的中文文本分类方法
CN115687609A (zh) 一种基于Prompt多模板融合的零样本关系抽取方法
CN111522926A (zh) 文本匹配方法、装置、服务器和存储介质
Wang et al. A survey of extractive question answering
US20220138425A1 (en) Acronym definition network
CN117932066A (zh) 一种基于预训练的“提取-生成”式答案生成模型及方法
CN116543289B (zh) 一种基于编码器-解码器及Bi-LSTM注意力模型的图像描述方法
CN117807235B (zh) 一种基于模型内部特征蒸馏的文本分类方法
CN115599918B (zh) 一种基于图增强的互学习文本分类方法及***
CN116910190A (zh) 多任务感知模型获取方法、装置、设备及可读存储介质
CN116403231A (zh) 基于双视图对比学习与图剪枝的多跳阅读理解方法及***
Jun et al. Hierarchical multiples self-attention mechanism for multi-modal analysis
CN114648005A (zh) 一种多任务联合学习的多片段机器阅读理解方法及装置
CN114969279A (zh) 一种基于层次图神经网络的表格文本问答方法
Rajapaksha et al. Explainable Attention Pruning: A Meta-learning-based Approach
CN113609839A (zh) 一种文本处理方法、模型及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant