CN116310545A - 一种基于深度层次化最优传输的跨域舌头图像分类方法 - Google Patents
一种基于深度层次化最优传输的跨域舌头图像分类方法 Download PDFInfo
- Publication number
- CN116310545A CN116310545A CN202310252527.2A CN202310252527A CN116310545A CN 116310545 A CN116310545 A CN 116310545A CN 202310252527 A CN202310252527 A CN 202310252527A CN 116310545 A CN116310545 A CN 116310545A
- Authority
- CN
- China
- Prior art keywords
- domain
- tongue
- tongue image
- representing
- image
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 230000005540 biological transmission Effects 0.000 title claims abstract description 110
- 238000000034 method Methods 0.000 title claims abstract description 38
- 238000009826 distribution Methods 0.000 claims abstract description 52
- 239000011159 matrix material Substances 0.000 claims description 34
- 238000012549 training Methods 0.000 claims description 18
- 238000013528 artificial neural network Methods 0.000 claims description 12
- 238000011176 pooling Methods 0.000 claims description 10
- 238000000605 extraction Methods 0.000 claims description 9
- 239000000203 mixture Substances 0.000 claims description 8
- 238000013507 mapping Methods 0.000 claims description 6
- 238000000638 solvent extraction Methods 0.000 claims description 4
- 238000010586 diagram Methods 0.000 claims description 3
- 238000010801 machine learning Methods 0.000 abstract description 5
- 230000008901 benefit Effects 0.000 abstract description 3
- 238000013145 classification model Methods 0.000 abstract 2
- 230000006870 function Effects 0.000 description 25
- 230000008569 process Effects 0.000 description 6
- 238000004364 calculation method Methods 0.000 description 3
- 230000008859 change Effects 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000006978 adaptation Effects 0.000 description 2
- 238000006243 chemical reaction Methods 0.000 description 2
- 230000000295 complement effect Effects 0.000 description 2
- 239000003814 drug Substances 0.000 description 2
- 230000007613 environmental effect Effects 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 239000003795 chemical substances by application Substances 0.000 description 1
- 239000003086 colorant Substances 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 238000003745 diagnosis Methods 0.000 description 1
- 230000004069 differentiation Effects 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000005286 illumination Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/0002—Inspection of images, e.g. flaw detection
- G06T7/0012—Biomedical image inspection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20081—Training; Learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20084—Artificial neural networks [ANN]
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Multimedia (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Databases & Information Systems (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
- Radiology & Medical Imaging (AREA)
- Quality & Reliability (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于深度层次化最优传输的跨域舌头图像分类方法,包括采集舌头图像;构建机器学习分类模型,所述机器学习分类模型利用深度层次化最优传输模型实现不同领域舌头图像特征的对齐,所述深度层次化最优传输模型包括两层网络结构,其中,第一层网络结构用于实现不同领域间的最优传输,第二层网络结构用于实现不同样本之间的最优传输;根据深度层次化最优传输模型对舌头图像进行分类,输出舌头图像类别。本发明的有益之处在于对齐不同分布的舌头图像的同时,增强舌头图像的分类能力。
Description
技术领域
本发明涉及辅助中医诊疗的舌头图像分类技术领域,更具体的说是涉及一种基于深度层次化最优传输的跨域舌头图像分类方法。
背景技术
现有基于机器学习的舌头图像分类方法大多是基于监督学习的。监督学习方法通常假设训练集与测试集服从同一个分布,因此在训练集上训练出来的模型在测试集同样可以表现的很好。然而,在现实应用中这样的假设是难以成立的,解决这类问题的主要思路是假设两个数据分布通过一个非线性映射到一个领域间共同的隐空间,以减少分布之间的漂移,让两个分布在经过非线性映射转换之后更加相似。这个非线性映射的过程称之为领域自适应。
现有基于机器学习的舌头图像分类就面临这样的问题。首先是不同人的舌头图像具有差异性,包括舌头图像的边缘纹理、颜色等。其次,不同医院的舌头图像采集设备可能是不同的,采集的舌头图像数据还受采集环境的影响,例如角度、光照等。另外,不同医院的地理位置不同,采集舌头图像的个体也有地域差异。这些因素导致了不同医院的舌头图像数据会出现比较大的分布差异。若每个医院的舌头图像数据为领域数据,则不同领域的数据分布有差异,这些差异将导致在已采集的数据集上训练出来的模型,部署到其他医院时会出现严重的性能退化。同时,由于医疗数据的标注代价较高,在目标域没有标签的数据时,更加困难,即只有源域的标签是可用的。
为了解决这个问题,不同领域分布需要对齐。目前解决不同领域分布对齐的主流方法主要包括两个步骤:首先,通过非线性转换让两个分布更加靠近;然后,在变换后的分布上利用源域的标签信息为目标域训练一个分类器,使得模型能够泛化到目标域,这也是域间知识迁移的过程。可见,如何找到这个非线性转换对于解决领域自适应问题是关键。近年来,最优传输方法在领域自适应问题上显示出来较大的优势,它可以直接在边缘分布上衡量两个分布的距离,而无需标签信息。在视觉领域,这个基于最优化传输的距离称之为EMD(Earth Mover’s Distance)距离。一方面,它可以直接在离散的经验分布(域)上计算两个分布之间的距离。另一方面,当两个域的支撑集不明显重叠时也能提供有意义的梯度,因此不容易导致训练失败。此外,它具有良好的可解释性,能够显式地建模领域之间的耦合。
通过优化两个领域的特征分布之间的最小传输代价,让源域与目标域的分布都能以最小代价变换到一个共同的隐空间里,而在这个隐空间里的特征都具有领域不变性。这个过程称之为领域对齐。在这样的特征上训练得到的分类器,具备迁移到目标域的能力。而领域对齐并不是最终目标,分类才是。然而,在最优传输中,代价矩阵的计算方式通常是计算两两样本的欧氏距离(L2距离)。在这样的度量空间里,当两个样本的支撑集不重叠,是无法提供有意义的距离的。这在视觉问题里表现为,当两个样本的背景过于杂乱或具有较大的类内外观变化,可能会使同一类别的图像在这样的度量空间中相去甚远。换句话说,这时的L2距离受背景变化的影响较大。这虽然可以通过神经网络的建模缓解,但需要充足的训练数据,而这在实际场景中(特别是医学场景)是很难做到的,需要强调其目标区域的局部特征,同时,L2距离作为一种全局表示,破坏了图像特征的空间结构,丢失了局部信息。而局部信息是可以提供区分性且可迁移的信息的,这对分类任务来说非常重要。尤其是在中医体征图像数据集,其采集过程的标准化程度较差,背景或者环境因素变化较大。基于上述原因,导致在现有的领域对齐的过程中,获得领域不变性特征的同时,也会模糊特征的类别区分性,即出现过度对齐。
因此,如何避免舌头图像分类过程中产生过度对齐的现象,提高舌头分类图像的准确度是本领域技术人员亟需解决的技术问题。
发明内容
为了解决这些问题,本发明公开了一种基于深度层次化最优传输的跨域舌头图像分类方法,从而让机器学习模型能够学习到对环境噪音更加鲁棒的不变性特征,使得不同分布的舌头图像数据具有自适应的能力,提高分类的准确率。
为了实现上述目的,本发明采用如下技术方案:
一种基于深度层次化最优传输的跨域舌头图像分类方法,包括:
S1、采集多个不同领域的舌头图像样本作为训练集;
S2、利用深度神经网络对训练集中源域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图;
利用深度神经网络对训练集中目标域舌头图像样本进行特征提取,获取对应的目标域舌头图像样本特征构成的目标域图像特征图;
S3、对源域图像特征图中的源域舌头图像样本特征进行分块,获取源域舌头图像样本对应的源域图像特征集;
对目标域图像特征图中的目标域舌头图像样本特征进行分块,获取目标域舌头图像样本对应的目标域图像特征集;
S4、计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,作为源域舌头图像样本和目标域图像样本之间的样本最优化传输距离;
S5、以所述样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
S6、根据步骤S2提取的源域舌头图像样本特征值计算softmax交叉熵损失,作为损失函数的一部分;将域间最优化传输距离作为损失函数里的另一部分,构建分类损失函数,利用所述分类损失函数训练分类器;
S7、利用训练好的分类器对待验证的舌头图像样本进行分类。
所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,包括以下方法:
联合使用EMD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
其中,g表示深度神经网络的特征提取器;
γin表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的最优传输方案,Cin表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的代价矩阵;<γin,Cin>F表示γin和Cin的Frobenius点乘;表示第i个源域舌头图像样本提取的源域图像特征图沿着空间维度进行全局平均池化结果,/> 表示第j个目标域舌头图像样本提取的目标域图像特征图沿着空间维度进行全局平均池化结果,/>ch表示通道数。
优选的,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
联合使用SWD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
P表示置换矩阵,表示所有置换矩阵的集合,Ui表示将源域舌头图像样本/>对应的特征转换到一个共同的高维隐层空间,Uj表示将目标域舌头图像样本/>对应的特征转换到一个共同的高维隐层空间,T为矩阵转置符号。
优选的,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
采用SWD距离、L2距离和类条件分布差异的交叉熵作为两个特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
λswd表示SWD距离的平衡系数,λl2表示L2距离的平衡系数,λcond表示类条件分布差异的交叉熵的平衡系数,表示源域舌头图像样本/>的标签,/>表示是类条件分布的差异,M表示投影矩阵总个数,Zi表示将源域舌头图像样本/>映射到隐层空间Z而得到的特征矩阵,Zj表示将目标域样本/>映射到隐层空间Z而得到的特征矩阵,/>表示将源域舌头图像样本/>或目标域舌头图像样本/>投影到隐层空间Z形成对应的第m个投影矩阵。
优选的,所述步骤S5中,包括采用mini-batch策略,具体包括每次随机从每个源域舌头图像样本和目标域舌头图像样本中分别抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的最优化传输距离:其中,/>OTn表示域间最优化传输距离的矩阵,/>表示源域舌头图像样本分布组成的矩阵,/>表示目标域舌头样本分布组成的矩阵,/>表示/>和/>的联合分布,γn表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的最优传输方案组成的n*n矩阵,Cn表示任意一个源域舌头图像样本和任意一个目标域图像样本之间的样本最优化传输距离组成的n*n矩阵,<γn,Cn>F表示γn和Cn的Frobenius点乘。
优选的,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输。
优选的,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输损失再增加源域的分类交叉熵损失函数。
经由上述的技术方案可知,与现有技术相比,本发明公开提供了基于深度层次化最优传输的跨域舌头图像分类方法,具有以下有益效果:
通过深度层次化的最优传输对齐不同分布的舌头图像的同时,增强分类能力;在第一层的最优传输中采用了非平衡的最优传输进行领域对齐,放松了最优传输的边缘约束,从而能够对小批量的训练提供一个更鲁棒的优化性能;对于第二层的最优传输,使用SWD代替EMD距离,增强样本的区分性特征,SWD是EMD距离的近似,但其计算代价更低。提高了舌头图像分类的准确率。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本发明提供的深度层次化图像分类方法流程示意图;
图2为本发明提供的深度层次化最优传输模型结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例公开了一种基于深度层次化最优传输的跨域舌头图像分类方法,包括:
一种基于深度层次化最优传输的跨域舌头图像分类方法,包括:
S1、采集多个不同领域的舌头图像样本作为训练集;
S2、利用深度神经网络分别对训练集中源域舌头图像样本及目标域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图及目标域舌头图像样本特征构成的目标域图像特征图;
即:利用深度神经网络对训练集中源域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图;
利用深度神经网络对训练集中目标域舌头图像样本进行特征提取,获取对应的目标域舌头图像样本特征构成的目标域图像特征图;
S3、对源域图像特征图进行分块,获取源域舌头图像样本对应的源域图像特征集;对目标域图像特征图进行分块,获取目标域舌头图像样本对应的目标域图像特征集;
S4、计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,作为源域舌头图像样本和目标域图像样本之间的样本最优化传输距离;
本发明中样本最优化传输距离中两个样本分别来自源域舌头图像样本和目标域图像样本,从而实现在领域对齐的同时,引入局部信息,保持其特征的区分性;
S5、以所述样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
S6、根据步骤S2提取的源域舌头图像样本特征值计算softmax交叉熵损失,作为损失函数的一部分;将域间最优化传输距离作为损失函数里的另一部分,构建分类损失函数,利用所述分类损失函数训练分类器;
S7、利用训练好的分类器对待验证的舌头图像样本进行分类。
假定和/>是分别来自源域分布μs和目标域分布μt的两个样本,Π(μs,μt)是源域分布μs和目标域分布μt的联合概率分布。假设两个域的样本数分别为Ns和Nt,C≥0且是μs,μt之间的代价矩阵,其中每个元素由/>计算而来是两个样本之间的代价成本,用来衡量两个样本之间的差异,c是衡量两个样本距离的代价函数,通常采用L2距离。以/>作为成本度量,可以计算领域之间的最优化传输距离。/>的计算有以下方法。
在一个实施例中,步骤S4中源域舌头图像样本和目标域图像样本之间的样本最优化传输距离可以联合使用EMD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数:
首先设计一个基深度神经网络的特征提取器g:x→z,它可以将输入映射到一个隐层空间Z。同时,设计一个分类器f:z→y,它可以将隐层空间映射到标签空间。图像x通过特征提取器g可以得到那么,源域图像特征集和目标域图像特征集之间的代价函数可以变成:
其中,g表示深度神经网络的特征提取器;
γin表示两个样本之间关于图像特征集的最优传输方案,Cin表示两个样本之间关于图像特征集的代价矩阵,γin∈RHiWi×HjWj;Cin∈RHiWi×HjWj;<γin,Cin>F表示γin和Cin的Frobenius点乘;表示源域图像特征图沿着空间维度进行全局平均池化结果, 表示目标域图像特征图沿着空间维度进行全局平均池化结果,ch表示通道数。
特征提取器可以采用卷积神经网络的卷积层实现。
为了进一步优化上述技术方案,另外的一个实施例中,步骤S4计算源域舌头图像样本和目标域图像样本之间的样本最优化传输距离,采用如下方式:
联合使用SWD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
在一个大小为n的mini-batch内,源域舌头图像样本和目标域图像样本之间的两两样本之间要计算一次公式(1),而这样的计算代价还是太高,公式(2)采用SWD距离(Sliced Wasserstein Distance)来代替EMD距离近似计算源域图像特征集和目标域图像特征集之间的代价函数:
为了进一步优化上述技术方案,另外的一个实施例中,步骤S4另一种计算源域舌头图像样本和目标域图像样本之间的样本最优化传输距离,采用如下方式:
公式(2)中的匹配问题是NP难问题,时间复杂度太高,为了更高效地解决这个问题,因此公式(2)将用以下的算法来近似:
其中,包含M个投影,这里无需显式求出置换矩阵P,而只需对投影后的每个区域排序,然后计算它们对应的距离。这样通过公式(3)来近似计算源域舌头图像样本和目标域图像样本之间的样本最优化传输距离,每次计算/>的计算复杂度可以由O(N3)(如果用线性规划求解)或者O(JN2)(如果用Sinkhorn scaling算法求解,J是Sinkhorn scaling算法迭代次数)降低到O(MN),其中N是问题复杂度。由于这里是基于深度特征提取器g提取的特征图,得到的特征图大小不大,能分割的区域有限,因此N会比较小,计算代价不会增加太多。
这里的边缘分布的差异由内层SWD距离与L2距离衡量,类条件分布由样本标签之间的熵衡量。由于目标域样本无标签信息,这里使用模型预测的标签/>作为代理。通过联合对齐边缘分布与类条件分布,能够引入更多的类别信息,提高类别之间的区分性。因此公式(3)将转换成:
其中,λswd表示SWD距离的平衡系数,λl2表示L2距离的平衡系数,λcond表示类条件分布差异的交叉熵的平衡系数,M表示投影矩阵总个数,Zi表示将源域舌头图像样本映射到隐层空间Z而得到的特征矩阵,Zj表示将目标域舌头图像样本/>映射到隐层空间Z而得到的特征矩阵,/>表示将源域舌头图像样本/>或目标域舌头图像样本/>投影到隐层空间Z形成对应的第m个投影矩阵,/>表示源域舌头图像样本/>的标签,/>表示是类条件分布的差异。
公式(4)中包含三项内容,第一项是源域舌头图像样本和目标域图像样本/>的SWD距离,具体来说,我们将/>的特征Zi和/>的特征Zj投影到多个Zi和Zj共享的空间,在每个这样的共享空间里可以分别对Zi和Zj的特征子集进行排序后,直接计算其欧式距离,得到Zi和Zj在该共享空间的距离。最后,对在多个共享空间里计算出来的Zi和Zj的距离取平均值,作为Zi和Zi的SWD距离。第二项的/>函数比g函数增加了全局平均池化操作,等价于将Zi或Zj进行全局平均池化,因此第二项内容是对Zi对应的全局平均池化结果和Zj对应的全局平均池化结果计算L2距离。第三项是计算/>对应的标签与对/>进行分类预测的标签之间的交叉熵,表示类条件分布的差异。最后,公式(4)使用了三个超参数作为这三项的平衡系数对其加权并求和。
值得注意的是,公式(4)的三项都是互补的:SWD距离与L2距离形成局部与全局的互补信息;SWD距离与L2距离计算的都是边缘分布的差异;衡量的是类条件分布的差异。通过这三项的互相补充,可以在进行领域对齐的同时,保持其类别的区分性,从而提高分类性能。
本实施案例中,特征提取器的网络结构采用ResNet-50,特征提取器g提取的特征图保持其空间结构,用以计算SWD距离。特征提取器g先在ImageNet上预训练,分类器f从头开始训练,因此分类器的学***衡系数将设置为λswd=0.001,λl2=0.001和λcond=1.0。
本实施案例的优化器采用了SGD优化器,动量设置为0.9。学习率的变化策略的设置进行线性变化。小批量的大小为65,并迭代10000次。
步骤S5,用于根据样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
最优化传输(Optimal Transport,OT)是一种衡量两个概率分布之间的距离的方法,它可以利用分布的几何结构。通常来说,OT会搜索两个分布μs和μt可能的耦合方式γ∈Π(μs,μt),找到传输代价最小的耦合方案:
其中和/>是分别来自源域分布μs和目标域分布μt的任意两个样本;/>是两个样本之间的代价成本,用来衡量两个样本之间的差异;Π(μs,μt)是边缘分布μs和μt的联合概率分布。而在经验分布上的离散形式的OT可以定义为:
其中,μs,μt为正向量,<·,·>F是Frobenius点乘。假设两个域的样本数分别为Ns和Nt,C≥0且是μs,μt之间的代价矩阵,其中每个元素由计算而来。c是衡量两个样本距离的代价函数,通常采用L2距离。通过优化公式(6),即最小化传输代价,可以得到最优传输流/>公式(6)可以通过线性规划求解而得。
在一个实施例中,步骤S5,计算源域和目标域之间的域间最优化传输距离,包括采用mini-batch策略,具体包括每次随机从每个域抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的代理最优传输:
其中,考虑到公式(6)的计算代价,本发明每次随机从每个域抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的代理最优传输,即将公式(6)转变成公式(7)Cn中的每个元素由等式(6)计算而来,从而构成层次化的最优传输模型。
作为一种改进的技术方案,在另外一个实施例中,步骤S5,另一种计算源域和目标域之间的域间最优化传输距离,采用基于mini-batch的非平衡的最优传输方式代替公式(7):
其中,Dφ是Csiszar Divergences,KL是Kullback-Leibler散度,和/>是γn的边缘分布。这里的τ是边际惩罚系数(Marginal Penalization),ε,是正则化系数(Regularization Coefficient),ε≥0,具体可以设置为ε=0.01和τ=0.5。
这样在每个mini-batch里,每个领域是样本的集合。等式(8)作为第一层源域和目标域之间的最优传输,而等式(8)中的代价矩阵Cn中的每个元素是由对应源域舌头图像样本和目标域图像两个样本计算(4)而得的,因此,第二层是源域舌头图像样本和目标域图像样本之间的最优传输,这里每个样本是图像特征图各空间区域的集合。这样两层最优传输构成了深度层次化最优传输模型(Deep Hierarchical Optimal Transport,DeepHOT),如图2所示。这样,对于给定的mini-batch,DeepHOT的目标问题是:
作为一种改进的技术方案,在另外一个实施例中,步骤S5,另一种计算源域和目标域之间的域间最优化传输距离,采用非平衡的最优传输损失再增加源域的分类交叉熵损失函数:
目的是避免源域上的“灾难性遗忘”(Catastrophic Forgetting)问题,最终的优化目标包括了源域的分类交叉熵损失L。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本发明。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。
Claims (7)
1.一种基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述方法包括以下步骤:
S1、采集多个不同领域的舌头图像样本作为训练集;
S2、利用深度神经网络对训练集中源域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图;
利用深度神经网络对训练集中目标域舌头图像样本进行特征提取,获取对应的目标域舌头图像样本特征构成的目标域图像特征图;
S3、对源域图像特征图中的源域舌头图像样本特征进行分块,获取源域舌头图像样本对应的源域图像特征集;
对目标域图像特征图中的目标域舌头图像样本特征进行分块,获取目标域舌头图像样本对应的目标域图像特征集;
S4、计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,作为源域舌头图像样本和目标域图像样本之间的样本最优化传输距离;
S5、以所述样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
S6、根据步骤S2提取的源域舌头图像样本特征值计算softmax交叉熵损失,作为损失函数的一部分;将域间最优化传输距离作为损失函数里的另一部分,构建分类损失函数,利用所述分类损失函数训练分类器;
S7、利用训练好的分类器对待验证的舌头图像样本进行分类。
2.根据权利要求1所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,包括以下方法:
联合使用EMD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
其中,g表示深度神经网络的特征提取器;
4.根据权利要求3所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
采用SWD距离、L2距离和类条件分布差异的交叉熵作为两个特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
5.根据权利要求4所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离,包括采用mini-batch策略,具体包括,
6.根据权利要求5所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输。
7.根据权利要求1所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输损失再增加源域的分类交叉熵损失函数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310252527.2A CN116310545A (zh) | 2023-03-16 | 2023-03-16 | 一种基于深度层次化最优传输的跨域舌头图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310252527.2A CN116310545A (zh) | 2023-03-16 | 2023-03-16 | 一种基于深度层次化最优传输的跨域舌头图像分类方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116310545A true CN116310545A (zh) | 2023-06-23 |
Family
ID=86781093
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310252527.2A Pending CN116310545A (zh) | 2023-03-16 | 2023-03-16 | 一种基于深度层次化最优传输的跨域舌头图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116310545A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116566743A (zh) * | 2023-07-05 | 2023-08-08 | 北京理工大学 | 一种账户对齐方法、设备及存储介质 |
-
2023
- 2023-03-16 CN CN202310252527.2A patent/CN116310545A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116566743A (zh) * | 2023-07-05 | 2023-08-08 | 北京理工大学 | 一种账户对齐方法、设备及存储介质 |
CN116566743B (zh) * | 2023-07-05 | 2023-09-08 | 北京理工大学 | 一种账户对齐方法、设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Lu et al. | Class-agnostic counting | |
Krebs et al. | Unsupervised probabilistic deformation modeling for robust diffeomorphic registration | |
CN107704877B (zh) | 一种基于深度学习的图像隐私感知方法 | |
CN109886121B (zh) | 一种遮挡鲁棒的人脸关键点定位方法 | |
Xu et al. | Ask, attend and answer: Exploring question-guided spatial attention for visual question answering | |
Boyda et al. | Deploying a quantum annealing processor to detect tree cover in aerial imagery of California | |
Papa et al. | Efficient supervised optimum-path forest classification for large datasets | |
US11494616B2 (en) | Decoupling category-wise independence and relevance with self-attention for multi-label image classification | |
Wang | Online Learning Behavior Analysis Based on Image Emotion Recognition. | |
Gao et al. | Small sample classification of hyperspectral image using model-agnostic meta-learning algorithm and convolutional neural network | |
Gong et al. | A coupling translation network for change detection in heterogeneous images | |
Liu et al. | Generative self-training for cross-domain unsupervised tagged-to-cine mri synthesis | |
Shu et al. | LVC-Net: Medical image segmentation with noisy label based on local visual cues | |
CN113298129B (zh) | 基于超像素和图卷积网络的极化sar图像分类方法 | |
CN111126464A (zh) | 一种基于无监督域对抗领域适应的图像分类方法 | |
Ning et al. | Conditional generative adversarial networks based on the principle of homologycontinuity for face aging | |
Alshehri | A content-based image retrieval method using neural network-based prediction technique | |
CN116310545A (zh) | 一种基于深度层次化最优传输的跨域舌头图像分类方法 | |
Franchi et al. | Latent discriminant deterministic uncertainty | |
Chen et al. | A robust automatic clustering algorithm for probability density functions with application to categorizing color images | |
Huang et al. | An evidential combination method with multi-color spaces for remote sensing image scene classification | |
Fu et al. | Personality trait detection based on ASM localization and deep learning | |
CN114612658A (zh) | 基于双重类别级对抗网络的图像语义分割方法 | |
Yun et al. | Land cover classification based on tolerant rough set | |
CN112819154B (zh) | 一种应用于图学习领域的预训练模型的生成方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |