CN113240113A - 一种增强网络预测鲁棒性的方法 - Google Patents
一种增强网络预测鲁棒性的方法 Download PDFInfo
- Publication number
- CN113240113A CN113240113A CN202110623241.1A CN202110623241A CN113240113A CN 113240113 A CN113240113 A CN 113240113A CN 202110623241 A CN202110623241 A CN 202110623241A CN 113240113 A CN113240113 A CN 113240113A
- Authority
- CN
- China
- Prior art keywords
- deep learning
- model
- submodel
- data
- enhancing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 36
- 230000002708 enhancing effect Effects 0.000 title claims abstract description 34
- 238000013135 deep learning Methods 0.000 claims abstract description 60
- 230000006870 function Effects 0.000 claims abstract description 27
- 238000012549 training Methods 0.000 claims abstract description 24
- 238000011056 performance test Methods 0.000 claims abstract description 8
- 238000012360 testing method Methods 0.000 claims description 12
- 230000010354 integration Effects 0.000 claims description 10
- 238000013136 deep learning model Methods 0.000 claims description 6
- 239000000126 substance Substances 0.000 claims description 6
- 238000004364 calculation method Methods 0.000 claims description 4
- 230000000717 retained effect Effects 0.000 claims description 3
- 238000005070 sampling Methods 0.000 claims description 3
- 230000008569 process Effects 0.000 abstract description 7
- 230000001149 cognitive effect Effects 0.000 abstract description 4
- 238000002372 labelling Methods 0.000 abstract description 3
- 230000000694 effects Effects 0.000 description 3
- 201000001441 melanoma Diseases 0.000 description 3
- 238000012795 verification Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000000052 comparative effect Effects 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000003745 diagnosis Methods 0.000 description 1
- 238000007636 ensemble learning method Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000001000 micrograph Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000001902 propagating effect Effects 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000005728 strengthening Methods 0.000 description 1
- 239000002699 waste material Substances 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种增强网络预测鲁棒性的方法,包括选取深度学习子模型,设置深度学习子模型的训练超参数及损失函数,对深度学习子模型进行初始化,对深度学习子模型进行改进的互学习训练,保存经过训练的增强深度学习子模型,对经过训练的增强深度学习子模型进行性能测试。通过上述方式,本发明能够在数据分布不均衡、数据标注存在错误、数据量不足等导致深度网络模型难以克服认知偏差的情况下,在互学习的过程中就对网络进行引导整合,最终通过所有模型学习到的知识共同做出正确的预测,有效提升模型的准确率。
Description
技术领域
本发明涉及数据预测技术领域,特别是涉及一种增强网络预测鲁棒性的方法。
背景技术
深度模型的精度不足和结果不确定性在深度学习模型的落地过程中造成了很大的阻碍。
用于深度学习的数据集对训练出模型的好坏有直接的影响,诸如数据不足、数据不均衡、数据不准确都可以导致模型产生严重的认知偏差从而做出错误的预测。一般来说,通过提升数据的数量和质量可以有效的提高深度模型的准确性和泛化性,然而这会造成非常大的数据收集成本。
提升网络性能的有效方法之一是集成学习,其中心思想是分别训练多个具有不同或相同特性的预测模型,整合多个模型的预测结果进行预测。主流的模型整合方法有Bagging和Boosting两类。Bagging是对所有子模型预测结果取均值并做出预测的方法,Boosting则是根据子模型的特性和好坏,对模型的性能做出评估,根据评估结果给各个模型的预测概率赋予不同的权重,并做出最终预测。但是,集成学习的方法存在以下缺陷:
(1)不同子模型之间关联性不强,预测结果不准确的模型可能会拉低模型的整体水平;
(2)模型之间协同度不足,很难确认模型之间集成后是增益还是减益的效果;
(3)最终预测概率难以协调,不管是采用Bagging或者Boosting的方法,都无法做出最合适的预测概率整合。
互学***衡数据学习中发挥着巨大的作用。但是互学习的多个网络训练往往只保留一个,这造成了资源的浪费。
随着元学习技术的发展,元学习技术也被用于鲁棒性方法中并且取得了不错的效果,但其往往需要一个无偏的验证集作为训练目标在训练中进行指导,事实上在真实数据中我们很难获取到和应用环境无偏的验证集。
发明内容
有鉴于此,本申请提出一种增强网络预测鲁棒性的方法,用于解决数据集存在类别不均衡、数据错误、数量不足、数据噪声等情况下提高网络预测准确性,降低模型产生等认知偏差,提高网络预测鲁棒性。
本申请提供的一种增强网络预测鲁棒性的方法,包括:
步骤1:选取深度学习子模型;
进一步的,所述选取深度学习子模型,是在深度学习模型中选取n项模型作为深度学习子模型;
进一步的,所述n项模型具有相同或不同的网络结构;
进一步的,所述n为正整数。
进一步的,所述在深度学习模型中选取n项模型作为深度学习子模型,包括:根据预测任务的数据类型和任务类型,在深度学习模型中进行选取。
步骤2:设置深度学习子模型的训练超参数及损失函数;
进一步的,所述超参数,包括:学习率、批次容量、优化器、迭代次数。
步骤3:对深度学习子模型进行初始化;
步骤4:对深度学习子模型进行互学习训练,得到经过训练的增强深度学习子模型;
进一步的,所述互学习训练,包括:
S47:交换模型间的梯度:
S48:通过设定的优化器对模型参数进行更新;
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试。
进一步的,所述性能测试,包括:
S61:加载训练好的增强深度学习子模型;
S62:输入数据并将数据转化为增强深度学习子模型输入数据的格式:
S66:对结果进行验证得到经过训练的增强深度学习子模型准确度:
本发明的有益效果:在数据分布不均衡、数据标注存在错误、数据量不足导致深度网络模型难以克服认知偏差的情况下,本发明的技术方案通过在互学习的过程中就对网络进行引导整合,最终通过所有模型学习到的知识共同做出正确的预测,有效提升模型预测的准确率。
附图说明
图1是本发明一种增强网络预测鲁棒性的方法一较佳实施例的流程图。
图2是本发明一种增强网络预测鲁棒性的方法一较佳实施例的流程图。
图3是本发明一种增强网络预测鲁棒性的方法一较佳实施例的流程图。
具体实施方式
下面结合附图对本发明的较佳实施例进行详细阐述,以使本发明的优点和特征能更易于被本领域技术人员理解,从而对本发明的保护范围做出更为清楚明确的界定。
本发明的技术方案是:
步骤1:选取深度学习子模型;
进一步的,所述选取深度学习子模型,是在深度学习模型中选取n项模型作为深度学习子模型;
进一步的,所述n项模型具有相同或不同的网络结构;
进一步的,所述n为正整数。
进一步的,所述在深度学习模型中选取n项模型作为深度学习子模型,包括:根据预测任务的数据类型和任务类型,在深度学习模型中进行选取。
步骤2:设置深度学习子模型的训练超参数及损失函数;
进一步的,所述超参数,包括:学习率、批次容量、优化器、迭代次数。
步骤3:对深度学习子模型进行初始化;
步骤4:对深度学习子模型进行互学习训练,得到经过训练的增强深度学习子模型;
进一步的,所述互学习训练,包括:
S47:交换模型间的梯度:
S48:通过设定的优化器对模型参数进行更新;
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试。
进一步的,所述性能测试,包括:
S61:加载训练好的增强深度学习子模型;
S62:输入数据并将数据转化为增强深度学习子模型输入数据的格式:
S66:对结果进行验证得到经过训练的增强深度学习子模型准确度:
基于上述技术方案,本发明对一种增强网络预测鲁棒性的方法举一个实例说明。在本实例中,以含有20%噪声标注的皮肤显微镜图像分类为例进行对比说明。此时预测正确率为准确率。
步骤1:根据计算任务,选择使用的模型,此处综合考量预测任务的数据类型和任务类型,选择两个ResNet-50作为子模型;
步骤2:设置合适的超参数,根据硬件条件设置超参数,如下表:
步骤3:使用随机初始化算法对模型参数进行随机初始化,得到初始化后的两个ResNet-50子模型;
步骤4:对两个ResNet-50子模型进行改进的互学习训练,直至网络收敛或到达迭代次数:
S41:从数据集中按照设定的批次大小采样出一批数据d(32组)。
S42:将数据中的图像输入到网络中,得到网络对所有数据的类别预测概率。
S43:根据模型预测概率和数据标注使用CELoss计算每个模型预测概率对每个样本的损失函数。
S46:反向传播损失函数,计算损失函数对网络的梯度。
S47:交换模型间的梯度。
S48:根据网络的梯度及设定的优化器对网络参数进行更新。
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试:
S61:加载训练好的两个ResNet-50子模型。
S62:加载测试数据并转换为网络输入数据格式。
S63:将数据分别输入子模型中,使用子模型对数据分别做出预测,得到所有模型对样本的预测结果。
S64:对预测结果进行集成。
S65:重复上述S63、S64,得到所有数据的预测结果。
S66:使用准确率对结果进行验证。
名词解释:
SGD(Stochastic Gradient Descent):随机梯度下降;
CE(Cross Entropy) Loss:交叉熵损失。
本实施例将收集的3324个皮肤镜图像诊断黑色素瘤按照6:3:1的比例划分为训练集、测试集、验证集。这些数据总共有两个类别,分别是黑色素瘤阴性和黑色素瘤阳性。
为了验证模型的鲁棒性,分别按照{0,0.05,0.1,0.2,0.3,0.4}的比例翻转数据的标注。
测试结果取最后10个Epoch的测试准确率的平均值,各方法准确率如下表:
由此可知,本申请的技术方案相比现有的方法,测试准确率明显较高,在增强网络预测鲁棒性上取得了更好的效果。
其中,Standard: 指标准的单模型神经网络训练方式;
Joint Optim: 一种基于标注修正的鲁棒性方法;
Co-teaching: 一种基于互学习的鲁棒性方法;
DivideMix: 一种以半监督方法为基础的鲁棒性方法;
Reweight: 一种基于元学习的鲁棒性方法;
Ours: 本申请中的方法。
以上所述仅为本发明的实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (6)
1.一种增强网络预测鲁棒性的方法,其特征在于,包括以下步骤:
步骤1:选取深度学习子模型;
步骤2:设置深度学习子模型的训练超参数及损失函数;
步骤3:对深度学习子模型进行初始化;
步骤4:对深度学习子模型进行互学习训练,得到经过训练的增强深度学习子模型;
步骤5:保存经过训练的增强深度学习子模型;
步骤6:对经过训练的增强深度学习子模型进行性能测试。
2.如权利要求1所述的一种增强网络预测鲁棒性的方法,其特征在于,所述选取深度学习子模型,是在深度学习模型中选取n项模型作为深度学习子模型;
所述n项模型具有相同或不同的网络结构;
所述n为正整数。
3.如权利要求2所述的一种增强网络预测鲁棒性的方法,其特征在于,
所述在深度学习模型中选取n项模型作为深度学习子模型,包括:根据预测任务的数据类型和任务类型,在深度学习模型中进行选取。
4.如权利要求1所述的一种增强网络预测鲁棒性的方法,其特征在于,所述超参数,包括:学习率、批次容量、优化器、迭代次数。
5.如权利要求1所述的一种增强网络预测鲁棒性的方法,其特征在于,所述互学习训练,包括:
S47:交换模型间的梯度:
S48:通过设定的优化器对模型参数进行更新;
S49:若模型收敛或到达预设迭代次数,则结束训练,否则转至S41;
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110623241.1A CN113240113B (zh) | 2021-06-04 | 2021-06-04 | 一种增强网络预测鲁棒性的方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110623241.1A CN113240113B (zh) | 2021-06-04 | 2021-06-04 | 一种增强网络预测鲁棒性的方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113240113A true CN113240113A (zh) | 2021-08-10 |
CN113240113B CN113240113B (zh) | 2024-05-28 |
Family
ID=77136814
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110623241.1A Active CN113240113B (zh) | 2021-06-04 | 2021-06-04 | 一种增强网络预测鲁棒性的方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113240113B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113907710A (zh) * | 2021-09-29 | 2022-01-11 | 山东师范大学 | 基于模型无关的图像增强元学习的皮肤病变分类*** |
CN114998613A (zh) * | 2022-06-24 | 2022-09-02 | 安徽工业大学 | 一种基于深度互学习的多标记零样本学习方法 |
CN115937617A (zh) * | 2023-03-06 | 2023-04-07 | 支付宝(杭州)信息技术有限公司 | 一种风险识别模型训练、风险控制方法、装置和设备 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108241650A (zh) * | 2016-12-23 | 2018-07-03 | 北京国双科技有限公司 | 训练分类标准的训练方法和装置 |
CN110533610A (zh) * | 2019-08-20 | 2019-12-03 | 东软医疗***股份有限公司 | 图像增强模型的生成方法及装置、应用方法及装置 |
US20200106788A1 (en) * | 2018-01-23 | 2020-04-02 | Hangzhou Dianzi University | Method for detecting malicious attacks based on deep learning in traffic cyber physical system |
CN112149556A (zh) * | 2020-09-22 | 2020-12-29 | 南京航空航天大学 | 一种基于深度互学习和知识传递的人脸属性识别方法 |
-
2021
- 2021-06-04 CN CN202110623241.1A patent/CN113240113B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108241650A (zh) * | 2016-12-23 | 2018-07-03 | 北京国双科技有限公司 | 训练分类标准的训练方法和装置 |
US20200106788A1 (en) * | 2018-01-23 | 2020-04-02 | Hangzhou Dianzi University | Method for detecting malicious attacks based on deep learning in traffic cyber physical system |
CN110533610A (zh) * | 2019-08-20 | 2019-12-03 | 东软医疗***股份有限公司 | 图像增强模型的生成方法及装置、应用方法及装置 |
CN112149556A (zh) * | 2020-09-22 | 2020-12-29 | 南京航空航天大学 | 一种基于深度互学习和知识传递的人脸属性识别方法 |
Non-Patent Citations (1)
Title |
---|
刘威;刘尚;白润才;周璇;周定宁: "《互学习神经网络训练方法研究》", 《互学习神经网络训练方法研究》, 31 March 2017 (2017-03-31), pages 1 - 18 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113907710A (zh) * | 2021-09-29 | 2022-01-11 | 山东师范大学 | 基于模型无关的图像增强元学习的皮肤病变分类*** |
CN114998613A (zh) * | 2022-06-24 | 2022-09-02 | 安徽工业大学 | 一种基于深度互学习的多标记零样本学习方法 |
CN114998613B (zh) * | 2022-06-24 | 2024-04-26 | 安徽工业大学 | 一种基于深度互学习的多标记零样本学习方法 |
CN115937617A (zh) * | 2023-03-06 | 2023-04-07 | 支付宝(杭州)信息技术有限公司 | 一种风险识别模型训练、风险控制方法、装置和设备 |
Also Published As
Publication number | Publication date |
---|---|
CN113240113B (zh) | 2024-05-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113240113A (zh) | 一种增强网络预测鲁棒性的方法 | |
CN110084271B (zh) | 一种图片类别的识别方法和装置 | |
CN112508334B (zh) | 融合认知特性及试题文本信息的个性化组卷方法及*** | |
CN111241243A (zh) | 面向知识测量的试题、知识、能力张量构建与标注方法 | |
CN108536784B (zh) | 评论信息情感分析方法、装置、计算机存储介质和服务器 | |
CN111343147B (zh) | 一种基于深度学习的网络攻击检测装置及方法 | |
CN111368920A (zh) | 基于量子孪生神经网络的二分类方法及其人脸识别方法 | |
CN110110610B (zh) | 一种用于短视频的事件检测方法 | |
CN111046961B (zh) | 基于双向长短时记忆单元和胶囊网络的故障分类方法 | |
CN109102002A (zh) | 结合卷积神经网络和概念机递归神经网络的图像分类方法 | |
CN112487193B (zh) | 一种基于自编码器的零样本图片分类方法 | |
CN112784031B (zh) | 一种基于小样本学习的客服对话文本的分类方法和*** | |
CN113591988B (zh) | 知识认知结构分析方法、***、计算机设备、介质、终端 | |
CN113902129A (zh) | 多模态的统一智能学习诊断建模方法、***、介质、终端 | |
CN112150304A (zh) | 电网运行状态轨迹稳定性预判方法、***及存储介质 | |
CN114663002A (zh) | 一种自动化匹配绩效考核指标的方法及设备 | |
CN110688484B (zh) | 一种基于不平衡贝叶斯分类的微博敏感事件言论检测方法 | |
CN112307536A (zh) | 一种大坝渗流参数反演方法 | |
CN112163106A (zh) | 二阶相似感知的图像哈希码提取模型建立方法及其应用 | |
CN111144462A (zh) | 一种雷达信号的未知个体识别方法及装置 | |
CN113283467B (zh) | 一种基于平均损失和逐类选择的弱监督图片分类方法 | |
CN117591961A (zh) | 基于自归一化分类模型的脉冲星候选体识别方法及*** | |
CN112579777A (zh) | 一种未标注文本的半监督分类方法 | |
CN116956171A (zh) | 基于ai模型的分类方法、装置、设备及存储介质 | |
Basheer | Stress-strain behavior of geomaterials in loading reversal simulated by time-delay neural networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |