CN116310545A - 一种基于深度层次化最优传输的跨域舌头图像分类方法 - Google Patents

一种基于深度层次化最优传输的跨域舌头图像分类方法 Download PDF

Info

Publication number
CN116310545A
CN116310545A CN202310252527.2A CN202310252527A CN116310545A CN 116310545 A CN116310545 A CN 116310545A CN 202310252527 A CN202310252527 A CN 202310252527A CN 116310545 A CN116310545 A CN 116310545A
Authority
CN
China
Prior art keywords
domain
tongue
tongue image
representing
image
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310252527.2A
Other languages
English (en)
Inventor
文贵华
徐映雪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
South China University of Technology SCUT
Original Assignee
South China University of Technology SCUT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by South China University of Technology SCUT filed Critical South China University of Technology SCUT
Priority to CN202310252527.2A priority Critical patent/CN116310545A/zh
Publication of CN116310545A publication Critical patent/CN116310545A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/0002Inspection of images, e.g. flaw detection
    • G06T7/0012Biomedical image inspection
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20081Training; Learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20084Artificial neural networks [ANN]

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Databases & Information Systems (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
  • Radiology & Medical Imaging (AREA)
  • Quality & Reliability (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于深度层次化最优传输的跨域舌头图像分类方法,包括采集舌头图像;构建机器学习分类模型,所述机器学习分类模型利用深度层次化最优传输模型实现不同领域舌头图像特征的对齐,所述深度层次化最优传输模型包括两层网络结构,其中,第一层网络结构用于实现不同领域间的最优传输,第二层网络结构用于实现不同样本之间的最优传输;根据深度层次化最优传输模型对舌头图像进行分类,输出舌头图像类别。本发明的有益之处在于对齐不同分布的舌头图像的同时,增强舌头图像的分类能力。

Description

一种基于深度层次化最优传输的跨域舌头图像分类方法
技术领域
本发明涉及辅助中医诊疗的舌头图像分类技术领域,更具体的说是涉及一种基于深度层次化最优传输的跨域舌头图像分类方法。
背景技术
现有基于机器学习的舌头图像分类方法大多是基于监督学习的。监督学习方法通常假设训练集与测试集服从同一个分布,因此在训练集上训练出来的模型在测试集同样可以表现的很好。然而,在现实应用中这样的假设是难以成立的,解决这类问题的主要思路是假设两个数据分布通过一个非线性映射到一个领域间共同的隐空间,以减少分布之间的漂移,让两个分布在经过非线性映射转换之后更加相似。这个非线性映射的过程称之为领域自适应。
现有基于机器学习的舌头图像分类就面临这样的问题。首先是不同人的舌头图像具有差异性,包括舌头图像的边缘纹理、颜色等。其次,不同医院的舌头图像采集设备可能是不同的,采集的舌头图像数据还受采集环境的影响,例如角度、光照等。另外,不同医院的地理位置不同,采集舌头图像的个体也有地域差异。这些因素导致了不同医院的舌头图像数据会出现比较大的分布差异。若每个医院的舌头图像数据为领域数据,则不同领域的数据分布有差异,这些差异将导致在已采集的数据集上训练出来的模型,部署到其他医院时会出现严重的性能退化。同时,由于医疗数据的标注代价较高,在目标域没有标签的数据时,更加困难,即只有源域的标签是可用的。
为了解决这个问题,不同领域分布需要对齐。目前解决不同领域分布对齐的主流方法主要包括两个步骤:首先,通过非线性转换让两个分布更加靠近;然后,在变换后的分布上利用源域的标签信息为目标域训练一个分类器,使得模型能够泛化到目标域,这也是域间知识迁移的过程。可见,如何找到这个非线性转换对于解决领域自适应问题是关键。近年来,最优传输方法在领域自适应问题上显示出来较大的优势,它可以直接在边缘分布上衡量两个分布的距离,而无需标签信息。在视觉领域,这个基于最优化传输的距离称之为EMD(Earth Mover’s Distance)距离。一方面,它可以直接在离散的经验分布(域)上计算两个分布之间的距离。另一方面,当两个域的支撑集不明显重叠时也能提供有意义的梯度,因此不容易导致训练失败。此外,它具有良好的可解释性,能够显式地建模领域之间的耦合。
通过优化两个领域的特征分布之间的最小传输代价,让源域与目标域的分布都能以最小代价变换到一个共同的隐空间里,而在这个隐空间里的特征都具有领域不变性。这个过程称之为领域对齐。在这样的特征上训练得到的分类器,具备迁移到目标域的能力。而领域对齐并不是最终目标,分类才是。然而,在最优传输中,代价矩阵的计算方式通常是计算两两样本的欧氏距离(L2距离)。在这样的度量空间里,当两个样本的支撑集不重叠,是无法提供有意义的距离的。这在视觉问题里表现为,当两个样本的背景过于杂乱或具有较大的类内外观变化,可能会使同一类别的图像在这样的度量空间中相去甚远。换句话说,这时的L2距离受背景变化的影响较大。这虽然可以通过神经网络的建模缓解,但需要充足的训练数据,而这在实际场景中(特别是医学场景)是很难做到的,需要强调其目标区域的局部特征,同时,L2距离作为一种全局表示,破坏了图像特征的空间结构,丢失了局部信息。而局部信息是可以提供区分性且可迁移的信息的,这对分类任务来说非常重要。尤其是在中医体征图像数据集,其采集过程的标准化程度较差,背景或者环境因素变化较大。基于上述原因,导致在现有的领域对齐的过程中,获得领域不变性特征的同时,也会模糊特征的类别区分性,即出现过度对齐。
因此,如何避免舌头图像分类过程中产生过度对齐的现象,提高舌头分类图像的准确度是本领域技术人员亟需解决的技术问题。
发明内容
为了解决这些问题,本发明公开了一种基于深度层次化最优传输的跨域舌头图像分类方法,从而让机器学习模型能够学习到对环境噪音更加鲁棒的不变性特征,使得不同分布的舌头图像数据具有自适应的能力,提高分类的准确率。
为了实现上述目的,本发明采用如下技术方案:
一种基于深度层次化最优传输的跨域舌头图像分类方法,包括:
S1、采集多个不同领域的舌头图像样本作为训练集;
S2、利用深度神经网络对训练集中源域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图;
利用深度神经网络对训练集中目标域舌头图像样本进行特征提取,获取对应的目标域舌头图像样本特征构成的目标域图像特征图;
S3、对源域图像特征图中的源域舌头图像样本特征进行分块,获取源域舌头图像样本对应的源域图像特征集;
对目标域图像特征图中的目标域舌头图像样本特征进行分块,获取目标域舌头图像样本对应的目标域图像特征集;
S4、计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,作为源域舌头图像样本和目标域图像样本之间的样本最优化传输距离;
S5、以所述样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
S6、根据步骤S2提取的源域舌头图像样本特征值计算softmax交叉熵损失,作为损失函数的一部分;将域间最优化传输距离作为损失函数里的另一部分,构建分类损失函数,利用所述分类损失函数训练分类器;
S7、利用训练好的分类器对待验证的舌头图像样本进行分类。
所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,包括以下方法:
联合使用EMD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure BDA0004128320780000041
其中,g表示深度神经网络的特征提取器;
Figure BDA0004128320780000042
表示第i个源域舌头图像样本提取的源域图像特征图,/>
Figure BDA0004128320780000043
表示第i个源域舌头图像样本,/>
Figure BDA0004128320780000044
Hi及Wi分别表示第i个源域舌头图像样本提取的源域图像特征图的宽和高;
Figure BDA0004128320780000045
表示第j个目标域舌头图像样本提取的目标域图像特征图,/>
Figure BDA0004128320780000046
表示第j个目标域舌头图像样本,/>
Figure BDA0004128320780000047
Hj及Wj分别表示第j个目标域舌头图像样本提取的目标域图像特征图的宽和高;
Figure BDA0004128320780000048
表示元域图像特征图和目标域图像特征图的联合图像特征图;
γin表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的最优传输方案,Cin表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的代价矩阵;<γin,Cin>F表示γin和Cin的Frobenius点乘;
Figure BDA0004128320780000049
表示第i个源域舌头图像样本提取的源域图像特征图沿着空间维度进行全局平均池化结果,/>
Figure BDA00041283207800000410
Figure BDA00041283207800000411
表示第j个目标域舌头图像样本提取的目标域图像特征图沿着空间维度进行全局平均池化结果,/>
Figure BDA00041283207800000412
ch表示通道数。
优选的,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
联合使用SWD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure BDA0004128320780000051
P表示置换矩阵,
Figure BDA0004128320780000052
表示所有置换矩阵的集合,Ui表示将源域舌头图像样本/>
Figure BDA0004128320780000053
对应的特征转换到一个共同的高维隐层空间,Uj表示将目标域舌头图像样本/>
Figure BDA0004128320780000054
对应的特征转换到一个共同的高维隐层空间,T为矩阵转置符号。
优选的,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
采用SWD距离、L2距离和类条件分布差异的交叉熵作为两个特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure BDA0004128320780000055
λswd表示SWD距离的平衡系数,λl2表示L2距离的平衡系数,λcond表示类条件分布差异的交叉熵的平衡系数,
Figure BDA0004128320780000056
表示源域舌头图像样本/>
Figure BDA0004128320780000057
的标签,/>
Figure BDA0004128320780000058
表示是类条件分布的差异,M表示投影矩阵总个数,Zi表示将源域舌头图像样本/>
Figure BDA0004128320780000059
映射到隐层空间Z而得到的特征矩阵,Zj表示将目标域样本/>
Figure BDA0004128320780000061
映射到隐层空间Z而得到的特征矩阵,/>
Figure BDA0004128320780000062
表示将源域舌头图像样本/>
Figure BDA0004128320780000063
或目标域舌头图像样本/>
Figure BDA0004128320780000064
投影到隐层空间Z形成对应的第m个投影矩阵。
优选的,所述步骤S5中,包括采用mini-batch策略,具体包括每次随机从每个源域舌头图像样本和目标域舌头图像样本中分别抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的最优化传输距离:
Figure BDA0004128320780000065
其中,/>
Figure BDA0004128320780000066
OTn表示域间最优化传输距离的矩阵,/>
Figure BDA0004128320780000067
表示源域舌头图像样本分布组成的矩阵,/>
Figure BDA0004128320780000068
表示目标域舌头样本分布组成的矩阵,/>
Figure BDA0004128320780000069
表示/>
Figure BDA00041283207800000610
和/>
Figure BDA00041283207800000611
的联合分布,γn表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的最优传输方案组成的n*n矩阵,Cn表示任意一个源域舌头图像样本和任意一个目标域图像样本之间的样本最优化传输距离组成的n*n矩阵,<γn,Cn>F表示γn和Cn的Frobenius点乘。
优选的,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输。
优选的,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输损失再增加源域的分类交叉熵损失函数。
经由上述的技术方案可知,与现有技术相比,本发明公开提供了基于深度层次化最优传输的跨域舌头图像分类方法,具有以下有益效果:
通过深度层次化的最优传输对齐不同分布的舌头图像的同时,增强分类能力;在第一层的最优传输中采用了非平衡的最优传输进行领域对齐,放松了最优传输的边缘约束,从而能够对小批量的训练提供一个更鲁棒的优化性能;对于第二层的最优传输,使用SWD代替EMD距离,增强样本的区分性特征,SWD是EMD距离的近似,但其计算代价更低。提高了舌头图像分类的准确率。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本发明提供的深度层次化图像分类方法流程示意图;
图2为本发明提供的深度层次化最优传输模型结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例公开了一种基于深度层次化最优传输的跨域舌头图像分类方法,包括:
一种基于深度层次化最优传输的跨域舌头图像分类方法,包括:
S1、采集多个不同领域的舌头图像样本作为训练集;
S2、利用深度神经网络分别对训练集中源域舌头图像样本及目标域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图及目标域舌头图像样本特征构成的目标域图像特征图;
即:利用深度神经网络对训练集中源域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图;
利用深度神经网络对训练集中目标域舌头图像样本进行特征提取,获取对应的目标域舌头图像样本特征构成的目标域图像特征图;
S3、对源域图像特征图进行分块,获取源域舌头图像样本对应的源域图像特征集;对目标域图像特征图进行分块,获取目标域舌头图像样本对应的目标域图像特征集;
S4、计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,作为源域舌头图像样本和目标域图像样本之间的样本最优化传输距离;
本发明中样本最优化传输距离中两个样本分别来自源域舌头图像样本和目标域图像样本,从而实现在领域对齐的同时,引入局部信息,保持其特征的区分性;
S5、以所述样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
S6、根据步骤S2提取的源域舌头图像样本特征值计算softmax交叉熵损失,作为损失函数的一部分;将域间最优化传输距离作为损失函数里的另一部分,构建分类损失函数,利用所述分类损失函数训练分类器;
S7、利用训练好的分类器对待验证的舌头图像样本进行分类。
假定
Figure BDA0004128320780000081
和/>
Figure BDA0004128320780000082
是分别来自源域分布μs和目标域分布μt的两个样本,Π(μst)是源域分布μs和目标域分布μt的联合概率分布。假设两个域的样本数分别为Ns和Nt,C≥0且
Figure BDA0004128320780000083
是μst之间的代价矩阵,其中每个元素由/>
Figure BDA0004128320780000084
计算而来是两个样本之间的代价成本,用来衡量两个样本之间的差异,c是衡量两个样本距离的代价函数,通常采用L2距离。以/>
Figure BDA0004128320780000085
作为成本度量,可以计算领域之间的最优化传输距离。/>
Figure BDA0004128320780000086
的计算有以下方法。
在一个实施例中,步骤S4中源域舌头图像样本和目标域图像样本之间的样本最优化传输距离可以联合使用EMD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数:
首先设计一个基深度神经网络的特征提取器g:x→z,它可以将输入映射到一个隐层空间Z。同时,设计一个分类器f:z→y,它可以将隐层空间映射到标签空间。图像x通过特征提取器g可以得到
Figure BDA0004128320780000091
那么,源域图像特征集和目标域图像特征集之间的代价函数可以变成:
Figure BDA0004128320780000092
其中,g表示深度神经网络的特征提取器;
Figure BDA0004128320780000093
表示源域图像特征图,/>
Figure BDA0004128320780000094
表示源域舌头图像样本,/>
Figure BDA0004128320780000095
Hi及Wi分别表示源域图像特征图的宽和高;
Figure BDA0004128320780000096
表示目标域图像特征图,/>
Figure BDA0004128320780000097
表示目标域舌头图像样本,/>
Figure BDA0004128320780000098
Hj及Wj分别表示目标域图像特征图的宽和高;
γin表示两个样本之间关于图像特征集的最优传输方案,Cin表示两个样本之间关于图像特征集的代价矩阵,γin∈RHiWi×HjWj;Cin∈RHiWi×HjWj;<γin,Cin>F表示γin和Cin的Frobenius点乘;
Figure BDA0004128320780000099
表示源域图像特征图沿着空间维度进行全局平均池化结果,
Figure BDA00041283207800000910
Figure BDA00041283207800000911
表示目标域图像特征图沿着空间维度进行全局平均池化结果,
Figure BDA00041283207800000912
ch表示通道数。
特征提取器可以采用卷积神经网络的卷积层实现。
为了进一步优化上述技术方案,另外的一个实施例中,步骤S4计算源域舌头图像样本和目标域图像样本之间的样本最优化传输距离,采用如下方式:
联合使用SWD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure BDA00041283207800000913
P表示置换矩阵,
Figure BDA0004128320780000101
表示所有置换矩阵的集合,Ui表示将样本/>
Figure BDA0004128320780000102
对应的特征转换到一个共同的高维隐层空间,Uj表示将样本/>
Figure BDA0004128320780000103
对应的特征转换到一个共同的高维隐层空间,T为矩阵转置符号。
在一个大小为n的mini-batch内,源域舌头图像样本和目标域图像样本之间的两两样本之间要计算一次公式(1),而这样的计算代价还是太高,公式(2)采用SWD距离(Sliced Wasserstein Distance)来代替EMD距离近似计算源域图像特征集和目标域图像特征集之间的代价函数:
公式(2)引入了置换矩阵P来匹配图像的不同区域,P包含了两张图像各区域之间的关联,而
Figure BDA0004128320780000104
表示所有置换矩阵的集合。/>
Figure BDA0004128320780000105
表示将样本/>
Figure BDA0004128320780000106
对应的特征转换到一个d维的共同隐层空间里。
为了进一步优化上述技术方案,另外的一个实施例中,步骤S4另一种计算源域舌头图像样本和目标域图像样本之间的样本最优化传输距离,采用如下方式:
公式(2)中的匹配问题是NP难问题,时间复杂度太高,为了更高效地解决这个问题,因此公式(2)将用以下的算法来近似:
Figure BDA0004128320780000107
其中,
Figure BDA0004128320780000108
包含M个投影,这里无需显式求出置换矩阵P,而只需对投影后的每个区域排序,然后计算它们对应的距离。这样通过公式(3)来近似计算源域舌头图像样本和目标域图像样本之间的样本最优化传输距离,每次计算/>
Figure BDA0004128320780000109
的计算复杂度可以由O(N3)(如果用线性规划求解)或者O(JN2)(如果用Sinkhorn scaling算法求解,J是Sinkhorn scaling算法迭代次数)降低到O(MN),其中N是问题复杂度。由于这里是基于深度特征提取器g提取的特征图,得到的特征图大小不大,能分割的区域有限,因此N会比较小,计算代价不会增加太多。
这里的边缘分布的差异由内层SWD距离与L2距离衡量,类条件分布由样本标签之间的熵衡量。由于目标域样本
Figure BDA0004128320780000111
无标签信息,这里使用模型预测的标签/>
Figure BDA0004128320780000112
作为代理。通过联合对齐边缘分布与类条件分布,能够引入更多的类别信息,提高类别之间的区分性。因此公式(3)将转换成:
Figure BDA0004128320780000113
其中,λswd表示SWD距离的平衡系数,λl2表示L2距离的平衡系数,λcond表示类条件分布差异的交叉熵的平衡系数,M表示投影矩阵总个数,Zi表示将源域舌头图像样本
Figure BDA0004128320780000114
映射到隐层空间Z而得到的特征矩阵,Zj表示将目标域舌头图像样本/>
Figure BDA0004128320780000115
映射到隐层空间Z而得到的特征矩阵,/>
Figure BDA0004128320780000116
表示将源域舌头图像样本/>
Figure BDA0004128320780000117
或目标域舌头图像样本/>
Figure BDA0004128320780000118
投影到隐层空间Z形成对应的第m个投影矩阵,/>
Figure BDA0004128320780000119
表示源域舌头图像样本/>
Figure BDA00041283207800001110
的标签,/>
Figure BDA00041283207800001111
表示是类条件分布的差异。
公式(4)中包含三项内容,第一项是源域舌头图像样本
Figure BDA00041283207800001112
和目标域图像样本/>
Figure BDA00041283207800001113
的SWD距离,具体来说,我们将/>
Figure BDA00041283207800001114
的特征Zi和/>
Figure BDA00041283207800001115
的特征Zj投影到多个Zi和Zj共享的空间,在每个这样的共享空间里可以分别对Zi和Zj的特征子集进行排序后,直接计算其欧式距离,得到Zi和Zj在该共享空间的距离。最后,对在多个共享空间里计算出来的Zi和Zj的距离取平均值,作为Zi和Zi的SWD距离。第二项的/>
Figure BDA00041283207800001116
函数比g函数增加了全局平均池化操作,等价于将Zi或Zj进行全局平均池化,因此第二项内容是对Zi对应的全局平均池化结果和Zj对应的全局平均池化结果计算L2距离。第三项是计算/>
Figure BDA0004128320780000121
对应的标签与对/>
Figure BDA0004128320780000122
进行分类预测的标签之间的交叉熵,表示类条件分布的差异。最后,公式(4)使用了三个超参数作为这三项的平衡系数对其加权并求和。
值得注意的是,公式(4)的三项都是互补的:SWD距离与L2距离形成局部与全局的互补信息;SWD距离与L2距离计算的都是边缘分布的差异;
Figure BDA0004128320780000123
衡量的是类条件分布的差异。通过这三项的互相补充,可以在进行领域对齐的同时,保持其类别的区分性,从而提高分类性能。
本实施案例中,特征提取器的网络结构采用ResNet-50,特征提取器g提取的特征图保持其空间结构,用以计算SWD距离。特征提取器g先在ImageNet上预训练,分类器f从头开始训练,因此分类器的学***衡系数将设置为λswd=0.001,λl2=0.001和λcond=1.0。
本实施案例的优化器采用了SGD优化器,动量设置为0.9。学习率的变化策略的设置进行线性变化。小批量的大小为65,并迭代10000次。
步骤S5,用于根据样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
最优化传输(Optimal Transport,OT)是一种衡量两个概率分布之间的距离的方法,它可以利用分布的几何结构。通常来说,OT会搜索两个分布μs和μt可能的耦合方式γ∈Π(μst),找到传输代价最小的耦合方案:
Figure BDA0004128320780000124
其中
Figure BDA0004128320780000125
和/>
Figure BDA0004128320780000126
是分别来自源域分布μs和目标域分布μt的任意两个样本;/>
Figure BDA0004128320780000127
是两个样本之间的代价成本,用来衡量两个样本之间的差异;Π(μst)是边缘分布μs和μt的联合概率分布。而在经验分布上的离散形式的OT可以定义为:
Figure BDA0004128320780000131
其中,μst为正向量,<·,·>F是Frobenius点乘。假设两个域的样本数分别为Ns和Nt,C≥0且
Figure BDA0004128320780000132
是μst之间的代价矩阵,其中每个元素由
Figure BDA0004128320780000133
计算而来。c是衡量两个样本距离的代价函数,通常采用L2距离。通过优化公式(6),即最小化传输代价,可以得到最优传输流/>
Figure BDA0004128320780000134
公式(6)可以通过线性规划求解而得。
在一个实施例中,步骤S5,计算源域和目标域之间的域间最优化传输距离,包括采用mini-batch策略,具体包括每次随机从每个域抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的代理最优传输:
Figure BDA0004128320780000135
其中,
Figure BDA0004128320780000136
考虑到公式(6)的计算代价,本发明每次随机从每个域抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的代理最优传输,即将公式(6)转变成公式(7)Cn中的每个元素由等式(6)计算而来,从而构成层次化的最优传输模型。
作为一种改进的技术方案,在另外一个实施例中,步骤S5,另一种计算源域和目标域之间的域间最优化传输距离,采用基于mini-batch的非平衡的最优传输方式代替公式(7):
Figure BDA0004128320780000137
其中,Dφ是Csiszar Divergences,KL是Kullback-Leibler散度,
Figure BDA0004128320780000141
和/>
Figure BDA0004128320780000142
是γn的边缘分布。这里的τ是边际惩罚系数(Marginal Penalization),ε,是正则化系数(Regularization Coefficient),ε≥0,具体可以设置为ε=0.01和τ=0.5。
这样在每个mini-batch里,每个领域是样本的集合。等式(8)作为第一层源域和目标域之间的最优传输,而等式(8)中的代价矩阵Cn中的每个元素是由对应源域舌头图像样本和目标域图像两个样本计算(4)而得的,因此,第二层是源域舌头图像样本和目标域图像样本之间的最优传输,这里每个样本是图像特征图各空间区域的集合。这样两层最优传输构成了深度层次化最优传输模型(Deep Hierarchical Optimal Transport,DeepHOT),如图2所示。这样,对于给定的mini-batch,DeepHOT的目标问题是:
Figure BDA0004128320780000143
其中,
Figure BDA0004128320780000144
Figure BDA0004128320780000145
作为一种改进的技术方案,在另外一个实施例中,步骤S5,另一种计算源域和目标域之间的域间最优化传输距离,采用非平衡的最优传输损失再增加源域的分类交叉熵损失函数:
Figure BDA0004128320780000146
目的是避免源域上的“灾难性遗忘”(Catastrophic Forgetting)问题,最终的优化目标包括了源域的分类交叉熵损失L。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本发明。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。

Claims (7)

1.一种基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述方法包括以下步骤:
S1、采集多个不同领域的舌头图像样本作为训练集;
S2、利用深度神经网络对训练集中源域舌头图像样本进行特征提取,获取对应的源域舌头图像样本特征构成的源域图像特征图;
利用深度神经网络对训练集中目标域舌头图像样本进行特征提取,获取对应的目标域舌头图像样本特征构成的目标域图像特征图;
S3、对源域图像特征图中的源域舌头图像样本特征进行分块,获取源域舌头图像样本对应的源域图像特征集;
对目标域图像特征图中的目标域舌头图像样本特征进行分块,获取目标域舌头图像样本对应的目标域图像特征集;
S4、计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,作为源域舌头图像样本和目标域图像样本之间的样本最优化传输距离;
S5、以所述样本最优化传输距离作为源域和目标域之间的成本度量,计算源域和目标域之间的域间最优化传输距离;
S6、根据步骤S2提取的源域舌头图像样本特征值计算softmax交叉熵损失,作为损失函数的一部分;将域间最优化传输距离作为损失函数里的另一部分,构建分类损失函数,利用所述分类损失函数训练分类器;
S7、利用训练好的分类器对待验证的舌头图像样本进行分类。
2.根据权利要求1所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,包括以下方法:
联合使用EMD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure FDA0004128320770000021
其中,g表示深度神经网络的特征提取器;
Figure FDA0004128320770000022
表示第i个源域舌头图像样本提取的源域图像特征图,/>
Figure FDA0004128320770000023
表示第i个源域舌头图像样本,/>
Figure FDA0004128320770000024
Hi及Wi分别表示第i个源域舌头图像样本提取的源域图像特征图的宽和高;
Figure FDA0004128320770000025
表示第j个目标域舌头图像样本提取的目标域图像特征图,/>
Figure FDA0004128320770000026
表示第j个目标域舌头图像样本,/>
Figure FDA0004128320770000027
Hj及Wj分别表示第j个目标域舌头图像样本提取的目标域图像特征图的宽和高;
Figure FDA0004128320770000028
表示元域图像特征图和目标域图像特征图的联合图像特征图;
γin表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的最优传输方案,Cin表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的代价矩阵;<γin,Cin>F表示γin和Cin的Frobenius点乘;
Figure FDA0004128320770000029
表示第i个源域舌头图像样本提取的源域图像特征图沿着空间维度进行全局平均池化结果,/>
Figure FDA00041283207700000210
Figure FDA00041283207700000211
表示第j个目标域舌头图像样本提取的目标域图像特征图沿着空间维度进行全局平均池化结果,/>
Figure FDA00041283207700000212
ch表示通道数。
3.根据权利要求2所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
联合使用SWD距离和L2距离作为源域图像特征集和目标域图像特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure FDA0004128320770000031
P表示置换矩阵,
Figure FDA0004128320770000032
表示所有置换矩阵的集合,Ui表示将源域舌头图像样本/>
Figure FDA0004128320770000033
对应的特征转换到一个共同的高维隐层空间,Uj表示将目标域舌头图像样本/>
Figure FDA0004128320770000034
对应的特征转换到一个共同的高维隐层空间,T为矩阵转置符号。
4.根据权利要求3所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S4中计算每一个源域舌头图像样本对应的源域图像特征集和目标域舌头图像样本对应的目标域图像特征集之间的最优化传输距离,还包括以下方法:
采用SWD距离、L2距离和类条件分布差异的交叉熵作为两个特征集之间的代价函数,计算两个特征集之间的最优化传输距离,所述代价函数具体包括:
Figure FDA0004128320770000035
λswd表示SWD距离的平衡系数,λl2表示L2距离的平衡系数,λcond表示类条件分布差异的交叉熵的平衡系数,
Figure FDA0004128320770000036
表示源域舌头图像样本/>
Figure FDA0004128320770000037
的标签,/>
Figure FDA0004128320770000038
表示类条件分布的差异,M表示投影矩阵总个数,Zi表示将源域舌头图像样本/>
Figure FDA0004128320770000039
映射到隐层空间Z而得到的特征矩阵,Zj表示将目标域舌头图像样本/>
Figure FDA00041283207700000310
映射到隐层空间Z而得到的特征矩阵,/>
Figure FDA00041283207700000311
表示将源域舌头图像样本/>
Figure FDA00041283207700000312
或目标域舌头图像样本/>
Figure FDA00041283207700000313
投影到隐层空间Z形成对应的第m个投影矩阵。
5.根据权利要求4所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离,包括采用mini-batch策略,具体包括,
每次随机从每个源域舌头图像样本和目标域舌头图像样本中分别抽取大小为n的mini-batch,计算这两个mini-batch之间的最优传输,作为领域之间的最优化传输距离:
Figure FDA0004128320770000041
其中,
Figure FDA0004128320770000042
OTn表示域间最优化传输距离的矩阵,/>
Figure FDA0004128320770000043
表示源域舌头图像样本分布组成的矩阵,/>
Figure FDA0004128320770000044
表示目标域舌头样本分布组成的矩阵,/>
Figure FDA0004128320770000045
表示/>
Figure FDA0004128320770000046
Figure FDA0004128320770000047
的联合分布,γn表示任意一个源域舌头图像样本和任意一个目标域舌头图像样本之间关于对应的图像特征集的最优传输方案组成的n*n矩阵,Cn表示任意一个源域舌头图像样本和任意一个目标域图像样本之间的样本最优化传输距离组成的n*n矩阵,<γn,Cn>F表示γn和Cn的Frobenius点乘。
6.根据权利要求5所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输。
7.根据权利要求1所述的基于深度层次化最优传输的跨域舌头图像分类方法,其特征在于,所述步骤S5中,计算源域和目标域之间的域间最优化传输距离还包括,采用非平衡的最优传输损失再增加源域的分类交叉熵损失函数。
CN202310252527.2A 2023-03-16 2023-03-16 一种基于深度层次化最优传输的跨域舌头图像分类方法 Pending CN116310545A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310252527.2A CN116310545A (zh) 2023-03-16 2023-03-16 一种基于深度层次化最优传输的跨域舌头图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310252527.2A CN116310545A (zh) 2023-03-16 2023-03-16 一种基于深度层次化最优传输的跨域舌头图像分类方法

Publications (1)

Publication Number Publication Date
CN116310545A true CN116310545A (zh) 2023-06-23

Family

ID=86781093

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310252527.2A Pending CN116310545A (zh) 2023-03-16 2023-03-16 一种基于深度层次化最优传输的跨域舌头图像分类方法

Country Status (1)

Country Link
CN (1) CN116310545A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116566743A (zh) * 2023-07-05 2023-08-08 北京理工大学 一种账户对齐方法、设备及存储介质

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116566743A (zh) * 2023-07-05 2023-08-08 北京理工大学 一种账户对齐方法、设备及存储介质
CN116566743B (zh) * 2023-07-05 2023-09-08 北京理工大学 一种账户对齐方法、设备及存储介质

Similar Documents

Publication Publication Date Title
Lu et al. Class-agnostic counting
Krebs et al. Unsupervised probabilistic deformation modeling for robust diffeomorphic registration
CN107704877B (zh) 一种基于深度学习的图像隐私感知方法
CN109886121B (zh) 一种遮挡鲁棒的人脸关键点定位方法
Xu et al. Ask, attend and answer: Exploring question-guided spatial attention for visual question answering
Boyda et al. Deploying a quantum annealing processor to detect tree cover in aerial imagery of California
Papa et al. Efficient supervised optimum-path forest classification for large datasets
US11494616B2 (en) Decoupling category-wise independence and relevance with self-attention for multi-label image classification
Wang Online Learning Behavior Analysis Based on Image Emotion Recognition.
Gao et al. Small sample classification of hyperspectral image using model-agnostic meta-learning algorithm and convolutional neural network
Gong et al. A coupling translation network for change detection in heterogeneous images
Liu et al. Generative self-training for cross-domain unsupervised tagged-to-cine mri synthesis
Shu et al. LVC-Net: Medical image segmentation with noisy label based on local visual cues
CN113298129B (zh) 基于超像素和图卷积网络的极化sar图像分类方法
CN111126464A (zh) 一种基于无监督域对抗领域适应的图像分类方法
Ning et al. Conditional generative adversarial networks based on the principle of homologycontinuity for face aging
Alshehri A content-based image retrieval method using neural network-based prediction technique
CN116310545A (zh) 一种基于深度层次化最优传输的跨域舌头图像分类方法
Franchi et al. Latent discriminant deterministic uncertainty
Chen et al. A robust automatic clustering algorithm for probability density functions with application to categorizing color images
Huang et al. An evidential combination method with multi-color spaces for remote sensing image scene classification
Fu et al. Personality trait detection based on ASM localization and deep learning
CN114612658A (zh) 基于双重类别级对抗网络的图像语义分割方法
Yun et al. Land cover classification based on tolerant rough set
CN112819154B (zh) 一种应用于图学习领域的预训练模型的生成方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination