CN113392967A - 领域对抗神经网络的训练方法 - Google Patents

领域对抗神经网络的训练方法 Download PDF

Info

Publication number
CN113392967A
CN113392967A CN202010165937.XA CN202010165937A CN113392967A CN 113392967 A CN113392967 A CN 113392967A CN 202010165937 A CN202010165937 A CN 202010165937A CN 113392967 A CN113392967 A CN 113392967A
Authority
CN
China
Prior art keywords
loss function
feature
data
unit
label
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010165937.XA
Other languages
English (en)
Inventor
钟朝亮
夏文升
石自强
孙俊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fujitsu Ltd
Original Assignee
Fujitsu Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fujitsu Ltd filed Critical Fujitsu Ltd
Priority to CN202010165937.XA priority Critical patent/CN113392967A/zh
Priority to JP2021020084A priority patent/JP2021144703A/ja
Publication of CN113392967A publication Critical patent/CN113392967A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

公开了领域对抗神经网络的训练方法。该领域对抗神经网络包括:特征提取单元,其针对已标注的源数据提取第一特征,并且针对未标注的目标数据提取第二特征;标签预测单元,其基于第一特征来预测源数据的标签,并且基于第二特征来预测目标数据的标签;判别单元,其基于第一特征和第二特征来判别输入的数据是源数据还是目标数据。该训练方法包括:基于标签预测单元的输出来构建第一损失函数,其中,该第一损失函数是与源数据有关的预测损失;通过利用源数据和目标数据之间的联合概率密度比对第一损失函数加权而获得第二损失函数;利用第一损失函数和第二损失函数来训练标签预测单元和特征提取单元。

Description

领域对抗神经网络的训练方法
技术领域
本发明总体上涉及领域自适应(domain adaptation),更具体地,涉及领域对抗神经网络的训练方法。
背景技术
在神经网络的训练和应用中通常涉及已标注的源数据集和未标注的目标数据集。由于源数据集和目标数据集之间的差异,诸如分布的差异,将利用源数据集训练得到的神经网络模型直接应用在目标数据集上往往性能不好。领域自适应的任务是通过训练来得到在目标数据集上有良好性能表现的模型。在此方面,目前已经提出了一些技术,例如领域对抗神经网络(DANN),其通过学习领域不变的特征(domain-invariant features)来解决问题。
图1示意性地示出了DANN的架构以及训练过程。如图1所示,DANN包括特征提取器110、标签预测器120以及领域判别器130。特征提取器110和标签预测器120共同组成标准的前向网络结构。通过添加领域判别器130来实现无监督的领域自适应。领域判别器130通过梯度反转层140与特征提取器110相连。在反向传播过程中,梯度反转层140通过将梯度乘以一个负的常数(例如,在图1中是“-1”)来实现梯度反转,从而使得领域判别器130与特征提取器110以相互对抗的方式操作。如果不存在梯度反转层140,则训练按照传统过程进行,即,以使得标签预测器120的损失(只针对源数据集样本)和领域判别器130的损失(针对源数据集样本和目标数据集样本)最小化的方式进行训练。然而,在存在梯度反转层140的情况下,特征提取器110和领域判别器130被以相互对抗的形式进行训练。具体来说,特征提取器110尽量提取一些特征,使得领域判别器130无法基于这些特征而识别样本是来自于源域还是目标域;另一方面,领域判别器130尽最大努力根据输入的特征来识别样本是来自于源域还是目标域。通过这样的对抗训练,特征提取器110最终能够学会提取领域不变的特征,使得领域判别器130无法判别样本来自于哪个域。在此过程中,梯度反转层140能够确保在源域和目标域所学习到的特征的分布的相似性,从而能够学习到领域不变的特征。
虽然DANN能够在许多领域自适应的任务中取得良好效果,但是其仍然存在一些问题。首先,在标签预测器120的优化过程中,仅使得针对源数据集样本的预测损失最小化,而没有考虑针对目标数据集样本的预测损失。其次,DANN通过特征空间的对齐来实现领域自适应,而不能实现类别级别的对齐。
发明内容
为了解决这些问题,本发明提出了对DANN的改进的训练方法。概括来说,根据本发明的方法利用目标数据集和源数据集之间的联合概率密度比来对源数据集上的损失函数进行加权,加权的损失函数可以近似目标数据集上的损失函数。因此,在训练标签预测器的过程中使用该加权的损失函数,可以使得训练好的预测器在被应用于目标数据集时具有更好的预测性能。
根据本发明的一个方面,提供了一种用于训练领域对抗神经网络模型的方法。所述领域对抗神经网络模型包括:特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据。所述方法包括:基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
根据本发明的另一个方面,提供了一种用于训练领域对抗神经网络模型的装置。所述领域对抗神经网络模型包括:特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据。所述装置包括:存储有程序的存储器;以及一个或多个处理器。所述处理器通过执行所述程序而执行以下操作:基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
根据本发明的另一个方面,提供了一种存储有用于训练领域对抗神经网络模型的程序的存储介质。所述领域对抗神经网络模型包括:特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据。所述程序在被计算机执行时使得所述计算机执行包括以下步骤的方法:基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
附图说明
图1示意性地示出了DANN的架构以及训练过程。
图2示意性地示出了根据本发明的一个实施例的领域自适应框架。
图3示出了参数λC的曲线。
图4示出了参数λp的曲线。
图5示出了根据该实施例的领域对抗神经网络的训练方法的流程图。
图6示出了根据该实施例的领域对抗神经网络的训练装置的模块化框架。
图7示意性地示出了根据本发明的另一个实施例的领域自适应框架。
图8示出了参数λTC的曲线。
图9示出了语义分割的一个示例。
图10示出了实现本发明的计算机硬件的示例性配置框图。
具体实施方式
本发明在DANN的基础上进行改进,提出了基于联合概率密度比估计的领域自适应方法(joint density ratio estimation-based domain adaptation,JDA)。图2示意性地示出了根据本发明的一个实施例的领域自适应框架。
如图2所示,根据本发明的领域对抗神经网络包括特征提取器G、标签预测器C和领域判别器D。此外,XP表示已标注的源样本,XQ表示未标注的目标样本,G(XP)表示源样本的特征,G(XQ)表示目标样本的特征,YP表示源样本的标签,YQ表示目标样本的标签。由于目标样本是未标注的数据,因此YQ是伪标签。
特征提取器G针对已标注的源样本XP提取特征G(XP),并且针对未标注的目标样本XQ提取特征G(XQ)。标签预测器C基于所提取的特征G(XP)来预测源样本XP的标签YP,并且基于所提取的特征G(XQ)来预测目标样本XQ的标签YQ。领域判别器D基于所提取的特征G(XP)和特征G(XQ)来判别输入的样本是源样本还是目标样本。
在对标签预测器C的训练中,利用标签预测损失函数
Figure BDA0002407455090000041
以及加权标签预测损失函数
Figure BDA0002407455090000042
以下数学式(1)和数学式(2)分别示出了这两个损失函数:
Figure BDA0002407455090000043
Figure BDA0002407455090000044
其中,E表示数学期望,P表示源数据的分布,l表示交叉熵损失函数,C表示分类器,即标签预测器。r(x,y)表示目标数据集和源数据集之间的联合概率密度比,并且可以表示为以下数学式(3):
Figure BDA0002407455090000045
其中,p(x,y)表示源域的联合概率密度函数,q(x,y)表示目标域的联合概率密度函数。
在加权标签预测损失函数
Figure BDA0002407455090000051
中,每个源样本的预测损失被联合概率密度比r(x,y)加权,因此损失函数
Figure BDA0002407455090000052
可以近似于目标数据集上的损失函数。因此,在标签预测器C的训练中使用损失函数
Figure BDA0002407455090000053
可以提高模型在目标数据集上的性能。
在本发明中,可以基于领域判别器D的输出来计算联合概率密度比r(x,y)。因此,领域判别器D不仅用于分辨源样本和目标样本,而且用于估计联合概率密度比。以下将进行详细说明。
首先,由数学式(4)表示领域判别器D的判别损失函数
Figure BDA0002407455090000054
Figure BDA0002407455090000055
其中,P表示源数据的分布,Q表示目标数据的分布。
损失函数
Figure BDA0002407455090000056
以D为自变量,因此需要计算使得损失函数
Figure BDA0002407455090000057
为最小值时的最优解D*(G(x),y)。
根据数学式(4)可以得到:
Figure BDA0002407455090000058
代入D*(G(x),y),进一步得到:
D*(G(x),y)=argmaxD(G(x),y)p(x,y)log(D(G(x),y))+q(x,y)log(1-D(G(x),y))--(6)
由于函数f(d)=p log(d)+q log(1-d)在区间(0,1)上取得最大值时,变量
Figure BDA0002407455090000059
因此根据数学式(6)可以得出:
Figure BDA00024074550900000510
其中D(G(x),y)∈(0,1)--(7)
然后,结合数学式(3)的定义,可以将联合概率密度比r(x,y)表示为以下数学式(8):
Figure BDA0002407455090000061
根据数学式(8),可以利用领域判别器D的损失函数
Figure BDA0002407455090000062
为最小值时的输出D*(G(x),y)来计算联合概率密度比r(x,y)。
由于领域判别器D的输出被用于估计联合概率密度比,因此对领域判别器D不仅要输入源样本和目标样本的特征G(XP)、G(XQ),而且要输入源样本和目标样本的标签YP、YQ。由于目标样本的标签是未知的,因此本发明中使用由标签预测器C对于目标样本的标签预测结果作为伪标签YQ输入领域判别器D中。
在本发明中,利用估计得到的联合概率密度比对针对源数据集的预测损失进行加权,然后使加权后的损失函数
Figure BDA0002407455090000063
最小化,由此可以实现类别级别的对齐,而不仅仅是特征空间的对齐。作为对比,仅实现特征空间的对齐会产生以下问题:虽然在特征空间中拉近了源域和目标域的样本的特征,但是不同类别的样本的特征可能会混杂在一起,从而不能很好地区分各个类别的样本。本发明由于可实现类别级别的对齐,因此能够解决这一问题。
在本发明中,使用损失函数的加权和来训练图2所示的模型。例如,如数学式(9)所示,基于以上描述的损失函数
Figure BDA0002407455090000064
的加权和来训练模型。
Figure BDA0002407455090000065
在数学式(9)中,权值λC用于控制在优化标签预测器C和特征提取器G的过程中损失函数
Figure BDA0002407455090000066
起作用的程度。在训练的开始阶段,由于对联合概率密度比的估计不准确,优选的是将λC的值设置得比较小,从而主要根据针对源样本的预测损失来优化模型。随着训练的进行,对联合概率密度比的估计逐渐变得准确,因此可以逐渐增大λC的值,使得可以根据目标样本的预测损失来优化模型。例如,在训练过程中,λC可以从初始值“0”逐渐增大到“1”。可以通过以下数学式(10)来表示λC
λC=α·min((2p)n,1)--(10)
其中,p表示训练进度,其线性地从0增大到1,α和n表示超参数。可以根据经验来设置α和n,例如,都设置为10。图3示出了权值λC的曲线,其中α被设置为1,n被设置为10。
另一方面,如图2所示,在反向传播过程中,梯度反转层240通过将来自领域判别器D的梯度乘以负常数“-λp”来实现梯度反转,这与DANN中的梯度反转层相似。在根据本发明的训练过程中,与参数λC类似地,逐渐地增大参数λp的值,例如,从初始值“0”逐渐增大到“1”。可以通过以下数学式(11)来表示λp
Figure BDA0002407455090000071
其中,p表示训练进度,其线性地从0增大到1,γ表示超参数。可以根据经验将其设置为例如10。图4示出了参数λp的曲线。需要说明的是,由于λp只被用于优化特征提取器G,因此其没有出现在数学式(9)中。
图5示出了根据本实施例的领域对抗神经网络的训练方法的流程图,图6示出了根据本实施例的领域对抗神经网络的训练装置的模块化框图。参考图5和图6,在步骤S510,根据数学式(1),基于标签预测器C的输出来构建损失函数
Figure BDA0002407455090000072
损失函数
Figure BDA0002407455090000073
是与源数据集有关的预测损失。可以由图6中的第一损失函数生成单元610执行此步骤。
在步骤S520,根据数学式(8),基于领域判别器D的输出来确定源数据集与目标数据集之间的联合概率密度比r(x,y)。可以由图6中的联合概率密度比确定单元640执行此步骤。
在步骤S530,通过利用联合概率密度比r(x,y)对损失函数
Figure BDA0002407455090000074
进行加权来获得损失函数
Figure BDA0002407455090000075
如数学式(2)所示。损失函数
Figure BDA0002407455090000076
能够近似与目标数据集有关的预测损失。可以由图6中的第二损失函数生成单元620执行此步骤。
在步骤S540,根据数学式(4),基于领域判别器D的输出来构建损失函数
Figure BDA0002407455090000081
可以由图6中的第三损失函数生成单元630执行此步骤。
在步骤S550,根据数学式(9),基于损失函数
Figure BDA0002407455090000082
损失函数
Figure BDA0002407455090000083
和损失函数
Figure BDA0002407455090000084
的加权组合来训练领域对抗神经网络模型。特别地,随着训练的进行,可以逐渐地增大损失函数
Figure BDA0002407455090000085
的权值λC。可以由图6中的训练单元650执行此步骤。以下将结合图7来描述根据本发明的另一个实施例的领域自适应框架。在本实施例中,将上文描述的JDA与自集成(self-ensembling)进行结合(JDA-SE)。如图7所示,根据本发明的领域对抗神经网络模型包括作为学生(student)网络的特征提取器G和标签预测器C,以及作为教师(teacher)网络的教师-特征提取器G和教师-标签预测器C。教师网络与相应的学生网络具有相同的网络结构。不同的是,教师网络的网络参数是不可训练的。此外,教师-特征提取器G的网络参数是特征提取器G的参数的指数移动平均值,教师-标签预测器C的网络参数是标签预测器C的参数的指数移动平均值。
本实施例与图2所示实施例的不同在于引入了目标一致性损失函数
Figure BDA0002407455090000086
如以下数学式(12)所示,其用于保证教师网络和学生网络之间预测结果的一致性。
Figure BDA0002407455090000087
在本实施例中,基于损失函数
Figure BDA0002407455090000088
的加权和来训练模型,如数学式(13)所示。
Figure BDA0002407455090000089
与λC类似地,数学式(13)中的λTC用于控制在优化标签预测器C和特征提取器G的过程中损失函数
Figure BDA00024074550900000810
起作用的程度。随着训练的进行,可以逐渐增大λTC的值,例如从初始值“0”逐渐增大到“1”。可以通过以下数学式(14)来表示λTC
λTC=α·pn+β--(14)
其中,p表示训练进度,其线性地从0增大到1。参数α、n和β可以根据经验预先设置,例如,α被设置为100,n被设置为10,β被设置为0。图8示出了参数λTC的曲线。
以下表1示出了根据本发明的JDA和JDA-SE方案与现有方案(如PFAN、ADDA、DANN等)的性能对比。基于MNIST<—>USPS数据集(公知的手写字符数据集)进行该对比。表1中的数值表示分类准确率,准确率越高,方案的性能越好。
特别地,表1中的Vanilla(source only)表示只利用源数据集的已标注数据、而不利用目标数据集的数据进行训练的方案,这是最简单的方案,作为比较的基准。此外,MDA方案是JDA的变体,在其中将联合概率密度比替换为边缘概率密度比。在表1中列出MDA方案的性能作为对比。
[表1]
方案 MNIST—>USPS USPS—>MNIST
Vanilla(Source only) 75.2±1.6 57.1±1.7
自集成(2018) 88.3±0.8 -
DANN(JMLR 2016) 88.6±2.1 87.3±5.7
ADDA(CVPR 2016) 89.4±0.2 90.1±0.8
PFAN(CVPR 2019) 95.0±1.3 -
MDA 93.9±1.4 94.5±0.6
JDA 94.1±0.9 94.8±0.8
JDA-SE 95.2±0.8 95.3±0.5
根据表1可以看出根据本发明的JDA和JDA-SE方法具有更好的性能。此外,相比于采用边缘概率密度比的移位补偿网络(Shift Compensation Network,表1中未示出),根据本发明的方法的优点在于不需要依赖于任何类型的移位(如协变量移位,标签移位等)的成立,而移位补偿网络需要依赖于协变量移位的成立。
根据本发明的领域自适应方法能够应用于广泛的领域,以下仅以举例方式给出有代表性的应用场景。
[应用场景一]语义分割(semantic segmentation)
语义分割是指将图像中表示不同物体的部分用不同颜色标识出来。图9示出了语义分割的一个示例。图9中最左边的两幅图像是原始图像,从左起第2列的两幅图像是对两幅原始图像的分割结果的真值(ground truth),其它列的图像是采用不同语义分割方法的分割结果。
在语义分割的应用场景中,由于对真实世界的图像进行人工标注的代价非常高,因此真实世界的图像很少是带有标签的。在此情况下,一种替代方法是利用仿真环境(如3D游戏)中的场景的图像来进行训练。由于在仿真环境中很容易通过编程来实现对物体的自动标注,因此很容易得到有标签的数据。这样,利用仿真环境中生成的有标签的数据来训练模型,然后利用经训练的模型来处理真实环境的图像。但是,由于仿真环境不可能与真实环境完全一致,因此利用仿真环境的数据所训练的模型在处理真实环境的图像时性能会大打折扣。
在此情况下,使用根据本发明的领域自适应方法,可以基于有标签的仿真环境数据和无标签的真实环境数据进行训练,从而提高模型处理真实环境图像的性能。
[应用场景二]手写字符的识别
手写字符通常包括手写的数字、文字(如中文、日文)等。在手写字符的识别中,常用的有标签的字符集包括MNIST、USPS、SVHN等,通常利用这些有标签的字符数据来训练模型。然而,在将经训练的模型应用于实际(无标签)的手写字符的识别时,其准确率可能会降低。
在此情况下,使用根据本发明的领域自适应方法,可以基于有标签的源数据和无标签的目标数据进行训练,从而提高模型处理目标数据的性能。
[应用场景三]时间序列数据的分类和预测
时间序列数据的预测例如包括空气污染指数预测、ICU病人住院时长(LOS)的预测、股票行情预测等等。以下将以细颗粒物PM 2.5指数的时间序列数据为例进行描述。
PM 2.5指数的时间序列数据集中的每一条数据记录了一定时间范围内(如1小时内)某个地区的PM 2.5指数、温度、气压、风速、风向、累计降雨量、累计降雪量等信息。假定需要预测三天后该地区的PM 2.5指数的范围。为此,构建预测模型,并且选取指定长度的数据来构建样本。例如,选择数据集中特定24小时的数据作为一个样本,并且每个数据包括8个维度的特征,由此该样本包括24*8维的特征。然后,根据数据集中的三天后的PM 2.5指数所处的范围来给该样本分配标签。以此方式,可以构建训练样本集,从而利用该训练样本集来训练预测模型。训练完成后,可以将训练好的模型应用于实际预测中,例如,基于当前时刻之前24个小时的数据(无标签数据)来预测三天后的PM 2.5指数的范围。
在此场景中,通过使用根据本发明的领域自适应方法,可以基于有标签的数据和无标签的数据来训练模型,从而提高模型的预测准确度。
[应用场景四]表格型数据的分类和预测
表格型数据可以包括金融数据,例如网络借贷数据。在此示例中,为了预测贷款者是否存在逾期还款的可能性,可以构建预测模型,并且使用根据本发明的方法来训练模型。
[应用场景五]图像识别
图像识别或图像分类是深度(卷积)神经网络比较擅长的领域。与语义分割场景类似,在此应用场景中,也存在着对于真实世界的图像数据集进行标注的代价高昂的问题。因此,可以使用根据本发明的领域自适应方法,选择一个已标注的数据集(如ImageNet)作为源数据集,基于该源数据集和未标注的目标数据集进行训练,从而获得性能满足要求的模型。
在上述实施例中描述的方法可以由软件、硬件或者软件和硬件的组合来实现。包括在软件中的程序可以事先存储在设备的内部或外部所设置的存储介质中。作为一个示例,在执行期间,这些程序被写入随机存取存储器(RAM)并且由处理器(例如CPU)来执行,从而实现在本文中描述的各种方法和处理。
图10示出了根据程序执行本发明的方法的计算机硬件的示例配置框图,该计算机硬件是根据本发明的用于训练领域对抗神经网络模型的装置的一个示例。
如图10所示,在计算机1000中,中央处理单元(CPU)1001、只读存储器(ROM)1002以及随机存取存储器(RAM)1003通过总线1004彼此连接。
输入/输出接口1005进一步与总线1004连接。输入/输出接口1005连接有以下组件:以键盘、鼠标、麦克风等形成的输入单元1006;以显示器、扬声器等形成的输出单元1007;以硬盘、非易失性存储器等形成的存储单元1008;以网络接口卡(诸如局域网(LAN)卡、调制解调器等)形成的通信单元1009;以及驱动移动介质1011的驱动器1010,该移动介质1011例如是磁盘、光盘、磁光盘或半导体存储器。
在具有上述结构的计算机中,CPU 1001将存储在存储单元1008中的程序经由输入/输出接口1005和总线1004加载到RAM 1003中,并且执行该程序,以便执行上文中描述的方法。
要由计算机(CPU 1001)执行的程序可以被记录在作为封装介质的移动介质1011上,该封装介质以例如磁盘(包括软盘)、光盘(包括压缩光盘-只读存储器(CD-ROM))、数字多功能光盘(DVD)等)、磁光盘、或半导体存储器来形成。此外,要由计算机(CPU 1001)执行的程序也可以经由诸如局域网、因特网、或数字卫星广播的有线或无线传输介质来提供。
当移动介质1011安装在驱动器1010中时,可以将程序经由输入/输出接口1005安装在存储单元1008中。另外,可以经由有线或无线传输介质由通信单元1009来接收程序,并且将程序安装在存储单元1008中。可替选地,可以将程序预先安装在ROM 1002或存储单元1008中。
由计算机执行的程序可以是根据本说明书中描述的顺序来执行处理的程序,或者可以是并行地执行处理或当需要时(诸如,当调用时)执行处理的程序。
本文中所描述的单元或装置仅是逻辑意义上的,并不严格对应于物理设备或实体。例如,本文所描述的每个单元的功能可能由多个物理实体来实现,或者,本文所描述的多个单元的功能可能由单个物理实体来实现。此外,在一个实施例中描述的特征、部件、元素、步骤等并不局限于该实施例,而是也可以应用于其它实施例,例如替代其它实施例中的特定特征、部件、元素、步骤等,或者与其相结合。
本发明的范围不限于在本文中描述的具体实施例。本领域普通技术人员应该理解的是,取决于设计要求和其他因素,在不偏离本发明的原理和精神的情况下,可以对本文中的实施例进行各种修改或变化。本发明的范围由所附权利要求及其等同方案来限定。
附记:
1.一种用于训练领域对抗神经网络模型的方法,所述领域对抗神经网络模型包括:
特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;
标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;
判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据;
所述方法包括:
基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;
通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;
利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
2.根据1所述的方法,其中,所述领域对抗神经网络模型用于执行图像识别,并且所述源数据和所述目标数据是图像数据,或者
所述领域对抗神经网络模型用于处理金融数据,并且所述源数据和所述目标数据是表格类型数据,或者
所述领域对抗神经网络模型用于处理环境气象数据或医疗数据,并且所述源数据和所述目标数据是时间序列数据。
3.根据1所述的方法,其中,所述第二损失函数能够近似与所述目标数据有关的预测损失。
4.根据1所述的方法,还包括:基于所述判别单元的输出来确定所述联合概率密度比。
5.根据4所述的方法,其中,通过以下等式来计算所述联合概率密度比r:
Figure BDA0002407455090000141
其中,D表示所述判别单元的输出。
6.根据4所述的方法,其中,所述判别单元还被提供以所述源数据的标签以及由所述标签预测单元预测的所述目标数据的标签。
7.根据1所述的方法,还包括:
基于所述判别单元的输出来构建第三损失函数
基于所述第一损失函数、所述第二损失函数和所述第三损失函数的加权组合来训练所述领域对抗神经网络模型;以及
随着训练的进行,逐渐增大用于对所述第二损失函数加权的权值。
8.根据1所述的方法,其中,所述判别单元经由梯度反转单元与所述特征提取单元连接,并且所述判别单元与所述特征提取单元以相互对抗的方式操作。
9.根据1所述的方法,其中,所述领域对抗神经网络还包括另一特征提取单元和另一标签预测单元,其中,所述另一特征提取单元的参数是所述特征提取单元的参数的指数移动平均,所述另一标签预测单元的参数是所述标签预测单元的参数的指数移动平均,
所述方法还包括:基于所述特征提取单元、所述另一特征提取单元、所述标签预测单元和所述另一标签预测单元的输出来构建第四损失函数。
10.根据9所述的方法,还包括:
基于所述第一损失函数、所述第二损失函数、所述第三损失函数和所述第四损失函数的加权组合来训练所述领域对抗神经网络模型;以及随着训练的进行,逐渐增大用于对所述第四损失函数加权的权值。
11.一种用于训练领域对抗神经网络模型的装置,所述领域对抗神经网络模型包括:
特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;
标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;
判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据;
所述装置包括:
存储有程序的存储器;以及
一个或多个处理器,所述处理器通过执行所述程序而执行以下操作:
基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;
通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;
利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
12.一种用于训练领域对抗神经网络模型的装置,所述领域对抗神经网络模型包括:
特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;
标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;
判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据;
所述装置包括:
第一损失函数生成单元,其基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;
第二损失函数生成单元,其通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而生成第二损失函数;
训练单元,其利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
13.一种存储有用于训练领域对抗神经网络模型的程序的存储介质,所述领域对抗神经网络模型包括:
特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;
标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;
判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据;
所述程序在被计算机执行时使得所述计算机执行包括以下步骤的方法:
基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;
通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;
利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。

Claims (10)

1.一种用于训练领域对抗神经网络模型的方法,所述领域对抗神经网络模型包括:
特征提取单元,其用于针对输入的已标注的源数据提取第一特征,并且针对输入的未标注的目标数据提取第二特征;
标签预测单元,其基于所提取的第一特征来预测源数据的标签,并且基于所提取的第二特征来预测目标数据的标签;
判别单元,其基于所提取的第一特征和第二特征来判别输入的数据是源数据还是目标数据;
所述方法包括:
基于所述标签预测单元的输出来构建第一损失函数,其中,所述第一损失函数是与所述源数据有关的预测损失;
通过利用所述源数据和所述目标数据之间的联合概率密度比对所述第一损失函数加权而获得第二损失函数;
利用所述第一损失函数和所述第二损失函数来训练所述标签预测单元和所述特征提取单元。
2.根据权利要求1所述的方法,其中,所述领域对抗神经网络模型用于执行图像识别,并且所述源数据和所述目标数据是图像数据,或者
所述领域对抗神经网络模型用于处理金融数据,并且所述源数据和所述目标数据是表格类型数据,或者
所述领域对抗神经网络模型用于处理环境气象数据或医疗数据,并且所述源数据和所述目标数据是时间序列数据。
3.根据权利要求1所述的方法,其中,所述第二损失函数能够近似与所述目标数据有关的预测损失。
4.根据权利要求1所述的方法,还包括:基于所述判别单元的输出来确定所述联合概率密度比。
5.根据权利要求4所述的方法,其中,通过以下等式来计算所述联合概率密度比r:
Figure FDA0002407455080000021
其中,D表示所述判别单元的输出。
6.根据权利要求4所述的方法,其中,所述判别单元还被提供以所述源数据的标签以及由所述标签预测单元预测的所述目标数据的标签
7.根据权利要求1所述的方法,还包括:
基于所述判别单元的输出来构建第三损失函数
基于所述第一损失函数、所述第二损失函数和所述第三损失函数的加权组合来训练所述领域对抗神经网络模型;以及
随着训练的进行,逐渐增大用于对所述第二损失函数加权的权值。
8.根据权利要求1所述的方法,其中,所述判别单元经由梯度反转单元与所述特征提取单元连接,并且所述判别单元与所述特征提取单元以相互对抗的方式操作。
9.根据权利要求1所述的方法,其中,所述领域对抗神经网络还包括另一特征提取单元和另一标签预测单元,其中,所述另一特征提取单元的参数是所述特征提取单元的参数的指数移动平均,所述另一标签预测单元的参数是所述标签预测单元的参数的指数移动平均,
所述方法还包括:基于所述特征提取单元、所述另一特征提取单元、所述标签预测单元和所述另一标签预测单元的输出来构建第四损失函数。
10.根据权利要求9所述的方法,还包括:
基于所述第一损失函数、所述第二损失函数、所述第三损失函数和所述第四损失函数的加权组合来训练所述领域对抗神经网络模型;以及
随着训练的进行,逐渐增大用于对所述第四损失函数加权的权值。
CN202010165937.XA 2020-03-11 2020-03-11 领域对抗神经网络的训练方法 Pending CN113392967A (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202010165937.XA CN113392967A (zh) 2020-03-11 2020-03-11 领域对抗神经网络的训练方法
JP2021020084A JP2021144703A (ja) 2020-03-11 2021-02-10 ドメイン敵対的ニューラルネットワークの訓練方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010165937.XA CN113392967A (zh) 2020-03-11 2020-03-11 领域对抗神经网络的训练方法

Publications (1)

Publication Number Publication Date
CN113392967A true CN113392967A (zh) 2021-09-14

Family

ID=77615398

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010165937.XA Pending CN113392967A (zh) 2020-03-11 2020-03-11 领域对抗神经网络的训练方法

Country Status (2)

Country Link
JP (1) JP2021144703A (zh)
CN (1) CN113392967A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113902898A (zh) * 2021-09-29 2022-01-07 北京百度网讯科技有限公司 目标检测模型的训练、目标检测方法、装置、设备和介质
CN114358283A (zh) * 2022-01-12 2022-04-15 深圳大学 气体识别神经网络模型的优化方法及相关设备
CN117911852A (zh) * 2024-03-20 2024-04-19 西北工业大学 基于部分无监督领域自适应的水下目标距离预测方法

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114947792B (zh) * 2022-05-19 2024-05-03 北京航空航天大学 一种基于视频的生理信号测量与增强方法
CN114821282B (zh) * 2022-06-28 2022-11-04 苏州立创致恒电子科技有限公司 一种基于域对抗神经网络的图像检测装置及方法

Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170220951A1 (en) * 2016-02-02 2017-08-03 Xerox Corporation Adapting multiple source classifiers in a target domain
US20180101768A1 (en) * 2016-10-07 2018-04-12 Nvidia Corporation Temporal ensembling for semi-supervised learning
CN107944410A (zh) * 2017-12-01 2018-04-20 中国科学院重庆绿色智能技术研究院 一种基于卷积神经网络的跨领域面部特征解析方法
US20180260957A1 (en) * 2017-03-08 2018-09-13 Siemens Healthcare Gmbh Automatic Liver Segmentation Using Adversarial Image-to-Image Network
CN108694443A (zh) * 2017-04-05 2018-10-23 富士通株式会社 基于神经网络的语言模型训练方法和装置
CN109580215A (zh) * 2018-11-30 2019-04-05 湖南科技大学 一种基于深度生成对抗网络的风电传动***故障诊断方法
CN109635280A (zh) * 2018-11-22 2019-04-16 园宝科技(武汉)有限公司 一种基于标注的事件抽取方法
US20190130220A1 (en) * 2017-10-27 2019-05-02 GM Global Technology Operations LLC Domain adaptation via class-balanced self-training with spatial priors
CN110222690A (zh) * 2019-04-29 2019-09-10 浙江大学 一种基于最大二乘损失的无监督域适应语义分割方法
US20190354807A1 (en) * 2018-05-16 2019-11-21 Nec Laboratories America, Inc. Domain adaptation for structured output via disentangled representations
CN110750665A (zh) * 2019-10-12 2020-02-04 南京邮电大学 基于熵最小化的开集域适应方法及***
CN110837850A (zh) * 2019-10-23 2020-02-25 浙江大学 一种基于对抗学习损失函数的无监督域适应方法

Patent Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170220951A1 (en) * 2016-02-02 2017-08-03 Xerox Corporation Adapting multiple source classifiers in a target domain
US20180101768A1 (en) * 2016-10-07 2018-04-12 Nvidia Corporation Temporal ensembling for semi-supervised learning
US20180260957A1 (en) * 2017-03-08 2018-09-13 Siemens Healthcare Gmbh Automatic Liver Segmentation Using Adversarial Image-to-Image Network
CN108694443A (zh) * 2017-04-05 2018-10-23 富士通株式会社 基于神经网络的语言模型训练方法和装置
US20190130220A1 (en) * 2017-10-27 2019-05-02 GM Global Technology Operations LLC Domain adaptation via class-balanced self-training with spatial priors
CN107944410A (zh) * 2017-12-01 2018-04-20 中国科学院重庆绿色智能技术研究院 一种基于卷积神经网络的跨领域面部特征解析方法
US20190354807A1 (en) * 2018-05-16 2019-11-21 Nec Laboratories America, Inc. Domain adaptation for structured output via disentangled representations
CN109635280A (zh) * 2018-11-22 2019-04-16 园宝科技(武汉)有限公司 一种基于标注的事件抽取方法
CN109580215A (zh) * 2018-11-30 2019-04-05 湖南科技大学 一种基于深度生成对抗网络的风电传动***故障诊断方法
CN110222690A (zh) * 2019-04-29 2019-09-10 浙江大学 一种基于最大二乘损失的无监督域适应语义分割方法
CN110750665A (zh) * 2019-10-12 2020-02-04 南京邮电大学 基于熵最小化的开集域适应方法及***
CN110837850A (zh) * 2019-10-23 2020-02-25 浙江大学 一种基于对抗学习损失函数的无监督域适应方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
蔡兴泉等: "基于CNN网络和多任务损失函数的实时叶片识别", ***仿真学报, no. 07 *

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113902898A (zh) * 2021-09-29 2022-01-07 北京百度网讯科技有限公司 目标检测模型的训练、目标检测方法、装置、设备和介质
CN114358283A (zh) * 2022-01-12 2022-04-15 深圳大学 气体识别神经网络模型的优化方法及相关设备
CN117911852A (zh) * 2024-03-20 2024-04-19 西北工业大学 基于部分无监督领域自适应的水下目标距离预测方法

Also Published As

Publication number Publication date
JP2021144703A (ja) 2021-09-24

Similar Documents

Publication Publication Date Title
CN113392967A (zh) 领域对抗神经网络的训练方法
CN110188358B (zh) 自然语言处理模型的训练方法及装置
CN109583501B (zh) 图片分类、分类识别模型的生成方法、装置、设备及介质
JP2022042487A (ja) ドメイン適応型ニューラルネットワークの訓練方法
CN110852447A (zh) 元学习方法和装置、初始化方法、计算设备和存储介质
CN113469186B (zh) 一种基于少量点标注的跨域迁移图像分割方法
CN110826609B (zh) 一种基于强化学习的双流特征融合图像识别方法
CN110929640B (zh) 一种基于目标检测的宽幅遥感描述生成方法
CN117611932B (zh) 基于双重伪标签细化和样本重加权的图像分类方法及***
EP4060548A1 (en) Method and device for presenting prompt information and storage medium
CN116432655B (zh) 基于语用知识学习的少样本命名实体识别方法和装置
CN115690534A (zh) 一种基于迁移学习的图像分类模型的训练方法
CN115482418B (zh) 基于伪负标签的半监督模型训练方法、***及应用
CN116450813B (zh) 文本关键信息提取方法、装置、设备以及计算机存储介质
CN111507406A (zh) 一种用于优化神经网络文本识别模型的方法与设备
JP2010282276A (ja) 映像認識理解装置、映像認識理解方法、及びプログラム
CN112926631A (zh) 金融文本的分类方法、装置及计算机设备
CN117437461A (zh) 一种面向开放世界的图像描述生成方法
CN116541507A (zh) 一种基于动态语义图神经网络的视觉问答方法及***
CN116630694A (zh) 一种偏多标记图像的目标分类方法、***及电子设备
CN116433909A (zh) 基于相似度加权多教师网络模型的半监督图像语义分割方法
CN113379037B (zh) 一种基于补标记协同训练的偏多标记学习方法
CN114973350A (zh) 一种源域数据无关的跨域人脸表情识别方法
CN114997175A (zh) 一种基于领域对抗训练的情感分析方法
CN117036790B (zh) 一种小样本条件下的实例分割多分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination