CN111738351A - 模型训练方法、装置、存储介质及电子设备 - Google Patents
模型训练方法、装置、存储介质及电子设备 Download PDFInfo
- Publication number
- CN111738351A CN111738351A CN202010623929.5A CN202010623929A CN111738351A CN 111738351 A CN111738351 A CN 111738351A CN 202010623929 A CN202010623929 A CN 202010623929A CN 111738351 A CN111738351 A CN 111738351A
- Authority
- CN
- China
- Prior art keywords
- distribution
- image
- training
- category
- encoder
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 178
- 238000000034 method Methods 0.000 title claims abstract description 61
- 238000009826 distribution Methods 0.000 claims abstract description 191
- 239000013598 vector Substances 0.000 claims abstract description 128
- 238000005070 sampling Methods 0.000 claims abstract description 32
- 238000013528 artificial neural network Methods 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 10
- 238000004364 calculation method Methods 0.000 claims description 9
- 230000000694 effects Effects 0.000 abstract description 6
- 239000000284 extract Substances 0.000 abstract description 5
- 230000008569 process Effects 0.000 description 12
- 238000004891 communication Methods 0.000 description 8
- 238000010586 diagram Methods 0.000 description 7
- 238000011156 evaluation Methods 0.000 description 6
- 230000006870 function Effects 0.000 description 5
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 238000013519 translation Methods 0.000 description 3
- 230000008859 change Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000011176 pooling Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000000052 comparative effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000007774 longterm Effects 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Molecular Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Evolutionary Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及图像聚类技术领域,提供一种模型训练方法、装置、存储介质及电子设备。其中,模型训练方法包括:获取训练图像并将其输入至编码器,获得编码器输出的特征向量的分布参数;根据特征向量的分布参数确定训练图像的类别;从训练图像的类别的分布中采样获得采样向量;将采样向量输入至解码器,获得解码器输出的重构图像;将训练图像以及重构图像分别输入至判别器,获得判别器输出的判别结果;重复获取训练图像至获得判别结果的步骤,以训练包括编码器、解码器以及判别器在内的图像聚类模型。该方法通过设置判别器并对模型进行对抗训练,使得编码器能够有效提取图像特征,因此后续利用训练好的编码器执行图像聚类任务能够取得较好的效果。
Description
技术领域
本发明涉及图像聚类技术领域,具体而言,涉及一种模型训练方法、装置、存储介质及电子设备。
背景技术
将对象的集合划分为由类似的对象组成的多个类别的过程被称为聚类。目前的无监督聚类方法主要利用已经提取出对象的特征进行聚类,但对于一些非结构化的数据,例如图像,并不容易提取出较好的特征,导致聚类的效果较差。
发明内容
本申请实施例的目的在于提供一种模型训练方法、装置、存储介质及电子设备,以改善上述技术问题。
为实现上述目的,本申请提供如下技术方案:
第一方面,本申请实施例提供一种模型训练方法,包括:获取训练图像并将所述训练图像输入至编码器,获得所述编码器输出的特征向量的分布参数;根据所述特征向量的分布参数确定所述训练图像的类别;从所述训练图像的类别的分布中采样获得采样向量;将所述采样向量输入至解码器,获得所述解码器输出的重构图像;将所述训练图像以及所述重构图像分别输入至判别器,获得所述判别器输出的判别结果,所述判别结果包括输入图像的真实性以及类别同一性;重复获取训练图像至获得判别结果的步骤,以训练包括所述编码器、所述解码器以及所述判别器在内的图像聚类模型;其中,训练的方式为对抗训练,训练的目标包括根据所述判别器输出的所述判别结果无法区分所述训练图像和所述重构图像的真实性以及类别。
上述方法是一种基于变分自编码器的无监督聚类方法,在该方法中,通过对包括编码器、解码器以及判别器在内的图像聚类模型进行对抗训练,使得在训练结束后,判别器难以区分由解码器输出的重构图像和真实的训练图像,而解码器进行图像重构是基于编码器提取到的特征向量的分布参数进行的,这说明此时编码器能够有效提取图像的特征,因此后续利用训练好的编码器执行图像聚类任务能够取得较好的效果。
另外,在上述方法中,通过设置判别器(例如,可以是一个神经网络)来评估训练图像和重构图像的区分度,代替了采用图像差分的方式(如,求训练图像和重构图像的L2距离)进行损失计算,避免了损失难以收敛、模型训练困难的问题。
在第一方面的一种实现方式中,所述根据所述特征向量的分布参数确定所述训练图像的类别,包括:根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离;将计算得到的距离最小的已有类别确定为所述训练图像的类别;利用所述特征向量的分布参数更新所述已有类别的分布。
在上述实现方式中,聚类的类别数目是固定的,通过计算特征向量的分布与每个已有类别的分布之间的距离,必然会将当前的训练图像划分至某个已有类别。此种实现方式下的已有类别可以是在训练开始前预设的若干个类别。
在第一方面的一种实现方式中,所述根据所述特征向量的分布参数确定所述训练图像的类别,包括:根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离;判断计算得到的距离中是否存在小于预设阈值的距离;若存在小于所述预设阈值的距离,则将小于所述预设阈值的距离中的最小值对应的已有类别确定为所述训练图像的类别,并利用所述特征向量的分布参数更新所述已有类别的分布;若不存在小于所述预设阈值的距离,则为所述训练图像分配一个新类别,并根据所述特征向量的分布参数确定所述新类别的分布。
在上述实现方式中,聚类的类别数目是不固定的,不仅要计算特征向量的分布与每个已有类别的分布之间的距离,还要考量计算出的距离与预设阈值之间的大小关系,根据该大小关系可能会将当前的训练图像划分至某个类别,但也可能为其创建新的类别。在此种方式下,于训练开始之前,可以预设若干个类别作为已有类别,也可以不预设任何类别。
在第一方面的一种实现方式中,所述根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离,包括:计算所述特征向量的分布参数与每个已有类别的分布参数之间的距离作为所述特征向量的分布与每个已有类别的分布之间的距离;或者,根据所述特征向量的分布参数确定所述特征向量的分布,并计算所述特征向量的分布与每个已有类别的分布之间的KL散度作为所述特征向量的分布与每个已有类别的分布之间的距离。
在上述实现方式中,提供了两种计算特征向量的分布与每个已有类别的分布之间的距离的方法:其一,计算分布参数之间的距离(因为分布参数也可以采用向量的形式,相当于计算向量间的距离);其二,计算KL散度,KL散度又称相对熵,用于评估两个分布之间的差异程度。当然,也不排除还有其他计算方法。
在第一方面的一种实现方式中,所述分布参数包括均值和方差。
一些分布的概率密度函数只需根据均值和方差就能够完全确定下来,例如,高斯分布。在可选的方案中,特征向量的分布、各类别的分布都可以假定其遵从高斯分布。
在第一方面的一种实现方式中,所述编码器、所述解码器以及所述判别器均采用神经网络。
神经网络具有良好的学***移,如果利用L2距离评估两幅图像之间区分度,由于L2距离仅仅表征了图像在像素值层面的差异,所以图像平移可能导致算出的L2距离很大,进而使得评估结果失准;但对于采用神经网络实现的判别器,则会基于图像深层次的特征(表征图像的具体内容)进行区分度评估,由于图像平移基本不改变图像的内容,所以得到的评估结果较为准确。
在第一方面的一种实现方式中,所述方法还包括:利用训练好的图像聚类模型中的编码器确定待处理图像的类别。
第二方面,本申请实施例提供一种模型训练装置,包括:编码模块,用于获取训练图像并将所述训练图像输入至编码器,获得所述编码器输出的特征向量的分布参数;聚类模块,用于根据所述特征向量的分布参数确定所述训练图像的类别;采样模块,用于从所述训练图像的类别的分布中采样获得采样向量;解码模块,用于将所述采样向量输入至解码器,获得所述解码器输出的重构图像;判别模块,用于将所述训练图像以及所述重构图像分别输入至判别器,获得所述判别器输出的判别结果,所述判别结果包括输入图像的真实性以及类别同一性;迭代模块,用于重复获取训练图像至获得判别结果的步骤,以训练包括所述编码器、所述解码器以及所述判别器在内的图像聚类模型;其中,训练的方式为对抗训练,训练的目标包括根据所述判别器输出的所述判别结果无法区分所述训练图像和所述重构图像的真实性以及类别。
第三方面,本申请实施例提供一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序指令,所述计算机程序指令被处理器读取并运行时,执行第一方面或第一方面的任意一种可能的实现方式提供的方法。
第四方面,本申请实施例提供一种电子设备,包括:存储器以及处理器,所述存储器中存储有计算机程序指令,所述计算机程序指令被所述处理器读取并运行时,执行第一方面或第一方面的任意一种可能的实现方式提供的方法。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1示出了变分自编码器的结构示意图;
图2示出了本申请实施例提供的一种模型训练方法的工作原理图;
图3示出了本申请实施例提供的一种模型训练方法的流程图;
图4示出了本申请实施例提供的一种模型训练装置的功能模块图;
图5示出了本申请实施例提供的一种电子设备的结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述。应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。
术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本申请实施例提供的模型训练方法是一种基于变分自编码器(VariationalAuto-Encoder,简称VAE)的无监督聚类方法(指训练好的模型可用于无监督聚类图像)。因此,在介绍本申请的方案之前,先简单介绍变分自编码器的概念,对于变分自编码器的实现细节,若未提及的,可以参考现有技术。
图1示出了变分自编码器的结构示意图。参照图1,变分自编码器的工作流程主要包括三个阶段:
编码阶段:由编码器对输入数据进行编码,得到隐变量的均值以及隐变量的方差(指隐变量所服从的分布的均值和方差)。
采样阶段:基于编码阶段得到的均值和方差从隐变量服从的分布中进行采样,得到采样变量。
解码阶段:基于采样变量解码出一个和输入数据维度一致的输出数据。
通过训练变分自编码器可以使得输出数据尽可能接近于输入数据,即实现了对输入数据的重构。因此,在一些应用场景中,可以将训练好的变分自编码器用于数据生成。特别地,在图像处理领域,可将变分自编码器用于图像生成。
图2示出了本申请实施例提供的一种模型训练方法的工作原理图,该方法用于训练一个可执行图像聚类任务的图像聚类模型(实际中可能只有编码器会用于执行图像聚类任务,模型的其余组件只在训练时使用,详见后文)。参照图2,不难看出,该图像聚类模型中也包含一个由编码器和解码器构成的变分自编码器,同时该模型与变分自编码器的主要区别在于加入了一个判别器。关于图2中的其他内容,将在阐述图3时一并进行介绍。
图3示出了本申请实施例提供的一种模型训练方法的流程图。该方法可由一电子设备执行,图5示出了该电子设备的一种可能的结构,具体可以参考后文关于图5的阐述。参照图3,该方法包括:
步骤S100:获取训练图像并将训练图像输入至编码器,获得编码器输出的特征向量的分布参数。
获取训练图像的方式不限定,例如,可以是从网络上搜集的无标注的图像,可以是某个数据集中的图像,可以是实时采集的图像等。编码器用于对训练图像进行编码,输出特征向量的分布参数。此处的特征向量对应于上文介绍变分自编码器时提到的隐变量,对于自编码器(Auto-Encoder,简称AE),其编码器直接提取输入图像的特征向量,但对于变分自编码器,在此处并不直接提取训练图像的特征向量,而是输出特征向量所服从的分布的参数,所以从这个意义上看该特征向量是隐含的。
编码器可以采用神经网络实现,该神经网络可以包含若干卷积层(当然不排除其他层,如池化层),用于通过特征提取将一幅图像编码为特征向量的分布参数。例如,编码器可以采用ResNet、VGG、LeNet、GoogleNet等网络结构。
步骤S100中需要输出哪些分布参数取决于特征向量的分布形式,该分布形式需要事先确定下来,反过来,若获知了特征向量的分布参数的具体数值,也能够确定其具体的分布。
例如,若假定特征向量服从高斯分布,则分布参数可以包括均值和方差,根据均值和方差能够唯一确定一个具体的高斯分布的概率密度函数;又例如,若假定特征向量服从指数分布,则分布参数可以包括率参数,根据率参数能够唯一确定一个具体的指数分布的概率密度函数。
步骤S110:根据特征向量的分布参数确定训练图像的类别。
将执行步骤S110时已经确定出的聚类类别称为已有类别,每个已有类别都对应一个自己的概率分布,每个已有类别下都包括零个或多个训练图像,这些训练图像可以视为对该类别的分布进行采样得到的样本。注意,即使某个类别下不包括任何训练图像,该类别也可能具有概率分布,这样的分布是预先指定的。聚类类别的分布形式需要事先确定下来,这样每个分布需要采用哪些分布参数来表示也可以确定下来,在对训练图像进行聚类的过程中,只需要根据训练图像的类别划分情况更新这些分布参数的取值就可以维护每个类别的分布。
从而,步骤S110中所说的根据特征向量的分布参数确定训练图像的类别,可以指根据特征向量的分布参数计算特征向量的分布与每个已有类别的分布之间的距离,然后根据计算出的距离确定当前的训练图像应该属于哪个已有类别,或者不属于任何已有类别。这里的距离,泛指某种相似性度量结果,用于度量两个分布之间的相似性,距离的取值和分布的相似性可以为正相关关系。下面列举两种距离计算方式,可以理解,还存在其他距离计算方式:
方式一:计算特征向量的分布参数与已有类别的分布参数之间的距离作为特征向量的分布与已有类别的分布之间的距离。由于分布参数可以决定分布的概率密度函数,从而,若两个分布的参数接近,则其相似性也越高,而相互之间的距离也越小。分布参数可以采用向量的形式表示(例如,均值向量、方差向量,其维度与特征向量相同),从而,计算分布参数的距离就是计算向量之间的距离。
方式二:根据特征向量的分布参数确定特征向量的分布,并计算特征向量的分布与已有类别的分布之间的KL散度作为特征向量的分布与已有类别的分布之间的距离。KL散度又称相对熵,其设计目的就是用于评估两个分布之间的差异程度,KL散度值越大,两个分部之间的差异越大,否则越小。
图像聚类中主要有两种情况:一种是预先指定聚类结果中包含的类别,在聚类过程中只是将待聚类图像划分至这些预设的类别,例如,事先获得了有关待聚类图像的某些先验知识,可以确定出其总共包含的类别;另一种是预先不指定类别,在聚类的过程允许产生新的类别,例如,事先并不知道代聚类图像可能被划分为几类。对于后一种情况,可能会指定一些预设类别作为初始类别,也可能不指定任何类别作为初始类别,在聚类的过程中才生成全部的类别。
对于第一种情况,步骤S110可以这样实现:首先,根据特征向量的分布参数计算特征向量的分布与每个已有类别的分布之间的距离;然后,将计算得到的距离最小的已有类别确定为训练图像的类别;最后,利用特征向量的分布参数更新已有类别的分布。
其中,已有类别是预设好的,其初始分布的分布参数可以随机给出,在有训练图像划分到该类别后,利用其对应的特征向量的分布参数对初始分布的参数进行更新,更新后的分布参数将会具有实际意义。在这种实现方式中,当前的训练图像必然会被划分至某个已有类别。
对于第二种情况,步骤S110可以这样实现:首先,根据特征向量的分布参数计算特征向量的分布与每个已有类别的分布之间的距离;然后,判断计算得到的距离中是否存在小于预设阈值的距离;若存在小于预设阈值的距离,则将小于预设阈值的距离中的最小值对应的已有类别确定为训练图像的类别,并利用特征向量的分布参数更新已有类别的分布;若不存在小于预设阈值的距离,则为训练图像分配一个新类别,并根据特征向量的分布参数确定新类别的分布。
在上述实现方式中,不仅要计算特征向量的分布与每个已有类别的分布之间的距离,还要考量计算出的距离与预设阈值之间的大小关系,根据该大小关系可能会将当前的训练图像划分至某个类别,但也可能为其创建新的类别,具体而言,若某个距离小于预设阈值,表明当前的训练图像很可能是对某个已有类别的分布进行采样的结果,所以应当将其划分至某个已有类别,否则应当为其创建新的类别。在此种方式下,于训练开始之前,可以预设若干个类别作为已有类别,也可以不预设任何类别,任由在聚类过程中再去产生类别。
当然,除了以上两种情况,也不排除一些其他的图像聚类方式,例如,先不指定任何预设类别进行图像聚类,但指定一个最大类别数目阈值,在聚类过程中若已有类别的数目未达到该阈值,则允许产生新的类别,若达到该阈值则不允许产生新的类别,只允许将图像划分至已有类别。
步骤S120:从训练图像的类别的分布中采样获得采样向量。
由于步骤S110中已经确定了训练图像的类别,从而只需要对该类别的分布进行采样就可以获得采样向量,此处的采样向量对应于上文介绍变分自编码器时提到的采样变量。基于一个已知分布进行采样属于现有技术,此处不进行具体介绍,采样产生的向量可以是一个随机向量。
步骤S130:将采样向量输入至解码器,获得解码器输出的重构图像。
解码器可以采用神经网络实现,该神经网络可以包含若干反卷积层(当然不排除其他层,如反池化层),用于将一个向量解码(或称重构)为一幅图像。
步骤S140:将训练图像以及重构图像分别输入至判别器,获得判别器输出的判别结果,判别结果包括输入图像的真实性以及类别同一性。
所谓输入图像的真实性(可以对应一个分数或概率),就是该输入图像到底是原始的训练图像还是重构出的训练图像,所谓输入图像的类别同一性(可以对应一个分数或概率),就是指输入的训练图像以及重构图像是否属于同一类别。
判别器可以采用神经网络实现,当然也不排除采用某些固有规则进行判别,例如计算训练图像以及重构图像的相似程度。但由于神经网络具有良好的学习和泛化能力,所以其经过训练后输出判别结果相较于单纯靠预设规则计算出的判别结果更加可靠。
步骤S150:对抗训练图像聚类模型,直至满足训练结束条件。
图像聚类模型至少包括编码器、解码器以及判别器,但也不排除包含其他结构。其中,编解码器视为一部分,判别器视为另一部分,所谓对抗训练是指这两部分之间的训练目标的对抗。编解码器的训练目标就是使解码器尽可能重构出与训练图像完全相同的图像,足以“欺骗”辨别器;判别器的训练目标就是尽量区分出重构图像与训练图像,不受编解码器的“欺骗”。从总体上看,图像聚类模型的训练目标包括:使得在训练完成后,根据训练好的判别器输出的判别结果无法区分训练图像和重构图像的真实性以及类别,即向判别器输入训练图像或重构图像,根据判别结果难以确定它们到底是重构的还是原始的训练图像,向判别器输入训练图像和重构图像,根据判别结果可以确定二者属于同一类别,也就是说解码器重构的图像已经足够接近真实的训练图像。
有关对抗训练的原理,可以参考现生成式对抗网络(Generative AdversarialNetworks,简称GAN)中的相关内容,其中,编解码器可视为GAN中的生成网络(Generator),判别器可视为GAN中的判别网络(Discriminator)。
步骤S100至步骤S140可视为训练过程中的一轮迭代过程(计算损失、更新网络参数的步骤从略),而步骤S150是一个迭代步骤,每经过一轮训练后,就会判断是否满足训练结束条件,若满足条件则结束训练,否则回到步骤S100开始下一轮训练。训练结束条件可能有多种设置方式,例如,可以是训练了一定轮次后结束,可以是训练了一定时间后结束,可以是判别器收敛后结束,或者也可以是以上多种条件的组合,等等。
图像聚类模型训练好后,可以将其中的编码器用于执行图像聚类任务,其过程类似步骤S100和步骤S110:获取待处理图像(指需要聚类的图像)并将待处理图像输入至训练好的编码器,获得编码器输出的特征向量的分布参数,根据特征向量的分布参数确定待处理图像的类别。其具体实现方式也可以参考步骤S100和步骤S110,此处不再重复。至于图像聚类模型中的其他组件,在执行实际的图像聚类任务时可不使用。
综上所述,在本申请实施例提供的模型训练方法中,通过对图像聚类模型进行对抗训练,使得在训练结束后,判别器难以区分由解码器输出的重构图像和真实的训练图像,由于解码器进行图像重构是基于编码器提取到的特征向量的分布参数进行的,这说明此时编码器能够有效提取图像的特征(只有提取到的特征足够好,判别器才会难以区分重构图像和训练图像),因此后续利用训练好的编码器执行图像聚类任务能够取得较好的效果。
在一些对照实施例中,不采用判别器,只使用变分自编码器进行聚类,并在变分自编码器的损失函数中加入重构损失(如,计算训练图像与重构图像的差分),以此使得重构图像足够接近训练图像。以采用L2距离计算重构损失为例,经发明人长期研究发现,由于L2距离仅仅表征了图像在像素值层面的差异,所以图像在像素层面的一些微小变化(如平移操作),都可能导致算出的L2距离很大,特别是在图像尺寸较大时这一问题更加明显,致使重构损失难以收敛,不利于模型的训练。在本申请的方案中,通过设置判别器来评估训练图像和重构图像的区分度,代替了采用图像差分的方式进行损失计算,从而避免了对照实施例中的上述问题,使得图像聚类模型易于训练,并可用于执行大尺寸图像的聚类任务。
进一步的,在该方法的一些实现方式中,图像聚类网络中的编码器、解码器以及判别器均采用神经网络实现。由于神经网络具有良好的学***移,如果利用L2距离评估两幅图像之间区分度,由于图像平移的影响可能导致算出的L2距离很大,进而使得评估结果失准;但对于采用神经网络实现的判别器,则会基于图像深层次的特征(表征图像的具体内容)进行区分度评估,由于图像平移基本不改变图像的内容,所以得到的评估结果较为准确。
图4示出了本申请实施例提供的模型训练装置200的功能模块图。参照图4,模型训练装置200包括:
编码模块210,用于获取训练图像并将所述训练图像输入至编码器,获得所述编码器输出的特征向量的分布参数;
聚类模块220,用于根据所述特征向量的分布参数确定所述训练图像的类别;
采样模块230,用于从所述训练图像的类别的分布中采样获得采样向量;
解码模块240,用于将所述采样向量输入至解码器,获得所述解码器输出的重构图像;
判别模块250,用于将所述训练图像以及所述重构图像分别输入至判别器,获得所述判别器输出的判别结果,所述判别结果包括输入图像的真实性以及类别同一性;
迭代模块260,用于重复获取训练图像至获得判别结果的步骤,以训练包括所述编码器、所述解码器以及所述判别器在内的图像聚类模型;其中,训练的方式为对抗训练,训练的目标包括根据所述判别器输出的所述判别结果无法区分所述训练图像和所述重构图像的真实性以及类别。
在模型训练装置200的一种实现方式中,聚类模块220根据所述特征向量的分布参数确定所述训练图像的类别,包括:根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离;将计算得到的距离最小的已有类别确定为所述训练图像的类别;利用所述特征向量的分布参数更新所述已有类别的分布。
在模型训练装置200的一种实现方式中,聚类模块220根据所述特征向量的分布参数确定所述训练图像的类别,包括:根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离;判断计算得到的距离中是否存在小于预设阈值的距离;若存在小于所述预设阈值的距离,则将小于所述预设阈值的距离中的最小值对应的已有类别确定为所述训练图像的类别,并利用所述特征向量的分布参数更新所述已有类别的分布;若不存在小于所述预设阈值的距离,则为所述训练图像分配一个新类别,并根据所述特征向量的分布参数确定所述新类别的分布。
在模型训练装置200的一种实现方式中,聚类模块220根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离,包括:计算所述特征向量的分布参数与每个已有类别的分布参数之间的距离作为所述特征向量的分布与每个已有类别的分布之间的距离;或者,根据所述特征向量的分布参数确定所述特征向量的分布,并计算所述特征向量的分布与每个已有类别的分布之间的KL散度作为所述特征向量的分布与每个已有类别的分布之间的距离。
在模型训练装置200的一种实现方式中,所述分布参数包括均值和方差。
在模型训练装置200的一种实现方式中,所述编码器、所述解码器以及所述判别器均采用神经网络。
在模型训练装置200的一种实现方式中,所述装置还包括:应用模块,用于利用训练好的图像聚类模型中的编码器确定待处理图像的类别。
本申请实施例提供的模型训练装置200,其实现原理及产生的技术效果在前述方法实施例中已经介绍,为简要描述,装置实施例部分未提及之处,可参考方法施例中相应内容。
图5示出了本申请实施例提供的电子设备300的一种可能的结构。参照图5,电子设备300包括:处理器310、存储器320以及通信接口330,这些组件通过通信总线340和/或其他形式的连接机构(未示出)互连并相互通讯。
其中,存储器320包括一个或多个(图中仅示出一个),其可以是,但不限于,随机存取存储器(Random Access Memory,简称RAM),只读存储器(Read Only Memory,简称ROM),可编程只读存储器(Programmable Read-Only Memory,简称PROM),可擦除可编程只读存储器(Erasable Programmable Read-Only Memory,简称EPROM),电可擦除可编程只读存储器(Electrically Erasable Programmable Read-Only Memory,简称EEPROM)等。处理器310以及其他可能的组件可对存储器320进行访问,读和/或写其中的数据。
处理器310包括一个或多个(图中仅示出一个),其可以是一种集成电路芯片,具有信号的处理能力。上述的处理器610可以是通用处理器,包括中央处理器(CentralProcessing Unit,简称CPU)、微控制单元(Micro Controller Unit,简称MCU)、网络处理器(Network Processor,简称NP)或者其他常规处理器;还可以是专用处理器,包括数字信号处理器(Digital Signal Processor,简称DSP)、专用集成电路(Application SpecificIntegrated Circuits,简称ASIC)、现场可编程门阵列(Field Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
通信接口330包括一个或多个(图中仅示出一个),可以用于和其他设备进行直接或间接地通信,以便进行数据的交互。通信接口330可以包括进行有线和/或无线通信的接口。
在存储器320中可以存储一个或多个计算机程序指令,处理器310可以读取并运行这些计算机程序指令,以实现本申请实施例提供的模型训练方法及其他期望的功能。
可以理解,图5所示的结构仅为示意,电子设备300还可以包括比图5中所示更多或者更少的组件,或者具有与图5所示不同的配置。图5中所示的各组件可以采用硬件、软件或其组合实现。电子设备300可能是实体设备,例如服务器、PC机、笔记本电脑、平板电脑、手机、可穿戴设备、图像采集设备、车载设备、无人机、机器人等,也可能是虚拟设备,例如虚拟机、虚拟化容器等。并且,电子设备300也不限于单台设备,也可以是多台设备的组合或者大量设备构成的一个或多个集群。
本申请实施例还提供一种计算机可读存储介质,该计算机可读存储介质上存储有计算机程序指令,所述计算机程序指令被计算机的处理器读取并运行时,执行本申请实施例提供的模型训练方法。例如,计算机可读存储介质可以实现为图5中电子设备300中的存储器320。
在本申请所提供的实施例中,应该理解到,所揭露装置和方法,可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,又例如,多个单元或组件可以结合或者可以集成到另一个***,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些通信接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
另外,作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
再者,在本申请各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。
以上所述仅为本申请的实施例而已,并不用于限制本申请的保护范围,对于本领域的技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (10)
1.一种模型训练方法,其特征在于,包括:
获取训练图像并将所述训练图像输入至编码器,获得所述编码器输出的特征向量的分布参数;
根据所述特征向量的分布参数确定所述训练图像的类别;
从所述训练图像的类别的分布中采样获得采样向量;
将所述采样向量输入至解码器,获得所述解码器输出的重构图像;
将所述训练图像以及所述重构图像分别输入至判别器,获得所述判别器输出的判别结果,所述判别结果包括输入图像的真实性以及类别同一性;
重复获取训练图像至获得判别结果的步骤,以训练包括所述编码器、所述解码器以及所述判别器在内的图像聚类模型;其中,训练的方式为对抗训练,训练的目标包括根据所述判别器输出的所述判别结果无法区分所述训练图像和所述重构图像的真实性以及类别。
2.根据权利要求1所述的模型训练方法,其特征在于,所述根据所述特征向量的分布参数确定所述训练图像的类别,包括:
根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离;
将计算得到的距离最小的已有类别确定为所述训练图像的类别;
利用所述特征向量的分布参数更新所述已有类别的分布。
3.根据权利要求1所述的模型训练方法,其特征在于,所述根据所述特征向量的分布参数确定所述训练图像的类别,包括:
根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离;
判断计算得到的距离中是否存在小于预设阈值的距离;
若存在小于所述预设阈值的距离,则将小于所述预设阈值的距离中的最小值对应的已有类别确定为所述训练图像的类别,并利用所述特征向量的分布参数更新所述已有类别的分布;
若不存在小于所述预设阈值的距离,则为所述训练图像分配一个新类别,并根据所述特征向量的分布参数确定所述新类别的分布。
4.根据权利要求2或3所述的模型训练方法,其特征在于,所述根据所述特征向量的分布参数计算所述特征向量的分布与每个已有类别的分布之间的距离,包括:
计算所述特征向量的分布参数与每个已有类别的分布参数之间的距离作为所述特征向量的分布与每个已有类别的分布之间的距离;或者,
根据所述特征向量的分布参数确定所述特征向量的分布,并计算所述特征向量的分布与每个已有类别的分布之间的KL散度作为所述特征向量的分布与每个已有类别的分布之间的距离。
5.根据权利要求1所述的模型训练方法,其特征在于,所述分布参数包括均值和方差。
6.根据权利要求1所述的模型训练方法,其特征在于,所述编码器、所述解码器以及所述判别器均采用神经网络。
7.根据权利要求1所述的模型训练方法,其特征在于,所述方法还包括:
利用训练好的图像聚类模型中的编码器确定待处理图像的类别。
8.一种模型训练装置,其特征在于,包括:
编码模块,用于获取训练图像并将所述训练图像输入至编码器,获得所述编码器输出的特征向量的分布参数;
聚类模块,用于根据所述特征向量的分布参数确定所述训练图像的类别;
采样模块,用于从所述训练图像的类别的分布中采样获得采样向量;
解码模块,用于将所述采样向量输入至解码器,获得所述解码器输出的重构图像;
判别模块,用于将所述训练图像以及所述重构图像分别输入至判别器,获得所述判别器输出的判别结果,所述判别结果包括输入图像的真实性以及类别同一性;
迭代模块,用于重复获取训练图像至获得判别结果的步骤,以训练包括所述编码器、所述解码器以及所述判别器在内的图像聚类模型;其中,训练的方式为对抗训练,训练的目标包括根据所述判别器输出的所述判别结果无法区分所述训练图像和所述重构图像的真实性以及类别。
9.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序指令,所述计算机程序指令被处理器读取并运行时,执行如权利要求1-7中任一项所述的方法。
10.一种电子设备,其特征在于,包括:存储器以及处理器,所述存储器中存储有计算机程序指令,所述计算机程序指令被所述处理器读取并运行时,执行如权利要求1-7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010623929.5A CN111738351B (zh) | 2020-06-30 | 2020-06-30 | 模型训练方法、装置、存储介质及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010623929.5A CN111738351B (zh) | 2020-06-30 | 2020-06-30 | 模型训练方法、装置、存储介质及电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111738351A true CN111738351A (zh) | 2020-10-02 |
CN111738351B CN111738351B (zh) | 2023-12-19 |
Family
ID=72652358
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010623929.5A Active CN111738351B (zh) | 2020-06-30 | 2020-06-30 | 模型训练方法、装置、存储介质及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111738351B (zh) |
Cited By (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112016638A (zh) * | 2020-10-26 | 2020-12-01 | 广东博智林机器人有限公司 | 一种钢筋簇的识别方法、装置、设备及存储介质 |
CN112328750A (zh) * | 2020-11-26 | 2021-02-05 | 上海天旦网络科技发展有限公司 | 训练文本判别模型的方法及*** |
CN112465020A (zh) * | 2020-11-25 | 2021-03-09 | 创新奇智(合肥)科技有限公司 | 训练数据集的生成方法及装置、电子设备、存储介质 |
CN113361583A (zh) * | 2021-06-01 | 2021-09-07 | 珠海大横琴科技发展有限公司 | 一种对抗样本检测方法和装置 |
CN113362403A (zh) * | 2021-07-20 | 2021-09-07 | 支付宝(杭州)信息技术有限公司 | 图像处理模型的训练方法及装置 |
CN113468820A (zh) * | 2021-07-21 | 2021-10-01 | 上海眼控科技股份有限公司 | 数据训练方法、装置、设备及存储介质 |
CN113936302A (zh) * | 2021-11-03 | 2022-01-14 | 厦门市美亚柏科信息股份有限公司 | 行人重识别模型的训练方法、装置、计算设备及存储介质 |
WO2022089522A1 (zh) * | 2020-10-28 | 2022-05-05 | 华为技术有限公司 | 一种数据传输的方法和装置 |
CN115100717A (zh) * | 2022-06-29 | 2022-09-23 | 腾讯科技(深圳)有限公司 | 特征提取模型的训练方法、卡通对象的识别方法及装置 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109615014A (zh) * | 2018-12-17 | 2019-04-12 | 清华大学 | 一种基于kl散度优化的数据分类***与方法 |
EP3477553A1 (en) * | 2017-10-27 | 2019-05-01 | Robert Bosch GmbH | Method for detecting an anomalous image among a first dataset of images using an adversarial autoencoder |
CN110009013A (zh) * | 2019-03-21 | 2019-07-12 | 腾讯科技(深圳)有限公司 | 编码器训练及表征信息提取方法和装置 |
CN110309853A (zh) * | 2019-05-20 | 2019-10-08 | 湖南大学 | 基于变分自编码器的医学图像聚类方法 |
CN110458904A (zh) * | 2019-08-06 | 2019-11-15 | 苏州瑞派宁科技有限公司 | 胶囊式内窥镜图像的生成方法、装置及计算机存储介质 |
US20200005154A1 (en) * | 2018-02-01 | 2020-01-02 | Siemens Healthcare Limited | Data encoding and classification |
CN111079649A (zh) * | 2019-12-17 | 2020-04-28 | 西安电子科技大学 | 基于轻量化语义分割网络的遥感图像地物分类方法 |
-
2020
- 2020-06-30 CN CN202010623929.5A patent/CN111738351B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
EP3477553A1 (en) * | 2017-10-27 | 2019-05-01 | Robert Bosch GmbH | Method for detecting an anomalous image among a first dataset of images using an adversarial autoencoder |
US20200005154A1 (en) * | 2018-02-01 | 2020-01-02 | Siemens Healthcare Limited | Data encoding and classification |
CN109615014A (zh) * | 2018-12-17 | 2019-04-12 | 清华大学 | 一种基于kl散度优化的数据分类***与方法 |
CN110009013A (zh) * | 2019-03-21 | 2019-07-12 | 腾讯科技(深圳)有限公司 | 编码器训练及表征信息提取方法和装置 |
CN110309853A (zh) * | 2019-05-20 | 2019-10-08 | 湖南大学 | 基于变分自编码器的医学图像聚类方法 |
CN110458904A (zh) * | 2019-08-06 | 2019-11-15 | 苏州瑞派宁科技有限公司 | 胶囊式内窥镜图像的生成方法、装置及计算机存储介质 |
CN111079649A (zh) * | 2019-12-17 | 2020-04-28 | 西安电子科技大学 | 基于轻量化语义分割网络的遥感图像地物分类方法 |
Non-Patent Citations (3)
Title |
---|
徐德荣;陈秀宏;田进;: "基于类编码的判别特征学习", 计算机工程与科学, no. 03 * |
杨晨曦;左?;孙频捷;: "基于自编码器的零样本学习方法研究进展", 现代计算机, no. 01 * |
陈梦雪;刘勇;: "基于对抗图卷积的网络表征学习框架", 模式识别与人工智能, no. 11 * |
Cited By (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112016638A (zh) * | 2020-10-26 | 2020-12-01 | 广东博智林机器人有限公司 | 一种钢筋簇的识别方法、装置、设备及存储介质 |
CN112016638B (zh) * | 2020-10-26 | 2021-04-06 | 广东博智林机器人有限公司 | 一种钢筋簇的识别方法、装置、设备及存储介质 |
WO2022089522A1 (zh) * | 2020-10-28 | 2022-05-05 | 华为技术有限公司 | 一种数据传输的方法和装置 |
CN112465020A (zh) * | 2020-11-25 | 2021-03-09 | 创新奇智(合肥)科技有限公司 | 训练数据集的生成方法及装置、电子设备、存储介质 |
CN112328750A (zh) * | 2020-11-26 | 2021-02-05 | 上海天旦网络科技发展有限公司 | 训练文本判别模型的方法及*** |
CN113361583A (zh) * | 2021-06-01 | 2021-09-07 | 珠海大横琴科技发展有限公司 | 一种对抗样本检测方法和装置 |
CN113362403A (zh) * | 2021-07-20 | 2021-09-07 | 支付宝(杭州)信息技术有限公司 | 图像处理模型的训练方法及装置 |
CN113468820A (zh) * | 2021-07-21 | 2021-10-01 | 上海眼控科技股份有限公司 | 数据训练方法、装置、设备及存储介质 |
CN113936302A (zh) * | 2021-11-03 | 2022-01-14 | 厦门市美亚柏科信息股份有限公司 | 行人重识别模型的训练方法、装置、计算设备及存储介质 |
CN113936302B (zh) * | 2021-11-03 | 2023-04-07 | 厦门市美亚柏科信息股份有限公司 | 行人重识别模型的训练方法、装置、计算设备及存储介质 |
CN115100717A (zh) * | 2022-06-29 | 2022-09-23 | 腾讯科技(深圳)有限公司 | 特征提取模型的训练方法、卡通对象的识别方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN111738351B (zh) | 2023-12-19 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111738351A (zh) | 模型训练方法、装置、存储介质及电子设备 | |
CN110929622B (zh) | 视频分类方法、模型训练方法、装置、设备及存储介质 | |
JP7414901B2 (ja) | 生体検出モデルのトレーニング方法及び装置、生体検出の方法及び装置、電子機器、記憶媒体、並びにコンピュータプログラム | |
CN111291817B (zh) | 图像识别方法、装置、电子设备和计算机可读介质 | |
CN110765860A (zh) | 摔倒判定方法、装置、计算机设备及存储介质 | |
CN112270686B (zh) | 图像分割模型训练、图像分割方法、装置及电子设备 | |
CN110135505B (zh) | 图像分类方法、装置、计算机设备及计算机可读存储介质 | |
CN113095370B (zh) | 图像识别方法、装置、电子设备及存储介质 | |
CN114048468A (zh) | 入侵检测的方法、入侵检测模型训练的方法、装置及介质 | |
CN111488810A (zh) | 人脸识别方法、装置、终端设备及计算机可读介质 | |
US20230206121A1 (en) | Modal information completion method, apparatus, and device | |
CN117011274A (zh) | 自动化玻璃瓶检测***及其方法 | |
CN111382791A (zh) | 深度学习任务处理方法、图像识别任务处理方法和装置 | |
CN108496174B (zh) | 用于面部识别的方法和*** | |
CN111783812A (zh) | 违禁图像识别方法、装置和计算机可读存储介质 | |
CN114004364A (zh) | 采样优化方法、装置、电子设备及存储介质 | |
CN109101984B (zh) | 一种基于卷积神经网络的图像识别方法及装置 | |
CN110866609B (zh) | 解释信息获取方法、装置、服务器和存储介质 | |
CN116805534A (zh) | 基于弱监督学习的疾病分型方法、***、介质及设备 | |
CN115713669B (zh) | 一种基于类间关系的图像分类方法、装置、存储介质及终端 | |
CN111652320A (zh) | 一种样本分类方法、装置、电子设备及存储介质 | |
CN113591969B (zh) | 面部相似度评测方法、装置、设备以及存储介质 | |
CN110929767B (zh) | 一种字形处理方法、***、设备和介质 | |
CN113536859A (zh) | 行为识别模型训练方法、识别方法、装置及存储介质 | |
CN111582404A (zh) | 内容分类方法、装置及可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |