CN114332539A - 针对类别不均衡数据集的网络训练方法 - Google Patents

针对类别不均衡数据集的网络训练方法 Download PDF

Info

Publication number
CN114332539A
CN114332539A CN202111671005.3A CN202111671005A CN114332539A CN 114332539 A CN114332539 A CN 114332539A CN 202111671005 A CN202111671005 A CN 202111671005A CN 114332539 A CN114332539 A CN 114332539A
Authority
CN
China
Prior art keywords
samples
sample
class
data set
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111671005.3A
Other languages
English (en)
Inventor
任维
姚鹏
徐亮
陈明潇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Yousheng Biotechnology Co ltd
University of Science and Technology of China USTC
Original Assignee
Shenzhen Yousheng Biotechnology Co ltd
University of Science and Technology of China USTC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Yousheng Biotechnology Co ltd, University of Science and Technology of China USTC filed Critical Shenzhen Yousheng Biotechnology Co ltd
Priority to CN202111671005.3A priority Critical patent/CN114332539A/zh
Publication of CN114332539A publication Critical patent/CN114332539A/zh
Pending legal-status Critical Current

Links

Images

Landscapes

  • Image Analysis (AREA)

Abstract

本发明公开了一种针对类别不均衡数据集的网络训练方法,包括:获取目标图像数据集,确定类别数目和每类样本量;利用每类样本量计算相应类别的权值,并结合设置的超参数构建限制误差损失函数;利用所述目标图像数据集对神经网络模型进行训练,并将样本的预测结果与真实标签带入所述限制误差损失函数进行误差计算,使用反向传播不断更新神经网络模型的参数,直至网络收敛达到预期目标。构建的限制误差损失函数按类别数量进行加权,并且通过引入超参数对尾部泛化进行正则化的LDAM,可以将训练的关注度更多的偏向于数量较少的尾部类别,防止网络训练对尾部类欠拟合,可以应用于不均衡的图像数据集中,并可以显著提高网络对于不均衡数据集的识别准确度。

Description

针对类别不均衡数据集的网络训练方法
技术领域
本发明涉及深度学习技术领域,尤其涉及一种针对类别不均衡数据集的网络训练方法。
背景技术
随着人工智能的发展,深度学习在各个领域都取得了显著的成果,如今在各个领域下都有广泛的应用。随着高质量、大规模数据集(如ImageNet ILSVRC 2012,MS COCO等数据集)的使用,计算机视觉领域也取得了重大突破。但是与这些人工干预的标签均匀分布的数据集相比,现实场景中的数据集往往呈现严重的类别不均衡现象。
这种类别不均衡现象可以分为两类:一、类别的数量分布不均衡:即少数类别(头部类)的样本占据大部分的样本数据,而多数类别(尾部类)却只有少量的样本。二、类别的难易程度分布不均衡:即由于采集、人为等因素的影响,数据集本身存在某些类别与多数类别差异较大的情况,如像素较差、成像不清晰或者前景与背景占比差距较大等,这些类别往往相较于其他类别训练难度更大。
这种类别数量以及难易程度分布极端不均衡的数据集在传统深度学习方法下往往难以实现出色的图像识别精度。为了使深度学习网络能够适应这种不均衡的数据集,可以选择通过两个方面来提升网络的性能。
在类别不均衡方面:往往需要通过特定的研究来关注尾部类样本的特征,从而鼓励网络在头部类和尾部类中寻找一个最佳权衡。针对该类不均衡问题,使用最小化边缘泛化边界损失函数(LDAM Loss)可以提供给尾部类比头部类更强的正则化,在维持样本数量占据较多的头部类的准确度不下降的前提下,改善了尾部类的泛化误差。
在难易程度不均衡方面:在实际样本中,存在类间样本相似,类内样本变化很大的现象,使得在分类时有些类可以很容易的区别,而有些类却很容易混淆,难以区分。针对该类不均衡问题,焦点损失函数(Focal Loss)通过对容易分类的样本降低损失权重,从而使训练更多聚焦在困难样本的分类上,以缓解样本数据难易不均衡的情况并提高网络性能。
另一方面,随机裁剪作为一种数据增强方法,被广泛的使用于深度学习网络训练,可以大大增强模型的空间鲁棒性。但是当图像的前景区域相对于背景区域很小时,采用随机裁剪的增强方法往往会产生一些可能只包含极少量前景甚至无前景的图像,如图1所示,左侧表示裁剪前的样本,右侧表示裁剪后的样本。对于这样的样本,以及数据集中可能出现的标签错误的样本,被称为“非常困难的样本”或者“异常样本”。这些异常样本在深度学习模型训练的收敛阶段仍然可能存在较大的损失值。如果在训练的过程中,模型被强制更好地对这些异常值进行分类,那么往往会显著降低对其他大量通常样本的分类准确度。
目前还没有一种方法,可以很好的同时处理数据集中存在的样本数量不均衡、样本难易不均衡、含有“异常样本”的问题。因此,需要构建一种方法来提高在不均衡数据集和存在部分“异常样本”时深度学习网络的性能,以提升网络模型分类性能。
发明内容
本发明的目的是提供一种针对类别不均衡数据集的网络训练方法,可以广泛应用于处理现实场景中存在的不均衡问题,并可以显著提高网络训练的准确度,提升网络模型图像分类性能。
本发明的目的是通过以下技术方案实现的:
一种针对类别不均衡数据集的网络训练方法,包括:
获取目标图像数据集,确定类别数目和每类样本量;
利用每类样本量计算相应类别的权值,并结合设置的超参数构建限制误差损失函数;
利用所述目标图像数据集对神经网络模型进行训练,并将样本的预测结果与真实标签带入所述限制误差损失函数进行误差计算,使用反向传播不断更新神经网络模型的参数,直至网络收敛达到预期目标。
由上述本发明提供的技术方案可以看出,构建的限制误差损失函数按类别数量进行加权,并且通过引入超参数对尾部泛化进行正则化的LDAM,可以将训练的关注度更多的偏向于数量较少的尾部类别,防止网络训练对尾部类欠拟合;基于限制误差损失函数进行网络训练,可以广泛应用于各种不均衡的图像数据集中,并可以显著提高网络对于不均衡数据集的识别准确度。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明背景技术提供的由随机裁剪导致的异常样本出现的示意图;
图2为本发明实施例提供的一种针对类别不均衡数据集的网络训练方法的流程图;
图3为本发明实施例提供的限制误差损失函数计算流程图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。
首先对本文中可能使用的术语进行如下说明:
术语“包括”、“包含”、“含有”、“具有”或其它类似语义的描述,应被解释为非排它性的包括。例如:包括某技术特征要素(如原料、组分、成分、载体、剂型、材料、尺寸、零件、部件、机构、装置、步骤、工序、方法、反应条件、加工条件、参数、算法、信号、数据、产品或制品等),应被解释为不仅包括明确列出的某技术特征要素,还可以包括未明确列出的本领域公知的其它技术特征要素。
下面对本发明所提供的一种针对类别不均衡数据集的网络训练方法进行详细描述。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。本发明实施例中未注明具体条件者,按照本领域常规条件或制造商建议的条件进行。
本发明实施例提供的一种针对类别不均衡数据集的网络训练方法,主要应用在深度学习模型的训练过程中,主要原理可以描述为:首先获取目标图像数据集,根据目标数据集确定数据样本的类别数目C和每类样本量Ni,设置超参数γ、T和S,利用每类样本量计算相应类别的权值,并结合设置的超参数构建限制误差损失函数LEloss(z,y),利用所述目标图像数据集对神经网络模型进行迭代训练,并将样本的预测结果与真实标签带入所述限制误差损失函数进行误差计算,使用反向传播不断更新神经网络模型的参数,直至网络收敛达到预期目标,最终完成训练。利用此限制误差损失函数不仅可以同时处理不同数据类别的样本数量不均衡问题和分类难度不均衡问题,还可以进一步降低异常样本对训练过程的影响,可以应用于存在类别不均衡问题的数据集,从而有效缓解类不均衡问题的影响。如图2所示,上述方案的主要步骤包括:
步骤1:获取目标图像数据集,并将目标图像数据集进行常用的数据增广(随机裁剪、随机翻转等),根据目标图像数据集确定样本的类别数目C和每类样本数量Ni,Ni为第i类样本的总数。
本发明实施例中,所述目标图像数据集为类别不均衡数据集,有些类别的样本数量很多,而另外一些类别的样本数量很少。并且,同一类别中的样本也分为困难样本和简单样本,所述困难样本是指预测时与真实标签误差超过设定最大阈值的样本,简单样本是指预测时与真实标签误差未超过设定最小阈值的样本;同时,所述目标图像数据集中还包含异常样本,所述异常样本包括:通过数据增广操作产生的前景消失的样本或者只包含前景区域小于设定值的样本,或者是标签错误的样本。
本发明实施例中,所涉及的最大阈值、最小阈值等各类设定值的具体数值大小可以由用户根据实际情况或者经验设定,本发明不对具体的数值大小进行限定。
步骤2:设置限制误差损失函数的超参数γ,T,S。
本发明实施例中,γ,T和S均为超参数,γ用来调节样本(简单样本)权重降低的速率,T是用于判断样本是否是异常样本的阈值,S用来调节头部类和尾部类的分类边界偏向头部类的程度。
本发明实施例中,所述超参数γ取值大于0,取值越大,表示简单样本的权重相对于困难样本的权重降低速率越快,训练过程将更加关注困难样本的分类;所述超参数T取值在0到0.5之间,取值越大,则会将更多的样本认定为异常样本。
本发明实施例中,超参数γ与T预先利用目标图像数据集设定,在训练过程中进一步搜索寻优,进行数值优化。
步骤3:利用每类样本量计算相应类别的权值,并结合设置的超参数构建限制误差损失函数。
本发明实施例中,所述限制误差损失函数表示为:
Figure BDA0003449738270000041
其中,z=[z1,…,zC],zj表示样本在类别j上的预测值,C表示类别数目,y表示真实标签;Ny表示真实标签对应的样本量,
Figure BDA0003449738270000051
即为相应类别的权值wy;Loss(z,y)为结合超参数计算的损失函数。
本发明实施例中,权重wy用来处理不同类别样本数量不均衡问题的影响。当某一类别样本数量较少时,对应权重wy会相应较大,通过上述损失函数计算得到的损失也会增大,使得神经网络训练时更加关注这一数量较少的类别,从而缓解不同类别数量不均衡问题的影响。
本发明实施例中,损失函数Loss(z,y)计算方式表示为:
Figure BDA0003449738270000052
其中,σ是一个由超参数T决定的常值,可以表示为:σ=(1T)γlog(T)。py表示网络预测样本为真实标签的概率。
通过损失函数Loss(z,y)的计算公式可知,当py小于等于阈值T时,表明预测值偏离真实标签很远,认为该样本是异常样本,此时Loss(z,y)的值将被限定为一个常数σ,通过上述损失函数计算得到的损失值相应减少,使得神经网络训练时对该异常样本关注减少,从而使得该异常样本对训练过程造成的影响降低。同时,令w*=(1-py)γ,当py越大时,表明输出预测值越接近真实标签值,样本为简单样本,此时权重w*就会越小。通过上述损失函数计算得到的损失值相应更小,使得神经网络训练时对简单样本关注减少,更加关注困难样本,从而缓解简单样本与困难样本不均衡的影响。
本发明实施例中,参数py是通过预测值集合z归一化后获得的预测样本为真实标签的概率;为了提高分类准确性,采用LDAM方法来计算,表示为:
Figure BDA0003449738270000053
其中,
Figure BDA0003449738270000054
S是一个超参数,用来调节头部类和尾部类的分类边界偏向头部类的程度,e为自然常数。
本发明实施例中,对于py的计算,采用LDAM方法代替传统的交叉熵的方法,可以进一步明晰不同类之间的分类边界,提高分类准确性,尤其是提高少数类的分类准确性。
图3展示了限制误差损失函数计算流程。
步骤4:将限制误差损失函数用于神经网络模型训练的反向传播过程中,并利用该限制误差损失函数对不同类别数据样本和不同难易样本进行损失计算,从而缓解不同类别数量不均衡和分类难度不均衡的问题的影响,直至网络收敛,最终达到网络训练的目的。
本发明实施例中,所涉及网络参数更新流程可参照常规技术实现,本发明不做赘述,所涉及的神经网络模型可以是目前任意结构的形式的图像分类网络。
本发明实施例上述方案,主要获得如下有益效果:
1、通过引入对尾部泛化进行正则化的LDAM以及对损失函数按类别数量进行加权的方法,将训练的关注度更多的偏向于数量较少的尾部类别,防止网络训练对尾部类欠拟合。
2、可以通过降低简单类样本的权重,放大困难类样本对损失函数的贡献,将网络的训练更加集中于学习困难样本的特征。
3、可以有效缓解由于预训练过程或标签错误导致的出现异常值的情况对于神经网络模型训练带来的负面影响。
4、可以广泛应用于各种不均衡数据集中,并可以显著提高网络对于不均衡数据集的识别准确度。
为了验证本发明上述方案的有效性,以现实场景中图像的分类为例,进行了相关实验。
选取的数据集为官方数据集CIFAR10,并通过常用的不均衡数据集转化方法,将均匀的十分类的原数据集按照指数衰减的形式转化为不均衡的样本,如表1所示。
类别 飞机 汽车 鹿 青蛙 卡车
数量 5000 2997 1796 1077 645 387 232 139 83 50
表1不均衡的样本数据分布
为了对比不同损失函数的实际效果,选用Resnet-32神经网络模型,并设置超参数T=0.5与γ=1.5,在此基础上与现有的常见的损失函数(Cross-Entropy Loss(CE),LDAMLoss,FocalLoss)进行对比实验,实验结果如表2所示。
损失函数 平均准确率(%)
CE 70.54
LDAM Loss 73.35
FocalLoss 70.38
LE Loss(本发明提供的限制误差损失函数) 75.84
表2实验结果
可以观察到,以不均衡长尾数据集为训练集,在使用LE Loss进行训练时,总体上基于不均衡数据的限制误差损失函数计算方法相对于现有CE Loss、LDAM Loss、FocalLoss准确率更高。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例可以通过软件实现,也可以借助软件加必要的通用硬件平台的方式来实现。基于这样的理解,上述实施例的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。

Claims (8)

1.一种针对类别不均衡数据集的网络训练方法,其特征在于,包括:
获取目标图像数据集,确定类别数目和每类样本量;
利用每类样本量计算相应类别的权值,并结合设置的超参数构建限制误差损失函数;
利用所述目标图像数据集对神经网络模型进行训练,并将样本的预测结果与真实标签带入所述限制误差损失函数进行误差计算,使用反向传播不断更新神经网络模型的参数,直至网络收敛达到预期目标。
2.根据权利要求1所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,所述限制误差损失函数表示为:
Figure FDA0003449738260000011
其中,z=[z1,...,zC],zj表示样本在类别j上的预测值,C表示类别数目,y表示真实标签;Ny表示真实标签对应的样本量,
Figure FDA0003449738260000012
即为相应类别的权值wy;Loss(z,y)为结合超参数计算的损失函数。
3.根据权利要求2所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,损失函数Loss(z,y)计算方式表示为:
Figure FDA0003449738260000013
其中,γ与T均为超参数,γ用来调节样本权重降低的速率,T是用于判断样本是否是异常样本的阈值;σ是一个由超参数T决定的常值;py表示网络预测样本为真实标签的概率。
4.根据权利要求3所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,参数py采用LDAM方法来计算,表示为:
Figure FDA0003449738260000014
其中,
Figure FDA0003449738260000015
S为超参数,e为自然常数。
5.根据权利要求3所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,常值σ的计算方式表示为:
σ=(1-T)γlog(T)。
6.根据权利要求1所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,所述目标图像数据集为类别不均衡数据集,并且,同一类别中的样本也分为困难样本和简单样本,所述困难样本是指预测时与真实标签误差超过设定最大阈值的样本,简单样本是指预测时与真实标签误差未超过设定最小阈值的样本;同时,所述目标图像数据集中还包含异常样本,所述异常样本包括:通过数据增广操作产生的前景消失的样本或者只包含前景区域小于设定值的样本,或者是标签错误的样本。
7.根据权利要求3所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,所述超参数γ取值大于0,取值越大,表示简单样本的权重相对于困难样本的权重降低速率越快,训练过程将更加关注困难样本的分类;所述超参数T,取值越大,则会将更多的样本认定为异常样本;
其中,所述困难样本是指预测时与真实标签误差超过设定最大阈值的样本,简单样本是指预测时与真实标签误差未超过设定最小阈值的样本;所述异常样本包括:通过数据增广操作产生的前景消失的样本或者只包含前景区域小于设定值的样本,或者是标签错误的样本。
8.根据权利要求1或2或3所述的一种针对类别不均衡数据集的网络训练方法,其特征在于,所述超参数预先利用目标图像数据集设定,在训练过程中进行数值优化。
CN202111671005.3A 2021-12-31 2021-12-31 针对类别不均衡数据集的网络训练方法 Pending CN114332539A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111671005.3A CN114332539A (zh) 2021-12-31 2021-12-31 针对类别不均衡数据集的网络训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111671005.3A CN114332539A (zh) 2021-12-31 2021-12-31 针对类别不均衡数据集的网络训练方法

Publications (1)

Publication Number Publication Date
CN114332539A true CN114332539A (zh) 2022-04-12

Family

ID=81020599

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111671005.3A Pending CN114332539A (zh) 2021-12-31 2021-12-31 针对类别不均衡数据集的网络训练方法

Country Status (1)

Country Link
CN (1) CN114332539A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114863193A (zh) * 2022-07-07 2022-08-05 之江实验室 基于混合批归一化的长尾学习图像分类、训练方法及装置
CN115453889A (zh) * 2022-10-12 2022-12-09 安徽机电职业技术学院 一种基于神经网络的数控车床控制信号设置方法及***
CN115677346A (zh) * 2022-11-07 2023-02-03 北京赛乐米克材料科技有限公司 彩色锆宝石陶瓷鼻托的制备方法
CN116660389A (zh) * 2023-07-21 2023-08-29 山东大禹水务建设集团有限公司 一种基于人工智能的河道底泥探测及修复***
CN117743857A (zh) * 2023-12-29 2024-03-22 北京海泰方圆科技股份有限公司 文本纠错模型训练、文本纠错方法、装置、设备和介质

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114863193A (zh) * 2022-07-07 2022-08-05 之江实验室 基于混合批归一化的长尾学习图像分类、训练方法及装置
CN115453889A (zh) * 2022-10-12 2022-12-09 安徽机电职业技术学院 一种基于神经网络的数控车床控制信号设置方法及***
CN115677346A (zh) * 2022-11-07 2023-02-03 北京赛乐米克材料科技有限公司 彩色锆宝石陶瓷鼻托的制备方法
CN115677346B (zh) * 2022-11-07 2023-09-12 北京赛乐米克材料科技有限公司 彩色锆宝石陶瓷鼻托的制备方法
CN116660389A (zh) * 2023-07-21 2023-08-29 山东大禹水务建设集团有限公司 一种基于人工智能的河道底泥探测及修复***
CN116660389B (zh) * 2023-07-21 2023-10-13 山东大禹水务建设集团有限公司 一种基于人工智能的河道底泥探测及修复***
CN117743857A (zh) * 2023-12-29 2024-03-22 北京海泰方圆科技股份有限公司 文本纠错模型训练、文本纠错方法、装置、设备和介质

Similar Documents

Publication Publication Date Title
CN114332539A (zh) 针对类别不均衡数据集的网络训练方法
CN108038859B (zh) 基于pso和综合评价准则的pcnn图分割方法及装置
Zhang et al. Brain tumor segmentation based on hybrid clustering and morphological operations
CN109242878B (zh) 一种基于自适应布谷鸟优化法的图像多阈值分割方法
KR101113006B1 (ko) 클러스터 간 상호정보를 이용한 클러스터링 장치 및 방법
Xie et al. Image de-noising algorithm based on Gaussian mixture model and adaptive threshold modeling
Chou et al. Turbulent-PSO-based fuzzy image filter with no-reference measures for high-density impulse noise
CN109871855B (zh) 一种自适应的深度多核学习方法
CN111696046A (zh) 一种基于生成式对抗网络的水印去除方法和装置
CN114283307B (zh) 一种基于重采样策略的网络训练方法
Yuan et al. Neighborloss: a loss function considering spatial correlation for semantic segmentation of remote sensing image
CN109658378B (zh) 基于土壤ct图像的孔隙辨识方法及***
WO2023088174A1 (zh) 目标检测方法及装置
CN110991554B (zh) 一种基于改进pca的深度网络图像分类方法
Yap et al. A recursive soft-decision approach to blind image deconvolution
CN116629376A (zh) 一种基于无数据蒸馏的联邦学习聚合方法和***
Zhao et al. Underwater fish detection in sonar image based on an improved Faster RCNN
CN106651781B (zh) 一种激光主动成像的图像噪声抑制方法
CN106952287A (zh) 一种基于低秩稀疏表达的视频多目标分割方法
CN113344935B (zh) 基于多尺度难度感知的图像分割方法及***
CN115294424A (zh) 一种基于生成对抗网络的样本数据增强方法
CN110675344B (zh) 一种基于真实彩色图像自相似性的低秩去噪方法及设备
CN106296704B (zh) 通用型图像分割方法
Turan ANN Based Removal for Salt and Pepper Noise
Qian et al. Region-based pixels integration mechanism for weakly supervised semantic segmentation

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20220412

RJ01 Rejection of invention patent application after publication