CN114005075B - 一种光流估算模型的构建方法、装置及光流估算方法 - Google Patents
一种光流估算模型的构建方法、装置及光流估算方法 Download PDFInfo
- Publication number
- CN114005075B CN114005075B CN202111635874.0A CN202111635874A CN114005075B CN 114005075 B CN114005075 B CN 114005075B CN 202111635874 A CN202111635874 A CN 202111635874A CN 114005075 B CN114005075 B CN 114005075B
- Authority
- CN
- China
- Prior art keywords
- image pair
- domain image
- optical flow
- network
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种光流估算模型的构建方法、装置及光流估算方法,所述方法包括:将仿真域图像对、仿真域图像对的光流真值及真实域图像对输入至初始神经网络模型中进行迭代训练,得到光流估算模型;在对初始神经网络模型进行训练时,以仿真域图像对及真实域图像对为输入对生成对抗网络进行对抗训练,以使生成对抗网络生成第一转换域图像对及第二转换域图像对;以仿真域图像对、第一转换域图像对及仿真域图像对的光流真值为输入对光流计算网络进行有监督训练,以第二转换域图像对为输入,对光流计算网络进行无监督训练。通过实施本发明能降低构建光流估算模型的人力成本,并提高对真实域图像的光流值的估算准确性。
Description
技术领域
本发明涉及计算机技术领域,尤其涉及一种光流估算模型的训练方法、装置及光流估算方法。
背景技术
目前的深度学习模型极度依赖于有任务相关的真值数据进行监督训练,但对于光流计算任务,获取真实域的光流真值是极其困难的。因此在现有技术中对于光流计算模型一般依赖于仿真域数据的预训练和极少量的真实域的光流真值进行微调训练。这里,域通常是指数据来源不同,例如通过游戏仿真的自动驾驶场景图像和摄像头直接采集的真实世界的自动驾驶图像,可归属于不同的“域”。
但是采用上述方法存在以下问题:1、在实际训练过程中,即使是获取极少的真实域的光流真值数据也需要耗费大量的人力成本;2、通常情况下基于一个域上的数据所训练得到的模型,在另一个域的数据上测试的结果较差,而由于上述方法所训练的模型其训练数据大部分是基于仿真域的数据进行训练的,这就导致了采用上述方法所训练出的光流计算模型,在对真实域的图像进行光流计算时,模型的泛化能力较弱,所计算出的光流值准确性较低。
发明内容
本发明实施例提供一种光流估算模型的构建方法、装置以及光流计算方法,能够降低构建光流估算模型的人力成本,并提高对真实域图像的光流值的估算准确性。
本发明一实施例提供了一种光流估算模型的构建方法,包括:获取仿真域训练集以及真实域训练集;其中,所述仿真域训练集中的每一仿真训练样本包括:相邻帧的仿真域图像对以及仿真域图像对的光流真值;所述真实域训练集中的每一真实训练样本包括相邻帧的真实域图像对;
将所述仿真域训练集以及真实域训练集输入至初始神经网络模型中进行迭代训练,直至达到预设训练次数或所述初始神经网络模型的总损失函数值达到预设值,得到光流估算模型;
其中,所述初始神经网络模型包括:生成对抗网络以及光流计算网络;
在对所述初始神经网络模型进行迭代训练时,以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,以使所述生成对抗网络将仿真域图像对以及真实域图像对转换至同一数据域,生成仿真域图像对所对应的第一转换域图像对,以及真实域图像对所对应的第二转换域图像对;
以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练;各重组图像对根据所述仿真域图像对以及第一转换域图像对生成。
进一步的,所述生成对抗网络包括:生成网络以及判别网络;所述以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,具体包括:
以所述仿真域图像对以及真实域图像对为输入并根据生成网络损失函数以及判别网络损失函数对所述生成对抗网络进行训练;
其中,所述生成网络损失函数为:
所述判别网络损失函数为:
G为生成网络,D为判别网络,S~p(S)表示来自仿真域图像对的仿真域图像,T~p(T) 表示来自真实域图像对的真实域图像,D(G(S))表示判别网络D对于生成网络G所编码的仿真域图像S的特征的分类分数,D(G(T))代表判别网络D对生成网络G所编码的真实域图像T的特征的分类分数,E为期望,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。
进一步的,所述以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,具体包括:
对所述仿真域图像对以及第一转换域图像对中的图像进行图像对重组,生成若干重组图像对;
根据所述仿真域图像对的光流真值确定各重组图像对的光流真值;
以各重组图像对以及各重组图像对的光流真值为输入,以各重组图像对的光流估算值为输出,并根据有监督训练损失函数对光流计算网络进行有监督训练;
其中,所述有监督训练损失函数为:L1=|F’-F|;F’为重组图像对的光流估算值,F为重组图像对的光流真值。
进一步的,所述以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练,具体包括:
以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,并根据无监督训练损失函数对光流计算网络进行无监督训练;
L为无监督训练损失函数,α和β为预设的平衡参数,ρ为预设的惩罚函数,T1*和T2*为第二转换域图像对中的两帧相邻图像,(x, y)为图像中像素点的坐标,(u, v)为像素点的光流估算值,∇为预设梯度算子。
在上述方法项实施例的基础上,本发明对应提供了装置实施例;
本发明一实施例提供了一种光流估算模型的构建装置,包括:数据获取模块以及模型训练模块;其中,所述模型训练模块包括第一训练模块和第二训练模块;
所述数据获取模块,用于获取仿真域训练集以及真实域训练集;其中,所述仿真域训练集中的每一仿真训练样本包括:相邻帧的仿真域图像对以及仿真域图像对的光流真值;所述真实域训练集中的每一真实训练样本包括相邻帧的真实域图像对;
所述模型训练模块,用于将所述仿真域训练集以及真实域训练集输入至初始神经网络模型中进行迭代训练,直至达到预设训练次数或所述初始神经网络模型的总损失函数值达到预设值,得到光流估算模型;其中,所述初始神经网络模型包括:生成对抗网络以及光流计算网络;
在对所述初始神经网络模型进行迭代训练时,所述第一训练模块以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,以使所述生成对抗网络将仿真域图像对以及真实域图像对转换至同一数据域,生成仿真域图像对所对应的第一转换域图像对,以及真实域图像对所对应的第二转换域图像对;
所述第二训练模块以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练;各重组图像对根据所述仿真域图像对以及第一转换域图像对生成。
进一步的,所述生成对抗网络包括:生成网络以及判别网络;第一训练模块,以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,具体包括:
第一训练模块以所述仿真域图像对以及真实域图像对为输入并根据生成网络损失函数以及判别网络损失函数对所述生成对抗网络进行训练;
其中,所述生成网络损失函数为:
所述判别网络损失函数为:
G为生成网络,D为判别网络,S~p(S)表示来自仿真域图像对的仿真域图像,T~p(T) 表示来自真实域图像对的真实域图像,D(G(S))表示判别网络D对于生成网络G所编码的仿真域图像S的特征的分类分数,D(G(T))代表判别网络D对生成网络G所编码的真实域图像T的特征的分类分数,E为期望,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。
进一步的,第二训练模块以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,具体包括:
第二训练模块对所述仿真域图像对以及第一转换域图像对中的图像进行图像对重组,生成若干重组图像对;
根据所述仿真域图像对的光流真值确定各重组图像对的光流真值;
以各重组图像对以及各重组图像对的光流真值为输入,以各重组图像对的光流估算值为输出,并根据有监督训练损失函数对光流计算网络进行有监督训练;
其中,所述有监督训练损失函数为:L1=|F’-F|;F’为重组图像对的光流估算值,F为重组图像对的光流真值。
进一步的,第二训练模块以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练,具体包括:
以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,并根据无监督训练损失函数对光流计算网络进行无监督训练;
L为无监督训练损失函数,α和β为预设的平衡参数,ρ为预设的惩罚函数,T1*和T2*为第二转换域图像对中的两帧相邻图像,(x, y)为图像中像素点的坐标,(u, v)为像素点的光流估算值,∇为预设梯度算子。
在上述方法项实施例的基础上,本发明另一实施例提供了一种光流估算方法,所述光流估算方法包括:获取待估算真实域图像对,并将所述待估算真实域图像对输入通过上述任意一项所述的光流估算模型的构建方法所构建的光流估算模型中,以使所述光流估算模型输出所述待估算真实域图像对的光流估算值。
通过实施本发明具有如下有益效果:
本发明实施例提供了一种光流估算模型的构建方法、装置及光流估算方法,在构建光流估算模型时,基于以仿真域图像对以及真实域图像对为输入,训练出用于将仿真域图像对及真实域图像对转换至同一数据域的转换域图像对,生成第一转换域图像对和第二转换域图像对,紧接着,以仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对以及第一转换域图像对生成所生成各重组图像对的光流估算值为输出,对光流估算模型中的光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练;与现有技术相比,当使用本发明所提供的光流估算模型进行光流值估算时,光流估算模型的生成对抗网络能够将待估算的真实域图像对,转换至转换域,生成转换域图像对,再由光流计算网络对转换域图像对的光流值进行估算,间接得到真实域图像对的光流值,虽然模型的输入是真实域图像对,但是最终光流计算网络是基于转换域的图像对进行训练的,因此在计算过程中不会存在基于一个域上的数据所训练得到的模型,在另一个域的数据上测试的结果较差的问题,使得所生成的光流估算模型在对真实域的图像进行光流计算时,能生成准确的光流估算值,提高了光流估算模型的泛化能力。此外,在训练光流估算模型过程中,采用仿真域图像对、转换域图像对以及仿真域图像对的光流值对光流计算网络做有监督训练,同时仅利用真实域图像对来对光流计算网络做无监督训练,从而达到利用仿真域学习的知识来提升模型在真实域的精度,而整个训练过程无需采用真实域图像对的光流真值数据;由于在实际操作过程中仿真域图像对、仿真域图像对的光流真值以及真实域图像对都是比较容易获取的,因此能够降低模型训练时所耗费的人力成本。
附图说明
图1是本发明一实施例提供的一种光流估算模型的构建方法的流程示意图。
图2是本发明一实施例提供的一种光流估算模型的结构示意图。
图3是本发明一实施例提供的一种光流估算模型的构建装置的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
如图1所示,本发明一实施例提供了一种光流估算模型的构建方法,至少包括如下步骤:
步骤S101:获取仿真域训练集以及真实域训练集;其中,所述仿真域训练集中的每一仿真训练样本包括:相邻帧的仿真域图像对以及仿真域图像对的光流真值;所述真实域训练集中的每一真实训练样本包括相邻帧的真实域图像对。
步骤S102:将所述仿真域训练集以及真实域训练集输入至初始神经网络模型中进行迭代训练,直至达到预设训练次数或所述初始神经网络模型的总损失函数值达到预设值,得到光流估算模型;其中,所述初始神经网络模型包括:生成对抗网络以及光流计算网络;在对所述初始神经网络模型进行迭代训练时,以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,以使所述生成对抗网络将仿真域图像对以及真实域图像对转换至同一数据域,生成仿真域图像对所对应的第一转换域图像对,以及真实域图像对所对应的第二转换域图像对;以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练;各重组图像对根据所述仿真域图像对以及第一转换域图像对生成。
对于步骤S101、在本发明中在构建光流估算模型时需要两个训练集数据,一个是基于仿真所获取的仿真域训练集,一个基于真实场景进行采集的真实域训练集,其中仿真域训练集的一个样本包括相邻帧的图像对(即上述仿真域图像对)和两帧图像之间的光流真值(即上述仿真域图像对的光流真值),真实域训练集的样本只包括相邻帧的图像对(即上述真实域图像对)。每一训练集中包含多个训练样本。示意性的,在训练阶段,获取仿真域图像对S1和S2,以及S1到S2之间的光流真值F1-2和S2到S1之间的光流真值F2-1,获取真实域图像对T1和T2。
对于步骤S102、在一个优选的实施例中,所述生成对抗网络包括:生成网络以及判别网络;所述以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,具体包括:以所述仿真域图像对以及真实域图像对为输入并根据生成网络损失函数以及判别网络损失函数对所述生成对抗网络进行训练;
其中,所述生成网络损失函数为:
所述判别网络损失函数为:
G为生成网络,D为判别网络,S~p(S)表示来自仿真域图像对的仿真域图像,T~p(T) 表示来自真实域图像对的真实域图像,D(G(S))表示判别网络D对于生成网络G所编码的仿真域图像S的特征的分类分数,D(G(T))代表判别网络D对生成网络G所编码的真实域图像T的特征的分类分数,E为期望,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。
在一个优选的实施例中,所述以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,具体包括:对所述仿真域图像对以及第一转换域图像对中的图像进行图像对重组,生成若干重组图像对;根据所述仿真域图像对的光流真值确定各重组图像对的光流真值;以各重组图像对以及各重组图像对的光流真值为输入,以各重组图像对的光流估算值为输出,并根据有监督训练损失函数对光流计算网络进行有监督训练;其中,所述有监督训练损失函数为:L1=|F’-F|;F’为重组图像对的光流估算值,F为重组图像对的光流真值。
在一个优选的实施例中,所述以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练,具体包括:以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,并根据无监督训练损失函数对光流计算网络进行无监督训练;
L为无监督训练损失函数,α和β为预设的平衡参数,ρ为预设的惩罚函数,T1*和T2*为第二转换域图像对中的两帧相邻图像,(x, y)为图像中像素点的坐标,(u, v)为像素点的光流估算值,∇为预设梯度算子。
示意性的,如图2所示,在步骤S101中获取了获取仿真域图像对S1和S2,以及S1到S2之间的光流真值F1-2和S2到S1之间的光流真值F2-1,获取真实域图像对T1和T2后,通过生成对抗网络的生成网络G进行编码及解码,在域转换的过程中,本发明通过一个编码器将仿真域图像对S1和S2,以及真实域图像对T1和T2,编码到一个共同的转换域空间,即通过编码器将不同域的图像编码到共同的特征空间,然后再将编码的特征通过解码器解码为转换域风格的图像S1*和S2*(即上述第一转换域图像对),T1*和T2*(即上述第二转换域图像对),从而实现将仿真域图像对以及真实域图像对转换至同一数据域。同时,在训练过程中通过判别网络D进行域转换判断,具体的,在训练的过程中,将仿真域图像对S1和S2在转换域的特征和真实域图像对T1和T2在转换域的特征,同时输入判别网络D进行对抗训练,即通过判别网络D来判断生成网络G是否将不同域的图像转换到相同的特征空间中。判别网络D和生成网络G共同构成对抗生成网络,进行对抗训练。
对于对抗训练,在本发明中根据最小二乘对抗生成网络,采用交叉熵损失函数训练对抗生成网络存在梯度消失的问题,因此训练生成网络G的损失函数采用最小二乘函数:
其中,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值。
训练判别网络D目的是尽可能区分仿真域样本和真实域样本的编码特征,因此损失函数的最小化形式如下:
其中,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。通过最小化这一损失函数,使判别网络D能够清晰地区分来自仿真域和真实域的不同数据类别。
整个对抗训练过程采用训练一步生成网络G,然后训练一步判别网络D,然后再训练一步G的循环训练模式。
由生成对抗网络所生成转换域风格的图像S1*和S2*(即上述第一转换域图像对),T1*和T2*(即上述第二转换域图像对),原始的仿真域图像对S1和S2,以及仿真域图像对S1和S2的光流真值F1-2、F2-1 流一起作为输入对光流计算模块(即上述光流计算网络)进行训练,在对图像进行域转换的过程中,保持图像的原始结构对最终的任务结果至关重要。在光流计算的过程中,对图像结构的保持要求更为苛刻,因为光流的计算精度要精确到像素级甚至亚像素级,来自同域的相邻两张图像在域转换的过程中结构如果发生错位,直接影响后续的光流计算精度。为此本发明在训练光流计算网络时提出交叉一致性有监督训练的方法,能够训练过程中对图像结构保持加以约束。对于仿真域图像对S1和S2,经过生成网络G转换为转换域风格的图像S1*和S2*,其中S1到S2之间的光流真值为F1-2,S2到S1之间的光流真值为F2-1。因此可以得到,S1*到S2之间的光流真值为F1-2,S2*到S1之间的光流真值为F2-1,S1*到S2*之间的光流真值为F1-2,这些有监督的训练真值限制了在域转换过程中,S1*到S2*需要保持对光流计算至关重要的图像结构不发生改变。其中,光流的有监督训练损失函数采用L1=|F’-F|损失函数,其中F’表示预测光流值,F表示真实光流值。如上所述,本发明将仿真域图像对S1和S2和第一转换域图像对S1*和S2*进行图像对重组,得到S1和S2、S1*和S2*、S1*和S2、S1和S2*一共四对重组图像对,每组图像对对应两个图像对光流真值,采用S1至S2、S1至S2*、S1*至S2,S1*至S2*、S2至S1、S2至S1*、S2*至S1,S2*至S1*之间共八组有监督训练损失函数来达到增加交叉一致性约束的目的。
在真实域的光流计算需要克服的不利因素包括光照变化、阴影等,因此可以假设经过域转换,在转换域,图像将会保留光流计算所需要的结构而去除上述不利因素。因此,发明在真实域图像对的转换图像对T1*和T2*之间增加无监督训练损失函数,对光流计算网络进行无监督训练,无监督训练损失函数如下:
L为无监督训练损失函数,α和β为预设的平衡参数,ρ为预设的惩罚函数,T1*和T2*为第二转换域图像对中的两帧相邻图像,(x, y)为图像中像素点的坐标,(u, v)为像素点的光流估算值,∇为预设梯度算子。
本发明所述的总损失函数即为上述生成网络损失函数、判别网络损失函数、有监督训练损失函数以及无监督损失函数之和,通过上述方法对整个初始神经模型进行训练,使得总损失函数的函数值收敛至预设值,或训练次数达到预设次数,即可获得光流估算模型。
通过本发明上述构建方法所构建的光流估算模型,在训练光流估算模型过程中,采用仿真域图像对、转换域图像对以及仿真域图像对的光流值对光流计算网络做有监督训练,同时仅利用真实域图像对来对光流计算网络做无监督训练,从而达到利用仿真域学习的知识来提升模型在真实域的精度,而整个训练过程无需采用真实域图像对的光流真值数据;由于在实际操作过程中仿真域图像对、仿真域图像对的光流真值以及真实域图像对都是比较容易获取的,因此能够降低模型训练时所耗费的人力成本。此外,当使用本发明所提供的光流估算模型进行光流值估算时,光流估算模型的生成对抗网络能够将待估算的真实域图像对,转换至转换域,生成转换域图像对,再由光流计算网络对转换域图像对的光流值进行估算,间接得到真实域图像对的光流值,虽然模型的输入是真实域图像对,但是最终光流计算网络是基于转换域的图像对进行训练的,因此在计算过程中不会存在基于一个域上的数据所训练得到的模型,在另一个域的数据上测试的结果较差的问题,使得所生成的光流估算模型在对真实域的图像进行光流计算时,能生成准确的光流估算值,提高了光流估算模型的泛化能力。
如图3所示,在上述实施例的基础上,本发明对应提供了一种种光流估算模型的构建装置,包括:数据获取模块以及模型训练模块;其中,所述模型训练模块包括第一训练模块和第二训练模块;
所述数据获取模块,用于获取仿真域训练集以及真实域训练集;其中,所述仿真域训练集中的每一仿真训练样本包括:相邻帧的仿真域图像对以及仿真域图像对的光流真值;所述真实域训练集中的每一真实训练样本包括相邻帧的真实域图像对;
所述模型训练模块,用于将所述仿真域训练集以及真实域训练集输入至初始神经网络模型中进行迭代训练,直至达到预设训练次数或所述初始神经网络模型的总损失函数值达到预设值,得到光流估算模型;其中,所述初始神经网络模型包括:生成对抗网络以及光流计算网络;
在对所述初始神经网络模型进行迭代训练时,所述第一训练模块以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,以使所述生成对抗网络将仿真域图像对以及真实域图像对转换至同一数据域,生成仿真域图像对所对应的第一转换域图像对,以及真实域图像对所对应的第二转换域图像对;
所述第二训练模块以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练;各重组图像对根据所述仿真域图像对以及第一转换域图像对生成。
在一个优选的实施例中,所述生成对抗网络包括:生成网络以及判别网络;第一训练模块,以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,具体包括:第一训练模块以所述仿真域图像对以及真实域图像对为输入并根据生成网络损失函数以及判别网络损失函数对所述生成对抗网络进行训练;
其中,所述生成网络损失函数为:
所述判别网络损失函数为:
G为生成网络,D为判别网络,S~p(S)表示来自仿真域图像对的仿真域图像,T~p(T) 表示来自真实域图像对的真实域图像,D(G(S))表示判别网络D对于生成网络G所编码的仿真域图像S的特征的分类分数,D(G(T))代表判别网络D对生成网络G所编码的真实域图像T的特征的分类分数,E为期望,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。
在一个优选的实施例中,第二训练模块以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,具体包括:第二训练模块对所述仿真域图像对以及第一转换域图像对中的图像进行图像对重组,生成若干重组图像对;根据所述仿真域图像对的光流真值确定各重组图像对的光流真值;以各重组图像对以及各重组图像对的光流真值为输入,以各重组图像对的光流估算值为输出,并根据有监督训练损失函数对光流计算网络进行有监督训练;其中,所述有监督训练损失函数为:L1=|F’-F|;F’为重组图像对的光流估算值,F为重组图像对的光流真值。
在一个优选的实施例中,第二训练模块以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练,具体包括:
以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,并根据无监督训练损失函数对光流计算网络进行无监督训练;
L为无监督训练损失函数,α和β为预设的平衡参数,ρ为预设的惩罚函数,T1*和T2*为第二转换域图像对中的两帧相邻图像,(x, y)为图像中像素点的坐标,(u, v)为像素点的光流估算值,∇为预设梯度算子。
在上述实施例的基础上,本发明一实施例提供了一种光流估算方法,所述光流估算方法包括:获取待估算真实域图像对,并将所述待估算真实域图像对输入通过上述任意一项所述的光流估算模型的构建方法所构建的光流估算模型中,以使所述光流估算模型输出所述待估算真实域图像对的光流估算值。
当待估算真实域图像对输入至光流估算模型后,光流估算模型先通过生成对抗网络中的生成网络将待估算真实域图像对转换为待估算转换域图像对,然后将待估算转换域图像对输入至光流计算网络中,由光流计算网络计算出待估算转换域图像对的光流值,继而将待估算转换域图像对的光流值作为待估算真实域图像对的光流值。
需说明的是,以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。另外,本发明提供的装置实施例附图中,模块之间的连接关系表示它们之间具有通信连接,具体可以实现为一条或多条通信总线或信号线。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
以上所述是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也视为本发明的保护范围。
Claims (9)
1.一种光流估算模型的构建方法,其特征在于,包括:
获取仿真域训练集以及真实域训练集;其中,所述仿真域训练集中的每一仿真训练样本包括:相邻帧的仿真域图像对以及仿真域图像对的光流真值;所述真实域训练集中的每一真实训练样本包括相邻帧的真实域图像对;
将所述仿真域训练集以及真实域训练集输入至初始神经网络模型中进行迭代训练,直至达到预设训练次数或所述初始神经网络模型的总损失函数值达到预设值,得到光流估算模型;
其中,所述初始神经网络模型包括:生成对抗网络以及光流计算网络;
在对所述初始神经网络模型进行迭代训练时,以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,以使所述生成对抗网络将仿真域图像对以及真实域图像对转换至同一数据域,生成仿真域图像对所对应的第一转换域图像对,以及真实域图像对所对应的第二转换域图像对;
以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练。
2.如权利要求1所述的光流估算模型的构建方法,其特征在于,所述生成对抗网络包括:生成网络以及判别网络;所述以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,具体包括:
以所述仿真域图像对以及真实域图像对为输入并根据生成网络损失函数以及判别网络损失函数对所述生成对抗网络进行训练;
其中,所述生成网络损失函数为:
所述判别网络损失函数为:
G为生成网络,D为判别网络,S~p(S)表示来自仿真域图像对的仿真域图像,T~p(T)表示来自真实域图像对的真实域图像,D(G(S))表示判别网络D对于生成网络G所编码的仿真域图像S的特征的分类分数,D(G(T))代表判别网络D对生成网络G所编码的真实域图像T的特征的分类分数,E为期望,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。
3.如权利要求1所述的光流估算模型的构建方法,其特征在于,所述以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,具体包括:
对所述仿真域图像对以及第一转换域图像对中的图像进行图像对重组,生成若干重组图像对;
根据所述仿真域图像对的光流真值确定各重组图像对的光流真值;
以各重组图像对以及各重组图像对的光流真值为输入,以各重组图像对的光流估算值为输出,并根据有监督训练损失函数对光流计算网络进行有监督训练;
其中,所述有监督训练损失函数为:L1=|F’-F|;F’为重组图像对的光流估算值,F为重组图像对的光流真值。
5.一种光流估算模型的构建装置,其特征在于,包括数据获取模块以及模型训练模块;其中,所述模型训练模块包括第一训练模块和第二训练模块;
所述数据获取模块,用于获取仿真域训练集以及真实域训练集;其中,所述仿真域训练集中的每一仿真训练样本包括:相邻帧的仿真域图像对以及仿真域图像对的光流真值;所述真实域训练集中的每一真实训练样本包括相邻帧的真实域图像对;
所述模型训练模块,用于将所述仿真域训练集以及真实域训练集输入至初始神经网络模型中进行迭代训练,直至达到预设训练次数或所述初始神经网络模型的总损失函数值达到预设值,得到光流估算模型;其中,所述初始神经网络模型包括:生成对抗网络以及光流计算网络;
在对所述初始神经网络模型进行迭代训练时,所述第一训练模块以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,以使所述生成对抗网络将仿真域图像对以及真实域图像对转换至同一数据域,生成仿真域图像对所对应的第一转换域图像对,以及真实域图像对所对应的第二转换域图像对;
所述第二训练模块以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,以第二转换域图像对为输入,以第二转换域图像对的光流估算值为输出,对所述光流计算网络进行无监督训练;各重组图像对根据所述仿真域图像对以及第一转换域图像对生成。
6.如权利要求5所述的光流估算模型的构建装置,其特征在于,所述生成对抗网络包括:生成网络以及判别网络;第一训练模块,以仿真域图像对以及真实域图像对为输入对所述生成对抗网络进行对抗训练,具体包括:
以所述仿真域图像对以及真实域图像对为输入并根据生成网络损失函数以及判别网络损失函数对所述生成对抗网络进行训练;
其中,所述生成网络损失函数为:
所述判别网络损失函数为:
G为生成网络,D为判别网络,S~p(S)表示来自仿真域图像对的仿真域图像,T~p(T)表示来自真实域图像对的真实域图像,D(G(S))表示判别网络D对于生成网络G所编码的仿真域图像S的特征的分类分数,D(G(T))代表判别网络D对生成网络G所编码的真实域图像T的特征的分类分数,E为期望,c为判别网络D判定生成网络G编码的真实域图像T的特征和仿真域图像S的特征属于同一转换域的目标值,a为真实域图像T的特征所对应的判别网络输出目标值,b为仿真域图像S的特征所对应的判别网络输出目标值。
7.如权利要求5所述的光流估算模型的构建装置,其特征在于,第二训练模块以所述仿真域图像对、第一转换域图像对以及仿真域图像对的光流真值为输入,以根据所述仿真域图像对及第一转换域图像对生成的各重组图像对的光流估算值为输出,对光流计算网络进行有监督训练,具体包括:
第二训练模块对所述仿真域图像对以及第一转换域图像对中的图像进行图像对重组,生成若干重组图像对;
根据所述仿真域图像对的光流真值确定各重组图像对的光流真值;
以各重组图像对以及各重组图像对的光流真值为输入,以各重组图像对的光流估算值为输出,并根据有监督训练损失函数对光流计算网络进行有监督训练;
其中,所述有监督训练损失函数为:L1=|F’-F|;F’为重组图像对的光流估算值,F为重组图像对的光流真值。
9.一种光流估算方法,其特征在于,包括:获取待估算真实域图像对,并将所述待估算真实域图像对输入通过权利要求1-4任意一项所述的光流估算模型的构建方法所构建的光流估算模型中,以使所述光流估算模型输出所述待估算真实域图像对的光流估算值。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111635874.0A CN114005075B (zh) | 2021-12-30 | 2021-12-30 | 一种光流估算模型的构建方法、装置及光流估算方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111635874.0A CN114005075B (zh) | 2021-12-30 | 2021-12-30 | 一种光流估算模型的构建方法、装置及光流估算方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114005075A CN114005075A (zh) | 2022-02-01 |
CN114005075B true CN114005075B (zh) | 2022-04-05 |
Family
ID=79932143
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111635874.0A Active CN114005075B (zh) | 2021-12-30 | 2021-12-30 | 一种光流估算模型的构建方法、装置及光流估算方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114005075B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116563169B (zh) * | 2023-07-07 | 2023-09-05 | 成都理工大学 | 基于混合监督学习的探地雷达图像异常区域增强方法 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106022229A (zh) * | 2016-05-11 | 2016-10-12 | 北京航空航天大学 | 基于视频运动信息特征提取与自适应增强算法的误差反向传播网络的异常行为识别方法 |
CN111369595A (zh) * | 2019-10-15 | 2020-07-03 | 西北工业大学 | 基于自适应相关卷积神经网络的光流计算方法 |
CN112396074A (zh) * | 2019-08-15 | 2021-02-23 | 广州虎牙科技有限公司 | 基于单目图像的模型训练方法、装置及数据处理设备 |
CN113920581A (zh) * | 2021-09-29 | 2022-01-11 | 江西理工大学 | 一种时空卷积注意力网络用于视频中动作识别的方法 |
CN113947732A (zh) * | 2021-12-21 | 2022-01-18 | 北京航空航天大学杭州创新研究院 | 基于强化学习图像亮度调节的空中视角人群计数方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018128741A1 (en) * | 2017-01-06 | 2018-07-12 | Board Of Regents, The University Of Texas System | Segmenting generic foreground objects in images and videos |
-
2021
- 2021-12-30 CN CN202111635874.0A patent/CN114005075B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106022229A (zh) * | 2016-05-11 | 2016-10-12 | 北京航空航天大学 | 基于视频运动信息特征提取与自适应增强算法的误差反向传播网络的异常行为识别方法 |
CN112396074A (zh) * | 2019-08-15 | 2021-02-23 | 广州虎牙科技有限公司 | 基于单目图像的模型训练方法、装置及数据处理设备 |
CN111369595A (zh) * | 2019-10-15 | 2020-07-03 | 西北工业大学 | 基于自适应相关卷积神经网络的光流计算方法 |
CN113920581A (zh) * | 2021-09-29 | 2022-01-11 | 江西理工大学 | 一种时空卷积注意力网络用于视频中动作识别的方法 |
CN113947732A (zh) * | 2021-12-21 | 2022-01-18 | 北京航空航天大学杭州创新研究院 | 基于强化学习图像亮度调节的空中视角人群计数方法 |
Also Published As
Publication number | Publication date |
---|---|
CN114005075A (zh) | 2022-02-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113658051B (zh) | 一种基于循环生成对抗网络的图像去雾方法及*** | |
Golts et al. | Unsupervised single image dehazing using dark channel prior loss | |
CN109271933B (zh) | 基于视频流进行三维人体姿态估计的方法 | |
WO2020037965A1 (zh) | 一种用于视频预测的多运动流深度卷积网络模型方法 | |
CN110533044B (zh) | 一种基于gan的域适应图像语义分割方法 | |
Aliakbarian et al. | Flag: Flow-based 3d avatar generation from sparse observations | |
CN109636721B (zh) | 基于对抗学习和注意力机制的视频超分辨率方法 | |
CN110689599A (zh) | 基于非局部增强的生成对抗网络的3d视觉显著性预测方法 | |
CN112258625B (zh) | 基于注意力机制的单幅图像到三维点云模型重建方法及*** | |
CN110599468A (zh) | 无参考视频质量评估方法及装置 | |
CN114972085B (zh) | 一种基于对比学习的细粒度噪声估计方法和*** | |
CN112884758B (zh) | 一种基于风格迁移方法的缺陷绝缘子样本生成方法及*** | |
CN114005075B (zh) | 一种光流估算模型的构建方法、装置及光流估算方法 | |
CN111898482A (zh) | 基于渐进型生成对抗网络的人脸预测方法 | |
CN113283577A (zh) | 一种基于元学***行数据生成方法 | |
CN116229106A (zh) | 一种基于双u结构的视频显著性预测方法 | |
CN117291232A (zh) | 一种基于扩散模型的图像生成方法与装置 | |
CN115272423B (zh) | 一种训练光流估计模型的方法、装置和可读存储介质 | |
CN111275751A (zh) | 一种无监督绝对尺度计算方法及*** | |
CN104320659B (zh) | 背景建模方法、装置及设备 | |
CN115630612A (zh) | 一种基于vae与wgan的软件度量缺陷数据增广方法 | |
CN113077383B (zh) | 一种模型训练方法及模型训练装置 | |
Li et al. | Context convolution dehazing network with channel attention | |
CN113315995A (zh) | 提高视频质量的方法、装置、可读存储介质及电子设备 | |
Sivaanpu et al. | Underwater Image Enhancement Using Dual Convolutional Neural Network with Skip Connections |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant | ||
CP03 | Change of name, title or address |
Address after: Floor 25, Block A, Zhongzhou Binhai Commercial Center Phase II, No. 9285, Binhe Boulevard, Shangsha Community, Shatou Street, Futian District, Shenzhen, Guangdong 518000 Patentee after: Shenzhen Youjia Innovation Technology Co.,Ltd. Address before: 518051 401, building 1, Shenzhen new generation industrial park, No. 136, Zhongkang Road, Meidu community, Meilin street, Futian District, Shenzhen, Guangdong Province Patentee before: SHENZHEN MINIEYE INNOVATION TECHNOLOGY Co.,Ltd. |
|
CP03 | Change of name, title or address |