CN112633407B - 分类模型的训练方法、装置、电子设备及存储介质 - Google Patents
分类模型的训练方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN112633407B CN112633407B CN202011637604.9A CN202011637604A CN112633407B CN 112633407 B CN112633407 B CN 112633407B CN 202011637604 A CN202011637604 A CN 202011637604A CN 112633407 B CN112633407 B CN 112633407B
- Authority
- CN
- China
- Prior art keywords
- classification model
- training
- loss function
- classification
- loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本发明实施例提供一种分类模型的训练方法,方法包括:获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签;在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;根据所述动态损失函数,对所述分类模型进行训练。可以提高分类模型的训练效率,并根据各个类别对应的动态损失函数,提高分类模型的分类识别精确度。
Description
技术领域
本发明涉及人工智能领域,尤其涉及一种分类模型的训练方法、装置、电子设备及存储介质。
背景技术
在分类模型的训练过程中,需要使用到样本数据作为输入,在有监督的情况下使分类模型能够学习到对样本数据的分类识别。为了使分类模型具有更高的分类精度,可以增大各个类别之间的距离,减小同一类别中各个样本之间的距离,通常的做法是在损失函数中加入一个超参数margin对同一个类别的样本进行框定,将超出框定范围的类内样本进行惩罚,使后续训练时,类内样本向框定范围内靠近。但是在实际训练任务中很难设置一个适应这批数据的、通用的margin值,或者说需要大量的调参实验和权衡各个类别之间的精度才能够调试出一组使用于该数据集的margin值,这样不仅浪费了大量精力和时间去“试错”,最终导致能够找到合适的margin值的概率非常小。因此,现有超参数margin的获取方式复杂且难度高,使得分类模型的训练效率较低。
发明内容
本发明实施例提供一种分类模型的训练方法,能够在分类模型的训练过程中通过各个类别的样本分布紧凑度对类别框参数(超参数margin)进行更新,不需要进行大量的调参实验和权衡各个类别之间的精度就可以得到适合各个类别的类别框参数,从而得到各个类别对应的动态损失函数,可以提高分类模型的训练效率,并根据各个类别对应的动态损失函数,提高分类识别精确度。
第一方面,本发明实施例提供一种分类模型的训练方法,所述方法包括:
获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;
在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及
根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;
在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;
根据所述动态损失函数,对所述分类模型进行训练。
可选的,在所述根据分类层参数计算各个类别的样本分布紧凑度之前,所述方法还包括:
对所述分类模型进行初始化,得到初始化分类模型,所述初始化分类模型中的损失函数为固定损失函数;
通过所述训练数据,结合使用所述固定损失函数对所述初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
可选的,所述在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度,包括:
获取第n次迭代时的分类层参数以及各个类别的样本特征,n为大于0的整数;
根据所述第n次迭代时的分类层参数以及各个类别的样本特征,计算第n次迭代时的各个类别的样本分布紧凑度。
可选的,所述根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失,包括:
获取第n-1次迭代时的样本分类结果;
基于第n-2次迭代时的损失函数对第n-1次迭代时的损失函数进行预设,并使用预设的损失函数计算第n-1次迭代时对应样本的误差损失。
可选的,所述在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数,包括:
在所述第n-1次迭代时对应样本的误差损失满足预设条件时,基于所述第n次迭代时的各个类别的样本分布紧凑度,对第n-1次迭代时的损失函数中的类别框参数进行更新,得到第n次迭代时的各个类别对应的动态损失函数。
可选的,在所述在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数之前,所述方法还包括:
维护一个损失条件集合,所述损失条件集合中包括离散的损失条件值,所述离散的损失条件值按排列顺序递减;
当所述第n-1次迭代时对应样本的误差损失较所述第n-2次迭代时对应样本的误差损失为减小,且达到所述损失条件集合中损失条件值时,则确定所述第n-1次迭代时对应样本的误差损失满足预设条件。
第二方面,本发明实施例还提供一种分类模型的训练装置,所述装置包括:
获取模块,用于获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;
第一计算模块,用于在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及
第二计算模块,用于根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;
更新模块,用于在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;
训练模块,用于根据所述动态损失函数,对所述分类模型进行训练。
可选的,所述装置还包括:
初始化模块,用于对所述分类模型进行初始化,得到初始化分类模型,所述初始化分类模型中的损失函数为固定损失函数;
预训练模块,用于通过所述训练数据,结合使用所述固定损失函数对所述初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
第三方面,本发明实施例提供一种电子设备,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本发明实施例提供的分类模型的训练方法中的步骤。
第四方面,本发明实施例提供一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现发明实施例提供的分类模型的训练方法中的步骤。
本发明实施例中,获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;根据所述动态损失函数,对所述分类模型进行训练。能够在分类模型的训练过程中通过各个类别的样本分布紧凑度对类别框参数(超参数margin)进行更新,不需要进行大量的调参实验和权衡各个类别之间的精度就可以得到适合各个类别的类别框参数,从而得到各个类别对应的动态损失函数,可以提高分类模型的训练效率,并根据各个类别对应的动态损失函数,提高分类模型的分类识别精确度。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的一种分类模型的训练方法的流程图;
图2是本发明实施例提供的另一种分类模型的训练方法的流程图;
图3是本发明实施例提供的一种分类模型的训练装置的结构示意图;
图4是本发明实施例提供的另一种分类模型的训练装置的结构示意图;
图5是本发明实施例提供的一种第一计算模块的结构示意图;
图6是本发明实施例提供的一种第二计算模块的结构示意图;
图7是本发明实施例提供的另一种分类模型的训练装置的结构示意图;
图8是本发明实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
请参见图1,图1是本发明实施例提供的一种分类模型的训练方法的流程图,如图1所示,包括以下步骤:
101、获取不同类别的训练数据对分类模型进行训练。
在本发明实施例中,上述训练数据包括不同类别的样本及类别标签。上述分类模型可以是需要对目标进行分类识别的模型,比如可以是行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型等,上述分类模型可以通过训练数据进行训练,训练数据中包括各个类别的样本与对应的类别标签。上述样本可以是对应各个类别的图像、文本、音频流等形式中的一种形式,比如,当分类模型为行人识别模型、车辆识别模型、物体检测模型、场景图像分类模型时,上述样本的形式可以是图像形式;当分类模型为文章分类模型时,上述样本的形式可以是文本形式。
上述训练数据可以是行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与分类模型对应的一项。
上述类别可以根据实际的模型需要进行确定,比如,行人识别模型中,训练数据为行人图像数据,行人图像数据中样本的类别可以是行人、车辆、背景等类别;在车辆识别模型中,训练数据为车辆图像数据,车辆图像数据中样本的类别可以是机动车、非机动车、交通信号灯、背景等类别;在物体检测模型中,训练数据为物体图像数据,物体图像数据中样本的类别可以是猫、狗、包、帽子等类别;在文章分类模型中,训练数据为文本数据,文本数据中样本的类别可以是说明文、散文、诗歌等类别;在音乐分类模型中,训练数据为音频数据,音频数据中样本的类别可以是流行音乐、说唱音乐、轻音乐等类别;在视频分类模型中,训练数据为视频数据,视频数据中样本的类别可以是记录片、爱情片、动作片等类别;在场景图像分类模型中,训练数据为场景图像数据,场景图像数据中样本的类别可以是室内监控图像,室外监控图像,仰角拍摄图像等不同场景图像。
可选的,由于训练数据中包括较多样本,即样本的数据量很大,将整个训练数据一次性输入分类模型中对分类模型进行训练的话,会存在训练速度很慢的问题。因此,在训练过程中,需要对训练数据中的样本进行批处理,将训练数据采样为若干个批数据来对分类模型进行训练,一个批数据的训练过程可以称为一次迭代过程,迭代次数与批数据的个数相同。比如,训练数据中存在10000个样本,将这10000个样本进行批处理,得到5个批数据,每个批数据中包含2000个样本,分类模型的迭代次数为5。
进一步的,可以通过样本重采样的策略使每个批数据中尽可能包含各个类别的样本。
102、在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度。
在本发明实施例中,上述训练过程可以是将批数据依次输入分类模型进行计算,进一步可以是将样本依次输入分类模型进行计算,在计算得到样本分类结果后,将样本分类结果与类别标签进行误差损失计算,得到样本分类结果与类别标签之间的误差损失进行反向传播,在反向传播过程中,通过梯度下降法来调整分类模型的模型参数,直到误差损失最小,完成对分类模型的训练。
在训练过程中,上述分类模型可以包括计算层、分类层、以及损失函数,其中,计算层的输入为样本,分类层的输入为计算层的输出,损失函数的输入为分类层的输出和类别标签。在分类层中,包括分类层参数,上述分类层参数用于对计算层的输出进行分类计算,具体计算计算层的输出与每个类别的距离,进而对计算层的输出进行分类。
上述样本分布紧凑度可以说明同一类别对应的所有样本的分布情况,样本分布紧凑度越小,则说明该类别下所有样本的分布越分散,样本分布紧凑度越大,则说明该类别下所有样本的分布越集中。对于分散的样本,则与其他类别的距离较小,更容易被误分类到其他类别;对于集中的样本,则与本类别其他样本的距离更近,与其他类别的距离较大,更容易分类到本类别。
具体的,上述样本分布紧凑度可以通过下述式子进行表示:
其中,上述IC(w)为一个类别的样本分布紧凑度(也可以称为类内紧凑度),上述w是分类层参数(也可以称为类中心向量),wk是第k个样本的向量embeding(计算层的输出),K是该个类别中样本总数,上述的s为一个预设的参数。在本发明实施例中,样本分布紧凑度IC越大,则代表了该类别的类内样本越紧凑,反之样本分布紧凑度IC越小,则代表了该类别的类内样本越松散。
103、根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失。
在本发明实施例中,上述样本分类结果可以是分类层的输出经过归一化后的分类结果,具体可以是通过归一化函数进行归一化的。上述预设的损失函数可以是softmax交叉熵损失函数,根据如下述式子所示:
其中,上述zy表示第i个样本对应的类别标签,zi表示第i个样本对应的样本分类结果,C表示类别的数量。
在本发明实施例中,上述预设的损失函数可以是添加了类别框参数margin的交叉熵损失函数,根据如下述式子所示:
其中,上述sn表示当前样本距本类别中心的距离,m表示类别框参数margin,γ是预设的超参数,N为类别的数量。
104、在误差损失满足预设条件时,基于各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数。
在本发明实施例中,上述误差损失所满足的预设条件可以是误差损失背叛到一个或多个预设值,或者误差损失为第j次迭代的误差损失。
可选的,上述误差损失可以是上一次迭代时的误差损失,上述各个类别的样本分布紧凑度可以是当前次迭代时的各个类别的样本分布紧凑度。上述的更新可以是:
其中,ICN为第n类别的样本分布紧凑度,上述类别框参数margin为默认的预设值。通过对添加了类别框参数margin的交叉熵损失函数进行更新,得到各个类别对应的动态损失函数。进而可以通过各个类别对应的动态损失函数,对分类模型进行训练。
在本发明实施例中,获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签;在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;根据所述动态损失函数,对所述分类模型进行训练。能够在分类模型的训练过程中通过各个类别的样本分布紧凑度对类别框参数(超参数margin)进行更新,不需要进行大量的调参实验和权衡各个类别之间的精度就可以得到适合各个类别的类别框参数,从而得到各个类别对应的动态损失函数,可以提高分类模型的训练效率,并根据各个类别对应的动态损失函数,提高分类模型的分类识别精确度。
需要说明的是,本发明实施例提供的分类模型的训练方法可以应用于可以进行分类模型的训练的手机、监控器、计算机、服务器等设备。
可选的,请参见图2,图2是本发明实施例提供的另一种分类模型的训练方法的流程图,如图2所示,在图1实施例的基础上,分类模型的训练方法还包括以下步骤:
201、对分类模型进行初始化,得到初始化分类模型。
在本发明实施例中,上述初始化分类模型中的损失函数为固定损失函数。比如,上述固定损失函数可以是softmax交叉熵损失函数,根据如下述式子所示:
其中,上述zy表示第i个样本对应的类别标签,zi表示第i个样本对应的样本分类结果,C表示类别的数量。
在一种可能的实施例中,上述固定损失函数可以是添加了类别框参数margin的交叉熵损失函数,根据如下述式子所示:
其中,上述sn表示当前样本距本类别中心的距离,m表示类别框参数margin,γ是预设的超参数,N为类别的数量。需要说明的是,在预训练过程中,上述类别框参数margin是固定的。
上述的初始化可以是默认值初始化或随机初始化,上述默认值初始化可以理解为将分类模型的参数按用户预先设置的参数进行初始化,上述随机初始化可以理解为将分类模型的参数按随机值进行初始化。
202、通过训练数据,结合使用固定损失函数对初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
上述预训练可以理解为训练过程开始的前几次迭代,此时,由于分类模型中的分类层参数还没有作为一个类别的中心的条件,因此,可以使用固定损失函数对初始化分类模型进行预设次数的预训练。上述预设训练次数可以是由用户进行指定,在预训练过程,也会不断调整更新分类模型中的分类层参数和计算层参数。
可以理解的是,通过预训练的分类模型,已经具有一定的分类能力。
可选的,在预训练到预设次数后,分类层参数可以初步做为一个类别的类别中心,此时,则可以取分类层参数与计算层输出来计算各个类别的样本分布紧凑度。
进一步的,可以假设当前迭代次数为第n次迭代,可以获取第n-1次迭代时的样本分类结果;基于第n-2次迭代时的损失函数对第n-1次迭代时的损失函数进行预设,并使用预设的损失函数计算第n-1次迭代时对应样本的误差损失。可以理解为的是,当前迭代时的损失函数可以在上一次迭代时的损失函数的基础上进行确定。在预训练阶段,当前迭代时的损失函数与上一次迭代时的损失函数可以是相同的固定损失函数。在预训练阶段之后,当前迭代时的损失函数则为上一次迭代时的损失函数,通过当前迭代时的各个类别的样本分布紧凑度进行更新来得到。
进一步的,在训练过程中,可以获取第n次迭代时的分类层参数以及各个类别的样本特征,上述样本特征为计算层的输出;根据第n次迭代时的分类层参数以及各个类别的样本特征,计算第n次迭代时的各个类别的样本分布紧凑度。
举例来说,假设分类模型训练所使用的训练数据为自然场景数据,自然场景数据一般都会出现样本分布不均衡的现象,具体表现为:某一些场景的样本数据量非常大,而其他部分场景的样本数据量非常小。进一步假设训练任务所用的训练数据可以包含N个不同场景:{D1,D2,…,DN},例如室内监控图像,室外监控图像,仰角拍摄图像等不同场景。每个场景的类别可以是{M_1,M_2,M_N}。在对分类模型进行预训练后,得到一个具有一定分类能力的分类模型,可以取到该分类模型分类层参数w,此时,该分类层参数w在一定程度上可以代表对应类别的类别中心,然后根据当前分类模型得到所有场景中所有样本的向量wk,计算每一个类别的样本分布紧凑度为{IC1,IC2,...,ICi},其中,上述ICi=avg(ICi1,ICi2…,ICiM_i),ICi表示第i个类别的样本分布紧凑度,ICiM_i表示第M_i个场景的样本分布紧凑度。
进一步的,可以在第n-1次迭代时对应样本的误差损失满足预设条件时,基于第n次迭代时的各个类别的样本分布紧凑度,对第n-1次迭代时的损失函数中的类别框参数进行更新,得到第n次迭代时的各个类别对应的动态损失函数。上述第n次迭代可以理解为当前次迭代,在当前次迭代过程中,最后一步便是通过动态损失函数,计算分类模型的误差损失,并根据该误差损失反向传播,通过梯度下降方法来调整该分类模型的参数,因此,动态损失函数是基于上一次迭代时的损失函数进行更新的。
具体的,可以维护一个损失条件集合,所述损失条件集合中包括离散的损失条件值,所述离散的损失条件值按排列顺序递减,上述的维护可以理解为创建一个损失条件集合,并保持该损失条件集合在内存中不会消亡;当第n-1次迭代时对应样本的误差损失较第n-2次迭代时对应样本的误差损失为减小,且达到损失条件集合中损失条件值时,则确定第n-1次迭代时对应样本的误差损失满足预设条件。可以理解的是,更新样本分布紧凑度IC可以是是根据训练过程中误差损失来判定的,具体的,可以根据经验设置一个Loss值集{loss_01,loss_02,…,loss_T}等一系列递减损失条件值,每当当前迭代时的误差损失减小到集合中的第i个损失条件值时,就进行一次样本分布紧凑度IC值更新,进而对损失函数进行动态更新。
在本发明实施例中,可以根据不同场景或类别数据margin设置规则,为margin超参数调整给出了明确可行的设置规则,同时引入各个类别对应的样本分布紧凑度来衡量各个类别的margin设置是否合适,从而可以直接提高模型在样本分布紧凑度较小的类别上的分类精度。
请参见图3,图3是本发明实施例提供的一种分类模型的训练装置的结构示意图,如图3所示,所述装置包括:
获取模块301,用于获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;
第一计算模块302,用于在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及
第二计算模块303,用于根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;
更新模块304,用于在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;
训练模块305,用于根据所述动态损失函数,对所述分类模型进行训练。
可选的,如图4所示,所述装置还包括:
初始化模块306,用于对所述分类模型进行初始化,得到初始化分类模型,所述初始化分类模型中的损失函数为固定损失函数;
预训练模块307,用于通过所述训练数据,结合使用所述固定损失函数对所述初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
可选的,如图5所示,第一计算模块302,包括:
第一获取单元3021,用于获取第n次迭代时的分类层参数以及各个类别的样本特征;
第一计算单元3022,用于根据所述第n次迭代时的分类层参数以及各个类别的样本特征,计算第n次迭代时的各个类别的样本分布紧凑度。
可选的,如图6所示,所述第二计算模块303,包括:
第二获取单元3031,用于获取第n-1次迭代时的样本分类结果;
第二计算单元3032,用于基于第n-2次迭代时的损失函数对第n-1次迭代时的损失函数进行预设,并使用预设的损失函数计算第n-1次迭代时对应样本的误差损失。
可选的,更新模块304还用于在所述第n-1次迭代时对应样本的误差损失满足预设条件时,基于所述第n次迭代时的各个类别的样本分布紧凑度,对第n-1次迭代时的损失函数中的类别框参数进行更新,得到第n次迭代时的各个类别对应的动态损失函数。
可选的,如图7所示,所述装置还包括:
维护模块308,用于维护一个损失条件集合,所述损失条件集合中包括离散的损失条件值,所述离散的损失条件值按排列顺序递减;
确定模块309,用于当所述第n-1次迭代时对应样本的误差损失较所述第n-2次迭代时对应样本的误差损失为减小,且达到所述损失条件集合中损失条件值时,则确定所述第n-1次迭代时对应样本的误差损失满足预设条件。
需要说明的是,本发明实施例提供的分类模型的训练装置可以应用于可以进行分类模型的训练的手机、监控器、计算机、服务器等设备。
本发明实施例提供的分类模型的训练装置能够实现上述方法实施例中分类模型的训练方法实现的各个过程,且可以达到相同的有益效果。为避免重复,这里不再赘述。
参见图8,图8是本发明实施例提供的一种电子设备的结构示意图,如图8所示,包括:存储器802、处理器801及存储在所述存储器802上并可在所述处理器801上运行的计算机程序,其中:
处理器801用于调用存储器802存储的计算机程序,执行如下步骤:
获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;
在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及
根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;
在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;
根据所述动态损失函数,对所述分类模型进行训练。
可选的,在所述根据分类层参数计算各个类别的样本分布紧凑度之前,处所述理器801执行的方法还包括:
对所述分类模型进行初始化,得到初始化分类模型,所述初始化分类模型中的损失函数为固定损失函数;
通过所述训练数据,结合使用所述固定损失函数对所述初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
可选的,处理器801执行的所述在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度,包括:
获取第n次迭代时的分类层参数以及各个类别的样本特征;
根据所述第n次迭代时的分类层参数以及各个类别的样本特征,计算第n次迭代时的各个类别的样本分布紧凑度。
可选的,处理器801执行的所述根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失,包括:
获取第n-1次迭代时的样本分类结果;
基于第n-2次迭代时的损失函数对第n-1次迭代时的损失函数进行预设,并使用预设的损失函数计算第n-1次迭代时对应样本的误差损失。
可选的,处理器801执行的所述在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数,包括:
在所述第n-1次迭代时对应样本的误差损失满足预设条件时,基于所述第n次迭代时的各个类别的样本分布紧凑度,对第n-1次迭代时的损失函数中的类别框参数进行更新,得到第n次迭代时的各个类别对应的动态损失函数。
可选的,在所述在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数之前,所述处理器801执行的方法还包括:
维护一个损失条件集合,所述损失条件集合中包括离散的损失条件值,所述离散的损失条件值按排列顺序递减;
当所述第n-1次迭代时对应样本的误差损失较所述第n-2次迭代时对应样本的误差损失为减小,且达到所述损失条件集合中损失条件值时,则确定所述第n-1次迭代时对应样本的误差损失满足预设条件。
需要说明的是,上述电子设备可以是可以应用于可以进行分类模型的训练的手机、监控器、计算机、服务器等设备。
本发明实施例提供的电子设备能够实现上述方法实施例中分类模型的训练方法实现的各个过程,且可以达到相同的有益效果,为避免重复,这里不再赘述。
本发明实施例还提供一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,该计算机程序被处理器执行时实现本发明实施例提供的分类模型的训练方法的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存取存储器(Random AccessMemory,简称RAM)等。
以上所揭露的仅为本发明较佳实施例而已,当然不能以此来限定本发明之权利范围,因此依本发明权利要求所作的等同变化,仍属本发明所涵盖的范围。
Claims (10)
1.一种分类模型的训练方法,其特征在于,包括以下步骤:
获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;
在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及
根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;
在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;
根据所述动态损失函数,对所述分类模型进行训练。
2.如权利要求1所述的方法,其特征在于,在所述根据分类层参数计算各个类别的样本分布紧凑度之前,所述方法还包括:
对所述分类模型进行初始化,得到初始化分类模型,所述初始化分类模型中的损失函数为固定损失函数;
通过所述训练数据,结合使用所述固定损失函数对所述初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
3.如权利要求2所述的方法,其特征在于,所述在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度,包括:
获取第n次迭代时的分类层参数以及各个类别的样本特征,n为大于0的整数;
根据所述第n次迭代时的分类层参数以及各个类别的样本特征,计算第n次迭代时的各个类别的样本分布紧凑度。
4.如权利要求3所述的方法,其特征在于,所述根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失,包括:
获取第n-1次迭代时的样本分类结果;
基于第n-2次迭代时的损失函数对第n-1次迭代时的损失函数进行预设,并使用预设的损失函数计算第n-1次迭代时对应样本的误差损失。
5.如权利要求4所述的方法,其特征在于,所述在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数,包括:
在所述第n-1次迭代时对应样本的误差损失满足预设条件时,基于所述第n次迭代时的各个类别的样本分布紧凑度,对第n-1次迭代时的损失函数中的类别框参数进行更新,得到第n次迭代时的各个类别对应的动态损失函数。
6.如权利要求5所述的方法,其特征在于,在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数之前,所述方法还包括:
维护一个损失条件集合,所述损失条件集合中包括离散的损失条件值,所述离散的损失条件值按排列顺序递减;
当所述第n-1次迭代时对应样本的误差损失较所述第n-2次迭代时对应样本的误差损失为减小,且达到所述损失条件集合中损失条件值时,则确定所述第n-1次迭代时对应样本的误差损失满足预设条件。
7.一种分类模型的训练装置,其特征在于,所述装置包括:
获取模块,用于获取不同类别的训练数据对分类模型进行训练,所述训练数据包括不同类别的样本及类别标签,所述分类模型为行人识别模型、车辆识别模型、物体检测模型、文章分类模型、音乐分类模型、视频分类模型、场景图像分类模型中的任意一个,所述训练数据为行人图像数据、车辆图像数据、物体图像数据、文本数据、音频数据、视频数据、场景图像数据中与所述分类模型对应的一项;
第一计算模块,用于在训练过程中,根据分类层参数计算各个类别的样本分布紧凑度;以及
第二计算模块,用于根据样本分类结果与类别标签,使用预设的损失函数计算对应样本的误差损失;
更新模块,用于在所述误差损失满足预设条件时,基于所述各个类别的样本分布紧凑度,对损失函数中的类别框参数进行更新,得到各个类别对应的动态损失函数;
训练模块,用于根据所述动态损失函数,对所述分类模型进行训练。
8.如权利要求7所述的装置,其特征在于,所述装置还包括:
初始化模块,用于对所述分类模型进行初始化,得到初始化分类模型,所述初始化分类模型中的损失函数为固定损失函数;
预训练模块,用于通过所述训练数据,结合使用所述固定损失函数对所述初始化分类模型进行预设次数的预训练,以更新初始化分类模型中的分类层参数。
9.一种电子设备,其特征在于,包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1至6中任一项所述的分类模型的训练方法中的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至6中任一项所述的分类模型的训练方法中的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011637604.9A CN112633407B (zh) | 2020-12-31 | 2020-12-31 | 分类模型的训练方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011637604.9A CN112633407B (zh) | 2020-12-31 | 2020-12-31 | 分类模型的训练方法、装置、电子设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112633407A CN112633407A (zh) | 2021-04-09 |
CN112633407B true CN112633407B (zh) | 2023-10-13 |
Family
ID=75290482
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011637604.9A Active CN112633407B (zh) | 2020-12-31 | 2020-12-31 | 分类模型的训练方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112633407B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113326889A (zh) * | 2021-06-16 | 2021-08-31 | 北京百度网讯科技有限公司 | 用于训练模型的方法和装置 |
Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
JP2011141674A (ja) * | 2010-01-06 | 2011-07-21 | Hitachi Ltd | ソフトウェア品質指標値管理システム、ソフトウェア品質指標値の真値を推定する推定方法及び推定プログラム |
CN108304859A (zh) * | 2017-12-29 | 2018-07-20 | 达闼科技(北京)有限公司 | 图像识别方法及云端*** |
CN110321965A (zh) * | 2019-07-10 | 2019-10-11 | 腾讯科技(深圳)有限公司 | 物体重识别模型的训练方法、物体重识别的方法及装置 |
CN110705489A (zh) * | 2019-10-09 | 2020-01-17 | 北京迈格威科技有限公司 | 目标识别网络的训练方法、装置、计算机设备和存储介质 |
CN110751197A (zh) * | 2019-10-14 | 2020-02-04 | 上海眼控科技股份有限公司 | 图片分类方法、图片模型训练方法及设备 |
CN111079790A (zh) * | 2019-11-18 | 2020-04-28 | 清华大学深圳国际研究生院 | 一种构建类别中心的图像分类方法 |
CN111144566A (zh) * | 2019-12-30 | 2020-05-12 | 深圳云天励飞技术有限公司 | 神经网络权重参数的训练方法、特征分类方法及对应装置 |
CN111160538A (zh) * | 2020-04-02 | 2020-05-15 | 北京精诊医疗科技有限公司 | 一种损失函数中margin参数值的更新方法和*** |
CN111553399A (zh) * | 2020-04-21 | 2020-08-18 | 佳都新太科技股份有限公司 | 特征模型训练方法、装置、设备及存储介质 |
WO2020221278A1 (zh) * | 2019-04-29 | 2020-11-05 | 北京金山云网络技术有限公司 | 视频分类方法及其模型的训练方法、装置和电子设备 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108615044A (zh) * | 2016-12-12 | 2018-10-02 | 腾讯科技(深圳)有限公司 | 一种分类模型训练的方法、数据分类的方法及装置 |
CN111950279B (zh) * | 2019-05-17 | 2023-06-23 | 百度在线网络技术(北京)有限公司 | 实体关系的处理方法、装置、设备及计算机可读存储介质 |
-
2020
- 2020-12-31 CN CN202011637604.9A patent/CN112633407B/zh active Active
Patent Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
JP2011141674A (ja) * | 2010-01-06 | 2011-07-21 | Hitachi Ltd | ソフトウェア品質指標値管理システム、ソフトウェア品質指標値の真値を推定する推定方法及び推定プログラム |
CN108304859A (zh) * | 2017-12-29 | 2018-07-20 | 达闼科技(北京)有限公司 | 图像识别方法及云端*** |
WO2020221278A1 (zh) * | 2019-04-29 | 2020-11-05 | 北京金山云网络技术有限公司 | 视频分类方法及其模型的训练方法、装置和电子设备 |
CN110321965A (zh) * | 2019-07-10 | 2019-10-11 | 腾讯科技(深圳)有限公司 | 物体重识别模型的训练方法、物体重识别的方法及装置 |
CN110705489A (zh) * | 2019-10-09 | 2020-01-17 | 北京迈格威科技有限公司 | 目标识别网络的训练方法、装置、计算机设备和存储介质 |
CN110751197A (zh) * | 2019-10-14 | 2020-02-04 | 上海眼控科技股份有限公司 | 图片分类方法、图片模型训练方法及设备 |
CN111079790A (zh) * | 2019-11-18 | 2020-04-28 | 清华大学深圳国际研究生院 | 一种构建类别中心的图像分类方法 |
CN111144566A (zh) * | 2019-12-30 | 2020-05-12 | 深圳云天励飞技术有限公司 | 神经网络权重参数的训练方法、特征分类方法及对应装置 |
CN111160538A (zh) * | 2020-04-02 | 2020-05-15 | 北京精诊医疗科技有限公司 | 一种损失函数中margin参数值的更新方法和*** |
CN111553399A (zh) * | 2020-04-21 | 2020-08-18 | 佳都新太科技股份有限公司 | 特征模型训练方法、装置、设备及存储介质 |
Non-Patent Citations (2)
Title |
---|
基于多任务学习的深层人脸识别算法;杨恢先;激光与光电子学进展;第56卷(第18期);全文 * |
基于自适应角度损失函数的深度人脸识别算法研究;姬东飞;丁学明;;计算机应用研究(第10期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN112633407A (zh) | 2021-04-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111401516B (zh) | 一种神经网络通道参数的搜索方法及相关设备 | |
Zhang et al. | Dynamic R-CNN: Towards high quality object detection via dynamic training | |
CN109583501B (zh) | 图片分类、分类识别模型的生成方法、装置、设备及介质 | |
CN110633745B (zh) | 一种基于人工智能的图像分类训练方法、装置及存储介质 | |
CN110991652A (zh) | 神经网络模型训练方法、装置及电子设备 | |
CN113570029A (zh) | 获取神经网络模型的方法、图像处理方法及装置 | |
CN112101544A (zh) | 适用于长尾分布数据集的神经网络的训练方法和装置 | |
CN110930996B (zh) | 模型训练方法、语音识别方法、装置、存储介质及设备 | |
US10810464B2 (en) | Information processing apparatus, information processing method, and storage medium | |
KR20210155824A (ko) | 적응적 하이퍼파라미터 세트를 이용한 멀티스테이지 학습을 통해 자율 주행 자동차의 머신 러닝 네트워크를 온디바이스 학습시키는 방법 및 이를 이용한 온디바이스 학습 장치 | |
CN113962965A (zh) | 图像质量评价方法、装置、设备以及存储介质 | |
CN114842343A (zh) | 一种基于ViT的航空图像识别方法 | |
CN114972850A (zh) | 多分支网络的分发推理方法、装置、电子设备及存储介质 | |
CN112633407B (zh) | 分类模型的训练方法、装置、电子设备及存储介质 | |
KR20230088714A (ko) | 개인화된 뉴럴 네트워크 프루닝 | |
CN115082752A (zh) | 基于弱监督的目标检测模型训练方法、装置、设备及介质 | |
KR20240034804A (ko) | 자동 회귀 언어 모델 신경망을 사용하여 출력 시퀀스 평가 | |
CN113449840A (zh) | 神经网络训练方法及装置、图像分类的方法及装置 | |
Makwe et al. | An empirical study of neural network hyperparameters | |
CN113361384A (zh) | 人脸识别模型压缩方法、设备、介质及计算机程序产品 | |
CN116128044A (zh) | 一种模型剪枝方法、图像处理方法及相关装置 | |
CN114566184A (zh) | 音频识别方法及相关装置 | |
CN115205573A (zh) | 图像处理方法、装置及设备 | |
CN111898465B (zh) | 一种人脸识别模型的获取方法和装置 | |
CN111310823B (zh) | 目标分类方法、装置和电子*** |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |