CN112149825A - 神经网络模型的训练方法及装置、电子设备、存储介质 - Google Patents
神经网络模型的训练方法及装置、电子设备、存储介质 Download PDFInfo
- Publication number
- CN112149825A CN112149825A CN202011019263.9A CN202011019263A CN112149825A CN 112149825 A CN112149825 A CN 112149825A CN 202011019263 A CN202011019263 A CN 202011019263A CN 112149825 A CN112149825 A CN 112149825A
- Authority
- CN
- China
- Prior art keywords
- sample data
- neural network
- network model
- data set
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000003062 neural network model Methods 0.000 title claims abstract description 159
- 238000012549 training Methods 0.000 title claims abstract description 97
- 238000000034 method Methods 0.000 title claims abstract description 66
- 239000000203 mixture Substances 0.000 claims description 28
- 239000013598 vector Substances 0.000 claims description 27
- 238000004364 calculation method Methods 0.000 claims description 25
- 238000004590 computer program Methods 0.000 claims description 4
- 238000000638 solvent extraction Methods 0.000 claims description 4
- 238000012935 Averaging Methods 0.000 claims description 3
- 238000012545 processing Methods 0.000 claims description 3
- 230000006870 function Effects 0.000 description 11
- 238000010586 diagram Methods 0.000 description 10
- 238000013528 artificial neural network Methods 0.000 description 5
- 230000000694 effects Effects 0.000 description 3
- 238000002372 labelling Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 2
- 238000005192 partition Methods 0.000 description 2
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000007689 inspection Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请提供一种神经网络模型的训练方法及装置、电子设备、计算机可读存储介质,方法包括:将两个批次的样本数据输入两个神经网络模型,并依据神经网络模型输出的预测类别信息将样本数据划分至干净数据集和不干净数据集;将两个批次的样本数据对调后,分别训练两个神经网络模型,且在训练过程区分干净数据集中的样本数据和不干净数据集中的样本数据。本申请实施例中,两个神经网络模型采用完全不同的样本数据进行训练,使得两个神经网络模型保持了一定的独立性,从而避免训练过程中学习到相同的错误信息。在样本数据存在噪声标签的情况下,这种训练方式可以提高神经网络模型的鲁棒性。
Description
技术领域
本申请涉及深度学习技术领域,特别涉及一种神经网络模型的训练方法及装置、电子设备、计算机可读存储介质。
背景技术
神经网络模型的训练过程需要大量高质量的标注数据。在实际应用时,为样本数据添加标签需要大量时间和人力成本,而且人工标注过程可能出错。为保证神经网络模型的应用效果,通常需对样本数据的标签进行人工核查,这又是个耗时且耗费人力成本的过程。如果可以利用含有噪声标签的标注数据对神经网络模型进行训练、同时避免噪声标签带来的负面影响,可以提高训练效率,降低训练成本。
发明内容
本申请实施例的目的在于提供一种神经网络模型的训练方法及装置、电子设备、计算机可读存储介质,用于利用含有噪声标签的标注数据训练神经网络模型,并避免神经网络模型对噪声标签的过拟合。
一方面,本申请提供了一种神经网络模型的训练方法,包括:
从样本数据集中选择第一批次的样本数据输入第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;
从所述样本数据集中选择第二批次的样本数据输入第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型;
针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集;
利用所述第一批次的样本数据对所述第二神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
利用所述第二批次的样本数据对所述第一神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
从所述样本数据集中重新选择两个批次的样本数据,重复上述训练过程,直至所述第一神经网络模型和所述第二神经网络模型收敛。
在一实施例中,所述针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型,包括:
针对每一样本数据,基于所述样本数据的预测类别信息中对应于各类别的置信度,拟合得到所述高斯混合模型;其中,所述高斯混合模型包括对应于真实预测类别信息的真实高斯模型、对应于虚假预测类别信息的虚假高斯模型。
在一实施例中,所述针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集,包括:
针对每一样本数据,判断拟合得到的与所述真实高斯模型对应的权重参数是否达到预设权重参数阈值;
如果是,将所述样本数据划分至所述干净数据集;
如果否,将所述样本数据划分至所述不干净数据集。
在一实施例中,所述利用所述第一批次的样本数据对所述第二神经网络模型进行训练,包括:
将所述第一批次的样本数据输入所述第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
对于所述干净数据集中的每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算交叉熵;
对于所述不干净数据集中每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算向量距离;
对所述交叉熵和所述向量距离加权求和,获得损失参数;
根据所述损失参数对所述第二神经网络模型的网络参数进行调整。
在一实施例中,所述利用所述第二批次的样本数据对所述第一神经网络模型进行训练,包括:
将所述第二批次的样本数据输入所述第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;
对于所述干净数据集中的每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算交叉熵;
对于所述不干净数据集中每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算向量距离;
对所述交叉熵和所述向量距离加权求和,获得损失参数;
根据所述损失参数对所述第一神经网络模型的网络参数进行调整。
在一实施例中,在训练神经网络模型之前,所述方法还包括:
针对所述干净数据集中的每一样本数据,根据所述样本数据对应所述权重参数、所述样本数据的标签和对应的预测类别信息,更新所述样本数据的标签。
在一实施例中,在训练神经网络模型之前,所述方法还包括:
针对所述不干净数据集中每一样本数据,对所述样本数据进行多次数据增强操作,得到多个增强样本数据;
将多个增强样本数据分别输入所述第一神经网络模型和所述第二神经网络模型,得到多个预测类别信息;
对所述多个预测类别信息进行均值化处理,获得所述样本数据的标签。
另一方面,本申请还提供了一种神经网络模型的训练装置,包括:
输入模块,用于从样本数据集中选择第一批次的样本数据输入第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;从所述样本数据集中选择第二批次的样本数据输入第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
拟合模块,用于针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型;
划分模块,用于针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集;
训练模块,用于利用所述第一批次的样本数据对所述第二神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;利用所述第二批次的样本数据对所述第一神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
选择模块,用于从所述样本数据集中重新选择两个批次的样本数据,重复上述训练过程,直至所述第一神经网络模型和所述第二神经网络模型收敛。
进一步的,本申请还提供了一种电子设备,所述电子设备包括:
处理器;
用于存储处理器可执行指令的存储器;
其中,所述处理器被配置为执行上述神经网络模型的训练方法。
另外,本申请还提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序可由处理器执行以完成上述神经网络模型的训练方法。
在本申请实施例中,由于两个神经网络模型采用完全不同的样本数据进行训练,使得两个神经网络模型保持了一定的独立性,从而避免训练过程中学习到相同的错误信息。在样本数据存在噪声标签的情况下,这种训练方式可以提高神经网络模型的鲁棒性。在训练过程中通过区分干净数据集中的样本数据和不干净数据集中的样本数据,可以降低噪声标签对神经网络模型的影响,避免神经网络模型对噪声标签过拟合。通过本实施例的措施,无需对标签进行人工核查,降低了人力成本。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图作简单地介绍。
图1为本申请一实施例提供的电子设备的结构示意图;
图2是本申请一实施例提供的神经网络模型的训练示意图;
图3是本申请一实施例提供的神经网络模型的训练方法的流程示意图;
图4是本申请一实施例提供的训练第一神经网络模型的流程示意图;
图5是本申请一实施例提供的训练第二神经网络模型的流程示意图;
图6是本申请一实施例提供的神经网络模型的训练装置的框图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述。
相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。同时,在本申请的描述中,术语“第一”、“第二”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
如图1所示,本实施例提供一种电子设备1,包括:至少一个处理器11和存储器12,图1中以一个处理器11为例。处理器11和存储器12通过总线10连接,存储器12存储有可被处理器11执行的指令,指令被处理器11执行,以使电子设备1可执行下述的实施例中方法的全部或部分流程。在一实施例中,电子设备1可以是执行神经网络模型训练的服务器或主机,为便于描述方案,下文以主机为执行主体。
存储器12可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,如静态随机存取存储器(Static Random Access Memory,简称SRAM),电可擦除可编程只读存储器(Electrically Erasable Programmable Read-Only Memory,简称EEPROM),可擦除可编程只读存储器(Erasable Programmable Read Only Memory,简称EPROM),可编程只读存储器(Programmable Red-Only Memory,简称PROM),只读存储器(Read-Only Memory,简称ROM),磁存储器,快闪存储器,磁盘或光盘。
本申请还提供了一种计算机可读存储介质,存储介质存储有计算机程序,计算机程序可由处理器11执行以完成本申请提供的神经网络模型的训练方法。
参见图2,为本申请一实施例提供的神经网络模型的训练示意图,如图2所示,本申请同时训练两个神经网络模型。主机选择两个批次的样本数据,分别输入到两个神经网络模型中进行预测,获得预测类别信息。以图2为例,主机将“batch-a”批次的样本数据输入第一神经网络模型,由第一神经网络输出“batch-a”批次的样本数据对应的预测类别信息;主机将“batch-b”批次的样本数据输入第二神经网络模型,由第二神经网络输出“batch-b”批次的样本数据对应的预测类别信息。
主机依据样本数据的预测类别信息,为每一样本数据拟合高斯混合模型(Gaussian Mixture Model,GMM),并以拟合结果将样本数据划分至干净数据集或不干净数据集。不干净数据集中的样本数据可能携带噪声标签。
训练阶段主机可以将两个批次的样本数据对调,进而训练两个神经网络模型。以图2为例,主机利用“batch-a”批次的样本数据训练第二神经网络模型,利用“batch-b”批次的样本数据训练第一神经网络模型。由于样本数据被划分为干净数据和不干净数据,在训练阶段可以对两者区别对待,降低噪声标签带来的过拟合和学习偏差等问题。
后文详细说明使用含有噪声标签的标注数据的训练方法。
参见图3,为本申请一实施例提供的神经网络模型的训练方法的流程示意图,如图3所示,该方法可以包括以下步骤310-步骤370。
步骤310:从样本数据集中选择第一批次的样本数据输入第一神经网络模型,获得第一神经网络模型输出的预测类别信息。
步骤320:从样本数据集中选择第二批次的样本数据输入第二神经网络模型,获得第二神经网络模型输出的预测类别信息。
样本数据集中包括大量已添加标签的样本数据,标签指示样本数据(图像)上目标的类别信息。类别信息可以是n维向量,n是可以识别的类别总数;向量中每一元素对应于一种类别,表示目标属于该类别的置信度。标签指示的类别是确定的,该类别在向量中对应的元素为1,向量中其它元素为0。
第一神经网络模型和第二神经网络模型可以是LeNet、AlexNet、VGG(VisualGeometry Group Network,视觉几何群网络)、GoogleNet、ResNet(Residual NeuralNetwork,残差神经网络)等用于分类的网络中的任意一种。第一神经网络模型和第二神经网络模型的模型结构可以相同,也可以不同。
主机可以从样本数据集中两次选择一定数量的样本数据,作为第一批次的样本数据和第二批次的样本数据。第一批次的样本数据与第二批次的样本数据相互独立,换而言之,一个样本数据不会同时被选到第一批次和第二批次中。
主机将第一批次的样本数据输入第一神经网络模型,获得第一神经网络模型输出的对应于每一样本数据的预测类别信息。预测类别信息可以是与标签上类别信息相同维度的向量,向量中每一元素对应一种类别,表示目标属于该类别的置信度。预测类别信息中每一元素都是0到1之间的数值,其一个向量中各元素相加的和为1。
主机将第二批次的样本数据输入第二神经网络模型,获得第二神经网络模型输出的对应于每一样本数据的预测类别信息。
步骤330:针对每一样本数据,基于预测类别信息拟合得到对应于样本数据的高斯混合模型。
其中,高斯混合模型包括对应于真实预测类别信息的真实高斯模型,以及对应于虚假预测类别信息的虚假高斯模型。
对于任一样本数据的预测类别信息而言,以预测类别信息的每一类别作为横坐标,以置信度作为纵坐标,根据每一类别对应的置信度可以绘制出一条曲线。一般,如果预测类别信息正确,在预测类别信息中最大的置信度明显大于其它置信度,在这种情况下,曲线的峰较为明显,可认为预测类别信息的正确的可能性较大。
将每一类别对应的置信度看作从两个单独的高斯分布中采集得到,因此,针对每一样本数据,主机可以将基于样本数据的预测类别信息中对应于各类别的置信度,拟合得到高斯混合模型。在拟合时,设置高斯混合模型包括两个高斯模型,分别为真实高斯模型和虚假高斯模型。拟合完成后,可以得到对应于真实高斯模型的均值、方差和权重参数,以及对应于虚假高斯模型的均值、方差和权重参数。此时,每一样本数据存在对应的高斯混合模型。主机在拟合得到两组均值、方差和权重参数后,从两组数据中选择权重参数较大的数据,作为与真实高斯模型对应的均值、方差和权重参数。
步骤340:针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将样本数据划分至干净数据集或不干净数据集。
干净数据集中的样本数据是干净数据,在统计概率上,被划分至干净数据集的样本数据的标签正确的可能性较大。
不干净数据集中的样本数据是不干净数据,在统计概率上,被划分至不干净数据集的样本数据的标签正确的可能性较小。换而言之,不干净数据集中的样本数据可能携带噪声标签。
在一实施例中,针对每一样本数据,主机可以判断该样本数据对应的高斯混合模型中真实高斯模型对应的权重参数,是否达到预设权重参数阈值。该权重参数阈值可以是预配置的经验值,用于区分干净数据和不干净数据。一方面,如果权重参数达到权重参数阈值,主机可以将样本数据划分至干净数据集。另一方面,如果权重参数未达到权重参数阈值,主机可以将样本数据划分至不干净数据集。
步骤350:利用第一批次的样本数据对第二神经网络模型进行训练;其中,在训练过程中干净数据集中样本数据对应的损失计算方式,与不干净数据集中样本数据对应的损失计算方式不同。
步骤360:利用第二批次的样本数据对第一神经网络模型进行训练;其中,在训练过程中干净数据集中样本数据对应的损失计算方式,与不干净数据集中样本数据对应的损失计算方式不同。
损失计算方式可以包括损失函数和损失权重,不同的损失权重可用于调节样本数据对训练过程的影响。
主机对调两个批次的样本数据,利用第一批次样本数据训练第二神经网络模型,利用第二批次样本数据训练第二神经网络模型。由于两个神经网络模型采用完全不同的样本数据进行训练,使得两个神经网络模型保持了一定的独立性,从而避免训练过程中学习到相同的错误信息。在样本数据存在噪声标签的情况下,这种训练方式可以提高神经网络模型的鲁棒性。
步骤370:从样本数据集中重新选择两个批次的样本数据,重复上述训练过程,直至第一神经网络模型和第二神经网络模型收敛。
主机通过第一批次的样本数据和第二批次的样本数据,对两个神经网络模型执行一轮训练后,可以从样本数据集中重新选择两个批次的样本数据,并重复迭代上述训练过程,直到评估训练结果的损失函数的函数值小于预设损失阈值,或者,函数值不再变小。此时,可以认为第一神经网络模型和所述第二神经网络模型收敛。
在一实施例中,主机在训练第一神经网络模型时,参见图4,为本申请一实施例提供的训练第一神经网络模型的流程示意图,如图4所示,训练过程可以包括以下步骤351-步骤355。
步骤351:将第一批次的样本数据输入第二神经网络模型,获得第二神经网络模型输出的预测类别信息。
步骤352:对于干净数据集中的每一样本数据,根据样本数据的标签和对应的预测类别信息计算交叉熵。
在通过第二神经网络对第一批次的样本数据计算出预测类别信息后,针对每一样本数据,主机可以判断该样本数据是否被划分至干净数据集。一方面,如果是,说明该样本数据携带的标签可能正确,在这种情况下,主机可以计算样本数据的标签和预测类别信息之间的交叉熵。另一方面,如果该样本数据被划分至不干净数据集,说明该样本数据可能携带噪声标签,此时可以执行步骤353。
步骤353:对于不干净数据集中每一样本数据,根据样本数据的标签和对应的预测类别信息计算向量距离。
其中,向量距离可以是欧式距离、曼哈顿距离等任意一种。
步骤354:对交叉熵和向量距离加权求和,获得损失参数。
步骤355:根据损失参数对第二神经网络模型的网络参数进行调整。
针对干净数据集中的样本数据,主机可以计算交叉熵的均值;针对不干净数据集中的样本数据,主机可以计算向量距离的均值。主机可以对这两个均值进行加权求和,从而获得损失参数。损失参数就是整体损失函数的函数值,用于评估神经网络模型的预测性能。在加权求和时,为提高干净数据集的样本数据对训练过程的影响,交叉熵的均值对应的权重较大。示例性的,交叉熵的均值对应的权重为0.6,向量距离的均值对应的权重为0.4。获得损失参数后,主机可以对第二神经网络模型的网络参数进行调整,从而完成这一轮的训练。
在一实施例中,主机在训练第二神经网络模型时,参见图5,为本申请一实施例提供的训练第二神经网络模型的流程示意图,如图5所示,训练过程可以包括以下步骤361-步骤365。
步骤361:将第二批次的样本数据输入第一神经网络模型,获得第一神经网络模型输出的预测类别信息。
步骤362:对于干净数据集中的每一样本数据,根据样本数据的标签和对应的预测类别信息计算交叉熵。
在通过第一神经网络对第二批次的样本数据计算出预测类别信息后,针对每一样本数据,主机可以判断该样本数据是否被划分至干净数据集。一方面,如果是,说明该样本数据携带的标签可能正确,在这种情况下,主机可以计算样本数据的标签和预测类别信息之间的交叉熵。另一方面,如果该样本数据被划分至不干净数据集,说明该样本数据可能携带噪声标签,此时可以执行步骤363。
步骤363:对于不干净数据集中每一样本数据,根据样本数据的标签和对应的预测类别信息计算向量距离。
步骤364:对交叉熵和向量距离加权求和,获得损失参数。
步骤365:根据损失参数对第一神经网络模型的网络参数进行调整。
针对干净数据集中的样本数据,主机可以计算交叉熵的均值;针对不干净数据集中的样本数据,主机可以计算向量距离的均值。主机可以对这两个均值进行加权求和,从而获得损失参数。在加权求和时,为提高干净数据集的样本数据对训练过程的影响,交叉熵的均值对应的权重较大。获得损失参数后,主机可以对第一神经网络模型的网络参数进行调整,从而完成这一轮的训练。
在对第一神经网络模型和第二神经网络模型,通过区分干净数据集中的样本数据和不干净数据集中的样本数据,可以更充分地利用标签正确性高的样本数据中的信息,降低噪声标签的影响,从而避免对噪声标签的过拟合,提高神经网络模型的泛化能力。
在一实施例中,主机在训练神经网络模型之前,为获得更好的训练效果,可以对样本数据的标签进行调整。
在将第一批次的样本数据和第二批次的样本数据划分至干净数据集和不干净数据集之后,针对干净数据集中每一样本数据,主机可以根据该样本数据的标签、对应的预测类别信息和权重参数,更新该样本数据的标签。
这里,样本数据对应的方差是样本数据对应的高斯混合模型中真实高斯混合模型的权重参数,该权重参数的值在0到1之间,可以认为表示该样本数据的标签正确的概率。更新标签的方式可以通过如下公式(1)来表示:
Y=w*y+(1-w)p (1)
其中,Y表示更新后的标签,w表示权重参数,y表示原来的标签,p表示预测类别信息。
标签和预测类别信息是相同维度的向量,公式(1)表示针对两个向量中每一元素加权求和,从而得到新的标签。
针对不干净数据集中每一样本数据,主机可以对样本数据进行多次数据增强操作,得到多个增强样本数据。其中,数据增强操作可以包括亮度调整、旋转、压缩、翻转等手段。增强样本数据是经过数据增强操作获得的样本数据。
主机可以将任一样本数据对应的多个增强样本数据分别输入第一神经网络模型和第二神经网络模型,得到多个预测类别信息。主机可以对多个预测类别信息进行均值化处理,获得该样本数据的标签。更新标签的方式可以通过如下公式(2)来表示:
其中,Y表示更新后的标签,m表示一个样本数据经过数据增强操作后得到的增强样本数据的数量,p1(i)表示第一神经网络模型对第i个增强样本数据输出的预测类别信息,p2(i)表示第二神经网络模型对第i个增强样本数据输出的预测类别信息。
多个预测类别信息是相同维度的向量,公式(2)表示针对多个向量中每一元素计算均值,从而得到新的标签。
主机对样本数据的标签进行调整后,可以降低噪声标签对神经网络模型的影响,避免对噪声标签过拟合。经过多次迭代训练后,两个神经网络模型趋于稳定,样本数据集中样本数据的标签逐步被调整,从而持续校正噪声标签,提高神经网络模型的训练效果。
参见图6,为本申请一实施例提供的神经网络模型的训练装置的框图,如图6所示,该装置可以包括:输入模块610、拟合模块620、划分模块630、训练模块640、选择模块650。
输入模块610,用于从样本数据集中选择第一批次的样本数据输入第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;从所述样本数据集中选择第二批次的样本数据输入第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
拟合模块620,用于针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型;
划分模块630,用于针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集;
训练模块640,用于利用所述第一批次的样本数据对所述第二神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;利用所述第二批次的样本数据对所述第一神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
选择模块650,用于从所述样本数据集中重新选择两个批次的样本数据,重复上述训练过程,直至所述第一神经网络模型和所述第二神经网络模型收敛。
上述装置中各个模块的功能和作用的实现过程具体详见上述神经网络模型的训练方法中对应步骤的实现过程,在此不再赘述。
在本申请所提供的几个实施例中,所揭露的装置和方法,也可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和框图显示了根据本申请的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。在有些作为替换的实现方式中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的***来实现,或者可以用专用硬件与计算机指令的组合来实现。
另外,在本申请各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。
功能如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
Claims (10)
1.一种神经网络模型的训练方法,其特征在于,包括:
从样本数据集中选择第一批次的样本数据输入第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;
从所述样本数据集中选择第二批次的样本数据输入第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型;
针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集;
利用所述第一批次的样本数据对所述第二神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
利用所述第二批次的样本数据对所述第一神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
从所述样本数据集中重新选择两个批次的样本数据,重复上述训练过程,直至所述第一神经网络模型和所述第二神经网络模型收敛。
2.根据权利要求1所述的方法,其特征在于,所述针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型,包括:
针对每一样本数据,基于所述样本数据的预测类别信息中对应于各类别的置信度,拟合得到所述高斯混合模型;其中,所述高斯混合模型包括对应于真实预测类别信息的真实高斯模型、对应于虚假预测类别信息的虚假高斯模型。
3.根据权利要求1所述的方法,其特征在于,所述针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集,包括:
针对每一样本数据,判断拟合得到的与所述真实高斯模型对应的权重参数是否达到预设权重参数阈值;
如果是,将所述样本数据划分至所述干净数据集;
如果否,将所述样本数据划分至所述不干净数据集。
4.根据权利要求1所述的方法,其特征在于,所述利用所述第一批次的样本数据对所述第二神经网络模型进行训练,包括:
将所述第一批次的样本数据输入所述第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
对于所述干净数据集中的每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算交叉熵;
对于所述不干净数据集中每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算向量距离;
对所述交叉熵和所述向量距离加权求和,获得损失参数;
根据所述损失参数对所述第二神经网络模型的网络参数进行调整。
5.根据权利要求1所述的方法,其特征在于,所述利用所述第二批次的样本数据对所述第一神经网络模型进行训练,包括:
将所述第二批次的样本数据输入所述第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;
对于所述干净数据集中的每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算交叉熵;
对于所述不干净数据集中每一样本数据,根据所述样本数据的标签和对应的预测类别信息计算向量距离;
对所述交叉熵和所述向量距离加权求和,获得损失参数;
根据所述损失参数对所述第一神经网络模型的网络参数进行调整。
6.根据权利要求4或5所述的方法,其特征在于,在训练神经网络模型之前,所述方法还包括:
针对所述干净数据集中的每一样本数据,根据所述样本数据对应所述权重参数、所述样本数据的标签和对应的预测类别信息,更新所述样本数据的标签。
7.根据权利要求4或5所述的方法,其特征在于,在训练神经网络模型之前,所述方法还包括:
针对所述不干净数据集中每一样本数据,对所述样本数据进行多次数据增强操作,得到多个增强样本数据;
将多个增强样本数据分别输入所述第一神经网络模型和所述第二神经网络模型,得到多个预测类别信息;
对所述多个预测类别信息进行均值化处理,获得所述样本数据的标签。
8.一种神经网络模型的训练装置,其特征在于,包括:
输入模块,用于从样本数据集中选择第一批次的样本数据输入第一神经网络模型,获得所述第一神经网络模型输出的预测类别信息;从所述样本数据集中选择第二批次的样本数据输入第二神经网络模型,获得所述第二神经网络模型输出的预测类别信息;
拟合模块,用于针对每一样本数据,基于所述预测类别信息拟合得到对应于所述样本数据的高斯混合模型;
划分模块,用于针对每一样本数据,根据拟合得到的与高斯混合模型中真实高斯模型对应的权重参数,将所述样本数据划分至干净数据集或不干净数据集;
训练模块,用于利用所述第一批次的样本数据对所述第二神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;利用所述第二批次的样本数据对所述第一神经网络模型进行训练;其中,在训练过程中所述干净数据集中样本数据对应的损失计算方式,与所述不干净数据集中样本数据对应的损失计算方式不同;
选择模块,用于从所述样本数据集中重新选择两个批次的样本数据,重复上述训练过程,直至所述第一神经网络模型和所述第二神经网络模型收敛。
9.一种电子设备,其特征在于,所述电子设备包括:
处理器;
用于存储处理器可执行指令的存储器;
其中,所述处理器被配置为执行权利要求1-7任意一项所述的神经网络模型的训练方法。
10.一种计算机可读存储介质,其特征在于,所述存储介质存储有计算机程序,所述计算机程序可由处理器执行以完成权利要求1-7任意一项所述的神经网络模型的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011019263.9A CN112149825A (zh) | 2020-09-24 | 2020-09-24 | 神经网络模型的训练方法及装置、电子设备、存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011019263.9A CN112149825A (zh) | 2020-09-24 | 2020-09-24 | 神经网络模型的训练方法及装置、电子设备、存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112149825A true CN112149825A (zh) | 2020-12-29 |
Family
ID=73896944
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011019263.9A Pending CN112149825A (zh) | 2020-09-24 | 2020-09-24 | 神经网络模型的训练方法及装置、电子设备、存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112149825A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112291122A (zh) * | 2020-12-31 | 2021-01-29 | 迈普通信技术股份有限公司 | 网络流量检测方法、装置、电子设备及可读存储介质 |
CN112988212A (zh) * | 2021-03-24 | 2021-06-18 | 厦门吉比特网络技术股份有限公司 | 神经网络模型之在线增量更新方法、装置、***及存储介质 |
CN114417987A (zh) * | 2022-01-11 | 2022-04-29 | 支付宝(杭州)信息技术有限公司 | 一种模型训练方法、数据识别方法、装置及设备 |
CN114662588A (zh) * | 2022-03-21 | 2022-06-24 | 合肥工业大学 | 一种自动更新模型的方法、***、设备及存储介质 |
-
2020
- 2020-09-24 CN CN202011019263.9A patent/CN112149825A/zh active Pending
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112291122A (zh) * | 2020-12-31 | 2021-01-29 | 迈普通信技术股份有限公司 | 网络流量检测方法、装置、电子设备及可读存储介质 |
CN112291122B (zh) * | 2020-12-31 | 2021-03-16 | 迈普通信技术股份有限公司 | 网络流量检测方法、装置、电子设备及可读存储介质 |
CN112988212A (zh) * | 2021-03-24 | 2021-06-18 | 厦门吉比特网络技术股份有限公司 | 神经网络模型之在线增量更新方法、装置、***及存储介质 |
CN114417987A (zh) * | 2022-01-11 | 2022-04-29 | 支付宝(杭州)信息技术有限公司 | 一种模型训练方法、数据识别方法、装置及设备 |
CN114662588A (zh) * | 2022-03-21 | 2022-06-24 | 合肥工业大学 | 一种自动更新模型的方法、***、设备及存储介质 |
CN114662588B (zh) * | 2022-03-21 | 2023-11-07 | 合肥工业大学 | 一种自动更新模型的方法、***、设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112149825A (zh) | 神经网络模型的训练方法及装置、电子设备、存储介质 | |
CN111523621B (zh) | 图像识别方法、装置、计算机设备和存储介质 | |
US11468262B2 (en) | Deep network embedding with adversarial regularization | |
US11816183B2 (en) | Methods and systems for mining minority-class data samples for training a neural network | |
CN111507469B (zh) | 对自动标注装置的超参数进行优化的方法和装置 | |
CN109271958B (zh) | 人脸年龄识别方法及装置 | |
CN110799995A (zh) | 数据识别器训练方法、数据识别器训练装置、程序及训练方法 | |
CN109840413B (zh) | 一种钓鱼网站检测方法及装置 | |
JPH07296117A (ja) | 減少された要素特徴部分集合を用いたパターン認識システム用の分類重みマトリックスを構成する方法 | |
CN114492279A (zh) | 一种模拟集成电路的参数优化方法及*** | |
CN117155706B (zh) | 网络异常行为检测方法及其*** | |
CN108985442B (zh) | 手写模型训练方法、手写字识别方法、装置、设备及介质 | |
CN108154186B (zh) | 一种模式识别方法和装置 | |
CN110991247B (zh) | 一种基于深度学习与nca融合的电子元器件识别方法 | |
CN115812210A (zh) | 用于增强机器学习分类任务的性能的方法和设备 | |
US20210365719A1 (en) | System and method for few-shot learning | |
CN114091597A (zh) | 基于自适应组样本扰动约束的对抗训练方法、装置及设备 | |
CN111666991A (zh) | 基于卷积神经网络的模式识别方法、装置和计算机设备 | |
JP2002251592A (ja) | パターン認識辞書学習方法 | |
Jang et al. | Improving classifier confidence using lossy label-invariant transformations | |
CN117522586A (zh) | 金融异常行为检测方法及装置 | |
CN112613550A (zh) | 一种数据分类方法、装置及相关设备 | |
US11609936B2 (en) | Graph data processing method, device, and computer program product | |
CN115936104A (zh) | 用于训练机器学习模型的方法和装置 | |
CN110866527A (zh) | 一种图像分割方法、装置、电子设备及可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20201229 |
|
RJ01 | Rejection of invention patent application after publication |