CN113673242A - 一种基于k邻近结点算法和对比学习的文本分类方法 - Google Patents

一种基于k邻近结点算法和对比学习的文本分类方法 Download PDF

Info

Publication number
CN113673242A
CN113673242A CN202110960433.1A CN202110960433A CN113673242A CN 113673242 A CN113673242 A CN 113673242A CN 202110960433 A CN202110960433 A CN 202110960433A CN 113673242 A CN113673242 A CN 113673242A
Authority
CN
China
Prior art keywords
encoder
training
classification
representing
samples
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110960433.1A
Other languages
English (en)
Inventor
邱锡鹏
宋德敏
李林阳
傅家庆
杨非
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fudan University
Zhejiang Lab
Original Assignee
Fudan University
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fudan University, Zhejiang Lab filed Critical Fudan University
Priority to CN202110960433.1A priority Critical patent/CN113673242A/zh
Publication of CN113673242A publication Critical patent/CN113673242A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/279Recognition of textual entities
    • G06F40/289Phrasal analysis, e.g. finite state techniques or chunking
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/205Parsing
    • G06F40/211Syntactic parsing, e.g. based on context-free grammar [CFG] or unification grammars
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种基于K邻近结点算法和对比学习的文本分类方法,该方法在训练阶段使用对比学习拉进类内距离,拉远类间距离,并且结合交叉熵损失,辅助对比学习进行联合训练,在推理过程中,通过联合训练好的模型,结合最邻近结点算法,进行联合预测,计算待推断文本的分类;本发明不仅能够在文本分类的准确率上取得比目前业内使用的文本分类方式更高的结果,而且在模型的鲁棒性上也取得了极大的提升。

Description

一种基于K邻近结点算法和对比学习的文本分类方法
技术领域
本发明涉及深度学习和自然语言处理,尤其是涉及一种基于K邻近结点算法和对比学习的文本分类方法。
背景技术
文本分类任务是自然语言处理中的一类基础任务,目前主流的文本分类方法是在大规模预训练模型(如BERT)的基础上,使用一个线性分类器进行分类。但是线性分类器往往不具备很好的鲁棒性,容易被TextFooler或BertAttack这类对抗攻击的方式所愚弄。
发明内容
为解决现有技术的不足,实现提高鲁棒性的同时,提升模型分类准确率的目的,本发明采用如下的技术方案:
一种基于K邻近结点算法和对比学习的文本分类方法,包括如下步骤:
S1,训练过程中,通过构建句子向量表示k的正负样本,进行对比学习,拉近类内间距,拉远类间间距,对比学习的损失函数如下:
Figure BDA0003221873920000011
其中,M表示正样本的数量,N表示负样本的数量,q表示预训练编码器encoder_q输出的句子的向量表示,k表示预训练编码器encoder_k输出的句子向量表示,encoder_q与encoder_k相同,kj表示第j个正样本k+,ki表示遍历负样本k-和kj的集合,exp(·)表示指数函数,τ为超参数;
结合交叉熵损失函数,进行联合训练,联合损失函数如下:
L=λLec+(1-λ)Lsc
Figure BDA0003221873920000012
其中,λ表示调节交叉熵损失函数Lec和所述对比学习的损失函数Lsc之间的权重参数,yc表示q的类别,C表示文本分类的分类数,F(·)表示线性分类器;
反向传播损失函数,更新encoder_q和线性分类器的参数;
联合损失函数为交叉熵损失函数和有监督对比学习损失函数的加权和,通过对比学习的损失函数Lsc来辅助交叉熵损失函数训练模型,使用对比学习训练模型,使得模型在训练过程中,能够自动对样本的embedding表示进行聚类,从而能够达到更好的分类效果;
S2,通过训练好的encoder_q和线性分类器,对文本进行分类。
进一步地,所述S2中,通过训练好的encoder_q获得待预测文本的句子向量表示q,使用联合预测函数预测文本分类,联合预测函数如下:
Figure BDA0003221873920000021
其中,S表示最终分类的概率值,
Figure BDA0003221873920000022
表示超参数,Softmax(·)表示激活函数,F(q)表示训练好的线性分类器,KNN(q)表示从队列Q中选取在样本空间中离q最近的K个训练样本,根据训练样本的分类标签,用投票的方式给出KNN模型的概率值,通过概率值得到分类结果,在推断样本类别时,使用KNN和线性分类器联合预测待预测样本的分类,通过K邻近结点算法,显著提高了模型的鲁棒性。
进一步地,所述选取离q最近的K个训练样本,由于K个训练样本的分类标签已知,K=s1+s2+……+sc,si表示样本的分类标签属于第i个类别的数量,c表示训练样本的类别数量,通过KNN模型,给出的q属于分类yi的概率值为
Figure BDA0003221873920000023
进一步地,所述q与训练样本的相似度,通过cos函数来计算。
进一步地,所述S1中,通过超参数m更新encoder_k的动量参数:
θk←mθk+(1-m)θq
其中θk表示encoder_k的动量参数,θq表示encoder_q的动量参数,在每个batch迭代过程中,将经过encoder_k编码获得的k存放在队列Q中,为了让队列中的样本表示,在每次迭代过程中,通过动量参数更新的方式,更新encoder_k,使其获得的k与直接通过encoder_q获得的q接近。
进一步地,所述队列Q,按先后顺序替换其中的元素k。
进一步地,从所述队列Q中获取与样本的分类标签相同的M个元素k作为正样本k+,与样本的分类标签不同的N个元素k作为负样本k-
本发明的优势和有益效果在于:
本发明不仅在模型的鲁棒性上取得了极大的改进,同时模型的准确率也有相应的提升。此外,为了使用K邻近算法预测样本的所属分类,我们在训练过程中添加了对比学习已期能够拉近同类样本的距离。同时在使用对比学习的过程中我们引入了MOCO的训练方式,极大的增加了正负样本的规模。
附图说明
图1是本发明的方法流程图。
图2是本发明的λ取值在不同数据集上对于模型准确率的影响折线图。
图3是本发明的
Figure BDA0003221873920000024
取值在不同数据集上对于模型准确率的影响折线图。
图4a是普通线性分类器样本空间分布图。
图4b是本发明的KNN-BERT样本空间分布图。
图5是本发明的模型分类准确率试验结果比较图。
图6是本发明的模型鲁棒性试验结果比较图。
具体实施方式
以下结合附图对本发明的具体实施方式进行详细说明。应当理解的是,此处所描述的具体实施方式仅用于说明和解释本发明,并不用于限制本发明。
一种基于K邻近结点算法和对比学习的文本分类方法,如图1所示,包括如下步骤:
第一部分:模型训练过程,具体地,分为以下步骤:
步骤1.1:使用预训练模型BERT作为样本编码器encoder_q,使用相同的预训练模型BERT作为样本编码器encoder_k。
步骤1.2:使用超参m=0.999来更新encoder_k的参数,具体地,动量参数更新的公式为:
θk←mθk+(1-m)θq
其中θk表示样本编码器encoder_k的动量参数,θq表示样本编码器encoder_q的动量参数。传统的对比学习采用batch内部选取正负样本,这样训练过程中使用到的正样本和负样本的数量过少,MoCo采用动量更新的方式,在每个batch迭代过程中,将经过编码器编码的样本存放在一个队列中,为了让队列中的样本表示,在每次迭代过程中,和直接通过编码器获得的样本表示接近,使用动量参数更新的方式更新encoder_k;
对于两个编码器encoder_q和encoder_k的更新都采用迭代更新的方式,基于训练数据,对每个batch进行更新。
步骤1.3:使用编码器encoder_q获得句子的向量表示[CLS]_q(即样本的句子表示q),使用编码器encoder_k获得句子的向量表示[CLS]_k(即样本的句子表示k)。
例如:对训练语句“北京是中国的首都”,BERT在编码时会给句首加入一个TokenCLS,在句尾加入一个Token SEP。一般情况下使用CLS的embedding向量作为整个句子的表示。
步骤1.4:将[CLS]_k存储在大小为32000的队列Q中,并按先后顺序替换Q中的元素;
步骤1.5:从Q中获取与样本标签相同的M个样本作为正样本k+,与样本标签不同的N个样本作为负样本k-
步骤1.6:使用正负样本计算对比学习的损失函数,拉近同类样本的距离,具体地,对比学习损失函数如下:
Figure BDA0003221873920000031
其中,q是encoder_q输出的句子表示,kj表示第j个正样本k+,ki表示遍历k-和kj的集合,exp(·)表示指数函数,τ为超参数,具体地,τ=0.07。
步骤1.7:使用交叉熵损失函数来辅助模型训练,取λ=0.01,具体地,模型训练的损失函数如下:
L=λLec+(1-λ)Lsc
Figure BDA0003221873920000041
其中,λ表示调节Lec和Lsc之间的权重参数,如图2所示的λ取值,是在RTE和MRPC两个数据集上的试验结果,yc表示q或者输入样本x(q是x经过encoder_q后得到的句子表示)的类别,C是文本分类的分类数,F(·)是线性分类器。
通过对比学习的损失函数Lsc来辅助交叉熵损失函数训练模型。其有益效果是使用对比学习训练模型,使得模型在训练过程中,能够自动对样本的embedding表示进行聚类。从而能够达到更好的分类效果。
步骤1.8:反向传播模型损失函数,更新encoder_q和线性分类器的参数。
第二部分:使用KNN和线性分类器联合预测待预测样本的分类,具体地,分为以下步骤:
步骤2.1:使用编码器encoder_q获得待预测样本的句子表示q;
步骤2.2:使用联合预测函数预测样本分类,取
Figure BDA0003221873920000042
具体地,联合预测函数如下:
Figure BDA0003221873920000043
其中,S是最终模型分类的概率值,
Figure BDA0003221873920000044
是超参数,KNN(q)是从Q中取在样本空间中离q最近的K个样本,然后根据这些样本的label用投票的方式给出KNN模型的概率值。
具体地,使用cos函数来计算两个样本的相似度,选取相似度最大的K个训练样本,因为这K个训练样本的分类信息label是已知的,假设训练样本一共有c个分类,s1+s2+……+sc=K,其中si表示样本Label属于第i个分类的数量,所以KNN模型给出的待预测样本x属于分类yi的概率值为
Figure BDA0003221873920000045
如图3所示的
Figure BDA0003221873920000046
值,是在RTE、MRPC和MNLI数据集上对于模型准确率的影响。
如图4a、4b所示,红点和蓝点表示两种不同类别的数据点,从图中可以看出KNN-BERT的样本分布的聚类效果要优于普通的线性分类器。
如图5所示,RTE、MRPC、QNLI、MNLI、SST-2、IMDB、AG’News均为目前常用文本分类数据集,BERT是目前比较通用的分类模型,SCL-Train是传统的对比学习+BERT的分类模型,MoCo是使用了动量参数更新方法后扩充了正负样本的分类模型,KNN-BERT是本发明所提出的分类模型。从图中可以看出,本发明所提出的方法,在各个数据集上相较于现有的方法,分类准确率均有提升,其中在RTE和MRPC这两个数据量较少的数据集上,提升效果更好。
如图6所示,IMDB、AG’s News是两种常用的文本分类数据集,Origin表示原始的准确率,Textfooler和BERT-Attack表示在这两种对抗攻击方式攻击下的分类准确率,BERT表示传统的分类方法,KNN表示采用本发明的方法,实验结果表明,当
Figure BDA0003221873920000051
时,即模型的结果只由KNN分类器给出时,模型的鲁棒性最好。
以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的范围。

Claims (7)

1.一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于包括如下步骤:
S1,训练过程中,通过构建句子向量表示k的正负样本,进行对比学习,对比学习的损失函数如下:
Figure FDA0003221873910000011
其中,M表示正样本的数量,N表示负样本的数量,q表示预训练编码器encoder_q输出的句子的向量表示,k表示预训练编码器encoder_k输出的句子向量表示,encoder_q与encoder_k相同,kj表示第j个正样本k+,ki表示遍历负样本k-和kj的集合,exp(·)表示指数函数,τ为超参数;
结合交叉熵损失函数,进行联合训练,联合损失函数如下:
L=λLec+(1-λ)Lsc
Figure FDA0003221873910000012
其中,λ表示调节交叉熵损失函数Lec和所述对比学习的损失函数Lsc之间的权重参数,yc表示q的类别,C表示文本分类的分类数,F(·)表示线性分类器;
反向传播损失函数,更新encoder_q和线性分类器的参数;
S2,通过训练好的encoder_q和线性分类器,对文本进行分类。
2.根据权利要求1所述的一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于所述S2中,通过训练好的encoder_q获得待预测文本的句子向量表示q,使用联合预测函数预测文本分类,联合预测函数如下:
Figure FDA0003221873910000013
其中,S表示最终分类的概率值,
Figure FDA0003221873910000014
表示超参数,Softmax(·)表示激活函数,F(q)表示训练好的线性分类器,KNN(q)表示从队列Q中选取离q最近的K个训练样本,根据训练样本的分类标签,用投票的方式给出KNN模型的概率值,通过概率值得到分类结果。
3.根据权利要求2所述的一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于所述选取离q最近的K个训练样本,由于K个训练样本的分类标签已知,K=s1+s2+……+sc,si表示样本的分类标签属于第i个类别的数量,c表示训练样本的类别数量,通过KNN模型,给出的q属于分类yi的概率值为
Figure FDA0003221873910000015
4.根据权利要求2所述的一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于所述q与训练样本的相似度,通过cos函数来计算。
5.根据权利要求1所述的一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于所述S1中,通过超参数m更新encoder_k的动量参数:
θk←mθk+(1-m)θq
其中θk表示encoder_k的动量参数,θq表示encoder_q的动量参数,将经过encoder_k编码获得的k存放在队列Q中,在每次迭代过程中,通过动量参数更新的方式,更新encoder_k,使其获得的k与直接通过encoder_q获得的q接近。
6.根据权利要求5所述的一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于所述队列Q,按先后顺序替换其中的元素k。
7.根据权利要求5所述的一种基于K邻近结点算法和对比学习的文本分类方法,其特征在于从所述队列Q中获取与样本的分类标签相同的M个元素k作为正样本k+,与样本的分类标签不同的N个元素k作为负样本k-
CN202110960433.1A 2021-08-20 2021-08-20 一种基于k邻近结点算法和对比学习的文本分类方法 Pending CN113673242A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110960433.1A CN113673242A (zh) 2021-08-20 2021-08-20 一种基于k邻近结点算法和对比学习的文本分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110960433.1A CN113673242A (zh) 2021-08-20 2021-08-20 一种基于k邻近结点算法和对比学习的文本分类方法

Publications (1)

Publication Number Publication Date
CN113673242A true CN113673242A (zh) 2021-11-19

Family

ID=78544489

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110960433.1A Pending CN113673242A (zh) 2021-08-20 2021-08-20 一种基于k邻近结点算法和对比学习的文本分类方法

Country Status (1)

Country Link
CN (1) CN113673242A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110533104A (zh) * 2019-08-30 2019-12-03 中山大学 一种基于不同类别的联合距离均值的分类方法
CN114090780A (zh) * 2022-01-20 2022-02-25 宏龙科技(杭州)有限公司 一种基于提示学习的快速图片分类方法
CN114299304A (zh) * 2021-12-15 2022-04-08 腾讯科技(深圳)有限公司 一种图像处理方法及相关设备
CN115346084A (zh) * 2022-08-15 2022-11-15 腾讯科技(深圳)有限公司 样本处理方法、装置、电子设备、存储介质及程序产品
CN117574309A (zh) * 2023-11-28 2024-02-20 东华理工大学南昌校区 融合多标签对比学习和knn的层次文本分类方法

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110533104A (zh) * 2019-08-30 2019-12-03 中山大学 一种基于不同类别的联合距离均值的分类方法
CN114299304A (zh) * 2021-12-15 2022-04-08 腾讯科技(深圳)有限公司 一种图像处理方法及相关设备
CN114299304B (zh) * 2021-12-15 2024-04-12 腾讯科技(深圳)有限公司 一种图像处理方法及相关设备
CN114090780A (zh) * 2022-01-20 2022-02-25 宏龙科技(杭州)有限公司 一种基于提示学习的快速图片分类方法
CN115346084A (zh) * 2022-08-15 2022-11-15 腾讯科技(深圳)有限公司 样本处理方法、装置、电子设备、存储介质及程序产品
CN117574309A (zh) * 2023-11-28 2024-02-20 东华理工大学南昌校区 融合多标签对比学习和knn的层次文本分类方法

Similar Documents

Publication Publication Date Title
CN113673242A (zh) 一种基于k邻近结点算法和对比学习的文本分类方法
CN109376242B (zh) 基于循环神经网络变体和卷积神经网络的文本分类方法
Peng et al. Accelerating minibatch stochastic gradient descent using typicality sampling
CN112560432A (zh) 基于图注意力网络的文本情感分析方法
CN112069310A (zh) 基于主动学习策略的文本分类方法及***
CN116644755B (zh) 基于多任务学习的少样本命名实体识别方法、装置及介质
WO2022241932A1 (zh) 一种基于非侵入式注意力预处理过程与BiLSTM模型的预测方法
CN115510245B (zh) 一种面向非结构化数据的领域知识抽取方法
CN113705238A (zh) 基于bert和方面特征定位模型的方面级情感分析方法及模型
CN111460097B (zh) 一种基于tpn的小样本文本分类方法
CN114547299A (zh) 一种基于复合网络模型的短文本情感分类方法及装置
CN114692605A (zh) 一种融合句法结构信息的关键词生成方法及装置
CN114722835A (zh) 基于lda和bert融合改进模型的文本情感识别方法
CN113255366A (zh) 一种基于异构图神经网络的方面级文本情感分析方法
CN113553245B (zh) 结合双向切片gru与门控注意力机制日志异常检测方法
Fonseca et al. Model-agnostic approaches to handling noisy labels when training sound event classifiers
CN112491891B (zh) 物联网环境下基于混合深度学习的网络攻击检测方法
CN118013038A (zh) 一种基于原型聚类的文本增量关系抽取方法
CN116050419B (zh) 一种面向科学文献知识实体的无监督识别方法及***
CN112861626A (zh) 基于小样本学习的细粒度表情分类方法
CN114357166B (zh) 一种基于深度学习的文本分类方法
CN115809346A (zh) 一种基于多视图语义增强的小样本知识图谱补全方法
Li et al. A position weighted information based word embedding model for machine translation
CN114707483A (zh) 基于对比学习和数据增强的零样本事件抽取***及方法
CN113342982A (zh) 融合RoBERTa和外部知识库的企业行业分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination