CN113240113B - 一种增强网络预测鲁棒性的方法 - Google Patents

一种增强网络预测鲁棒性的方法 Download PDF

Info

Publication number
CN113240113B
CN113240113B CN202110623241.1A CN202110623241A CN113240113B CN 113240113 B CN113240113 B CN 113240113B CN 202110623241 A CN202110623241 A CN 202110623241A CN 113240113 B CN113240113 B CN 113240113B
Authority
CN
China
Prior art keywords
model
deep learning
data
sub
models
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110623241.1A
Other languages
English (en)
Other versions
CN113240113A (zh
Inventor
李瑞瑞
刘嘉润
赵伟
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Futong Oriental Technology Co ltd
Original Assignee
Beijing Futong Oriental Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Futong Oriental Technology Co ltd filed Critical Beijing Futong Oriental Technology Co ltd
Priority to CN202110623241.1A priority Critical patent/CN113240113B/zh
Publication of CN113240113A publication Critical patent/CN113240113A/zh
Application granted granted Critical
Publication of CN113240113B publication Critical patent/CN113240113B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Physics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种增强网络预测鲁棒性的方法,包括选取深度学习子模型,设置深度学习子模型的训练超参数及损失函数,对深度学习子模型进行初始化,对深度学习子模型进行改进的互学习训练,保存经过训练的增强深度学习子模型,对经过训练的增强深度学习子模型进行性能测试。通过上述方式,本发明能够在数据分布不均衡、数据标注存在错误、数据量不足等导致深度网络模型难以克服认知偏差的情况下,在互学习的过程中就对网络进行引导整合,最终通过所有模型学习到的知识共同做出正确的预测,有效提升模型的准确率。

Description

一种增强网络预测鲁棒性的方法
技术领域
本发明涉及数据预测技术领域,特别是涉及一种增强网络预测鲁棒性的方法。
背景技术
深度模型的精度不足和结果不确定性在深度学习模型的落地过程中造成了很大的阻碍。
用于深度学习的数据集对训练出模型的好坏有直接的影响,诸如数据不足、数据不均衡、数据不准确都可以导致模型产生严重的认知偏差从而做出错误的预测。一般来说,通过提升数据的数量和质量可以有效的提高深度模型的准确性和泛化性,然而这会造成非常大的数据收集成本。
提升网络性能的有效方法之一是集成学习,其中心思想是分别训练多个具有不同或相同特性的预测模型,整合多个模型的预测结果进行预测。主流的模型整合方法有Bagging和Boosting两类。Bagging是对所有子模型预测结果取均值并做出预测的方法,Boosting则是根据子模型的特性和好坏,对模型的性能做出评估,根据评估结果给各个模型的预测概率赋予不同的权重,并做出最终预测。但是,集成学习的方法存在以下缺陷:
(1)不同子模型之间关联性不强,预测结果不准确的模型可能会拉低模型的整体水平;
(2)模型之间协同度不足,很难确认模型之间集成后是增益还是减益的效果;
(3)最终预测概率难以协调,不管是采用Bagging或者Boosting的方法,都无法做出最合适的预测概率整合。
互学***衡数据学习中发挥着巨大的作用。但是互学习的多个网络训练往往只保留一个,这造成了资源的浪费。
随着元学习技术的发展,元学习技术也被用于鲁棒性方法中并且取得了不错的效果,但其往往需要一个无偏的验证集作为训练目标在训练中进行指导,事实上在真实数据中我们很难获取到和应用环境无偏的验证集。
发明内容
有鉴于此,本申请提出一种增强网络预测鲁棒性的方法,用于解决数据集存在类别不均衡、数据错误、数量不足、数据噪声等情况下提高网络预测准确性,降低模型产生等认知偏差,提高网络预测鲁棒性。
本申请提供的一种增强网络预测鲁棒性的方法,包括:
步骤1:选取深度学习子模型;
进一步的,所述选取深度学习子模型,是在深度学习模型中选取n项模型作为深度学习子模型;
进一步的,所述n项模型具有相同或不同的网络结构;
进一步的,所述n为正整数。
进一步的,所述在深度学习模型中选取n项模型作为深度学习子模型,包括:根据预测任务的数据类型和任务类型,在深度学习模型中进行选取。
步骤2:设置深度学习子模型的训练超参数及损失函数;
进一步的,所述超参数,包括:学习率、批次容量、优化器、迭代次数。
步骤3:对深度学习子模型进行初始化;
步骤4:对深度学习子模型进行互学习训练,得到经过训练的增强深度学习子模型;
进一步的,所述互学习训练,包括:
S41:从数据集D中采样一批样本 ,/>,并对每一个样本按序设置索引号/>;
S42:将X输入到每一个子模型中,得到/>对/>的预测值/>:
S43:根据和/>计算子模型的预测概率对每一个样本的损失函数/>
S44:对每一个中的损失函数值进行排序并得到损失函数值最小的/>个样本的索引号/>,并根据/>选出相应的样本的损失函数得到/>
S45:加权集成每个模型的损失函数
S46:反向传播并计算每一个对模型产生的梯度;
S47:交换模型间的梯度:
S48:通过设定的优化器对模型参数进行更新;
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
所述,为每个批次样本中保留的比例;
所述,计算结果向下取整;
其中,、/>为数据集中真实值,/>为模型预测值,/>为样本索引号,/>为根据真实值与模型预测值计算得到的损失函数,/>为模型梯度。
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试。
进一步的,所述性能测试,包括:
S61:加载训练好的增强深度学习子模型;
S62:输入数据并将数据转化为增强深度学习子模型输入数据的格式:
S63:将数据输入到各个增强深度学习子模型中,得到所有增强深度学习子模型的预测概率;
S64:对增强深度学习子模型的预测结果进行加权集成:
S65:重复S61-S63,对所有测试数据进行测试并得到
S66:对结果进行验证得到经过训练的增强深度学习子模型准确度:
其中,、/>为数据集中真实值,/>为模型预测值,/>为数据数量,/>为准确度。
本发明的有益效果:在数据分布不均衡、数据标注存在错误、数据量不足导致深度网络模型难以克服认知偏差的情况下,本发明的技术方案通过在互学习的过程中就对网络进行引导整合,最终通过所有模型学习到的知识共同做出正确的预测,有效提升模型预测的准确率。
附图说明
图1是本发明一种增强网络预测鲁棒性的方法一较佳实施例的流程图。
图2是本发明一种增强网络预测鲁棒性的方法一较佳实施例的流程图。
图3是本发明一种增强网络预测鲁棒性的方法一较佳实施例的流程图。
具体实施方式
下面结合附图对本发明的较佳实施例进行详细阐述,以使本发明的优点和特征能更易于被本领域技术人员理解,从而对本发明的保护范围做出更为清楚明确的界定。
本发明的技术方案是:
步骤1:选取深度学习子模型;
进一步的,所述选取深度学习子模型,是在深度学习模型中选取n项模型作为深度学习子模型;
进一步的,所述n项模型具有相同或不同的网络结构;
进一步的,所述n为正整数。
进一步的,所述在深度学习模型中选取n项模型作为深度学习子模型,包括:根据预测任务的数据类型和任务类型,在深度学习模型中进行选取。
步骤2:设置深度学习子模型的训练超参数及损失函数;
进一步的,所述超参数,包括:学习率、批次容量、优化器、迭代次数。
步骤3:对深度学习子模型进行初始化;
步骤4:对深度学习子模型进行互学习训练,得到经过训练的增强深度学习子模型;
进一步的,所述互学习训练,包括:
S41:从数据集D中采样一批样本,/>,并对每一个样本按序设置索引号/>;
S42:将X输入到每一个子模型中,得到/>对/>的预测值/>:
S43:根据和/>计算子模型的预测概率对每一个样本的损失函数/>
S44:对每一个中的损失函数值进行排序并得到损失函数值最小的/>个样本的索引号/>,并根据/>选出相应的样本的损失函数得到/>
S45:加权集成每个模型的损失函数
S46:反向传播并计算每一个对模型产生的梯度;
S47:交换模型间的梯度:
S48:通过设定的优化器对模型参数进行更新;
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
所述,为每个批次样本中保留的比例;
所述,计算结果向下取整;
其中,、/>为数据集中真实值,/>为模型预测值,/>为样本索引号,/>为根据真实值与模型预测值计算得到的损失函数,/>为模型梯度。
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试。
进一步的,所述性能测试,包括:
S61:加载训练好的增强深度学习子模型;
S62:输入数据并将数据转化为增强深度学习子模型输入数据的格式:
S63:将数据输入到各个增强深度学习子模型中,得到所有增强深度学习子模型的预测概率;
S64:对增强深度学习子模型的预测结果进行加权集成:
S65:重复S61-S63,对所有测试数据进行测试并得到
S66:对结果进行验证得到经过训练的增强深度学习子模型准确度:
其中,、/>为数据集中真实值,/>为模型预测值,/>为数据数量,/>为准确度。
基于上述技术方案,本发明对一种增强网络预测鲁棒性的方法举一个实例说明。在本实例中,以含有20%噪声标注的皮肤显微镜图像分类为例进行对比说明。此时预测正确率为准确率。
步骤1:根据计算任务,选择使用的模型,此处综合考量预测任务的数据类型和任务类型,选择两个ResNet-50作为子模型;
步骤2:设置合适的超参数,根据硬件条件设置超参数,如下表:
步骤3:使用随机初始化算法对模型参数进行随机初始化,得到初始化后的两个ResNet-50子模型;
步骤4:对两个ResNet-50子模型进行改进的互学习训练,直至网络收敛或到达迭代次数:
S41:从数据集中按照设定的批次大小采样出一批数据d(32组)。
S42:将数据中的图像输入到网络中,得到网络对所有数据的类别预测概率。
S43:根据模型预测概率和数据标注使用CELoss计算每个模型预测概率对每个样本的损失函数。
S44:对损失函数进行排序,计算
S45:加权集成得到两个模型的损失函数
S46:反向传播损失函数,计算损失函数对网络的梯度。
S47:交换模型间的梯度。
S48:根据网络的梯度及设定的优化器对网络参数进行更新。
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试:
S61:加载训练好的两个ResNet-50子模型。
S62:加载测试数据并转换为网络输入数据格式。
S63:将数据分别输入子模型中,使用子模型对数据分别做出预测,得到所有模型对样本的预测结果。
S64:对预测结果进行集成。
S65:重复上述S63、S64,得到所有数据的预测结果。
S66:使用准确率对结果进行验证。
名词解释:
SGD(Stochastic Gradient Descent):随机梯度下降;
CE(Cross Entropy) Loss:交叉熵损失。
本实施例将收集的3324个皮肤镜图像诊断黑色素瘤按照6:3:1的比例划分为训练集、测试集、验证集。这些数据总共有两个类别,分别是黑色素瘤阴性和黑色素瘤阳性。
为了验证模型的鲁棒性,分别按照{0,0.05,0.1,0.2,0.3,0.4}的比例翻转数据的标注。
测试结果取最后10个Epoch的测试准确率的平均值,各方法准确率如下表:
由此可知,本申请的技术方案相比现有的方法,测试准确率明显较高,在增强网络预测鲁棒性上取得了更好的效果。
其中,Standard: 指标准的单模型神经网络训练方式;
Joint Optim: 一种基于标注修正的鲁棒性方法;
Co-teaching: 一种基于互学习的鲁棒性方法;
DivideMix: 一种以半监督方法为基础的鲁棒性方法;
Reweight: 一种基于元学习的鲁棒性方法;
Ours: 本申请中的方法。
以上所述仅为本发明的实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。

Claims (4)

1.一种增强网络预测鲁棒性的方法,其特征在于,包括以下步骤:
步骤1:对标注的皮肤显微镜图像进行分类预测,根据皮肤显微镜图像分类预测任务的数据类型和任务类型,选取深度学习子模型;
步骤2:根据硬件条件设置深度学习子模型的训练超参数及损失函数;
步骤3:使用随机初始化算法对模型参数进行随机初始化,对深度学习子模型进行初始化;将收集的皮肤显微镜图像诊断黑色素瘤数据按比例划分为训练集、测试集、验证集;
步骤4:依据所述训练集的数据对深度学习子模型进行互学习训练,得到经过训练的增强深度学习子模型;
步骤5:保存经过训练的增强深度学习子模型;
步骤6:依据所述测试集的数据对经过训练的增强深度学习子模型进行性能测试:
所述选取深度学习子模型,是在深度学习模型中选取n项模型作为深度学习子模型;
所述n项模型具有相同或不同的网络结构;
所述n为正整数:
所述互学习训练,包括:
S41:从所述训练集的数据D中采样一批样本d={(x0,y0),(x1,y1),...,(xn,yn)},X={x0,x1,...,xn},Y={y0,y1,...,yn},并对每一个样本按序设置索引号I;
S42:将X输入到每一个子模型fkk)中,得到fkk)对X的预测值
其中/>是第k个子模型中第i个样本的预测值,i,k∈n;
S43:根据和Y={y0,y1,...,yn}计算子模型的预测概率对每一个样本的损失函数lk={l0,l1,...,ln}:
其中li是第i样本的损失函数;
S44:对每一个lk中的损失函数值进行排序并得到损失函数值最小的λ×n个样本的索引号Ik,并根据Ik选出相应的样本的损失函数得到
S45:加权集成每个模型的损失函数:
S46:反向传播并计算每一个lk对模型产生的梯度δ={δ0,δ1,...,δn};
S47:交换模型间的梯度:
其中,δ0是指第0个损失函数对应梯度,δn为第n个损失函数对应梯度;
S48:通过设定的优化器对模型参数进行更新;
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转S41;
所述λ,为每个批次样本中保留的比例;
所述λ×n,计算结果向下取整;
其中,xi、yi为数据集中真实值,pk为模型预测值,Ik为样本索引号,lk为根据真实值与模型预测值计算得到的损失函数,δ为模型梯度。
2.如权利要求1所述的一种增强网络预测鲁棒性的方法,其特征在于,所述在深度学习模型中选取n项模型作为深度学习子模型,包括:根据预测任务的数据类型和任务类型,在深度学习模型中进行选取。
3.如权利要求1所述的一种增强网络预测鲁棒性的方法,其特征在于,所述超参数,包括:学习率、批次容量、优化器、迭代次数。
4.如权利要求1所述的一种增强网络预测鲁棒性的方法,其特征在于,所述性能测试,包括:
S61:加载训练好的增强深度学习子模型;
S62:输入数据并将数据转化为增强深度学习子模型输入数据的格式:dtes={(x0,y0),(x1,y1),...,(xm,ym)},其中dtes为测试集数据,m为测试集数据量;
S63:将数据输入到各个增强深度学习子模型中,得到所有增强深度学习子模型的预测概率p={p0,p1,...,pn};
S64:对增强深度学习子模型的预测结果进行加权集成:pfinal=mean(p);
S65:重复S61-S63,对所有测试数据进行测试并得到
S66:对结果进行验证得到经过训练的增强深度学习子模型准确度:
其中,xi、yi为数据集中真实值,为第i个测试集数据预测结果,m为测试集数据量,acc为准确度。
CN202110623241.1A 2021-06-04 2021-06-04 一种增强网络预测鲁棒性的方法 Active CN113240113B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110623241.1A CN113240113B (zh) 2021-06-04 2021-06-04 一种增强网络预测鲁棒性的方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110623241.1A CN113240113B (zh) 2021-06-04 2021-06-04 一种增强网络预测鲁棒性的方法

Publications (2)

Publication Number Publication Date
CN113240113A CN113240113A (zh) 2021-08-10
CN113240113B true CN113240113B (zh) 2024-05-28

Family

ID=77136814

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110623241.1A Active CN113240113B (zh) 2021-06-04 2021-06-04 一种增强网络预测鲁棒性的方法

Country Status (1)

Country Link
CN (1) CN113240113B (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113907710A (zh) * 2021-09-29 2022-01-11 山东师范大学 基于模型无关的图像增强元学习的皮肤病变分类***
CN114998613B (zh) * 2022-06-24 2024-04-26 安徽工业大学 一种基于深度互学习的多标记零样本学习方法
CN115937617B (zh) * 2023-03-06 2023-05-30 支付宝(杭州)信息技术有限公司 一种风险识别模型训练、风险控制方法、装置和设备

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108241650A (zh) * 2016-12-23 2018-07-03 北京国双科技有限公司 训练分类标准的训练方法和装置
CN110533610A (zh) * 2019-08-20 2019-12-03 东软医疗***股份有限公司 图像增强模型的生成方法及装置、应用方法及装置
CN112149556A (zh) * 2020-09-22 2020-12-29 南京航空航天大学 一种基于深度互学习和知识传递的人脸属性识别方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108040073A (zh) * 2018-01-23 2018-05-15 杭州电子科技大学 信息物理交通***中基于深度学习的恶意攻击检测方法

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108241650A (zh) * 2016-12-23 2018-07-03 北京国双科技有限公司 训练分类标准的训练方法和装置
CN110533610A (zh) * 2019-08-20 2019-12-03 东软医疗***股份有限公司 图像增强模型的生成方法及装置、应用方法及装置
CN112149556A (zh) * 2020-09-22 2020-12-29 南京航空航天大学 一种基于深度互学习和知识传递的人脸属性识别方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
刘威 ; 刘尚 ; 白润才 ; 周璇 ; 周定宁.《互学习神经网络训练方法研究》.《互学习神经网络训练方法研究》.2017,第1-18页. *

Also Published As

Publication number Publication date
CN113240113A (zh) 2021-08-10

Similar Documents

Publication Publication Date Title
CN113240113B (zh) 一种增强网络预测鲁棒性的方法
CN109587713B (zh) 一种基于arima模型的网络指标预测方法、装置及存储介质
CN111542843A (zh) 利用协作生成器积极开发
CN111860982A (zh) 一种基于vmd-fcm-gru的风电场短期风电功率预测方法
CN112465040B (zh) 一种基于类不平衡学习算法的软件缺陷预测方法
RU2517286C2 (ru) Классификация данных выборок
CN110046706A (zh) 模型生成方法、装置及服务器
CN113591988B (zh) 知识认知结构分析方法、***、计算机设备、介质、终端
CN108647772B (zh) 一种用于边坡监测数据粗差剔除的方法
CN113902129A (zh) 多模态的统一智能学习诊断建模方法、***、介质、终端
CN112307536A (zh) 一种大坝渗流参数反演方法
JP2016194914A (ja) 混合モデル選択の方法及び装置
CN111522743A (zh) 一种基于梯度提升树支持向量机的软件缺陷预测方法
CN114692507A (zh) 基于堆叠泊松自编码器网络的计数数据软测量建模方法
CN114519508A (zh) 基于时序深度学习和法律文书信息的信用风险评估方法
CN114169460A (zh) 样本筛选方法、装置、计算机设备和存储介质
CN112488188B (zh) 一种基于深度强化学习的特征选择方法
US20210103807A1 (en) Computer implemented method and system for running inference queries with a generative model
CN112597687A (zh) 一种基于少样本学习的涡轮盘结构混合可靠性分析方法
CN112200271A (zh) 一种训练样本确定方法、装置、计算机设备及存储介质
CN111652264A (zh) 基于最大均值差异的负迁移样本筛选方法
CN116523136A (zh) 基于多模型集成学习的矿产资源空间智能预测方法及装置
CN116956171A (zh) 基于ai模型的分类方法、装置、设备及存储介质
CN113128556B (zh) 基于变异分析的深度学习测试用例排序方法
CN115017125B (zh) 改进knn方法的数据处理方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant