CN110826688B - 一种保障gan模型最大最小损失函数平稳收敛的训练方法 - Google Patents

一种保障gan模型最大最小损失函数平稳收敛的训练方法 Download PDF

Info

Publication number
CN110826688B
CN110826688B CN201910896955.2A CN201910896955A CN110826688B CN 110826688 B CN110826688 B CN 110826688B CN 201910896955 A CN201910896955 A CN 201910896955A CN 110826688 B CN110826688 B CN 110826688B
Authority
CN
China
Prior art keywords
generator
gen
follows
updating
maximum
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN201910896955.2A
Other languages
English (en)
Other versions
CN110826688A (zh
Inventor
陈旋
吕成云
林善冬
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jiangsu Aijia Household Products Co Ltd
Original Assignee
Jiangsu Aijia Household Products Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jiangsu Aijia Household Products Co Ltd filed Critical Jiangsu Aijia Household Products Co Ltd
Priority to CN201910896955.2A priority Critical patent/CN110826688B/zh
Publication of CN110826688A publication Critical patent/CN110826688A/zh
Application granted granted Critical
Publication of CN110826688B publication Critical patent/CN110826688B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/217Validation; Performance evaluation; Active pattern learning techniques
    • G06F18/2193Validation; Performance evaluation; Active pattern learning techniques based on specific statistical tests
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种保障GAN模型最大最小损失函数平稳收敛的训练方法,涉及GAN模型训练方法的深度学习领域,本方法通过合理设定生成器和对抗器的参数更新条件和频率,解决上述的GAN模型训练过程中生成器与对抗器博弈失衡问题,所谓博弈失衡问题指的是,生成器和对抗器的其中一方在训练过程中快速收敛,使得另一方的优化曲面近乎不可导,从而无法顺利训练的现象。

Description

一种保障GAN模型最大最小损失函数平稳收敛的训练方法
技术领域
本发明涉及GAN模型训练方法的深度学***稳收敛的训练方法。
背景技术
GAN模型在图像生成、语音生成、文字生成等领域都有广泛应用。对GAN模型的训练,一般包括两个步骤,一步是根据生成器的损失函数梯度更新生成器的参数,一步是根据对抗器的损失函数梯度更新对抗器的参数。这两步一般交替进行,直至生成器生成的数据分布逼近真实数据分布,对抗器无法判别生成器的输出和真实数据为止。但在实际操作过程中,这种不顾生成器和对抗器的收敛情况一味地交替训练的方法会造成生成器和对抗器博弈失衡。生成器的损失函数包含了对抗器对于生成器输出的计算,当对抗器快速收敛时,如果生成器的收敛速度跟不上,大概率情况下会造成生成器的损失函数梯度不断增大直至最后梯度***,在数值计算领域不可导,从而使得优化器无法更新生成器参数,最终GAN模型训练失败。
发明内容
本发明的目的是针对背景技术的不足提供了一种保障GAN模型最大最小损失函数平稳收敛的训练方法,其通过合理设定生成器和对抗器的参数更新条件和频率,解决上述的GAN模型训练过程中生成器与对抗器博弈失衡问题。
本发明为解决上述技术问题采用以下技术方案:
一种保障GAN模型最大最小损失函数平稳收敛的训练方法,包括如下步骤:
步骤1,准备MNIST数据集;
步骤2,随机生成一个n维向量z;
步骤3,构建生成器G(.),生成器G(.)的具体结构选用反卷积结构,过程式具体如下:
G(.)=Tranpose_CNN(.);
步骤4,将步骤2生成的n维向量z传入步骤构建的生成器G(.)中,输出一个尺寸和MNIST数据集中的图片一样的矩阵Igen,过程式如下:
Igen=G(z)
步骤5,构建对抗器D(.),对抗器D(.)的具体结构选用卷积结构,过程式具体如下:
D(.)=CNN(.);
步骤6,计算生成结果的对抗分数Paen,过程式如下:
Pgen=D(Igen);
步骤7,计算生成器损失函数LG,采用交叉熵形式,过程式如下:
LG=-log(Pgen);
步骤8,从步骤1中准备的MNIST集中随机取出一张图Ireal,传入对抗器D(.)中,计算出真实图片的对抗分数Preal,过程式如下:
Preal=D(Ireal)
步骤9,计算对抗器损失函数LD,采用Wasserstein形式,过程式如下:
LD=log(Pgen)-log(Preal)
步骤10,更新生成器损失值LG动量均值
Figure BDA0002210538130000021
对抗器损失值LD的动量均值
Figure BDA0002210538130000022
若是第一次迭代,则
Figure BDA0002210538130000023
直接取值LG
Figure BDA0002210538130000024
直接取值LD;若不是第一次迭代,则更新过程式如下:
Figure BDA0002210538130000025
Figure BDA0002210538130000026
其中,γ为动量系数;
步骤11,比较两损失值的相对值LG_r、LD_r,计算过程式如下:
Figure BDA0002210538130000027
Figure BDA0002210538130000028
进而比较LG_r和LD_r的大小;根据比较结果,更新模型权重;若LG_r>LD_r,则更新生成器的参数;反之,则更新对抗器的参数;
步骤12,重复步骤2至步骤11,直到生成器输出满意的结果。
作为本发明一种保障GAN模型最大最小损失函数平稳收敛的训练方法的进一步优选方案,生成器和对抗器的参数并不是交替更新,而是根据各自的相对变化幅度,调整更新频率,使得收敛较慢的一方优先更新权重,有效避免发生博弈失衡的局面。
作为本发明一种保障GAN模型最大最小损失函数平稳收敛的训练方法的进一步优选方案,在步骤4中,输出一个尺寸和MNIST数据集中的图片一样的矩阵Igen,其尺寸为(28,28,1)。
作为本发明一种保障GAN模型最大最小损失函数平稳收敛的训练方法的进一步优选方案,在步骤10中,所述步骤10中的动量系数γ,取值范围为[0,1),参考经验值为0.9。
有益效果
本发明采用以上技术方案与现有技术相比,具有以下技术效果:
1、本发通过合理设定生成器和对抗器的参数更新条件和频率,解决上述的GAN模型训练过程中生成器与对抗器博弈失衡问题,所谓博弈失衡问题指的是,生成器和对抗器的其中一方在训练过程中快速收敛,使得另一方的优化曲面近乎不可导,从而无法顺利训练的现象
2、本发明根据生成器和对抗器的训练情况动态调整生成器和对抗器的参数更新频率,能够明显提高GAN模型的训练效果。
附图说明
图1是本发明的方法图;
图2是一张MNIST图片例子。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
一种保障GAN模型最大最小损失函数平稳收敛的训练方法,如图1所示,包括如下步骤:
步骤1,准备MNIST数据集;
步骤2,随机生成一个n维向量z;
步骤3,构建生成器G(.),生成器G(.)的具体结构选用反卷积结构,也可选择其它结构,这里生成器G(.)的具体结构选用反卷积结构,过程式具体如下:
G(.)=Tranpose_CNN(.);
步骤4,将步骤2生成的n维向量z传入步骤构建的生成器G(.)中,输出一个尺寸和MNIST数据集中的图片一样的矩阵Igen,其尺寸为(28,28,1);过程式如下:
Igen=G(z)
步骤5,构建对抗器D(.),对抗器D(.)的具体结构选用卷积结构,也可采用其它结构,对抗器D(.)的具体结构选用卷积结构过程式具体如下:
D(.)=CNN(.);
步骤6,计算生成结果的对抗分数Pgen,过程式如下:
Pgen=D(Igen);
步骤7,计算生成器损失函数LG,采用交叉熵形式,过程式如下:
LG=-log(Pgen);
步骤8,从步骤1中准备的MNIST集中随机取出一张图Ireal,传入对抗器D(.)中,计算出真实图片的对抗分数Preal,过程式如下:
Preal=D(Ireal)
步骤9,计算对抗器损失函数LD,采用Wasserstein形式,过程式如下:
LD=log(Pgen)-log(Preal)
步骤10,更新生成器损失值LG动量均值
Figure BDA0002210538130000041
对抗器损失值LD的动量均值
Figure BDA0002210538130000042
若是第一次迭代,则
Figure BDA0002210538130000043
直接取值LG
Figure BDA0002210538130000044
直接取值LD;若不是第一次迭代,则更新过程式如下:
Figure BDA0002210538130000045
Figure BDA0002210538130000046
其中,γ为动量系数;取值范围为[0,1),参考经验值为0.9
步骤11,比较两损失值的相对值LG_r、LD_r,计算过程式如下:
Figure BDA0002210538130000047
Figure BDA0002210538130000048
进而比较LG_r和LD_r的大小;根据比较结果,更新模型权重;若LG_r>LD_r,则更新生成器的参数;反之,则更新对抗器的参数;
步骤12,重复步骤2至步骤11,直到生成器输出满意的结果。
生成器和对抗器的参数并不是交替更新,而是根据各自的相对变化幅度,调整更新频率,使得收敛较慢的一方优先更新权重,有效避免发生博弈失衡的局面。
具体实施例如下:
在GAN模型(生成对抗网络模型)中,存在两个相互博弈的模型,一个是生成器(generator),一个是对抗器(discriminator)。简记生成器为函数G(.),对抗器为函数D(.)。
具体实施的时候,我们选用tensorflow机器学习平台进行算法开发。
为了方法容易复现,我们使用开源的MNIST数据为例。注意,利用DCGAN生成MNIST数据图片非本发明特征,这里只贴出关键步骤主要是为了方便复现。
1.准备数据
用以下命令从互联网上下载MNIST数据集:
(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()
其中,变量train_images中存储的是MNIST图片信息,变量train_labels中存储的是MNIST标签信息。图2是随机抽取的一张标签为“1”的MNIST图片。MNIST图片是大小为(28,28)的灰度图片。
2.定义tensorflow计算图
定义输入节点:
Figure BDA0002210538130000051
定义生成器:
Figure BDA0002210538130000052
定义对抗器:
Figure BDA0002210538130000053
超参配置如下:
真实图像的大小:img_size=train_images[0].shape[0]
传入给generator的噪声大小:noise_size=100
生成器隐层参数:g_units=128
判别器隐层参数:d_units=128
leaky ReLU的参数:alpha=0.01
学习率:learning_rate=0.001
均值动量系数:gama=0.9
定义对抗器的损失:
识别真实图片的损失:
Figure BDA0002210538130000061
识别生成的图片的损失:
Figure BDA0002210538130000062
总体对抗器损失:
d_loss=tf.add(d_loss_real,d_loss_fake)
定义生成器的损失:
Figure BDA0002210538130000063
3.定义优化器
对抗器的优化器:
d_train_opt=tf.train.AdamOptimizer(learning_rate).minimize(d_loss,var_list=d_vars)
生成器的优化器:
g_train_opt=tf.train.AdamOptimizer(learning_rate).minimize(g_loss,var_list=g_vars)
4.训练过程,注意,此处含有本发明的特征
以下代码的关键步骤都作了注释。本发明的特征在于,在对生成器和对抗器进行参数更新之前,先计算生成器和对抗器的相对损失值,然后,根据相对损失值的比较结果,决定是更新生成器的参数还是更新对抗器的参数。由此来动态地控制生成器和对抗器的参数更新频率,从而防止两者博弈失衡。
Figure BDA0002210538130000071
Figure BDA0002210538130000081
以上这段代码,具体实现了本发明的特征,代码用的是tensorflow平台进行开发,但开发平台与本发明无关,亦非本发明特征。
由此可见,本发明根据生成器和对抗器的训练情况动态调整生成器和对抗器的参数更新频率,能够明显提高GAN模型的训练效果;在两种实验中都得到验证。一种是MNIST数据生成实验,另一种是家具户型图生成实验。实验采用定性和定量两种指标同时检测。定性指标用的是百张通过率,指生成器生成100张图片,人工审核通过的平均比率。定量指标用的是FID,即Fréchet Inception Distance。在MNIST数据生成实验中,对比测试结果如表1所示:
表1
Figure BDA0002210538130000082
可以看出,本发明所述方法,由于根据生成器和对抗器的训练情况动态调整生成器和对抗器的参数更新频率,能够明显提高GAN模型的训练效果。
在家具户型图生成实验中,对比测试结果如表2所示:
表2
Figure BDA0002210538130000083
基准方法在迭代150K次时,生成器已经无法训练。
最后应说明的几点是:首先,在本申请的描述中,需要说明的是,除非另有规定和限定,术语“安装”、“相连”、“连接”应做广义理解,可以是机械连接或电连接,也可以是两个元件内部的连通,可以是直接相连,“上”、“下”、“左”、“右”等仅用于表示相对位置关系,当被描述对象的绝对位置改变,则相对位置关系可能发生改变;
其次:本发明公开实施例附图中,只涉及到与本公开实施例涉及到的结构,其他结构可参考通常设计,在不冲突情况下,本发明同一实施例及不同实施例可以相互组合;
最后:以上所述仅为本发明的优选实施例而已,并不用于限制本发明,凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (5)

1.一种保障GAN模型最大最小损失函数平稳收敛的训练方法,其特征在于,包括如下步骤:
步骤1,准备MNIST数据集;
步骤2,随机生成一个n维向量z;
步骤3,构建生成器G(.),生成器G(.)的具体结构选用反卷积结构,过程式具体如下:
G(.)=Tranpose_CNN(.);
步骤4,将步骤2生成的n维向量z传入步骤构建的生成器G(.)中,输出一个尺寸和MNIST数据集中的图片一样的矩阵Igen,过程式如下:
Igen=G(z)
步骤5,构建对抗器D(.),对抗器D(.)的具体结构选用卷积结构,过程式具体如下:
D(.)=CNN(.);
步骤6,计算生成结果的对抗分数Pgen,过程式如下:
Pgen=D(Igen);
步骤7,计算生成器损失函数LG,采用交叉熵形式,过程式如下:
LG=-log(Pgen);
步骤8,从步骤1中准备的MNIST集中随机取出一张图Ireal,传入对抗器D(.)中,计算出真实图片的对抗分数Preal,过程式如下:
Preal=D(Ireal)
步骤9,计算对抗器损失函数LD,采用Wasserstein形式,过程式如下:
LD=log(Pgen)-log(Preal)
步骤10,更新生成器损失值LG动量均值
Figure FDA0003581474320000011
对抗器损失值LD的动量均值
Figure FDA0003581474320000012
若是第一次迭代,则
Figure FDA0003581474320000013
直接取值LG
Figure FDA0003581474320000014
直接取值LD;若不是第一次迭代,则更新过程式如下:
Figure FDA0003581474320000015
Figure FDA0003581474320000016
其中,γ为动量系数;
步骤11,比较两损失值的相对值LG_r、LD_r,计算过程式如下:
Figure FDA0003581474320000017
Figure FDA0003581474320000018
进而比较LG_r和LD_r的大小;根据比较结果,更新模型权重;若LG_r>LD_r,则更新生成器的参数;反之,则更新对抗器的参数;
步骤12,重复步骤2至步骤11,直到生成器输出满意的结果。
2.根据权利要求1所述的一种保障GAN模型最大最小损失函数平稳收敛的训练方法,其特征在于:生成器和对抗器的参数并不是交替更新,而是根据各自的相对变化幅度,调整更新频率,使得收敛较慢的一方优先更新权重,有效避免发生博弈失衡的局面。
3.根据权利要求1所述的一种保障GAN模型最大最小损失函数平稳收敛的训练方法,其特征在于:在步骤4中,输出一个尺寸和MNIST数据集中的图片一样的矩阵Igen,其尺寸为(28,28,1)。
4.根据权利要求1所述的一种保障GAN模型最大最小损失函数平稳收敛的训练方法,其特征在于:在步骤10中,所述步骤10中的动量系数γ,取值范围为[0,1)。
5.根据权利要求1所述的一种保障GAN模型最大最小损失函数平稳收敛的训练方法,其特征在于:动量系数γ为0.9。
CN201910896955.2A 2019-09-23 2019-09-23 一种保障gan模型最大最小损失函数平稳收敛的训练方法 Active CN110826688B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910896955.2A CN110826688B (zh) 2019-09-23 2019-09-23 一种保障gan模型最大最小损失函数平稳收敛的训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910896955.2A CN110826688B (zh) 2019-09-23 2019-09-23 一种保障gan模型最大最小损失函数平稳收敛的训练方法

Publications (2)

Publication Number Publication Date
CN110826688A CN110826688A (zh) 2020-02-21
CN110826688B true CN110826688B (zh) 2022-07-29

Family

ID=69548166

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910896955.2A Active CN110826688B (zh) 2019-09-23 2019-09-23 一种保障gan模型最大最小损失函数平稳收敛的训练方法

Country Status (1)

Country Link
CN (1) CN110826688B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112884640B (zh) * 2021-03-01 2024-04-09 深圳追一科技有限公司 模型训练方法及相关装置、可读存储介质

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107767384A (zh) * 2017-11-03 2018-03-06 电子科技大学 一种基于对抗训练的图像语义分割方法
CN108665058A (zh) * 2018-04-11 2018-10-16 徐州工程学院 一种基于分段损失的生成对抗网络方法

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN107767384A (zh) * 2017-11-03 2018-03-06 电子科技大学 一种基于对抗训练的图像语义分割方法
CN108665058A (zh) * 2018-04-11 2018-10-16 徐州工程学院 一种基于分段损失的生成对抗网络方法

Also Published As

Publication number Publication date
CN110826688A (zh) 2020-02-21

Similar Documents

Publication Publication Date Title
CN107392255B (zh) 少数类图片样本的生成方法、装置、计算设备及存储介质
CN111563841B (zh) 一种基于生成对抗网络的高分辨率图像生成方法
CN110460600B (zh) 可抵御生成对抗网络攻击的联合深度学习方法
WO2020259502A1 (zh) 神经网络模型的生成方法及装置、计算机可读存储介质
CN110070174A (zh) 一种生成对抗网络的稳定训练方法
CN111353582A (zh) 一种基于粒子群算法的分布式深度学习参数更新方法
CN106897662A (zh) 基于多任务学习的人脸关键特征点的定位方法
CN107229966A (zh) 一种模型数据更新方法、装置及***
CN109146061A (zh) 神经网络模型的处理方法和装置
CN109983480A (zh) 使用聚类损失训练神经网络
CN107609506A (zh) 用于生成图像的方法和装置
WO2020259504A1 (zh) 一种强化学习的高效探索方法
CN110826688B (zh) 一种保障gan模型最大最小损失函数平稳收敛的训练方法
CN112580728B (zh) 一种基于强化学习的动态链路预测模型鲁棒性增强方法
Lin et al. Evolutionary architectural search for generative adversarial networks
CN113033822A (zh) 基于预测校正和随机步长优化的对抗性攻击与防御方法及***
CN107590538B (zh) 一种基于在线序列学习机的危险源识别方法
CN114781654A (zh) 联邦迁移学习方法、装置、计算机设备及介质
KR20210060146A (ko) 딥 뉴럴 네트워크 모델을 이용한 데이터 처리 방법 및 장치, 딥 뉴럴 네트워크 모델을 학습시키는 학습 방법 및 장치
CN114169527A (zh) 一种基于量子计算的生成对抗网络
CN111144574A (zh) 使用指导者模型训练学习者模型的人工智能***和方法
CN107566051B (zh) 一种mimo ota最大三维测试区域大小的确定方法及装置
Hiraoka et al. Generation of stripe-patchwork images by entropy and inverse filter
JP6961527B2 (ja) 情報処理装置、学習方法、及びプログラム
CN116017476A (zh) 无线传感器网络覆盖设计方法、装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
CB02 Change of applicant information

Address after: 211100 floor 5, block a, China Merchants high speed rail Plaza project, No. 9, Jiangnan Road, Jiangning District, Nanjing, Jiangsu (South Station area)

Applicant after: JIANGSU AIJIA HOUSEHOLD PRODUCTS Co.,Ltd.

Address before: 211100 No. 18 Zhilan Road, Science Park, Jiangning District, Nanjing City, Jiangsu Province

Applicant before: JIANGSU AIJIA HOUSEHOLD PRODUCTS Co.,Ltd.

CB02 Change of applicant information
GR01 Patent grant
GR01 Patent grant