CN109583485B - 一种基于反馈训练的有监督深度学习方法 - Google Patents
一种基于反馈训练的有监督深度学习方法 Download PDFInfo
- Publication number
- CN109583485B CN109583485B CN201811367393.4A CN201811367393A CN109583485B CN 109583485 B CN109583485 B CN 109583485B CN 201811367393 A CN201811367393 A CN 201811367393A CN 109583485 B CN109583485 B CN 109583485B
- Authority
- CN
- China
- Prior art keywords
- sample
- training
- deep learning
- loss value
- supervised deep
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 39
- 238000000034 method Methods 0.000 title claims abstract description 34
- 238000013135 deep learning Methods 0.000 title claims abstract description 18
- 238000005070 sampling Methods 0.000 claims abstract description 31
- 238000013136 deep learning model Methods 0.000 claims abstract description 6
- 230000006870 function Effects 0.000 claims description 12
- 238000012937 correction Methods 0.000 claims description 5
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000012545 processing Methods 0.000 claims description 2
- 230000000694 effects Effects 0.000 abstract description 7
- 239000000523 sample Substances 0.000 description 40
- 239000011159 matrix material Substances 0.000 description 6
- 238000012952 Resampling Methods 0.000 description 4
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Artificial Intelligence (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及一种基于反馈训练的有监督深度学习方法,该方法在训练有监督深度学习模型的过程中,在每次迭代开始时,以一采样概率对训练集中的各样本进行抽样,所述采样概率随各样本的预测损失值动态调整。与现有技术相比,本发明通过将有监督学习训练过程中各样本的预测损失值与其自身被采样频率相关联,利用反馈训练的方式实现了调整离群样本被采样到的概率,具有提高训练效果等优点。
Description
技术领域
本发明涉及深度学习领域,尤其是涉及一种基于反馈训练的有监督深度学习方法。
背景技术
现在的有监督深度学习方法在使用过程中,需要使用大量的样本数据进行学习,为了降低有监督深度学习模型训练时对硬件的需求,通常采用小批量采样或单样本输入的方式训练模型。通常的采样方式为均匀采样或采用按序输入。
在该种情况下,大量的常规样本与小量的离群样本会有相等的概率被送入模型训练,导致模型难以学到小量离群样本的空间分布。当模型的训练目标需要检测或识别小量离群样本时,通过常规采样方式的有监督训练不仅降低了模型的准确率,而且降低了模型的训练速度。
为了解决上述问题,现有的解决方式通常为数据重采样、类别均衡采样、代价敏感矩阵与代价敏感向量的方法的方式进行训练。重采样和类别均衡采样的方式均为将不同类别的样本采样相同的个数进行训练。该方法在解决类间差异大、类内差异小的不同种类的样本数量不平衡问题上效果较好。然而当类内差异大,即出现少量离群样本时,模型极难学到其样本分布。代价敏感矩阵或代价敏感向量的方法可以通过构建混淆矩阵或代价敏感矩阵,对被错分的类别增大学习率,从而加速模型对离群样本的学习。但当离群样本存在于大样本量类别中时,由于离群样本被抽到的概率微乎其微,该方法的效果几乎可以忽略不计。
因此,为了提升离群样本的学***衡的问题,而且要解决类内样本数量不平衡的问题。而现有技术难以解决上述问题。
发明内容
本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种基于反馈训练的有监督深度学习方法。
本发明的目的可以通过以下技术方案来实现:
一种基于反馈训练的有监督深度学习方法,该方法在训练有监督深度学习模型的过程中,在每次迭代开始时,以一采样概率对训练集中的各样本进行抽样,所述采样概率随各样本的预测损失值动态调整。
进一步地,所述采样概率动态调整的过程具体包括:
1)初始化各样本权重参数;
2)根据各样本当前的权重参数计算对应的采样概率:
其中,P(i)是样本i的采样概率,α为优先级系数,pi为样本i的权重参数;
3)进行一次迭代后,获得各样本的预测损失值,基于所述预测损失值更新权重参数;
4)在下一次迭代开始时,令pi=p(i),返回步骤2)。
进一步地,所述初始化各样本权重参数时,令各样本权重参数均为1。
进一步地,所述基于所述预测损失值更新权重参数具体为:
p(i)=|δ(i)|+ε
其中,p(i)为更新后的样本i的权重参数,δ(i)为样本i的预测损失值,ε为修正因子。
进一步地,所述修正因子ε为一大于0的正数。
进一步地,所述预测损失值δ(i)的表达式为:
δ(i)=L(yi,f(xi))
其中,xi为输入,yi为xi对应的真值标签,函数f为通过输入xi预测标签的函数,函数L为计算真值标签yi与预测标签f(xi)差异的损失函数。
进一步地,所述基于所述预测损失值更新权重参数时,权重参数与预测损失值的倒数成正比。
与现有技术相比,本发明具有以如下有益效果:
第一,本发明首次提出在有监督深度学习中采用动态调整采样频率的方法,通过增加离群样本被学习的概率以使模型更快地学到整体样本空间分布,从而减少了模型训练时间并可提高模型训练效果。
第二,本发明可以结合其他采样方式(重采样、类别均衡采样、代价敏感矩阵等方式)以达到效果更好的训练效果。
第三,本发明可以逆向使用,通过降低离群样本被采样的概率增大模型学习常规样本特征的能力。
附图说明
图1为本发明训练有监督深度学习模型的流程示意图。
具体实施方式
下面结合附图和具体实施例对本发明进行详细说明。本实施例以本发明技术方案为前提进行实施,给出了详细的实施方式和具体的操作过程,但本发明的保护范围不限于下述的实施例。
本发明提供一种基于反馈训练的有监督深度学习方法,运行于GPU中,应用于图像处理过程,该方法在训练有监督深度学习模型的过程中,在每次迭代开始时,以一采样概率对训练集中的各样本进行抽样,所述采样概率随各样本的预测损失值动态调整。
采样概率动态调整的过程具体包括:
1)初始化各样本权重参数pi=1;
2)根据各样本当前的权重参数计算对应的采样概率:
其中,P(i)是样本i的采样概率,pi为样本i的权重参数,α为优先级系数,值越大则代表优先级越大,当α取0时为均匀采样;
3)进行一次迭代后,获得各样本的预测损失值,更新权重参数:
p(i)=|δ(i)|+ε
其中,p(i)为更新后的样本i的权重参数,δ(i)为样本i的预测损失值,ε为修正因子,可以取10-5等很小的正常数,以防止δ(i)=0时x0将不会再被抽样;
4)在下一次迭代开始时,令pi=p(i),返回步骤2)。
预测损失值δ(i)的表达式为:
δ(i)=L(yi,f(xi))
其中,xi为输入,yi为xi对应的真值标签,函数f为通过输入xi预测标签的函数,函数L为计算真值标签yi与预测标签f(xi)差异的损失函数。
上述方法可以结合其他采样方式(如重采样、类别均衡采样、代价敏感矩阵等方式)以达到效果更好的训练效果。以与类别均衡采样结合为例,从大量的样本类与小量的样本类中分别采集同量的样本,类内采集概率均按权重值计算。
上述方法可以逆向使用,通过降低离群样本被采样的概率增大模型学习常规样本特征的能力。例如在使用自动编码器(Auto-encoder)时,其需要学习更多正常标准样本的特征,此时需要采样更多的正常样本,通过将损失值的倒数作为自身权重计算概率时,离群样本会被更少地采样。
如图1所示,基于上述采样概率动态调整的有监督深度学习模型训练过程具体为:
在步骤401中,预先读入所有图片样本与其对应分类标签的信息;
在步骤402中,对所有读入的图像样本信息进行采集权值初始化,初始化值为1;
在步骤403中,计算各图像样本的采集概率;
在步骤404中,根据各图像样本的采集概率采集图像与其对应的分类标签;
在步骤405中,将采集到的图像送入有监督深度学习网络模型训练,并与其损失值;
在步骤406中,判断该有监督深度学习网络模型是否达到训练迭代次数上限,若达到上限则终止训练,否则执行步骤407;
在步骤407中,利用步骤405计算所得各图像样本的损失值;
在步骤408中,更新各样本的权重,完成后执行步骤403。
以上详细描述了本发明的较佳具体实施例。应当理解,本领域的普通技术人员无需创造性劳动就可以根据本发明的构思作出诸多修改和变化。因此,凡本技术领域中技术人员依本发明的构思在现有技术的基础上通过逻辑分析、推理或者有限的实验可以得到的技术方案,皆应在由权利要求书所确定的保护范围内。
Claims (4)
1.一种基于反馈训练的有监督深度学习方法,其特征在于,该方法应用于图像处理过程,在训练有监督深度学习模型的过程中,在每次迭代开始时,以一采样概率对训练集中的各样本进行抽样,所述采样概率随各样本的预测损失值动态调整,该方法包括以下步骤:
步骤401,预先读入所有图片样本与其对应分类标签的信息;
步骤402,对所有读入的图像样本信息进行采集权值初始化,初始化值为1;
步骤403,计算各图像样本的采集概率;
步骤404,根据各图像样本的采集概率采集图像与其对应的分类标签;
步骤405,将采集到的图像送入有监督深度学习网络模型训练,并计算其损失值;
步骤406,判断该有监督深度学习网络模型是否达到训练迭代次数上限,若达到上限则终止训练,否则执行步骤407;
步骤407,获取步骤405计算所得各图像样本的损失值;
步骤408,更新各样本的权重,完成后执行步骤403;
所述采样概率动态调整的过程具体包括:
1)初始化各样本权重参数;
2)根据各样本当前的权重参数计算对应的采样概率:
其中,P(i)是样本i的采样概率,α为优先级系数,pi为样本i的权重参数;
3)进行一次迭代后,获得各样本的预测损失值,基于所述预测损失值更新权重参数;
4)在下一次迭代开始时,令pi=p(i),返回步骤2);
所述基于所述预测损失值更新权重参数具体为:
p(i)=|δ(i)|+ε
其中,p(i)为更新后的样本i的权重参数,δ(i)为样本i的预测损失值,ε为修正因子。
2.根据权利要求1所述的基于反馈训练的有监督深度学习方法,其特征在于,所述修正因子ε为一大于0的正数。
3.根据权利要求1所述的基于反馈训练的有监督深度学习方法,其特征在于,所述预测损失值δ(i)的表达式为:
δ(i)=L(yi,f(xi))
其中,xi为输入,yi为xi对应的真值标签,函数f为通过输入xi预测标签的函数,函数L为计算真值标签yi与预测标签f(xi)差异的损失函数。
4.根据权利要求1所述的基于反馈训练的有监督深度学习方法,其特征在于,所述基于所述预测损失值更新权重参数时,权重参数与预测损失值的倒数成正比。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811367393.4A CN109583485B (zh) | 2018-11-16 | 2018-11-16 | 一种基于反馈训练的有监督深度学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811367393.4A CN109583485B (zh) | 2018-11-16 | 2018-11-16 | 一种基于反馈训练的有监督深度学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN109583485A CN109583485A (zh) | 2019-04-05 |
CN109583485B true CN109583485B (zh) | 2023-12-08 |
Family
ID=65922667
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201811367393.4A Active CN109583485B (zh) | 2018-11-16 | 2018-11-16 | 一种基于反馈训练的有监督深度学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109583485B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112052900B (zh) * | 2020-09-04 | 2024-05-24 | 京东科技控股股份有限公司 | 机器学习样本权重调整方法和装置、存储介质 |
CN113420792A (zh) * | 2021-06-03 | 2021-09-21 | 阿波罗智联(北京)科技有限公司 | 图像模型的训练方法、电子设备、路侧设备及云控平台 |
CN116484744B (zh) * | 2023-05-12 | 2024-01-16 | 北京百度网讯科技有限公司 | 物体仿真方法、模型训练方法、装置、设备及存储介质 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104102700A (zh) * | 2014-07-04 | 2014-10-15 | 华南理工大学 | 一种面向因特网不平衡应用流的分类方法 |
CN105096375A (zh) * | 2014-05-09 | 2015-11-25 | 三星电子株式会社 | 图像处理方法和设备 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20150332169A1 (en) * | 2014-05-15 | 2015-11-19 | International Business Machines Corporation | Introducing user trustworthiness in implicit feedback based search result ranking |
-
2018
- 2018-11-16 CN CN201811367393.4A patent/CN109583485B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN105096375A (zh) * | 2014-05-09 | 2015-11-25 | 三星电子株式会社 | 图像处理方法和设备 |
CN104102700A (zh) * | 2014-07-04 | 2014-10-15 | 华南理工大学 | 一种面向因特网不平衡应用流的分类方法 |
Non-Patent Citations (1)
Title |
---|
中心损失与Softmax损失联合监督下的人脸识别;余成波;田桐;熊递恩;许琳英;;重庆大学学报(第05期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN109583485A (zh) | 2019-04-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112633384B (zh) | 基于图像识别模型的对象识别方法、装置和电子设备 | |
CN109583485B (zh) | 一种基于反馈训练的有监督深度学习方法 | |
CN113688933B (zh) | 分类网络的训练方法及分类方法和装置、电子设备 | |
CN114283287B (zh) | 基于自训练噪声标签纠正的鲁棒领域自适应图像学习方法 | |
CN112470160A (zh) | 个性化自然语言理解的装置和方法 | |
KR20200145827A (ko) | 얼굴 특징 추출 모델 학습 방법, 얼굴 특징 추출 방법, 장치, 디바이스 및 저장 매체 | |
WO2020186887A1 (zh) | 一种连续小样本图像的目标检测方法、装置及设备 | |
CN108229673B (zh) | 卷积神经网络的处理方法、装置和电子设备 | |
CN108320306B (zh) | 融合tld和kcf的视频目标跟踪方法 | |
US20220092407A1 (en) | Transfer learning with machine learning systems | |
CN110458022B (zh) | 一种基于域适应的可自主学习目标检测方法 | |
CN113469186A (zh) | 一种基于少量点标注的跨域迁移图像分割方法 | |
CN116894985B (zh) | 半监督图像分类方法及半监督图像分类*** | |
CN116740384B (zh) | 洗地机的智能控制方法及*** | |
CN114937025A (zh) | 图像分割方法、模型训练方法、装置、设备及介质 | |
CN115861462A (zh) | 图像生成模型的训练方法、装置、电子设备及存储介质 | |
CN113642635B (zh) | 模型训练方法及装置、电子设备和介质 | |
CN111291902A (zh) | 后门样本的检测方法、装置和电子设备 | |
CN111239137A (zh) | 基于迁移学习与自适应深度卷积神经网络的谷物质量检测方法 | |
CN110768864B (zh) | 一种网络流量批量生成图像的方法及装置 | |
CN117408959A (zh) | 模型的训练方法、缺陷检测方法、装置、电子设备及介质 | |
CN115511012B (zh) | 一种最大熵约束的类别软标签识别训练方法 | |
CN116958548A (zh) | 基于类别统计驱动的伪标签自蒸馏语义分割方法 | |
CN114758130B (zh) | 图像处理及模型训练方法、装置、设备和存储介质 | |
CN116630745A (zh) | 用于图像的端到端半监督目标检测方法、装置和可读介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |