CN112784677A - 模型训练方法及装置、存储介质、计算设备 - Google Patents
模型训练方法及装置、存储介质、计算设备 Download PDFInfo
- Publication number
- CN112784677A CN112784677A CN202011415641.5A CN202011415641A CN112784677A CN 112784677 A CN112784677 A CN 112784677A CN 202011415641 A CN202011415641 A CN 202011415641A CN 112784677 A CN112784677 A CN 112784677A
- Authority
- CN
- China
- Prior art keywords
- model
- category
- reference model
- probability
- error
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V40/00—Recognition of biometric, human-related or animal-related patterns in image or video data
- G06V40/10—Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
- G06V40/103—Static body considered as a whole, e.g. static pedestrian or occupant recognition
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Human Computer Interaction (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
一种模型训练方法及装置、存储介质、计算设备,模型训练方法包括:将训练数据输入至构建好的基准模型和老师模型,基准模型的网络层数小于老师模型的网络层数;获取基准模型针对训练数据的第一输出结果和老师模型针对训练数据的第二输出结果;基于每一类别的第一分类概率生成非该类别的第三分类概率,以及基于每一类别第二分类概率生成非该类别的第四分类概率;利用每一类别下的第一概率分布和第二概率分布计算KL散度,以及计算基准模型自身的误差;利用KL散度以及基准模型自身的误差在基准模型中进行反向传播,以用于调整基准模型的网络参数。本发明技术方案能够提升模型分类效果的准确性和实时性。
Description
技术领域
本发明涉及数据处理技术领域,尤其涉及一种模型训练方法及装置、存储介 质、计算设备。
背景技术
对于数据的特征提取和分类,通常是利用深度网络模型来实现的,尤其 是行人属性数据。行人属性就像每个人随身携带的特性,好的模型可以极大 的提升其应用场景。
目前市面上主流的行人属性模型基本上是通过摄像头采集视频,通过行 人检测模块得到行人框,再通过属性识别模块得到行人属性。
但是,目前的人体属性模型往往在人体人检测模型之后,较依赖于检测 框,而实际场景比较复杂,人体检测模型的效果很难保证,导致在人体部分 缺失或者误检人体的情况下,行人属性预测效果很不理想(在大量现有开源 接口上做过实验得出此结论)。其二,市面上的行人属性模型往往在要求准确 率高的同时牺牲实时性,反之亦然。其三,行人属性模型在跨域场景下泛化 力不强。
发明内容
本发明解决的技术问题是如何通过模型训练提升模型分类效果的准确性 和实时性。
为解决上述技术问题,本发明实施例提供一种模型训练方法,模型训练 方法包括:将训练数据输入至构建好的基准模型和老师模型,所述基准模型 的网络层数小于所述老师模型的网络层数;获取所述基准模型针对所述训练 数据的第一输出结果和老师模型针对所述训练数据的第二输出结果,所述第 一输出结果包括针对每一类别的第一分类概率,所述第二输出结果包括针对 每一类别第二分类概率;基于每一类别的第一分类概率生成非该类别的第三 分类概率,以及基于每一类别第二分类概率生成非该类别的第四分类概率, 以得到每一类别的第一概率分布和第二概率分布,所述第一概率分布包括各 个类别及其第一分类概率、非该类别及其第三分类概率,所述第二概率分布 包括各个类别及其第二分类概率、非该类别及其概率;利用每一类别下的第 一概率分布和第二概率分布计算KL散度,以及计算所述基准模型自身的误差; 利用所述KL散度以及所述基准模型自身的误差在所述基准模型中进行反向 传播,以用于调整所述基准模型的网络参数。
可选的,所述利用所述KL散度以及所述基准模型自身的误差在所述基准 模型中进行反向传播包括:计算所述KL散度与第一权重的乘积以及所述基准 模型自身的误差与第二权重的乘积之和,以作为响应误差;利用所述响应误 差在所述基准模型中进行反向传播。
可选的,所述计算所述基准模型自身的误差包括:采用Focal loss计算所 述基准模型自身的误差。
可选的,所述计算所述基准模型自身的误差包括:获取所述训练数据针 对每一类别的样本比例,所述样本比例为包含该类别的样本数与在该类别下 有效样本总数量的比值;根据所述第一输出结果计算所述基准模型的原始误 差;将所述原始误差与所述样本比例进行加权,以得到所述准模型自身的误 差。
可选的,所述将训练数据输入至构建好的基准模型和老师模型之前还包 括:获取原始样本数据,所述原始样本数据为标注好的行人图像,所述原始 样本数据中包括关键点;根据所述原始样本数据的关键点的坐标,将行人的 上半身图像或下半身图像进行随机擦除,并更改所述行人图像中的属性值, 以得到所述训练数据。
可选的,利用行人重识别模型作为人体模型的预训练模型,所述基准模 型中网络架构的Backbone中的网络参数是直接调用所述行人重识别模型中的 网络参数的。
可选的,所述将训练数据输入至构建好的基准模型和老师模型之前还包 括:获取原始样本数据,所述原始样本数据包括具有多种属性的样本;
将具备第一属性的样本输入至预先训练好的生成式对抗网络,以生成具 备第二属性的样本,所述第一属性与所述第二属性属于同一类别,所述具备 第二属性的样本为数量小于预设门限的样本。
可选的,所述构建好的基准模型在初始化时采用Kaiming算法初始化权 重,所述构建好的基准模型在全连接层采用Normal初始化权重。
可选的,所述基准模型是基于ResNet18构建的,所述老师模型是基于 ResNet101构建的。
为解决上述技术问题,本发明实施例还提供了一种模型训练装置,模型 训练装置包括:输入模块,用于将训练数据输入至构建好的基准模型和老师 模型,所述基准模型的网络层数小于所述老师模型的网络层数;输出结果获 取模块,用于获取所述基准模型针对所述训练数据的第一输出结果和老师模 型针对所述训练数据的第二输出结果,所述第一输出结果包括针对每一类别 的第一分类概率,所述第二输出结果包括针对每一类别第二分类概率;概率 生成模块,用于基于每一类别的第一分类概率生成非该类别的第三分类概率, 以及基于每一类别第二分类概率生成非该类别的第四分类概率,以得到每一 类别的第一概率分布和第二概率分布,所述第一概率分布包括各个类别及其 第一分类概率、非该类别及其第三分类概率,所述第二概率分布包括各个类 别及其第二分类概率、非该类别及其概率;KL散度计算模块,用于利用每一 类别下的第一概率分布和第二概率分布计算KL散度,以及计算所述基准模型 自身的误差;参数调整模块,用于利用所述KL散度以及所述基准模型自身的 误差在所述基准模型中进行反向传播,以用于调整所述基准模型的网络参数。
本发明实施例还提供了一种存储介质,其上存储有计算机程序,所述计 算机程序被处理器运行时执行所述模型训练方法的步骤。
本发明实施例还提供了一种计算设备,包括存储器和处理器,所述存储 器上存储有可在所述处理器上运行的计算机程序,所述处理器运行所述计算 机程序时执行所述模型训练方法的步骤。
与现有技术相比,本发明实施例的技术方案具有以下有益效果:
本发明技术方案中,分别对具有不同网络层数的基准模型和老师模型分 别输入训练数据,并对两个模型的输出结果的概率分布计算KL散度,以用于 基准模型的反向传播,最终得到网络参数优化后的精准模型。由于精准模型 的网络层数较少,因此运行较快,可以保证实时性;并且由于精准模型是利 用网络层数较多的老师模型进行参数调整的,因此可以保证分类准确性,也 即本发明技术方案训练完成的精准模型可以兼顾数据分类的实时性和准确性。
进一步地,获取所述训练数据针对每一类别的样本比例,所述样本比例 为包含该类别的样本数与在该类别下有效样本总数量的比值;根据所述第一 输出结果计算所述基准模型的原始误差;将所述原始误差与所述样本比例进 行加权,以得到所述准模型自身的误差。本发明技术方案在计算反向传播所 使用的误差时,将样本比例加权至原始误差,可以保证对数量较少的样本的 训练效果,进而提升精准模型对所有数据最终的分类准确率。
进一步地,获取原始样本数据,所述原始样本数据为标注好的行人图像; 将所述原始样本数据中行人的上半身图像或下半身图像进行随机擦除,并随 机更改所述行人图像中的属性值,以得到所述训练数据。本发明技术方案通 过对样本数据进行在线扩增,也即随机擦除,实现样本类型的多样化,从而 提升训练效果,提升最终训练完成的精准模型在人体部分缺失或者误检人体 的情况下的分类效果。
进一步地,获取原始样本数据,所述原始样本数据包括具有多种属性的 样本;将具备第一属性的样本输入至预先训练好的生成式对抗网络,以生成 具备第二属性的样本,所述第一属性与所述第二属性属于同一类别,所述具 备第二属性的样本为数量小于预设门限的样本。为了保证训练效果,在样本 数量较少或缺失的情况下,本发明技术方案使用生成式对抗网络实现对上述 样本的补充,保证样本的全面性和多样性,进而保证模型训练效果。
附图说明
图1是本发明实施例一种模型训练方法的流程图;
图2是图1所示步骤S104的一种具体实施方式的流程图;
图3是本发明实施例一种模型训练方法的具体实施方式的部分流程图;
图4是本发明实施例一种模型网络架构的示意图;
图5是本发明实施例一种模型训练装置的结构示意图。
具体实施方式
如背景技术中所述,目前的人体属性模型往往在人体人检测模型之后, 较依赖于检测框,而实际场景比较复杂,人体检测模型的效果很难保证,导 致在人体部分缺失或者误检人体的情况下,行人属性预测效果很不理想(在 大量现有开源接口上做过实验得出此结论)。其二,市面上的行人属性模型往 往在要求准确率高的同时牺牲实时性,反之亦然。其三,行人属性模型在跨 域场景下泛化力不强。
本发明技术方案中,首先提出了一个比较强的基准模型流程,然后在基 准模型基础上采用知识蒸馏技术优化基准模型。
首先,关于基准模型的设计,在设计基准模型时,采用Resnet18作为网 络架构(Backbone),之后接平均池化层,再接全连接层,最后输出层。在训练 阶段,使用Focalloss、Sample ratio、基于关键点的数据在线扩增、基于Reid 模型的预训练、GAN等技术来实现一个很强的基准模型。此基准模型可以解 决人体部分缺失、行人跨域场景下行人识别等目前市面模型解决不了的问题。
其次,在知识蒸馏技术方面,分别对具有网络数少的基准模型和网络数 多的老师模型分别输入训练数据,并对两个模型的输出结果的概率分布计算 KL散度,以用于基准模型的反向传播,最终得到网络参数优化后的精准模型。 由于精准模型的网络层数较少,因此运行较快,可以保证实时性;并且由于 精准模型是利用网络层数较多的老师模型进行参数调整的,因此可以保证分 类准确性,也即本发明技术方案训练完成的精准模型可以兼顾数据分类的实 时性和准确性。
为使本发明的上述目的、特征和优点能够更为明显易懂,下面结合附图 对本发明的具体实施例做详细的说明。
图1是本发明实施例一种模型训练方法的流程图。
本发明技术方案中可以用于计算设备,也即可以由该计算设备执行所述 方法的各个步骤。所述计算设备可以是各种恰当的终端,例如手机、电脑、 物联网设备等,但并不限于此。
具体而言,所述模型训练方法可以包括以下步骤:
步骤S101:将训练数据输入至构建好的基准模型和老师模型,所述基准 模型的网络层数小于所述老师模型的网络层数;
步骤S102:获取所述基准模型针对所述训练数据的第一输出结果和老师 模型针对所述训练数据的第二输出结果,所述第一输出结果包括针对每一类 别的第一分类概率,所述第二输出结果包括针对每一类别第二分类概率;
步骤S103:基于每一类别的第一分类概率生成非该类别的第三分类概率, 以及基于每一类别第二分类概率生成非该类别的第四分类概率,以得到每一 类别的第一概率分布和第二概率分布,所述第一概率分布包括各个类别及其 第一分类概率、非该类别及其第三分类概率,所述第二概率分布包括各个类 别及其第二分类概率、非该类别及其概率;
步骤S104:利用每一类别下的第一概率分布和第二概率分布计算KL散 度,以及计算所述基准模型自身的误差;
步骤S105:利用所述KL散度以及所述基准模型自身的误差在所述基准 模型中进行反向传播,以用于调整所述基准模型的网络参数。
需要指出的是,本实施例中各个步骤的序号并不代表对各个步骤的执行 顺序的限定。
本实施例中,训练数据可以是预先标注好的数据,例如可以是预先标注 好的行人图像。
在步骤S101的具体实施中,可以预先构建好基准模型和老师模型。基准 模型的网络层数小于所述老师模型的网络层数。其中,模型的网络层数越大, 模型的精准性越高,但模型的运行速度也越慢。本发明实施例所要实现的正 是使网络层数较少的模型拥有网络层数较多的模型的分类精准性。
在一个具体的例子中,所述基准模型是基于深度残差网络(Deep residualnetwork,ResNet)18构建的,所述老师模型是基于ResNet101构建的。ResNet18 表示网络层数为18,ResNet101表示网络层数为101。
需要说明的是,在实际应用中,还可以运用其他深度网络构建模型,例 如类Alexnet、类Mobilenet、类Shufflenet、类Hrnet、类Vggnet、类Darknet 等,本发明实施例对此不作限制。
基准模型和老师模型针对训练数据会分别给出相应的输出结果。在步骤 S102的具体实施中,获取基准模型针对所述训练数据的第一输出结果和老师 模型针对所述训练数据的第二输出结果。第一输出结果包括针对每一类别的 第一分类概率,所述第二输出结果包括针对每一类别第二分类概率。在类别 的数量为N时,第一输出结果和第二输出结果则是N维向量,每一个数值表 示对应类别的分类概率。
具体地,对于不同的应用场景所设置的具体类别可以是不同的。例如, 对于行人属性的识别而言,具体的类别可以是性别为男、性别为女、年龄为 儿童、年龄为少年、年龄为青年、年龄为中年、年龄为老年、发型为长发、 发型为短发、上衣颜色为白色、上衣颜色为黑色等等,本发明实施例对此不 一一赘述。
在一个具体例子中,第一输出结果可以是上衣颜色为白色,其概率为0.9; 第二输出结果可以是上衣颜色为白色,其概率为0.99。
由于计算KL散度需要使得输出结果中概率和为1,而输出结果中仅给出 了某一类别的概率,因此在步骤S103的具体实施中,对第一输出结果和第二 输出结果进行处理。也即,基于每一类别的第一分类概率生成非该类别的第 三分类概率,以及基于每一类别第二分类概率生成非该类别的第四分类概率。
具体地,第一分类概率与第三分类概率之和为1,第二分类概率和第四分 类概率之和为1。
在一个具体例子中,第一输出结果可以是上衣颜色为白色,其概率为0.9; 第二输出结果可以是上衣颜色为白色,其概率为0.99。则第一概率分布中上 衣为非白色的概率为0.1,第二概率分布中上衣为非白色的概率为0.01。
进而在步骤S104的具体实施中,针对每一类别计算KL散度 (Kullback-LeiblerDivergence,也称相对熵)。也就是说,本发明实施例是利用 KL散度去监督标签(label)的概率分布。此处的标签(label)也就是基准模型和老 师模型输出结果中的类别。KL散度可以衡量第一概率分布和第二概率分布之 间的相似性。
此外,还可以计算基准模型自身的误差。基准模型自身的误差可以是指 基准模型的输出值与相应的期望值的误差。
需要说明的是,关于计算KL散度以及基准模型自身的误差的具体算法可 以参照现有技术,本发明实施例对此不作限制。
在步骤S105的具体实施中,在所述基准模型中进行反向传播时,使用的 是KL散度以及所述基准模型自身的误差的加权之和,以实现对基准模型的网 络参数调整的优化。
本领域技术人员应当理解的是,训练数据作为输入,输入至基准模型得 到第一输出结果的过程是正向传播的过程。在正向传播过程中,输入信息通 过输入层经隐含层,逐层处理并传向输出层。如果在输出层得不到期望的输 出值,则取输出值与期望的误差的平方和作为目标函数,转入反向传播,逐 层求出目标函数对模型中各神经元权值的偏导数,构成目标函数对权值向量 的梯量,作为修改权值的依据,网络的学习在权值修改过程中完成。误差落 入预定的范围内时,训练过程结束。
本发明实施例中,由于基准模型的网络层数较少,因此运行较快,可以 保证实时性;并且由于基准模型是利用网络层数较多的老师模型进行参数调 整的,因此可以保证分类准确性,也即本发明实施例训练完成的基准模型可 以兼顾数据分类的实时性和准确性。
在一个非限制性的实施例中,图1所示步骤S105可以包括以下步骤:计 算所述KL散度与第一权重的乘积以及所述基准模型自身的误差与第二权重 的乘积之和,以作为响应误差;利用所述响应误差在所述基准模型中进行反 向传播。所述第一权重大于和所述第二权重的比值可以根据实际的应用需求 来确定。
在一个优选实施例中,所述第一权重大于所述第二权重。
本实施例中,为了使基准模型能够更好地学习到老师模型的分类能力, 在利用KL散度以及基准模型自身的误差进行反向传播时,相对于基准模型自 身的误差,可以使KL散度的比重更大。也即在对KL散度与准模型自身的误 差进行加权计算时,设置第一权重大于第二权重。
在一个具体的例子中,第一权重和第二权重可以分别是7和1。
在一个非限制性的实施例中,图1所示步骤S104可以包括以下步骤:采 用Focalloss计算所述基准模型自身的误差。
本发明实施例能够有效避免训练数据正负样本不均衡的问题。其中,在 采用Focal loss计算误差时,所采用的公式可以表示为:其中,y=1表示样本图像为正样本,也即,样本图像中有此属性,y=0表示样 本图像为负样本,也即样本图像中无此属性;p为预测概率。
在一个具体的例子中,参数λ为1.5,参数α为0.5。具体地,参数λ作 为调节因子可以调节正负样本的重要性程度,参数λ的值越大,那些数量越 少的样本越被重视;参数α则是反调节参数λ作的调节因子,以防参数λ调 节程度过大。
在一个非限制性的实施例中,请参照图2,图1所示步骤S104可以包括 以下步骤:
步骤S201:获取所述训练数据针对每一类别的样本比例(sample ratio), 所述样本比例为包含该类别的样本数与在该类别下有效样本总数量的比值;
步骤S202:根据所述第一输出结果计算所述基准模型的原始误差;
步骤S203:将所述原始误差与所述样本比例进行加权,以得到所述准模 型自身的误差。
本实施例中,在计算反向传播所使用的误差时,将样本比例加权至原始 误差,可以保证对数量较少的样本的训练效果,进而提升精准模型对所有数 据最终的分类准确率。
具体实施中,可以计算包含该类别的样本数与在该类别下有效样本总数 量的比值,例如对于上衣颜色为白色这一类别,样本比例为上衣颜色为白色 的图片数与上衣有颜色的图片数的比值,对于上衣没有颜色的图片(如没有 标注的样本)则是无效样本,无效样本不参与样本比例的计算。
需要说明的是,计算原始误差的具体方式可以是任意可实施的误差计算 算法,本发明实施例对此不作限制。
在一个非限制性的实施例中,在图1所示步骤S101之前还可以包括以下 步骤:获取原始样本数据,所述原始样本数据为标注好的行人图像;将所述 原始样本数据中行人的上半身图像或下半身图像进行随机擦除,并随机更改 所述行人图像中的属性值,以得到所述训练数据。
为了保证训练好的基准模型对于图像中人体部分缺失的场景有更高的识 别准确性,本发明实施例对原始样本数据进行预处理,具体可以是随机擦除 上半身图像或下半身图像,以及随机更改属性值(也即基于关键点的数据在 线扩增),从而保证训练数据的多样性和灵活性,提升训练效果。
在一个非限制性的实施例中,利用行人重识别模型作为人体模型的预训 练模型,所述基准模型中网络架构的Backbone中的网络参数是直接调用所述 行人重识别模型中的网络参数的。。
具体实施中,基准模型可以包括Backbone(也称支柱,或核心,用于特 征提取)、池化层和全连接层。通过使用行人再识别模型(Person Re-identification,ReID作为人体属性的预训练模型,也即直接调用行人再识别 模型中Backbone(如ResNet18)的网络参数。由于ReID模型有大规模的数 据集作为训练,以及ReID模型和用于人体属性识别的基准模型在特征提取阶 段有着高维特征的相似性,因此能够有效解决基准模型在跨域场景下精准性 不高的问题。
在一个非限制性的实施例中,请参照图3,在图1所示步骤S101之前还 可以包括以下步骤:
步骤S301:获取原始样本数据,所述原始样本数据包括具有多种属性的 样本;
步骤S302:将具备第一属性的样本输入至预先训练好的生成式对抗网络(Generative Adversarial Networks,GAN),以生成具备第二属性的样本,所述 第一属性与所述第二属性属于同一类别,所述具备第二属性的样本为数量小 于预设门限的样本。
为了保证训练效果,在样本数量较少或缺失的情况下,本发明实施例使 用生成式对抗网络实现对上述样本的补充,保证样本的全面性和多样性,进 而保证模型训练效果。
具体地,生成式对抗网络可以是预先训练好的。生成式对抗网络的输入 和输出是具有相似性的属性,例如生成式对抗网络的输入为上衣为红色,输 出为上衣为绿色;输入为双肩包,输出为单肩包;输入为性别女,输出为性 别男等等。
在一个非限制性的实施例中,所述构建好的基准模型在初始化时采用 Kaiming算法初始化权重,所述构建好的基准模型在全连接层采用Normal初 始化权重。
本发明实施例通过采用Kaiming初始化,可以保证在有relu激活层时, 每层的输出值保持高斯分布,进而解决在训练时的梯度消失的问题。具体地, Kaiming初始化的放缩系数为这样可以保证在输入层和输出层的方差一致, 使前后两层数据都为高斯分布,从而能够在梯度反传时不会由于方差递减引 起的梯度消失的现象。
在一个非限制性的实施例中,所述基准模型自身的误差为焦点损失。
本发明实施例在计算误差时不采用常规的SoftmaxBCELoss,而是采用焦 点损失(Focal loss),可以有效避免训练数据不均衡的问题,极大程度上提高模 型的准确率。
在一个具体应用场景中,请参照图4,预先构建基准模型(baseline)41 和老师模型42。基准模型41是基于ResNet18构建的,所述老师模型42是基 于ResNet101构建的。训练数据可以是一批次的图片。
如图4所示,基准模型41可以包括backbone(也即ResNet18)、池化层(Pooling)、全连接层(FC)、输出层(Outputs)。老师模型42可以包括backbone (也即ResNet101)。
在训练时,将训练数据分别输入至基准模型41和老师模型42。基准模型 41和老师模型42分别输出第一输出结果和第二输出结果。基准模型41计算 出自身的误差Focalloss;老师模型42计算出KL散度KLDivloss。利用误差 Focalloss和KL散度KLDivloss在所述基准模型41中进行反向传播。
本发明实施例是在数据集上训练一个较大的模型ResNet101,采用知识蒸 馏(Knowledge Distillation,KD)的技术原理,然后将ResNet101用做知识蒸 馏的老师,去在线教ResNet18的模型,实现知识迁移。本发明实施例没有在 池化(Pooling)层之前做KD,而是对最终的标签(label)采用KD,而label的 学习往往不能直接用L2loss,而是用KL散度去监督label的概率分布。
请参照图5,本发明实施例还公开了一种模型训练装置50,模型训练装 置50可以包括:
输入模块501,用于将训练数据输入至构建好的基准模型和老师模型,所 述基准模型的网络层数小于所述老师模型的网络层数;
输出结果获取模块502,用于获取所述基准模型针对所述训练数据的第一 输出结果和老师模型针对所述训练数据的第二输出结果,所述第一输出结果 包括针对每一类别的第一分类概率,所述第二输出结果包括针对每一类别第 二分类概率;
概率生成模块503,用于基于每一类别的第一分类概率生成非该类别的第 三分类概率,以及基于每一类别第二分类概率生成非该类别的第四分类概率, 以得到每一类别的第一概率分布和第二概率分布,所述第一概率分布包括各 个类别及其第一分类概率、非该类别及其第三分类概率,所述第二概率分布 包括各个类别及其第二分类概率、非该类别及其概率;
KL散度计算模块504,用于利用每一类别下的第一概率分布和第二概率 分布计算KL散度,以及计算所述基准模型自身的误差;
参数调整模块505,用于利用所述KL散度以及所述基准模型自身的误差 在所述基准模型中进行反向传播,以用于调整所述基准模型的网络参数。
由于基准模型的网络层数较少,因此运行较快,可以保证实时性;并且 由于基准模型是利用网络层数较多的老师模型进行参数调整的,因此可以保 证分类准确性,也即本发明技术方案训练完成的精准模型可以兼顾数据分类 的实时性和准确性。
关于所述模型训练装置50的工作原理、工作方式的更多内容,可以参照 图1至图4中的相关描述,这里不再赘述。
本发明实施例还公开了一种存储介质,所述存储介质为计算机可读存储 介质,其上存储有计算机程序,所述计算机程序运行时可以执行图1—图3中 所示方法的步骤。所述存储介质可以包括ROM、RAM、磁盘或光盘等。所述 存储介质还可以包括非挥发性存储器(non-volatile)或者非瞬态(non-transitory) 存储器等。
本发明实施例还公开了一种计算设备,所述计算设备可以包括存储器和 处理器,所述存储器上存储有可在所述处理器上运行的计算机程序。所述处 理器运行所述计算机程序时可以执行图1—图3中所示方法的步骤。所述计算 设备包括但不限于手机、计算机、平板电脑等终端设备。
应理解,上述的处理器可以是通用处理器、数字信号处理器(digital signalprocessor,DSP)、专用集成电路(application specific integrated circuit,ASIC)、 现成可编程门阵列(field programmable gate array,FPGA)或者其他可编程逻辑 器件、分立门或者晶体管逻辑器件、分立硬件组件,还可以是***芯片(system on chip,SoC),还可以是中央处理器(central processor unit,CPU),还可以是 网络处理器(networkprocessor,NP),还可以是数字信号处理电路(digital signal processor,DSP),还可以是微控制器(micro controller unit,MCU),还可以是 可编程控制器(programmable logicdevice,PLD)或其他集成芯片。可以实现或 者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以 是微处理器或者该处理器也可以是任何常规的处理器等。结合本申请实施例 所公开的方法的步骤可以直接体现为硬件译码处理器执行完成,或者用译码 处理器中的硬件及软件模块组合执行完成。软件模块可以位于随机存储器, 闪存、只读存储器,可编程只读存储器或者电可擦写可编程存储器、寄存器 等本领域成熟的存储介质中。该存储介质位于存储器,处理器读取存储器中 的信息,结合其硬件完成上述方法的步骤。
还应理解,本发明实施例中提及的存储器可以是易失性存储器或非易失 性存储器,或可包括易失性和非易失性存储器两者。其中,非易失性存储器 可以是只读存储器(read-only memory,ROM)、可编程只读存储器 (programmable ROM,PROM)、可擦除可编程只读存储器(erasable PROM, EPROM)、电可擦除可编程只读存储器(electrically EPROM,EEPROM)或闪存。 易失性存储器可以是随机存取存储器(random access memory,RAM),其用作 外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如 静态随机存取存储器(static RAM,SRAM)、动态随机存取存储器(dynamic RAM,DRAM)、同步动态随机存取存储器(synchronous DRAM,SDRAM)、 双倍数据速率同步动态随机存取存储器(doubledata rate SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(enhanced SDRAM,ESDRAM)、 同步连接动态随机存取存储器(synchlink DRAM,SLDRAM)和直接内存总线 随机存取存储器(direct rambus RAM,DR RAM)。应注意,本文描述的***和 方法的存储器旨在包括但不限于这些和任意其它适合类型的存储器。
需要说明的是,当处理器为通用处理器、DSP、ASIC、FPGA或者其他 可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件时,存储器(存 储模块)集成在处理器中。应注意,本文描述的存储器旨在包括但不限于这些 和任意其它适合类型的存储器。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各 示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结 合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特 定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方 法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描 述的***、装置和单元的具体工作过程,可以参考前述方法实施例中的对应 过程,在此不再赘述。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用 时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的 技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可 以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中, 包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网 络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储 介质包括:U盘、移动硬盘、只读存储器(read-only memory,ROM)、随机存 取存储器(random access memory,RAM)、磁碟或者光盘等各种可以存储程序 代码的介质。
虽然本发明披露如上,但本发明并非限定于此。任何本领域技术人员, 在不脱离本发明的精神和范围内,均可作各种更动与修改,因此本发明的保 护范围应当以权利要求所限定的范围为准。
Claims (12)
1.一种模型训练方法,其特征在于,包括:
将训练数据输入至构建好的基准模型和老师模型,所述基准模型的网络层数小于所述老师模型的网络层数;
获取所述基准模型针对所述训练数据的第一输出结果和老师模型针对所述训练数据的第二输出结果,所述第一输出结果包括针对每一类别的第一分类概率,所述第二输出结果包括针对每一类别第二分类概率;
基于每一类别的第一分类概率生成非该类别的第三分类概率,以及基于每一类别第二分类概率生成非该类别的第四分类概率,以得到每一类别的第一概率分布和第二概率分布,所述第一概率分布包括各个类别及其第一分类概率、非该类别及其第三分类概率,所述第二概率分布包括各个类别及其第二分类概率、非该类别及其概率;
利用每一类别下的第一概率分布和第二概率分布计算KL散度,以及计算所述基准模型自身的误差;
利用所述KL散度以及所述基准模型自身的误差在所述基准模型中进行反向传播,以用于调整所述基准模型的网络参数。
2.根据权利要求1所述的模型训练方法,其特征在于,所述利用所述KL散度以及所述基准模型自身的误差在所述基准模型中进行反向传播包括:
计算所述KL散度与第一权重的乘积以及所述基准模型自身的误差与第二权重的乘积之和,以作为响应误差;
利用所述响应误差在所述基准模型中进行反向传播。
3.根据权利要求1所述的模型训练方法,其特征在于,所述计算所述基准模型自身的误差包括:
采用Focalloss计算所述基准模型自身的误差。
4.根据权利要求1所述的模型训练方法,其特征在于,所述计算所述基准模型自身的误差包括:
获取所述训练数据针对每一类别的样本比例,所述样本比例为包含该类别的样本数与在该类别下有效样本总数量的比值;
根据所述第一输出结果计算所述基准模型的原始误差;
将所述原始误差与所述样本比例进行加权,以得到所述准模型自身的误差。
5.根据权利要求1所述的模型训练方法,其特征在于,所述将训练数据输入至构建好的基准模型和老师模型之前还包括:
获取原始样本数据,所述原始样本数据为标注好的行人图像,所述原始样本数据中包括关键点;
根据所述原始样本数据的关键点的坐标,将行人的上半身图像或下半身图像进行随机擦除,并更改所述行人图像中的属性值,以得到所述训练数据。
6.根据权利要求1所述的模型训练方法,其特征在于,利用行人重识别模型作为人体模型的预训练模型,所述基准模型中网络架构的Backbone中的网络参数是直接调用所述行人重识别模型中的网络参数的。
7.根据权利要求1所述的模型训练方法,其特征在于,所述将训练数据输入至构建好的基准模型和老师模型之前还包括:
获取原始样本数据,所述原始样本数据包括具有多种属性的样本;
将具备第一属性的样本输入至预先训练好的生成式对抗网络,以生成具备第二属性的样本,所述第一属性与所述第二属性属于同一类别,所述具备第二属性的样本为数量小于预设门限的样本。
8.根据权利要求1至7任一项所述的模型训练方法,其特征在于,所述构建好的基准模型在初始化时采用Kaiming算法初始化权重,所述构建好的基准模型在全连接层采用Normal初始化权重。
9.根据权利要求1至7任一项所述的模型训练方法,其特征在于,所述基准模型是基于ResNet18构建的,所述老师模型是基于ResNet101构建的。
10.一种模型训练装置,其特征在于,包括:
输入模块,用于将训练数据输入至构建好的基准模型和老师模型,所述基准模型的网络层数小于所述老师模型的网络层数;
输出结果获取模块,用于获取所述基准模型针对所述训练数据的第一输出结果和老师模型针对所述训练数据的第二输出结果,所述第一输出结果包括针对每一类别的第一分类概率,所述第二输出结果包括针对每一类别第二分类概率;
概率生成模块,用于基于每一类别的第一分类概率生成非该类别的第三分类概率,以及基于每一类别第二分类概率生成非该类别的第四分类概率,以得到每一类别的第一概率分布和第二概率分布,所述第一概率分布包括各个类别及其第一分类概率、非该类别及其第三分类概率,所述第二概率分布包括各个类别及其第二分类概率、非该类别及其概率;
KL散度计算模块,用于利用每一类别下的第一概率分布和第二概率分布计算KL散度,以及计算所述基准模型自身的误差;
参数调整模块,用于利用所述KL散度以及所述基准模型自身的误差在所述基准模型中进行反向传播,以用于调整所述基准模型的网络参数。
11.一种存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器运行时执行权利要求1至9中任一项所述模型训练方法的步骤。
12.一种计算设备,包括存储器和处理器,所述存储器上存储有可在所述处理器上运行的计算机程序,其特征在于,所述处理器运行所述计算机程序时执行权利要求1至9中任一项所述模型训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011415641.5A CN112784677A (zh) | 2020-12-04 | 2020-12-04 | 模型训练方法及装置、存储介质、计算设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011415641.5A CN112784677A (zh) | 2020-12-04 | 2020-12-04 | 模型训练方法及装置、存储介质、计算设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112784677A true CN112784677A (zh) | 2021-05-11 |
Family
ID=75750750
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011415641.5A Pending CN112784677A (zh) | 2020-12-04 | 2020-12-04 | 模型训练方法及装置、存储介质、计算设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112784677A (zh) |
Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108921092A (zh) * | 2018-07-02 | 2018-11-30 | 浙江工业大学 | 一种基于卷积神经网络模型二次集成的黑色素瘤分类方法 |
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN110021051A (zh) * | 2019-04-01 | 2019-07-16 | 浙江大学 | 一种基于生成对抗网络通过文本指导的人物图像生成方法 |
CN110147456A (zh) * | 2019-04-12 | 2019-08-20 | 中国科学院深圳先进技术研究院 | 一种图像分类方法、装置、可读存储介质及终端设备 |
CN110197212A (zh) * | 2019-05-20 | 2019-09-03 | 北京邮电大学 | 图像分类方法、***及计算机可读存储介质 |
CN110321928A (zh) * | 2019-06-03 | 2019-10-11 | 深圳中兴网信科技有限公司 | 环境检测模型的生成方法、计算机设备及可读存储介质 |
CN110659573A (zh) * | 2019-08-22 | 2020-01-07 | 北京捷通华声科技股份有限公司 | 一种人脸识别方法、装置、电子设备及存储介质 |
CN111008654A (zh) * | 2019-11-26 | 2020-04-14 | 江苏艾佳家居用品有限公司 | 一种户型图中房间的识别方法及*** |
CN111488945A (zh) * | 2020-04-17 | 2020-08-04 | 上海眼控科技股份有限公司 | 图像处理方法、装置、计算机设备和计算机可读存储介质 |
CN111860147A (zh) * | 2020-06-11 | 2020-10-30 | 北京市威富安防科技有限公司 | 行人重识别模型优化处理方法、装置和计算机设备 |
-
2020
- 2020-12-04 CN CN202011415641.5A patent/CN112784677A/zh active Pending
Patent Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108921092A (zh) * | 2018-07-02 | 2018-11-30 | 浙江工业大学 | 一种基于卷积神经网络模型二次集成的黑色素瘤分类方法 |
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN110021051A (zh) * | 2019-04-01 | 2019-07-16 | 浙江大学 | 一种基于生成对抗网络通过文本指导的人物图像生成方法 |
CN110147456A (zh) * | 2019-04-12 | 2019-08-20 | 中国科学院深圳先进技术研究院 | 一种图像分类方法、装置、可读存储介质及终端设备 |
CN110197212A (zh) * | 2019-05-20 | 2019-09-03 | 北京邮电大学 | 图像分类方法、***及计算机可读存储介质 |
CN110321928A (zh) * | 2019-06-03 | 2019-10-11 | 深圳中兴网信科技有限公司 | 环境检测模型的生成方法、计算机设备及可读存储介质 |
CN110659573A (zh) * | 2019-08-22 | 2020-01-07 | 北京捷通华声科技股份有限公司 | 一种人脸识别方法、装置、电子设备及存储介质 |
CN111008654A (zh) * | 2019-11-26 | 2020-04-14 | 江苏艾佳家居用品有限公司 | 一种户型图中房间的识别方法及*** |
CN111488945A (zh) * | 2020-04-17 | 2020-08-04 | 上海眼控科技股份有限公司 | 图像处理方法、装置、计算机设备和计算机可读存储介质 |
CN111860147A (zh) * | 2020-06-11 | 2020-10-30 | 北京市威富安防科技有限公司 | 行人重识别模型优化处理方法、装置和计算机设备 |
Non-Patent Citations (2)
Title |
---|
KAIMING HE ET AL: "Delving Deep into Rectifiers:Surpassing Human-Level Performance on ImageNet Classification", 《ARXIV》 * |
李汉冰等: "基于YOLOV3改进的实时车辆检测方法", 《激光与光电子学进展》 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2019100724A1 (zh) | 训练多标签分类模型的方法和装置 | |
WO2019100723A1 (zh) | 训练多标签分类模型的方法和装置 | |
US20230196117A1 (en) | Training method for semi-supervised learning model, image processing method, and device | |
US20220058426A1 (en) | Object recognition method and apparatus, electronic device, and readable storage medium | |
CN109871781B (zh) | 基于多模态3d卷积神经网络的动态手势识别方法及*** | |
US11417148B2 (en) | Human face image classification method and apparatus, and server | |
Yun et al. | Focal loss in 3d object detection | |
US11755889B2 (en) | Method, system and apparatus for pattern recognition | |
CN111191526B (zh) | 行人属性识别网络训练方法、***、介质及终端 | |
WO2017096753A1 (zh) | 人脸关键点跟踪方法、终端和非易失性计算机可读存储介质 | |
CN111133453B (zh) | 人工神经网络 | |
US20190073553A1 (en) | Region proposal for image regions that include objects of interest using feature maps from multiple layers of a convolutional neural network model | |
US20180144246A1 (en) | Neural Network Classifier | |
CN112257815A (zh) | 模型生成方法、目标检测方法、装置、电子设备及介质 | |
CN111222487B (zh) | 视频目标行为识别方法及电子设备 | |
US11093800B2 (en) | Method and device for identifying object and computer readable storage medium | |
CN111401521B (zh) | 神经网络模型训练方法及装置、图像识别方法及装置 | |
CN110968734A (zh) | 一种基于深度度量学习的行人重识别方法及装置 | |
CN113781164B (zh) | 虚拟试衣模型训练方法、虚拟试衣方法和相关装置 | |
WO2021043023A1 (zh) | 图像处理方法及装置、分类器训练方法以及可读存储介质 | |
Gorijala et al. | Image generation and editing with variational info generative AdversarialNetworks | |
CN114170654A (zh) | 年龄识别模型的训练方法、人脸年龄识别方法及相关装置 | |
CN112749737A (zh) | 图像分类方法及装置、电子设备、存储介质 | |
CN116863194A (zh) | 一种足溃疡图像分类方法、***、设备及介质 | |
CN113723287A (zh) | 基于双向循环神经网络的微表情识别方法、装置及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
WD01 | Invention patent application deemed withdrawn after publication | ||
WD01 | Invention patent application deemed withdrawn after publication |
Application publication date: 20210511 |