CN110598806A - 一种基于参数优化生成对抗网络的手写数字生成方法 - Google Patents

一种基于参数优化生成对抗网络的手写数字生成方法 Download PDF

Info

Publication number
CN110598806A
CN110598806A CN201910692092.7A CN201910692092A CN110598806A CN 110598806 A CN110598806 A CN 110598806A CN 201910692092 A CN201910692092 A CN 201910692092A CN 110598806 A CN110598806 A CN 110598806A
Authority
CN
China
Prior art keywords
network
data
discriminator
generator
weight parameter
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201910692092.7A
Other languages
English (en)
Inventor
朱敏
方超
储昭碧
董学平
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hefei University of Technology
Hefei Polytechnic University
Original Assignee
Hefei Polytechnic University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hefei Polytechnic University filed Critical Hefei Polytechnic University
Priority to CN201910692092.7A priority Critical patent/CN110598806A/zh
Publication of CN110598806A publication Critical patent/CN110598806A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T11/002D [Two Dimensional] image generation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition
    • G06V30/24Character recognition characterised by the processing or recognition method
    • G06V30/242Division of the character sequences into groups prior to recognition; Selection of dictionaries
    • G06V30/244Division of the character sequences into groups prior to recognition; Selection of dictionaries using graphical properties, e.g. alphabet type or font
    • G06V30/2455Discrimination between machine-print, hand-print and cursive writing
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V30/00Character recognition; Recognising digital ink; Document-oriented image-based pattern recognition
    • G06V30/10Character recognition

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

本发明提供了一种基于参数优化生成对抗网络的手写数字生成方法。包括:准备手写数字数据集作为样本训练数据集,采样得到真实数据并初始化随机噪声数据;建立生成对抗网络,初始化生成器网络权值参数θ和判别器网络权值参数ω;通过搬土距离W建立生成器损失函数和判别器损失函数,并且为判别器损失函数额外添加梯度惩罚损失项;迭代训练生成器网络和判别器网络,优化生成器网络权值参数θ和判别器网络权值参数ω。本发明实例解决了原始生成对抗网络收敛缓慢,训练不稳定,计算开销大等问题。实现了对生成对抗网络的优化,充分提高了生成对抗网络的网络性能,并且生成器能够生成质量更高的手写数字图像。

Description

一种基于参数优化生成对抗网络的手写数字生成方法
技术领域
本发明实施例涉及深度学习神经网络,尤其涉及一种基于参数优化生成对抗网络的手写数字生成方法。
背景技术
生成对抗网络(GAN)是于2014年提出的一种训练模型的方法,该方法通过两个模型(生成器网络G和判别器网络D)之间的对抗训练,参考了博弈论里面的极小极大问题的思路,最终使得两个模型的效果均会有所提高。生成对抗网络的目标,给定一个真实样本分布的集合,根据该集合不断迭代训练生成器网络G和判别器网络D,最终使得生成器网络G可以从噪声信号生成尽可能符合真实样本分布的样本,而判别器网络D可以从样本的分布判别该样本是否符合真实样本的分布。
利普希茨连续性条件,以德国数学家利普希茨命名,是一个比通常连续更强的光滑性条件。利普希茨连续函数限制了函数改变的速度,符合利普希茨条件的函数的斜率,必小于一个称为利普希茨常数的实数(该常数依函数而定)。其定义为:对于函数f(x),在其任意定义域中的x1,x2,都存在K>0,使得|f(x1)-f(x2)|≤K|x1-x2|。
随着生成对抗网络的不断发展,其被广泛应用于人工智能的各个领域,但原始生成对抗网络在训练过程中容易导致梯度消失,训练不稳定,收敛速度慢等问题。由于原始生成对抗网络采用JS散度(Jensen-Shannon divergence)衡量生成样本与真实样本的距离,但在的训练过程中,采用JS散度训练到底,无论生成器的输出质量如何,JS散度始终是一个固定值,判别器的损失就不能用来衡量生成样本的质量,判别器的输出判断意义不大。
Arjovsky Martin,Chintala Soumith,Bottou Leon.Wasserstein generativeadversarial networks[C]//International Conference on Machine Learning.2017:214-223.(马丁·阿约夫斯基,苏米什·金塔拉,里昂·巴顿,沃瑟斯坦生成对抗网络,国际机器学习会议,2017:214-223),提出了一种生成对抗网络的优化方法,该方法有以下不足:采用梯度剪切方法来实现利普希茨连续性条件,即将判别器的所有参数约束在一个阈值范围内,容易导致判别器的参数集中在该阈值范围的最值上,另外该方法对阈值的设置要求较高,阈值设置偏小容易导致梯度消失,阈值设置偏大容易导致梯度***,从而导致优化效果不佳,优化难度偏大。
中国发明专利文献(CN 108470196A)于2018年8月31日公开的《一种基于深度卷积对抗网络模型生成手写数字的方法》采用深度卷积对抗网络模型实现手写数字生成,该方法有以下不足:未对生成对抗网络的生成器损失函数和判别器损失函数进行单独优化设计,不能充分发挥生成对抗网络的性能;未对生成器网络权值参数和判别器网络权值参数进行优化,容易导致判别器网络收敛较快,过早的失去对生成器网络的影响,生成的手写数字图像质量不高。
因此,一种基于参数优化生成对抗网络的手写数字生成方法具有重要的研究意义和应用价值。
发明内容
为了克服上述深度卷积对抗网络模型在手写数字生成上的不足,本发明实施例提供一种基于参数优化生成对抗网络的手写数字生成方法,该方法能大幅提高生成对抗网络的收敛速度,训练稳定性,有利于生成质量更高的手写数字图像。
本发明的目的是这样实现的。本发明提供一种基于参数优化生成对抗网络的手写数字生成方法,建立生成对抗网络,所述生成对抗网络包括生成器网络G和判别器网络D,通过迭代训练优化生成器网络权值参数θ和判别器网络权值参数ω,所述迭代训练分批次进行,每批次训练的样本数据为m个,将m记为批尺寸m,并将批尺寸m中的任意一个样本数据记为当前样本数据i,i=1,2,...,m,具体的,一个批次训练的具体步骤包括:
步骤1,准备手写数字数据集作为样本训练数据集,通过离散二维滤波器对样本训练数据集采样得到真实数据xr,设置生成对抗网络,包括生成器网络G和判别器网络D,初始化随机噪声数据z,设随机噪声数据z满足随机噪声分布Pz;设置生成对抗网络相关训练参数;
所述生成对抗网络相关训练参数包括:神经网络的学习率α、梯度惩罚系数λ、生成器网络G期望误差e、生成器网络G最大迭代训练次数ng、判别器网络D单轮最大迭代训练次数nd,所述判别器网络D单轮最大迭代训练次数nd指的是生成器网络G每迭代训练一次判别器网络需要迭代训练nd次;
步骤2,设置生成对抗网络中生成器网络G权值参数θ和判别器网络D权值参数ω并进行初始化处理;
步骤3,设当前生成器网络G迭代训练次数为t1,t1=1,2,...,ng,将当前生成器网络G迭代训练次数t1置1;
步骤4,设当前判别器网络D迭代训练次数为t2,t2=1,2,...,nd,将当前判别器网络D迭代训练次数t2置1;
步骤5,训练判别器网络D
步骤5.1,当前样本数据i置1;
步骤5.2,获取生成器网络G输入和判别器网络D输入,所述生成器网络G的输入是随机噪声数据z,随机噪声数据z满足随机噪声分布Pz;所述判别器网络D的输入包括真实数据xr和生成数据xg两种形式的输入,真实数据xr来源于所学习的样本训练数据集,生成数据xg是随机噪声数据z经过生成器网络G重构得到的数据;
步骤5.3,获取梯度惩罚损失项中的数据分布,通过在真实数据xr和生成数据xg之间随机插值取样得到梯度惩罚数据梯度惩罚数据满足梯度惩罚数据分布
步骤5.4,通过搬土距离W建立判别器损失函数LD,所述搬土距离W为生成数据xg和真实数据xr之间的距离,表达式为:
其中,Pr表示真实数据分布,Pg表示生成数据分布,sup表示对所有满足||f||Lip≤K的f函数取到的上界,||f||Lip≤K表示f函数需要满足利普希茨连续性条件,其中K为f函数的利普希茨常数,xr~Pr表示真实数据xr满足真实数据分布Pr,xg~Pg表示生成数据xg满足生成数据分布Pg,E表示期望,f(x)表示神经网络,表示在真实数据分布Pr下真实数据xr对f(x)的期望,表示在生成数据分布Pg下生成数据xg对f(x)的期望;
判别器损失函数LD表达式为:
其中,D(xg)表示采用判别器网络D对生成数据xg进行评价打分,D(xr)表示采用判别器网络D对真实数据xr进行评价打分,得分为0.1表示判别器网路D对生成数据xg的评价,得分为0.9表示判别器网路D对真实数据xr的评价;表示在生成数据分布Pg下生成数据xg对D(xg)的期望,表示在真实数据分布Pr下真实数据xr对D(xr)的期望;
步骤5.5,为通过搬土距离W建立的判别器损失函数LD额外添加梯度惩罚损失项LP,并得到更新后的判别器损失函数L'D
梯度惩罚损失项LP表达式为:
其中,表示梯度惩罚数据满足梯度惩罚数据分布 表示判别器网络D梯度的p范数,H表示判别器网络D梯度的利普希茨常数;
更新后的判别器损失函数L'D表达式为:
步骤5.6,采用更新后的判别器损失函数L'D对判别器网络D权值参数ω进行反向传播,得到更新后的判别器网络D权值参数ω';
步骤5.7,通过谱归一化对更新后的判别器网络D权值参数ω'进行归一化,得到谱归一化后的判别器网络D权值参数ω”,其中σ(ω')表示更新后的判别器网络D权值参数ω'的谱范数;
步骤5.8,判断当前样本数据i是否大于批尺寸m,若i大于m,转到步骤5.10;若i不大于m,转到步骤5.9;
步骤5.9,将当前样本数据i加1,并用谱归一化后的判别器网络D权值参数ω”更新判别器网络D权值参数ω,返回步骤5.2;
步骤5.10,通过自适应矩估计优化器优化判别器网络D权值参数ω;
步骤6,将当前判别器网络D迭代训练次数t2加1,得到更新后的判别器网络D迭代训练次数t2',用更新后的判别器网络D迭代训练次数t2'更新当前判别器网络D迭代训练次数t2,判断当前判别器网络D迭代训练次数t2是否大于判别器网络D单轮最大迭代训练次数nd,若t2大于nd,转到步骤7;若t2不大于nd,返回步骤5;
步骤7,训练生成器网络G
步骤7.1,当前样本数据i置1;
步骤7.2,通过搬土距离W建立生成器损失函数LG,生成器损失函数LG表达式为:
步骤7.3,采用生成器损失函数LG进行反向传播,得到更新后的生成器网络G权值参数θ';
步骤7.4,判断当前样本数据i是否大于批尺寸m,若i大于m,转到步骤7.6;若i不大于m,转到步骤7.5;
步骤7.5,将当前样本数据i加1,并用更新后的生成器网络G权值参数θ'更新生成器网络G权值参数θ,返回步骤7.2;
步骤7.6,通过自适应矩估计优化器优化生成器网络G权值参数θ;
步骤8,将当前生成器网络G迭代训练次数t1加1,得到更新后的生成器网络G迭代训练次数t1',用更新后的生成器网络G迭代训练次数t1'更新当前生成器网络G迭代训练次数t1,判断是否满足条件:当前生成器网络G迭代训练次数t1大于生成器网络G最大迭代训练次数ng或生成器损失函数LG小于等于生成器网络G期望误差e,如果满足以上任意条件,转到步骤9;若不满足以上任意条件,返回步骤4;
步骤9,本批次训练结束。
相对于现有技术,本发明的有益效果为:
1、由于本方法采用搬土距离W建立生成器损失函数LG和判别器损失函数LD,相较于JS散度更有利于生成高质量的生成样本。
2、通过对判别器损失函数LD额外添加梯度惩罚损失项LP得到更新后的判别器损失函数L'D,使判别器网络梯度满足利普希茨连续性,因此判别器损失函数不会由于变化过于剧烈导致失去对生成器的影响作用,并且可以大幅提升生成对抗网络的收敛速度,使训练过程更加稳定,生成样本质量更高,因此生成器最终能够生成高质量的手写数字图像。
附图说明
图1为本发明一种基于参数优化生成对抗网络的手写数字生成方法示意图。
图2为本发明一种基于参数优化生成对抗网络的手写数字生成方法训练流程图。
图3为本发明实施例提供的生成对抗网络示意图。
图4为本发明实施例提供的判别器网络D训练流程图。
图5为本发明实施例提供的生成器网络G训练流程图。
图6为本发明实施例提供的原始生成对抗网络生成图像效果图。
图7为本发明实施例提供的参数优化生成对抗网络生成图像效果图。
具体实施方式
下面结合附图对本实施例进行具体的描述。
图1是本发明一种基于参数优化生成对抗网络的手写数字生成方法示意图,图2为本发明一种基于参数优化生成对抗网络的手写数字生成方法训练流程图,由图1、图2可见,在本发明中,建立生成对抗网络,所述生成对抗网络包括生成器网络G和判别器网络D,通过迭代训练优化生成器网络权值参数θ和判别器网络权值参数ω。在训练过程中,需要通过生成器网络G和判别器网络D对每一个样本数据进行训练。所述迭代训练分批次进行,每批次训练的样本数据为m个,将m记为批尺寸m,并将批尺寸m中的任意一个样本数据记为当前样本数据i,i=1,2,...,m。本批次训练结束后,下一轮训练新批次的样本数据。
具体的,一个批次训练的具体步骤包括:
步骤1,准备手写数字数据集作为样本训练数据集,通过离散二维滤波器对样本训练数据集采样得到真实数据xr,设置生成对抗网络,包括生成器网络G和判别器网络D,初始化随机噪声数据z,设随机噪声数据z满足随机噪声分布Pz;设置生成对抗网络相关训练参数。
所述生成对抗网络相关训练参数包括:神经网络的学习率α、梯度惩罚系数λ、生成器网络G期望误差e、生成器网络G最大迭代训练次数ng、判别器网络D单轮最大迭代训练次数nd,所述判别器网络D单轮最大迭代训练次数nd指的是生成器网络G每迭代训练一次判别器网络需要迭代训练nd次。
步骤2,设置生成对抗网络中生成器网络G权值参数θ和判别器网络D权值参数ω并进行初始化处理。
图3为本发明实施例提供的生成对抗网络示意图。由该图可见:
设生成器网络G和判别器网络D均为多层卷积神经网络,总共N层网络,n表示当前层网络,n-1表示上一层网络。
生成器网络G的输入输出关系表示为:
Gθ(z)=aNN(aN-1N-1(…a11z)…))))
其中,Gθ(z)表示生成器网络G的输出,a1,...,aN-1,aN表示生成器网络G第1,...,N-1,N层的激活函数,θ1,...,θN-1N表示生成器网络G第1,...,N-1,N层的权值参数。
判别器网络D的输入输出关系表示为:
Dω(x)=bNN(bN-1N-1(…b11x)…))))
其中,Dω(x)表示判别器网络D的输出,b1,...,bN-1,bN表示判别器网络D第1,...,N-1,N层的激活函数ω1,...,ωN-1N表示判别器网络D第1,...,N-1,N层的权值参数,x表示判别器网络D的输入数据,包括真实数据xr和生成数据xg两种输入。
生成器网络G和判别器网络D都采用Leaky RELU激活函数,采用何氏初始化对生成对抗网络中生成器网络G权值参数θ和判别器网络D权值络参数ω进行初始化处理,根据当前层网络维度dims[n]和上一层网络维度dims[n-1]采用随机数组生成函数生成n×(n-1)维的数组array(n,n-1),将其乘上实现初始化,因此生成器网络G权值参数θ初始化为判别器网络D权值参数ω初始化为
步骤3,设当前生成器网络G迭代训练次数为t1,t1=1,2,...,ng,将当前生成器网络G迭代训练次数t1置1。
步骤4,设当前判别器网络D迭代训练次数为t2,t2=1,2,...,nd,将当前判别器网络D迭代训练次数t2置1。
步骤5,训练判别器网络D
图4为本发明实施例提供的判别器网络D训练流程图。由该图可见判别器网络D具体训练步骤如下:
步骤5.1,当前样本数据i置1。
步骤5.2,获取生成器网络G输入和判别器网络D输入,所述生成器网络G的输入是随机噪声数据z,随机噪声数据z满足随机噪声分布Pz;所述判别器网络D的输入包括真实数据xr和生成数据xg两种形式的输入,真实数据xr来源于所学习的样本训练数据集,生成数据xg是随机噪声数据z经过生成器网络G重构得到的数据。
步骤5.3,获取梯度惩罚损失项中的数据分布,通过在真实数据xr和生成数据xg之间随机插值取样得到梯度惩罚数据x,梯度惩罚数据x满足梯度惩罚数据分布
步骤5.4,通过搬土距离W建立判别器损失函数LD,所述搬土距离W为生成数据xg和真实数据xr之间的距离,表达式为:
其中,Pr表示真实数据分布,Pg表示生成数据分布,sup表示对所有满足||f||Lip≤K的f函数取到的上界,||f||Lip≤K表示f函数需要满足利普希茨连续性条件,其中K为f函数的利普希茨常数,xr~Pr表示真实数据xr满足真实数据分布Pr,xg~Pg表示生成数据xg满足生成数据分布Pg,E表示期望,f(x)表示神经网络,表示在真实数据分布Pr下真实数据xr对f(x)的期望,表示在生成数据分布Pg下生成数据xg对f(x)的期望。
判别器损失函数LD表达式为:
其中,D(xg)表示采用判别器网络D对生成数据xg进行评价打分,D(xr)表示采用判别器网络D对真实数据xr进行评价打分,得分为0.1表示判别器网路D对生成数据xg的评价,得分为0.9表示判别器网路D对真实数据xr的评价;表示在生成数据分布Pg下生成数据xg对D(xg)的期望,表示在真实数据分布Pr下真实数据xr对D(xr)的期望。
步骤5.5,为通过搬土距离W建立的判别器损失函数LD额外添加梯度惩罚损失项LP,并得到更新后的判别器损失函数L'D
梯度惩罚损失项LP表达式为:
其中,表示梯度惩罚数据满足梯度惩罚数据分布 表示判别器网络D梯度的p范数,H表示判别器网络D梯度的利普希茨常数。
更新后的判别器损失函数L'D表达式为:
步骤5.6,采用更新后的判别器损失函数L'D对判别器网络D权值参数ω进行反向传播,得到更新后的判别器网络D权值参数ω'。
步骤5.7,通过谱归一化对更新后的判别器网络D权值参数ω'进行归一化,得到谱归一化后的判别器网络D权值参数ω”,其中σ(ω')表示更新后的判别器网络D权值参数ω'的谱范数。
步骤5.8,判断当前样本数据i是否大于批尺寸m,若i大于m,转到步骤5.10;若i不大于m,转到步骤5.9。
步骤5.9,将当前样本数据i加1,并用谱归一化后的判别器网络D权值参数ω”更新判别器网络D权值参数ω,返回步骤5.2。
步骤5.10,通过自适应矩估计优化器优化判别器网络D权值参数ω。
步骤6,将当前判别器网络D迭代训练次数t2加1,得到更新后的判别器网络D迭代训练次数t2',用更新后的判别器网络D迭代训练次数t2'更新当前判别器网络D迭代训练次数t2,判断当前判别器网络D迭代训练次数t2是否大于判别器网络D单轮最大迭代训练次数nd,若t2大于nd,转到步骤7;若t2不大于nd,返回步骤5。
步骤7,训练生成器网络G
图5为本发明实施例提供的生成器网络G训练流程图。由该图可见生成器网络G具体训练步骤如下:
步骤7.1,当前样本数据i置1。
步骤7.2,通过搬土距离W建立生成器损失函数LG,生成器损失函数LG表达式为:
步骤7.3,采用生成器损失函数LG进行反向传播,得到更新后的生成器网络G权值参数θ'。
步骤7.4,判断当前样本数据i是否大于批尺寸m,若i大于m,转到步骤7.6;若i不大于m,转到步骤7.5。
步骤7.5,将当前样本数据i加1,并用更新后的生成器网络G权值参数θ'更新生成器网络G权值参数θ,返回步骤7.2。
步骤7.6,通过自适应矩估计优化器优化生成器网络G权值参数θ。
步骤8,将当前生成器网络G迭代训练次数t1加1,得到更新后的生成器网络G迭代训练次数t1',用更新后的生成器网络G迭代训练次数t1'更新当前生成器网络G迭代训练次数t1,判断是否满足条件:当前生成器网络G迭代训练次数t1大于生成器网络G最大迭代训练次数ng或生成器损失函数LG小于等于生成器网络G期望误差e,如果满足以上任意条件,转到步骤9;若不满足以上任意条件,返回步骤4。
步骤9,本批次训练结束。
下面结合仿真实验对发明的效果做进一步的描述。
1.仿真实验条件。
本发明的仿真实验平台为:处理器为Inter core i7-6700HQ,操作***为64位Windows 10,显卡为NVIDIA GTX 1080Ti,使用pycharm编辑器,使用python3.5版本,使用pytorch深度学习框架。手写数字数据集为:手写体数字图像(MNIST),该数据集包含70000张28×28手写数字的灰度图像,其中包括60000张训练数据集和10000张测试数据集两部分。本实验仅用到60000张训练数据集。
2.仿真实验内容。
分别在原始生成对抗网络和基于参数优化生成对抗网络上进行手写数字图像生成,除自身算法外,网络结构基本保持一致,迭代训练相同次数,对图像生成结果进行对比,结果如图6,图7所示,其中:
图6是采用原始生成对抗网络中的生成器网络G在输入随机噪声数据z后得到的生成图像效果图。
图7是采用参数优化生成对抗网络中的生成器网络G在输入随机噪声数据z后得到的生成图像效果图。
3.仿真结果分析。
通过比较图6和图7,能够明显看到,经过相同的迭代训练次数,原始生成对抗网络生成的图像细节明显,但是存在噪点和乱点;参数优化生成对抗网络生成的图像则更加清晰,且噪点和乱点相比原始生成对抗网络更少,图像质量更高。

Claims (1)

1.一种基于参数优化生成对抗网络的手写数字生成方法,其特征在于,建立生成对抗网络,所述生成对抗网络包括生成器网络G和判别器网络D,通过迭代训练优化生成器网络权值参数θ和判别器网络权值参数ω,所述迭代训练分批次进行,每批次训练的样本数据为m个,将m记为批尺寸m,并将批尺寸m中的任意一个样本数据记为当前样本数据i,i=1,2,...,m,具体的,一个批次训练的具体步骤包括:
步骤1,准备手写数字数据集作为样本训练数据集,通过离散二维滤波器对样本训练数据集采样得到真实数据xr,设置生成对抗网络,包括生成器网络G和判别器网络D,初始化随机噪声数据z,设随机噪声数据z满足随机噪声分布Pz;设置生成对抗网络相关训练参数;
所述生成对抗网络相关训练参数包括:神经网络的学习率α、梯度惩罚系数λ、生成器网络G期望误差e、生成器网络G最大迭代训练次数ng、判别器网络D单轮最大迭代训练次数nd,所述判别器网络D单轮最大迭代训练次数nd指的是生成器网络G每迭代训练一次判别器网络需要迭代训练nd次;
步骤2,设置生成对抗网络中生成器网络G权值参数θ和判别器网络D权值参数ω并进行初始化处理;
步骤3,设当前生成器网络G迭代训练次数为t1,t1=1,2,...,ng,将当前生成器网络G迭代训练次数t1置1;
步骤4,设当前判别器网络D迭代训练次数为t2,t2=1,2,...,nd,将当前判别器网络D迭代训练次数t2置1;
步骤5,训练判别器网络D
步骤5.1,当前样本数据i置1;
步骤5.2,获取生成器网络G输入和判别器网络D输入,所述生成器网络G的输入是随机噪声数据z,随机噪声数据z满足随机噪声分布Pz;所述判别器网络D的输入包括真实数据xr和生成数据xg两种形式的输入,真实数据xr来源于所学习的样本训练数据集,生成数据xg是随机噪声数据z经过生成器网络G重构得到的数据;
步骤5.3,获取梯度惩罚损失项中的数据分布,通过在真实数据xr和生成数据xg之间随机插值取样得到梯度惩罚数据梯度惩罚数据满足梯度惩罚数据分布
步骤5.4,通过搬土距离W建立判别器损失函数LD,所述搬土距离W为生成数据xg和真实数据xr之间的距离,表达式为:
其中,Pr表示真实数据分布,Pg表示生成数据分布,sup表示对所有满足||f||Lip≤K的f函数取到的上界,||f||Lip≤K表示f函数需要满足利普希茨连续性条件,其中K为f函数的利普希茨常数,xr~Pr表示真实数据xr满足真实数据分布Pr,xg~Pg表示生成数据xg满足生成数据分布Pg,E表示期望,f(x)表示神经网络,表示在真实数据分布Pr下真实数据xr对f(x)的期望,表示在生成数据分布Pg下生成数据xg对f(x)的期望;
判别器损失函数LD表达式为:
其中,D(xg)表示采用判别器网络D对生成数据xg进行评价打分,D(xr)表示采用判别器网络D对真实数据xr进行评价打分,得分为0.1表示判别器网路D对生成数据xg的评价,得分为0.9表示判别器网路D对真实数据xr的评价;表示在生成数据分布Pg下生成数据xg对D(xg)的期望,表示在真实数据分布Pr下真实数据xr对D(xr)的期望;
步骤5.5,为通过搬土距离W建立的判别器损失函数LD额外添加梯度惩罚损失项LP,并得到更新后的判别器损失函数L'D
梯度惩罚损失项LP表达式为:
其中,表示梯度惩罚数据满足梯度惩罚数据分布表示判别器网络D梯度的p范数,H表示判别器网络D梯度的利普希茨常数;
更新后的判别器损失函数L'D表达式为:
步骤5.6,采用更新后的判别器损失函数L'D对判别器网络D权值参数ω进行反向传播,得到更新后的判别器网络D权值参数ω';
步骤5.7,通过谱归一化对更新后的判别器网络D权值参数ω'进行归一化,得到谱归一化后的判别器网络D权值参数ω”,其中σ(ω')表示更新后的判别器网络D权值参数ω'的谱范数;
步骤5.8,判断当前样本数据i是否大于批尺寸m,若i大于m,转到步骤5.10;若i不大于m,转到步骤5.9;
步骤5.9,将当前样本数据i加1,并用谱归一化后的判别器网络D权值参数ω”更新判别器网络D权值参数ω,返回步骤5.2;
步骤5.10,通过自适应矩估计优化器优化判别器网络D权值参数ω;
步骤6,将当前判别器网络D迭代训练次数t2加1,得到更新后的判别器网络D迭代训练次数t2',用更新后的判别器网络D迭代训练次数t2'更新当前判别器网络D迭代训练次数t2,判断当前判别器网络D迭代训练次数t2是否大于判别器网络D单轮最大迭代训练次数nd,若t2大于nd,转到步骤7;若t2不大于nd,返回步骤5;
步骤7,训练生成器网络G
步骤7.1,当前样本数据i置1;
步骤7.2,通过搬土距离W建立生成器损失函数LG,生成器损失函数LG表达式为:
步骤7.3,采用生成器损失函数LG进行反向传播,得到更新后的生成器网络G权值参数θ';
步骤7.4,判断当前样本数据i是否大于批尺寸m,若i大于m,转到步骤7.6;若i不大于m,转到步骤7.5;
步骤7.5,将当前样本数据i加1,并用更新后的生成器网络G权值参数θ'更新生成器网络G权值参数θ,返回步骤7.2;
步骤7.6,通过自适应矩估计优化器优化生成器网络G权值参数θ;
步骤8,将当前生成器网络G迭代训练次数t1加1,得到更新后的生成器网络G迭代训练次数t1',用更新后的生成器网络G迭代训练次数t1'更新当前生成器网络G迭代训练次数t1,判断是否满足条件:当前生成器网络G迭代训练次数t1大于生成器网络G最大迭代训练次数ng或生成器损失函数LG小于等于生成器网络G期望误差e,如果满足以上任意条件,转到步骤9;若不满足以上任意条件,返回步骤4;
步骤9,本批次训练结束。
CN201910692092.7A 2019-07-29 2019-07-29 一种基于参数优化生成对抗网络的手写数字生成方法 Pending CN110598806A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910692092.7A CN110598806A (zh) 2019-07-29 2019-07-29 一种基于参数优化生成对抗网络的手写数字生成方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910692092.7A CN110598806A (zh) 2019-07-29 2019-07-29 一种基于参数优化生成对抗网络的手写数字生成方法

Publications (1)

Publication Number Publication Date
CN110598806A true CN110598806A (zh) 2019-12-20

Family

ID=68853081

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910692092.7A Pending CN110598806A (zh) 2019-07-29 2019-07-29 一种基于参数优化生成对抗网络的手写数字生成方法

Country Status (1)

Country Link
CN (1) CN110598806A (zh)

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111419213A (zh) * 2020-03-11 2020-07-17 哈尔滨工业大学 一种基于深度学习的ecg心电信号生成方法
CN111476200A (zh) * 2020-04-27 2020-07-31 华东师范大学 基于生成对抗网络的人脸去识别化生成方法
CN111582348A (zh) * 2020-04-29 2020-08-25 武汉轻工大学 条件生成式对抗网络的训练方法、装置、设备及存储介质
CN111898373A (zh) * 2020-08-21 2020-11-06 中国工商银行股份有限公司 手写日期样本生成方法及装置
CN111966997A (zh) * 2020-07-20 2020-11-20 华南理工大学 基于梯度惩罚的生成式对抗网络的密码破解方法及***
CN112488294A (zh) * 2020-11-20 2021-03-12 北京邮电大学 一种基于生成对抗网络的数据增强***、方法和介质
CN112598125A (zh) * 2020-11-25 2021-04-02 西安科技大学 一种基于双判别器加权生成对抗网络的手写数字生成方法
CN112766489A (zh) * 2021-01-12 2021-05-07 合肥黎曼信息科技有限公司 一种基于对偶距离损失的生成对抗网络训练方法
CN113269356A (zh) * 2021-05-18 2021-08-17 中国人民解放***箭军工程大学 一种面向缺失数据的设备剩余寿命预测方法及***
CN114301667A (zh) * 2021-12-27 2022-04-08 杭州电子科技大学 基于wgan动态惩罚的网络安全不平衡数据集分析方法
CN116031894A (zh) * 2023-03-29 2023-04-28 武汉新能源接入装备与技术研究院有限公司 一种有源滤波器的控制方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108711138A (zh) * 2018-06-06 2018-10-26 北京印刷学院 一种基于生成对抗网络的灰度图片彩色化方法
CN109002686A (zh) * 2018-04-26 2018-12-14 浙江工业大学 一种自动生成样本的多牌号化工过程软测量建模方法
CN109191402A (zh) * 2018-09-03 2019-01-11 武汉大学 基于对抗生成神经网络的图像修复方法和***
CN109598279A (zh) * 2018-09-27 2019-04-09 天津大学 基于自编码对抗生成网络的零样本学习方法
US20190147582A1 (en) * 2017-11-15 2019-05-16 Toyota Research Institute, Inc. Adversarial learning of photorealistic post-processing of simulation with privileged information
US20190147321A1 (en) * 2017-10-26 2019-05-16 Preferred Networks, Inc. Image generation method, image generation apparatus, and image generation program

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190147321A1 (en) * 2017-10-26 2019-05-16 Preferred Networks, Inc. Image generation method, image generation apparatus, and image generation program
US20190147582A1 (en) * 2017-11-15 2019-05-16 Toyota Research Institute, Inc. Adversarial learning of photorealistic post-processing of simulation with privileged information
CN109002686A (zh) * 2018-04-26 2018-12-14 浙江工业大学 一种自动生成样本的多牌号化工过程软测量建模方法
CN108711138A (zh) * 2018-06-06 2018-10-26 北京印刷学院 一种基于生成对抗网络的灰度图片彩色化方法
CN109191402A (zh) * 2018-09-03 2019-01-11 武汉大学 基于对抗生成神经网络的图像修复方法和***
CN109598279A (zh) * 2018-09-27 2019-04-09 天津大学 基于自编码对抗生成网络的零样本学习方法

Non-Patent Citations (4)

* Cited by examiner, † Cited by third party
Title
MARTIN ARJOVSKY等: "Wasserstein GAN", 《ARXIV:1701.07875V3》 *
TAKERU MIYATO等: "Spectral normalization for generative adversarial networks", 《ARXIV PREPRINT ARXIV: 1802.05957》 *
冯永等: "GP-WIRGAN:梯度惩罚优化的Wasserstein图像循环生成对抗网络模型", 《计算机学报》 *
林懿伦等: "人工智能研究的新前线:生成式对抗网络", 《自动化学报》 *

Cited By (18)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111419213A (zh) * 2020-03-11 2020-07-17 哈尔滨工业大学 一种基于深度学习的ecg心电信号生成方法
CN111476200A (zh) * 2020-04-27 2020-07-31 华东师范大学 基于生成对抗网络的人脸去识别化生成方法
CN111476200B (zh) * 2020-04-27 2022-04-19 华东师范大学 基于生成对抗网络的人脸去识别化生成方法
CN111582348A (zh) * 2020-04-29 2020-08-25 武汉轻工大学 条件生成式对抗网络的训练方法、装置、设备及存储介质
CN111582348B (zh) * 2020-04-29 2024-02-27 武汉轻工大学 条件生成式对抗网络的训练方法、装置、设备及存储介质
CN111966997A (zh) * 2020-07-20 2020-11-20 华南理工大学 基于梯度惩罚的生成式对抗网络的密码破解方法及***
CN111898373B (zh) * 2020-08-21 2023-09-26 中国工商银行股份有限公司 手写日期样本生成方法及装置
CN111898373A (zh) * 2020-08-21 2020-11-06 中国工商银行股份有限公司 手写日期样本生成方法及装置
CN112488294A (zh) * 2020-11-20 2021-03-12 北京邮电大学 一种基于生成对抗网络的数据增强***、方法和介质
CN112598125A (zh) * 2020-11-25 2021-04-02 西安科技大学 一种基于双判别器加权生成对抗网络的手写数字生成方法
CN112598125B (zh) * 2020-11-25 2024-04-30 西安科技大学 一种基于双判别器加权生成对抗网络的手写数字生成方法
CN112766489A (zh) * 2021-01-12 2021-05-07 合肥黎曼信息科技有限公司 一种基于对偶距离损失的生成对抗网络训练方法
CN113269356A (zh) * 2021-05-18 2021-08-17 中国人民解放***箭军工程大学 一种面向缺失数据的设备剩余寿命预测方法及***
CN113269356B (zh) * 2021-05-18 2024-03-15 中国人民解放***箭军工程大学 一种面向缺失数据的设备剩余寿命预测方法及***
CN114301667B (zh) * 2021-12-27 2024-01-30 杭州电子科技大学 基于wgan动态惩罚的网络安全不平衡数据集分析方法
CN114301667A (zh) * 2021-12-27 2022-04-08 杭州电子科技大学 基于wgan动态惩罚的网络安全不平衡数据集分析方法
CN116031894B (zh) * 2023-03-29 2023-06-02 武汉新能源接入装备与技术研究院有限公司 一种有源滤波器的控制方法
CN116031894A (zh) * 2023-03-29 2023-04-28 武汉新能源接入装备与技术研究院有限公司 一种有源滤波器的控制方法

Similar Documents

Publication Publication Date Title
CN110598806A (zh) 一种基于参数优化生成对抗网络的手写数字生成方法
CN111724478B (zh) 一种基于深度学习的点云上采样方法
CN109035142B (zh) 一种对抗网络结合航拍图像先验的卫星图像超分辨方法
CN109190684B (zh) 基于素描及结构生成对抗网络的sar图像样本生成方法
CN109949255B (zh) 图像重建方法及设备
CN108711141B (zh) 利用改进的生成式对抗网络的运动模糊图像盲复原方法
CN111861906B (zh) 一种路面裂缝图像虚拟增广模型建立及图像虚拟增广方法
CN110675321A (zh) 一种基于渐进式的深度残差网络的超分辨率图像重建方法
CN109003234B (zh) 针对运动模糊图像复原的模糊核计算方法
CN115131347B (zh) 一种用于锌合金零件加工的智能控制方法
CN114881861B (zh) 基于双采样纹理感知蒸馏学习的不均衡图像超分方法
CN114493995A (zh) 图像渲染模型训练、图像渲染方法及装置
CN113256508A (zh) 一种改进的小波变换与卷积神经网络图像去噪声的方法
CN113743474A (zh) 基于协同半监督卷积神经网络的数字图片分类方法与***
CN111861924A (zh) 一种基于进化gan的心脏磁共振图像数据增强方法
CN116563682A (zh) 一种基于深度霍夫网络的注意力方案和条带卷积语义线检测的方法
Wang et al. An adaptive learning image denoising algorithm based on eigenvalue extraction and the GAN model
Li et al. A novelty harmony search algorithm of image segmentation for multilevel thresholding using learning experience and search space constraints
CN116843544A (zh) 一种高超声速流场导入卷积神经网络进行超分辨率重建的方法、***及设备
CN108428226B (zh) 一种基于ica稀疏表示与som的失真图像质量评价方法
CN114897884A (zh) 基于多尺度边缘特征融合的无参考屏幕内容图像质量评估方法
CN111598839A (zh) 一种基于孪生网络的手腕骨等级分类方法
CN112686807A (zh) 一种图像超分辨率重构方法及***
CN113205159B (zh) 一种知识迁移方法、无线网络设备个体识别方法及***
CN111489381A (zh) 一种婴幼儿脑部核磁共振图像群配准方法及其装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WD01 Invention patent application deemed withdrawn after publication
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20191220