CN114417975A - 基于深度pu学习与类别先验估计的数据分类方法及*** - Google Patents

基于深度pu学习与类别先验估计的数据分类方法及*** Download PDF

Info

Publication number
CN114417975A
CN114417975A CN202111591020.7A CN202111591020A CN114417975A CN 114417975 A CN114417975 A CN 114417975A CN 202111591020 A CN202111591020 A CN 202111591020A CN 114417975 A CN114417975 A CN 114417975A
Authority
CN
China
Prior art keywords
model
student
teacher
data
learning
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111591020.7A
Other languages
English (en)
Inventor
黄庆明
赵昀睿
姜阳邦彦
许倩倩
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Institute of Computing Technology of CAS
Original Assignee
Institute of Computing Technology of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Institute of Computing Technology of CAS filed Critical Institute of Computing Technology of CAS
Priority to CN202111591020.7A priority Critical patent/CN114417975A/zh
Publication of CN114417975A publication Critical patent/CN114417975A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2155Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • G06F16/353Clustering; Classification into predefined classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/217Validation; Performance evaluation; Active pattern learning techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Databases & Information Systems (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明提出一种基于基于深度PU学***均教师、温度锐化等技术,提高算法性能和稳定性。该框架能应用于包括计算机视觉、推荐***、生物医疗等在内各领域的PU问题,并且效果优异,兼具科学价值和实用价值。

Description

基于深度PU学习与类别先验估计的数据分类方法及***
技术领域
本发明涉及机器学习中正样本和无标签(Positive-Unlabeled,PU)学习技术领域,特别涉及基于代价敏感的基于深度PU学习。
背景技术
近年来,随着互联网与信息技术的发展,人类进入了大数据时代。基于海量数据的深度学习,受到广泛关注并取得了突破性进展。然而,深度学习算法取得的优异性能,依赖于大量数据,尤其离不开完整的类别标签信息的指导。而许多实际应用场景中,数据标注获取困难、代价高昂,在有限的人力物力条件下往往仅能获得一小部分数据的标签。因此,对数据标注依赖更少的学习方式成为当下热门,PU学习即为其中之一。例如罕见病的分类问题,确诊样本可以看作正类样本,而未确诊的其他样本,即无标签样本,仍存在患罕见病可能,即无标签样本包含了正类和负类样本。类似的情况还出现在恶意URL检测、虚假评论检测、冷冻电镜的粒子拾取等任务里。由此可见,仅利用正样本和无标签数据进行学习,又称为PU学习,具有重要价值。
PU学***滑假设或聚类假设,从无标签数据中提取得到可靠的负样本和正样本,从而将PU问题转换为一般性的半监督学习问题后,运用半监督或监督学习方法进行训练。另一类较为直观的方法是有偏PU学习,即将无标签数据处理成有噪声的负样本。此外,无偏PU学习在代价敏感(Cost Sensitive)学习的框架下,通过仅有的正样本估计正类的损失,同时基于将无标签数据全部当作负类而产生的损失和类别先验知识,间接构造负类的损失,以此实现对常用分类优化目标的无偏估计,取得了当前最先进的性能。
发明内容
发明人在进行无偏PU学***,面对基于海量数据的深度学习算法时仍然捉襟见肘。因此,如何在缺失类别先验知识的条件下,以尽可能小的计算代价,准确地估计正类先验,以便进行基于代价敏感的基于深度PU学习,成为问题的关键。
具体来说,为了克服上述技术问题,本发明提出了一种基于深度PU学习与类别先验估计的数据分类方法,其中包括:
步骤1、获取包括多个数据样本的训练集,且在该训练集中只有部分数据样本标有类别标签,将该训练集同时输入至两个网络结构相同,但参数不同的学生模型和教师模型中,分别得到学生模型和教师模型输出的各数据样本对应的学生预测分数和教师预测分数;
步骤2、将所有教师预测分数输入至高斯混合模型,得到正类先验;基于所有学生预测分数,构建温度锐化损失;基于所有学生预测分数和教师预测分数,构建一致性损失;基于该正类先验和所有学生预测分数,得到非负PU风险,合并该一致性损失、该非负PU风险和该温度锐化损失,得到目标损失,并基于该目标损失,使用梯度反向传播对该学生模型的参数进行更新,直至该目标收敛或达到预设迭代次数,保存当前学生模型或老师模型作为数据分类模型,例如保存当前教师模型和学生模型中性能更优的作为数据分类模型;
步骤3、将待分类数据输入至该数据分类模型,已得到该待分类数据的类别。
所述的基于深度PU学习与类别先验估计的数据分类方法,其中当用于恶意URL检测时,该训练集中数据为部分已标注恶意类别的URL和无标签的URL,且学生模型和教师模型均为循环神经网络;当用于虚假评论检测时,该训练集中数据为已标注虚假类别的评论和无标签的评论,且学生模型和教师模型均为循环神经网络;当用于冷冻电镜的粒子拾取时,该训练集中数据为已标注选中类别的粒子区域和无标签的粒子区域,且学生模型和教师模型均为卷积神经网络。
所述的基于深度PU学习与类别先验估计的数据分类方法,其中该步骤2包括:
该学生模型、该教师模型各自的预测分数分别为:
S=sigmoid(f(X,Θt))
S′=sigmoid(f(X,Θ′t))
其中,
Figure BDA0003429906560000031
X为训练集,S为学生模型输出的学生预测分数,S′为该教师模型输出的教师预测分数,Θt为学生模型在t时刻的参数,Θ′t为教师模型在t时刻的参数。
所述的基于深度PU学习与类别先验估计的数据分类方法,其中该学生模型和该教师模型输出的一致性损失
Figure BDA0003429906560000032
为:
Figure BDA0003429906560000033
其中,xi∈X表示训练集X中的第i个样本,ci是其置信度,N=|X|;
Figure BDA0003429906560000037
是指示函数,当满足条件(·)时函数值取1,反之取0;τ为置信度阈值,Θ为学生模型的参数,Θ′为教师模型的参数。
所述的基于深度PU学习与类别先验估计的数据分类方法,其中该温度锐化损失
Figure BDA0003429906560000034
为:
Figure BDA0003429906560000035
Figure BDA0003429906560000036
式中T为类别分布的温度,s为学生模型输出的学生预测分数。
本发明还提出了一种基于深度PU学习与类别先验估计的数据分类***,其中包括:
初始模块,用于获取包括多个数据样本的训练集,且在该训练集中只有部分数据样本标有类别标签,将该训练集同时输入至两个网络结构相同,但参数不同的学生模型和教师模型中,分别得到学生模型和教师模型输出的各数据样本对应的学生预测分数和教师预测分数;
训练模块,用于将所有教师预测分数输入至高斯混合模型,得到正类先验;基于所有学生预测分数,构建温度锐化损失;基于所有学生预测分数和教师预测分数,构建一致性损失;基于该正类先验和所有学生预测分数,得到非负PU风险,合并该一致性损失、该非负PU风险和该温度锐化损失,得到目标损失,并基于该目标损失,使用梯度反向传播对该学生模型的参数进行更新,直至该目标收敛或达到预设迭代次数,保存当前学生模型或老师模型作为数据分类模型;
分类模块,用于将待分类数据输入至该数据分类模型,已得到该待分类数据的类别。
所述的基于深度PU学习与类别先验估计的数据分类***,其中当用于恶意URL检测时,该训练集中数据为部分已标注恶意类别的URL和无标签的URL,且学生模型和教师模型均为循环神经网络;当用于虚假评论检测时,该训练集中数据为已标注虚假类别的评论和无标签的评论,且学生模型和教师模型均为循环神经网络;当用于冷冻电镜的粒子拾取时,该训练集中数据为已标注选中类别的粒子区域和无标签的粒子区域,且学生模型和教师模型均为卷积神经网络。
所述的基于深度PU学习与类别先验估计的数据分类***,其中该训练模块用于基于下式为该学生模型、该教师模型各自的预测分数:
S=sigmoid(f(X,Θt))
S′=sigmoid(f(X,Θ′t))
其中,
Figure BDA0003429906560000041
X为训练集,S为学生模型输出的学生预测分数,S′为该教师模型输出的教师预测分数,Θt为学生模型在t时刻的参数,Θ′t为教师模型在t时刻的参数。
所述的基于深度PU学习与类别先验估计的数据分类***,其中该学生模型和该教师模型输出的一致性损失
Figure BDA0003429906560000042
为:
Figure BDA0003429906560000043
其中,xi∈X表示训练集X中的第i个样本,ci是其置信度,N=|X|;
Figure BDA0003429906560000044
是指示函数,当满足条件(·)时函数值取1,反之取0;τ为置信度阈值,Θ为学生模型的参数,Θ′为教师模型的参数。
所述的基于深度PU学习与类别先验估计的数据分类***,其中该温度锐化损失
Figure BDA0003429906560000051
为:
Figure BDA0003429906560000052
Figure BDA0003429906560000053
式中T为类别分布的温度,s为学生模型输出的学生预测分数。
本发明还提出了一种存储介质,用于存储执行所述任意一种基于深度PU学习与类别先验估计的数据分类方法的程序。
本发明还提出了一种客户端,用于上述任意一种基于深度PU学习与类别先验估计的数据分类***。
由以上方案可知,本发明的优点在于:
本发明提出了一种迭代式的深度PU学***均教师、温度锐化等技术,提高算法性能和稳定性。该框架能应用于包括计算机视觉、推荐***、生物医疗等在内各领域的PU问题上,且效果优异,兼具科学价值和实用价值。
附图说明
图1为本发明流程框图。
具体实施方式
本发明的目标是解决在缺失类别先验知识的条件下如何进行无偏PU学习的问题。现有的无偏PU学习方法默认类别先验已知或易于估计,而现实PU问题中的类别先验往往未知,且难以估计。另外,已有的类别先验估计算法则主要针对传统的机器学习分类器进行设计,没有发挥深度学习在大规模数据集的优势。为了克服以上问题,本发明提出了一个基于无监督混合模型的迭代式深度PU学习框架。它利用了深度神经网络对不同的类别的样本(正样本和负样本)给出的预测分数具有不同的分布这一特性,使用高斯混合模型近似拟合预测分数的混合分布。再结合常用的半监督学习技术和针对PU问题的优化目标,实现了与基于真实正类先验的PU算法相媲美的分类性能。
本发明包括以下关键技术点:
关键点1,深度神经网络发生过拟合现象之前,正样本的预测分数与负样本的预测分数呈现出不同的分布,正样本的预测分数集中分布在分值较高的区间,而负样本的预测分数集中分布在分值较低的区间,两者分别构成了两条中间高、两端低的钟型曲线。基于上述观察,提出使用高斯混合模型(Gaussian Mixture Model,GMM)对预测分数进行无监督式地建模。GMM的参数量少,不受PU学习缺失负类标签的约束,其求解所需的时间复杂度、空间复杂度均为。因而,所提方法占有的计算资源少,可广泛应用于各种规模的数据集;
关键点2,同时考虑类别先验估计和PU学习,提出针对PU问题的迭代式解决方案,即模型的训练与GMM估计类别先验迭代进行;在理想情况下,随着时间的推移,深度神经网络的分类性能若越来越好,那么它给出的预测分数也应当趋于可靠;若预测分数愈发可信,那么GMM估计的正类先验将更加准确,并将进一步促进模型分类性能的提升。模型的训练和GMM估计类别先验互相存在着正向反馈,这一特性能为迭代式的框架所用;
关键点3,为框架引入半监督学***均教师(Mean Teacher)和温度锐化(Temperature Sharpening)。未引入时,类别先验估计不够稳定,先验估计值随着训练epoch数的增长始终震荡,方差较大。引入后,平均教师通过历史参数的平均保证了预测分数的平稳;而温度锐化鼓励预测分数向0或1靠近,使得预测分数中不同类别对应的钟形曲线更可区分,GMM的拟合效果因之得到加强。二者共同作用,平稳了类别先验的估计值,有效提升了算法分类性能。
为让本发明的上述特征和效果能阐述的更明确易懂,下文特举实施例,并配合说明书附图作详细说明如下。
本发明提出了一种新的迭代式深度PU学***均Exponential Moving Average运算得到);S=f(X,Θ),S′=f(·,Θ′)分别为学生模型、教师模型各自的预测分数;对教师模型的预测分数S′使用高斯混合建模GMM,求解以估计正类先验
Figure BDA0003429906560000071
由此,便可以借助
Figure BDA0003429906560000072
计算非负PU风险
Figure BDA0003429906560000073
结合平均教师的一致性损失
Figure BDA0003429906560000074
和温度锐化损失
Figure BDA0003429906560000075
进而计算优化目标
Figure BDA0003429906560000076
最后使用梯度反向传播算法对学生模型的参数Θ进行更新。上述步骤如此迭代反复,直至优化目标
Figure BDA0003429906560000077
收敛。接下来详细阐述各个步骤的计算过程。
其中,文本类的任务,例如恶意URL检测和虚假评论检测,它们的学生模型和教师模型可采用循环神经网络。图像类的任务,例如冷冻电镜的粒子拾取,可采用卷积神经网络。不同应用的模型、置信度阈值τ、温度T、超参数λ1、λ2均可视实际情况自由设置。预测分数可以看作模型给出的正类概率。它由样本经过深度神经网络和sigmoid变换得到。
(1)计算预测分数S,S′;
给定三元组(x,y,z),x是输入特征,y是其类别标签,z∈{1,0}表示其类别标签的有无。PU学习中,由于真实的y未知,因此PU训练集通常由若干个二元组(x,z)组成。此外,有标签的样本均为正类,即有Pr(y=1|z=1)=1。样本集合X=Xl∪Xu,Xl为有标签的正样本子集,Xu为无标签样本子集。t时刻学生模型和教师模型的参数分别用Θt,Θ′t表示,其对应的预测函数定义分别为f(·,Θt),f(·,Θ′t):
Figure BDA0003429906560000078
那么,学生模型、教师模型各自的预测分数分别为:
S=sigmoid(f(X,Θt)),
S′=sigmoid(f(X,Θ′t)).
其中,
Figure BDA0003429906560000079
使用GMM对S′建模,以获取正类先验πp的估计值
Figure BDA00034299065600000710
此公式的·是任意实数。由于
Figure BDA00034299065600000711
不存在闭式解,因此需要使用EM算法迭代逼近(见M步和参数更新方程)。此步骤使用GMM对S′建模以估计正类先验是本发明的发明点之一,其带来的技术进步是适用于深度学***稳器(平均教师和温度锐化),其带来的技术进步是正类先验估计过程更加准确平稳。以及正类先验估计与深度模型迭代进行,其带来的技术进步是二者互相促进,使得正类先验估计更准确,模型分类性能更好。
GMM是一种无监督建模方法,其对预测分数S′的建模如下:
Figure BDA0003429906560000081
Figure BDA0003429906560000082
Figure BDA0003429906560000083
其中,类别标签y是隐变量;πp是混合系数,同时也表示正类先验;πn=1-πp
Figure BDA0003429906560000084
分别表示正(负)样本预测分数所服从的高斯分布;高斯分布
Figure BDA0003429906560000085
μ,σ分别表示其均值和方差;因为正样本的预测分数整体大于负样本,所以μn<μp
通常使用最大期望(Expectation Maximization,EM)算法对GMM进行求解。设
Figure BDA0003429906560000086
Figure BDA0003429906560000087
表示t时刻GMM的参数值,
Figure BDA0003429906560000088
表示参数为Φ(t)的高斯分布,EM算法交替迭代地执行(1)期望(E)步,即计算s由
Figure BDA0003429906560000089
生成的条件概率:
Figure BDA00034299065600000810
和(2)最大化(M)步,即:
Figure BDA00034299065600000811
Figure BDA00034299065600000812
直至收敛。根据文献,Φ(t+1)的参数更新方程可展开为以下形式:
Figure BDA00034299065600000813
Figure BDA00034299065600000814
Figure BDA0003429906560000091
Figure BDA0003429906560000092
(2)计算优化目标
Figure BDA0003429906560000093
步骤(2)得到了正类先验的估计值
Figure BDA0003429906560000094
由此,借助
Figure BDA0003429906560000095
计算非负PU风险
Figure BDA0003429906560000096
Figure BDA0003429906560000097
Figure BDA0003429906560000098
Figure BDA0003429906560000099
Figure BDA00034299065600000910
将模型输出的置信度c定义为样本经过学生模型输出得到的类别概率最大值,即:
c=max(s,1-s).
利用基于置信度的掩码技术,将置信度阈值设为τ,那么学生模型和教师模型输出的一致性损失
Figure BDA00034299065600000911
为:
Figure BDA00034299065600000912
其中,xi∈X表示样本集合X中的第i个样本,ci是其置信度,N=|X|;
Figure BDA00034299065600000913
是指示函数,当满足条件(·)时函数值取1,反之取0。且该一致性损失
Figure BDA00034299065600000914
可平稳类别先验估计过程,最终提升深度模型的分类性能。
给定预测分数s,使用锐化函数来降低其关于类别分布的信息熵。温度锐化通过调整类别分布的温度(Temperature,T),实现上述目的,公式如下:
Figure BDA00034299065600000915
结合平均教师中所提及的基于置信度的掩码技术,只对可信的输出进行温度锐化,那么其损失
Figure BDA00034299065600000916
为:
Figure BDA00034299065600000917
综上所述,最终的优化目标
Figure BDA00034299065600000918
Figure BDA0003429906560000101
其中,λ1,λ2为超参数。
(3)更新模型参数Θ,Θ′
运用梯度反向传播算法,对学生模型的参数Θ进行更新:
Figure BDA0003429906560000102
其中,η表示学习率,表示
Figure BDA0003429906560000103
优化目标
Figure BDA0003429906560000104
关于模型参数Θ的导数。通过EMA运算更新教师模型的参数Θ′:
Θ′t+1=αΘ′t+(1-α)Θt+1.
其中,α是平滑系数,取值在[0,1]内。
迭代执行步骤(1)-(4),直至优化目标
Figure BDA0003429906560000105
收敛。算法收敛时的
Figure BDA0003429906560000106
便是最终正类先验的估计值,f(·,Θ),f(·,Θ′)为所得分类器。
以下为与上述方法实施例对应的***实施例,本实施方式可与上述实施方式互相配合实施。上述实施方式中提到的相关技术细节在本实施方式中依然有效,为了减少重复,这里不再赘述。相应地,本实施方式中提到的相关技术细节也可应用在上述实施方式中。
本发明还提出了一种基于深度PU学习与类别先验估计的数据分类***,其中包括:
初始模块,用于获取包括多个数据样本的训练集,且在该训练集中只有部分数据样本标有类别标签,将该训练集同时输入至两个网络结构相同,但参数不同的学生模型和教师模型中,分别得到学生模型和教师模型输出的各数据样本对应的学生预测分数和教师预测分数;
训练模块,用于将所有教师预测分数输入至高斯混合模型,得到正类先验;基于所有学生预测分数,构建温度锐化损失;基于所有学生预测分数和教师预测分数,构建一致性损失;基于该正类先验和所有学生预测分数,得到非负PU风险,合并该一致性损失、该非负PU风险和该温度锐化损失,得到目标损失,并基于该目标损失,使用梯度反向传播对该学生模型的参数进行更新,直至该目标收敛或达到预设迭代次数,保存当前学生模型或老师模型作为数据分类模型;
分类模块,用于将待分类数据输入至该数据分类模型,已得到该待分类数据的类别。
所述的基于深度PU学习与类别先验估计的数据分类***,其中当用于恶意URL检测时,该训练集中数据为部分已标注恶意类别的URL和无标签的URL,且学生模型和教师模型均为循环神经网络;当用于虚假评论检测时,该训练集中数据为已标注虚假类别的评论和无标签的评论,且学生模型和教师模型均为循环神经网络;当用于冷冻电镜的粒子拾取时,该训练集中数据为已标注选中类别的粒子区域和无标签的粒子区域,且学生模型和教师模型均为卷积神经网络。
所述的基于深度PU学习与类别先验估计的数据分类***,其中该训练模块用于基于下式为该学生模型、该教师模型各自的预测分数:
S=sigmoid(f(X,Θt))
S′=sigmoid(f(X,Θ′t))
其中,
Figure BDA0003429906560000111
X为训练集,S为学生模型输出的学生预测分数,S′为该教师模型输出的教师预测分数,Θt为学生模型在t时刻的参数,Θ′t为教师模型在t时刻的参数。
所述的基于深度PU学习与类别先验估计的数据分类***,其中该学生模型和该教师模型输出的一致性损失
Figure BDA0003429906560000112
为:
Figure BDA0003429906560000113
其中,xi∈X表示训练集X中的第i个样本,ci是其置信度,N=|X|;
Figure BDA0003429906560000114
是指示函数,当满足条件(·)时函数值取1,反之取0;τ为置信度阈值,Θ为学生模型的参数,Θ′为教师模型的参数。
所述的基于深度PU学习与类别先验估计的数据分类***,其中该温度锐化损失
Figure BDA0003429906560000115
为:
Figure BDA0003429906560000116
Figure BDA0003429906560000117
式中T为类别分布的温度,s为学生模型输出的学生预测分数。
本发明还提出了一种存储介质,用于存储执行所述任意一种基于深度PU学习与类别先验估计的数据分类方法的程序。
本发明还提出了一种客户端,用于上述任意一种基于深度PU学习与类别先验估计的数据分类***。

Claims (12)

1.一种基于基于深度PU学习与类别先验估计的数据分类方法,其特征在于,包括:
步骤1、获取包括多个数据样本的训练集,且在该训练集中只有部分数据样本标有类别标签,将该训练集同时输入至两个网络结构相同,但参数不同的学生模型和教师模型中,分别得到学生模型和教师模型输出的各数据样本对应的学生预测分数和教师预测分数;
步骤2、将所有教师预测分数输入至高斯混合模型,得到正类先验;基于所有学生预测分数,构建温度锐化损失;基于所有学生预测分数和教师预测分数,构建一致性损失;基于该正类先验和所有学生预测分数,得到非负PU风险,合并该一致性损失、该非负PU风险和该温度锐化损失,得到目标损失,并基于该目标损失,使用梯度反向传播对该学生模型的参数进行更新,直至该目标收敛或达到预设迭代次数,保存当前学生模型或老师模型作为数据分类模型;
步骤3、将待分类数据输入至该数据分类模型,已得到该待分类数据的类别。
2.如权利要求1所述的基于深度PU学习与类别先验估计的数据分类方法,其特征在于,当用于恶意URL检测时,该训练集中数据为部分已标注恶意类别的URL和无标签的URL,且学生模型和教师模型均为循环神经网络;当用于虚假评论检测时,该训练集中数据为已标注虚假类别的评论和无标签的评论,且学生模型和教师模型均为循环神经网络;当用于冷冻电镜的粒子拾取时,该训练集中数据为已标注选中类别的粒子区域和无标签的粒子区域,且学生模型和教师模型均为卷积神经网络。
3.如权利要求1或2所述的基于深度PU学习与类别先验估计的数据分类方法,其特征在于,该步骤2包括:
该学生模型、该教师模型各自的预测分数分别为:
S=sigmoid(f(X,Θt))
S′=sigmoid(f(X,Θ′t))
其中,
Figure FDA0003429906550000011
X为训练集,S为学生模型输出的学生预测分数,S′为该教师模型输出的教师预测分数,Θt为学生模型在t时刻的参数,Θ′t为教师模型在t时刻的参数。
4.如权利要求1或2所述的基于深度PU学习与类别先验估计的数据分类方法,其特征在于,该学生模型和该教师模型输出的一致性损失
Figure FDA0003429906550000021
为:
Figure FDA0003429906550000022
其中,xi∈X表示训练集X中的第i个样本,ci是其置信度,N=|X|;
Figure FDA0003429906550000023
是指示函数,当满足条件(·)时函数值取1,反之取0;τ为置信度阈值,Θ为学生模型的参数,Θ′为教师模型的参数。
5.如权利要求4所述的基于深度PU学习与类别先验估计的数据分类方法,其特征在于,该温度锐化损失
Figure FDA0003429906550000024
为:
Figure FDA0003429906550000025
Figure FDA0003429906550000026
式中T为类别分布的温度,s为学生模型输出的学生预测分数。
6.一种基于深度PU学习与类别先验估计的数据分类***,其特征在于,包括:
初始模块,用于获取包括多个数据样本的训练集,且在该训练集中只有部分数据样本标有类别标签,将该训练集同时输入至两个网络结构相同,但参数不同的学生模型和教师模型中,分别得到学生模型和教师模型输出的各数据样本对应的学生预测分数和教师预测分数;
训练模块,用于将所有教师预测分数输入至高斯混合模型,得到正类先验;基于所有学生预测分数,构建温度锐化损失;基于所有学生预测分数和教师预测分数,构建一致性损失;基于该正类先验和所有学生预测分数,得到非负PU风险,合并该一致性损失、该非负PU风险和该温度锐化损失,得到目标损失,并基于该目标损失,使用梯度反向传播对该学生模型的参数进行更新,直至该目标收敛或达到预设迭代次数,保存当前学生模型或老师模型作为数据分类模型;
分类模块,用于将待分类数据输入至该数据分类模型,已得到该待分类数据的类别。
7.如权利要求6所述的基于深度PU学习与类别先验估计的数据分类***,其特征在于,当用于恶意URL检测时,该训练集中数据为部分已标注恶意类别的URL和无标签的URL,且学生模型和教师模型均为循环神经网络;当用于虚假评论检测时,该训练集中数据为已标注虚假类别的评论和无标签的评论,且学生模型和教师模型均为循环神经网络;当用于冷冻电镜的粒子拾取时,该训练集中数据为已标注选中类别的粒子区域和无标签的粒子区域,且学生模型和教师模型均为卷积神经网络。
8.如权利要求6或7所述的基于深度PU学习与类别先验估计的数据分类***,其特征在于,该训练模块用于基于下式为该学生模型、该教师模型各自的预测分数:
S=sigmoid(f(X,Θt))
S′=sigmoid(f(X,Θ′t))
其中,
Figure FDA0003429906550000031
X为训练集,S为学生模型输出的学生预测分数,S′为该教师模型输出的教师预测分数,Θt为学生模型在t时刻的参数,Θ′t为教师模型在t时刻的参数。
9.如权利要求6或7所述的基于深度PU学习与类别先验估计的数据分类***,其特征在于,该学生模型和该教师模型输出的一致性损失
Figure FDA0003429906550000032
为:
Figure FDA0003429906550000033
其中,xi∈X表示训练集X中的第i个样本,ci是其置信度,N=|X|;
Figure FDA0003429906550000034
是指示函数,当满足条件(·)时函数值取1,反之取0;τ为置信度阈值,Θ为学生模型的参数,Θ′为教师模型的参数。
10.如权利要求9所述的基于深度PU学习与类别先验估计的数据分类***,其特征在于,该温度锐化损失
Figure FDA0003429906550000035
为:
Figure FDA0003429906550000036
Figure FDA0003429906550000037
式中T为类别分布的温度,s为学生模型输出的学生预测分数。
11.一种存储介质,用于存储执行如权利要求1到5所述任意一种基于深度PU学习与类别先验估计的数据分类方法的程序。
12.一种客户端,用于权利要求6至10中任意一种基于深度PU学习与类别先验估计的数据分类***。
CN202111591020.7A 2021-12-23 2021-12-23 基于深度pu学习与类别先验估计的数据分类方法及*** Pending CN114417975A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111591020.7A CN114417975A (zh) 2021-12-23 2021-12-23 基于深度pu学习与类别先验估计的数据分类方法及***

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111591020.7A CN114417975A (zh) 2021-12-23 2021-12-23 基于深度pu学习与类别先验估计的数据分类方法及***

Publications (1)

Publication Number Publication Date
CN114417975A true CN114417975A (zh) 2022-04-29

Family

ID=81266728

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111591020.7A Pending CN114417975A (zh) 2021-12-23 2021-12-23 基于深度pu学习与类别先验估计的数据分类方法及***

Country Status (1)

Country Link
CN (1) CN114417975A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115859106A (zh) * 2022-12-05 2023-03-28 中国地质大学(北京) 一种基于半监督学习的矿产勘探方法、装置和存储介质
CN117574258A (zh) * 2024-01-15 2024-02-20 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 一种基于文本噪声标签和协同训练策略的文本分类方法

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115859106A (zh) * 2022-12-05 2023-03-28 中国地质大学(北京) 一种基于半监督学习的矿产勘探方法、装置和存储介质
CN117574258A (zh) * 2024-01-15 2024-02-20 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 一种基于文本噪声标签和协同训练策略的文本分类方法
CN117574258B (zh) * 2024-01-15 2024-04-26 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 一种基于文本噪声标签和协同训练策略的文本分类方法

Similar Documents

Publication Publication Date Title
CN112308158B (zh) 一种基于部分特征对齐的多源领域自适应模型及方法
CN113378632B (zh) 一种基于伪标签优化的无监督域适应行人重识别方法
CN109190524B (zh) 一种基于生成对抗网络的人体动作识别方法
CN109376242B (zh) 基于循环神经网络变体和卷积神经网络的文本分类方法
US11816183B2 (en) Methods and systems for mining minority-class data samples for training a neural network
CN111126488B (zh) 一种基于双重注意力的图像识别方法
CN113326731B (zh) 一种基于动量网络指导的跨域行人重识别方法
CN106778796B (zh) 基于混合式协同训练的人体动作识别方法及***
CN114492574A (zh) 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法
CN112085055B (zh) 一种基于迁移模型雅克比阵特征向量扰动的黑盒攻击方法
CN110929848B (zh) 基于多挑战感知学习模型的训练、跟踪方法
CN107945210B (zh) 基于深度学习和环境自适应的目标跟踪方法
CN110097060B (zh) 一种面向树干图像的开集识别方法
CN111564179B (zh) 一种基于三元组神经网络的物种生物学分类方法及***
CN114417975A (zh) 基于深度pu学习与类别先验估计的数据分类方法及***
CN109840595B (zh) 一种基于群体学习行为特征的知识追踪方法
CN110728694A (zh) 一种基于持续学习的长时视觉目标跟踪方法
CN112232395B (zh) 一种基于联合训练生成对抗网络的半监督图像分类方法
CN112784921A (zh) 任务注意力引导的小样本图像互补学习分类算法
CN117611932B (zh) 基于双重伪标签细化和样本重加权的图像分类方法及***
CN113743474A (zh) 基于协同半监督卷积神经网络的数字图片分类方法与***
Demirel et al. Meta-tuning loss functions and data augmentation for few-shot object detection
Qiao et al. A multi-level thresholding image segmentation method using hybrid Arithmetic Optimization and Harris Hawks Optimizer algorithms
Xia et al. Detecting smiles of young children via deep transfer learning
CN105809200B (zh) 一种生物启发式自主抽取图像语义信息的方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination