CN112329916A - 模型训练方法、装置、计算机设备及存储介质 - Google Patents
模型训练方法、装置、计算机设备及存储介质 Download PDFInfo
- Publication number
- CN112329916A CN112329916A CN202011162501.1A CN202011162501A CN112329916A CN 112329916 A CN112329916 A CN 112329916A CN 202011162501 A CN202011162501 A CN 202011162501A CN 112329916 A CN112329916 A CN 112329916A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- picture
- sample set
- target
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 186
- 238000000034 method Methods 0.000 title claims abstract description 52
- 238000004821 distillation Methods 0.000 claims abstract description 28
- 238000004590 computer program Methods 0.000 claims description 23
- 239000011159 matrix material Substances 0.000 claims description 9
- 230000006870 function Effects 0.000 claims description 8
- 238000012360 testing method Methods 0.000 description 14
- 230000008569 process Effects 0.000 description 10
- 238000001514 detection method Methods 0.000 description 8
- 230000003190 augmentative effect Effects 0.000 description 6
- 238000003062 neural network model Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 238000005516 engineering process Methods 0.000 description 4
- 238000013434 data augmentation Methods 0.000 description 3
- 230000004927 fusion Effects 0.000 description 3
- 238000002372 labelling Methods 0.000 description 3
- 230000003321 amplification Effects 0.000 description 2
- 230000003416 augmentation Effects 0.000 description 2
- 238000013480 data collection Methods 0.000 description 2
- 239000000463 material Substances 0.000 description 2
- 238000003199 nucleic acid amplification method Methods 0.000 description 2
- 230000009466 transformation Effects 0.000 description 2
- 230000008859 change Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及一种模型训练方法、装置、计算机设备及存储介质。该方法包括获取第一训练样本集合,所述第一训练样本集合包括多个第一图片样本;获取多个样本旋转角度,对于各所述样本旋转角度,将所述第一训练样本集合中的各所述第一图片样本旋转至所述样本旋转角度,得到与所述样本旋转角度对应的第二训练样本集合;其中,各所述样本旋转角度大于预设的角度阈值;基于各所述第二训练样本集合分别对初始模型进行训练,得到与各所述第二训练样本集合一一对应的中间模型;对所述多个中间模型进行模型蒸馏,得到目标模型。本申请实施例可以提高目标模型的精度。
Description
技术领域
本申请涉及机器学习技术领域,特别是涉及一种模型训练方法、装置、计算机设备及存储介质。
背景技术
监督学习是指利用标注有标签的样本对神经网络模型进行训练,基于监督学习方法训练神经网络模型的过程中,需要首先获取带有标签的样本。
现有技术中,由于样本标注标签需要耗费大量人力物力成本,因此,带有标签的样本的数量较少,而样本的数量较少会导致神经网络模型学习到的特征较少,从而导致最终的模型的测试精度较差。
发明内容
基于此,有必要针对上述技术问题,提供一种能够提高模型测试精度的模型训练方法、装置、计算机设备和存储介质。
一种模型训练方法,该方法包括:
获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
对多个中间模型进行模型蒸馏,得到目标模型。
在其中一个实施例中,对多个中间模型进行模型蒸馏,得到目标模型,包括:
获取目标图片;
将目标图片分别旋转至与各中间模型对应的角度,得到旋转后的中间图片;
将中间图片分别输入至各中间模型中,并根据各中间模型的输出结果获取目标图片的标签;
基于目标图片和目标图片的标签对初始模型进行训练,得到目标模型。
在其中一个实施例中,将目标图片分别旋转至与各中间模型对应的角度,包括:
利用矩阵旋转函数将目标图片分别旋转至各中间模型对应的角度。
在其中一个实施例中,中间模型的输出结果为特征图,根据各中间模型的输出结果获取目标图片的标签,包括:
将各中间模型输出的特征图进行按位相加,得到融合后的特征图;
将融合后的特征图确定为目标图片的标签。
在其中一个实施例中,各样本旋转角度之间的角度差相等,且各样本旋转角度大于20度。
在其中一个实施例中,样本旋转角度包括0度,45度,90度,135度,180度,225度,270度和315度。
一种模型训练装置,该装置包括:
第一获取模块,用于获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
第二获取模块,用于获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
中间模型获取模块,用于基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
蒸馏模块,用于对多个中间模型进行模型蒸馏,得到目标模型。
在其中一个实施例中,蒸馏模块还用于:
获取目标图片;
将目标图片分别旋转至与各中间模型对应的角度,得到旋转后的中间图片;
将中间图片分别输入至各中间模型中,并根据各中间模型的输出结果获取目标图片的标签;
基于目标图片和目标图片的标签对初始模型进行训练,得到目标模型。
一种计算机设备,包括存储器和处理器,存储器存储有计算机程序,处理器执行计算机程序时实现以下步骤:
获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
对多个中间模型进行模型蒸馏,得到目标模型。
一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现以下步骤:
获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
对多个中间模型进行模型蒸馏,得到目标模型。
上述模型训练方法、装置、计算机设备和存储介质,可以提高模型的测试精度。该模型训练方法通过获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;对多个中间模型进行模型蒸馏,得到目标模型。本申请实施例中,将第一训练样本集合中的第一图片样本旋转至大于预设的角度阈值的样本旋转角度后,得到了多个第二训练样本集合,这样实现了对第一训练样本集合的增广,对于增广后得到的多个第二训练样本集合,分别对初始模型进行训练,得到多个与各第二训练样本集合对应的中间模型,由于各个中间模型是通过不同的第二训练样本集合训练得到的,因此各个中间模型之间互不干扰,在此基础上,用模型蒸馏的方式把多个中间模型的知识迁移到最终的目标模型中,从而提高了最终的目标模型的精度。
附图说明
图1为一个实施例中模型训练方法的流程示意图;
图2为一个实施例中进行模型蒸馏的步骤的流程示意图;
图3为一个实施例中进行特征融合的步骤的流程示意图;
图4为一个实施例中模型训练装置的结构框图;
图5为一个实施例中计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
监督学习是指利用标注有标签的样本对神经网络模型进行训练,基于监督学习方法训练神经网络模型的过程中,需要首先获取带有标签的样本。
传统技术中,由于样本标注标签需要耗费大量人力物力成本,因此,带有标签的样本的数量较少,而样本的数量较少会导致神经网络模型学习到的特征较少,从而导致最终的模型的测试精度较差。
为了解决样本图像的数量较少带来的模型精度较差的问题,现有技术提出了对样本图像进行数据增广的方式,其中,数据增广可以是指将样本图像旋转一定角度,样本图像旋转后,可以得到处于不同旋转角度的样本图像,但样本图像中的图像内容并未发生变化,也就是说,该旋转后的样本对应的标签也不会发生变化,这样就可以实现扩充样本图像数量,而不需要对新增的样本图像标注标签。从而实现了以较小的成本来扩大样本图像数量的目的。
在上述方法中,对样本图像旋转的角度一般限于-20度到+20度之间。这是由于旋转的角度在(-20,+20)之间时,旋转前和旋转后的样本图像的相似性较高,因此可以用于一个模型训练中。若样本图像的旋转角度较大时,例如样本图像的旋转角度为90度或者180度等较大范围时,旋转前的样本图像和旋转后的样本图像对于模型而言,具有极大差异,此时将旋转前和旋转后的样本图像用于同一个模型训练时,会导致模型崩溃。
基于上述内容可知,现有的样本图像的增广方式受到旋转角度的限制,因此,扩展后的样本图像的数量有限不能满足模型训练需求。
本申请实施例中,提供了一种模型训练方法,将第一训练样本集合中的第一图片样本旋转至大于预设的角度阈值的样本旋转角度后,得到了多个第二训练样本集合,其中,预设的角度阈值为大于20度的角度。由于旋转角度较大,因此,可以认为旋转后得到的多个第二训练样本集合属于不同的领域。从而实现了通过领域变换来扩充数据的目的。这样不需要另外人为搜集其他领域的数据,充分发挥第一训练样本集合的潜能。进一步的,基于每个第二训练样本集合,分别对初始模型进行训练,得到与每个第二训练样本集合一一对应的中间模型,并使用模型蒸馏技术,以多个中间模型为老师来训练最终的目标模型,从而提升目标模型的鲁棒性和测试精度。
在一个实施例中,如图1所示,提供了一种模型训练方法,本实施例以该方法应用于计算机设备中进行举例说明,该计算机设备可以是终端也可以是服务器。本实施例中,该方法包括以下步骤:
步骤101,计算机设备获取第一训练样本集合,第一训练样本集合包括多个第一图片样本。
本申请实施例中,第一训练样本集合为原始数据。
步骤102,计算机设备获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合。
其中,样本旋转角度是指对第一训练样本集合中的多个第一图片样本的旋转角度。各样本旋转角度大于预设的角度阈值。可选的,预设的角度阈值可以为20度。可选的,预设的角度阈值可以为45度。
可选的,样本旋转角度的角度值可以是人为任意设定的。
可选的,多个样本旋转角度之间的角度差相等。在角度差相等的情况下,可以实现旋转后的图片样本的差异最大化。
可选的,本申请实施例中,样本旋转角度可以包括0度,45度,90度,135度,180度,225度,270度和315度。
本申请实施例中,计算机设备对第一训练样本集合中的各第一图片样本进行旋转的过程如下:以样本旋转角度为45度为例,计算机设备可以逐一地将第一训练样本集合中的多个第一图片样本旋转45度,以得到旋转后的多个第一图片样本,将旋转为45度的多个第一图片样本组成第二训练样本集合,该第二训练样本集合可以表示为X45°,Y45°。
需要说明的是,在对第一图片样本进行旋转时,旋转前和旋转后的第一图片样本的标签并不发生变化。
可选的,本申请实施例中,默认第一训练样本集合中的多个第一图片样本的角度为0度。可选的,对各第一图片样本进行旋转是指顺时针旋转。
可以参考上述方法,分别将第一训练样本集合中的各第一图片样本旋转至其他的样本旋转角度,从而得到对应不同样本旋转角度的第二训练样本集合。
需要说明的是,本申请实施例中,由于样本旋转角度大于预设的角度阈值,因此,旋转前的第一训练样本集合和旋转后的第二训练样本集合之间属于不同的领域,并且,各个第二训练样本集合也属于不同的领域,属于不同领域的各个第二训练样本集包括的图片样本不可以用于同一个模型中进行训练,否则,会导致模型崩溃。
本申请实施例中,通过将第一训练样本集合包括的多个第一图片样本旋转至大于预设的角度阈值的样本旋转角度,可以得到增广后的对应不同领域的多个第二训练样本集合,从而实现了对原始数据进行领域增广的目的。
在一种可选的实现方式中,对于属于不同领域的每个第二训练样本集合,本申请实施例中,还可以对第二训练样本集合中的多个图片样本再次进行增广,再次进行增广是指,对第二训练样本集合中的多个图片样本进行再次旋转。可选的,再次旋转时,默认以第二训练样本集合中每个图片样本的当前角度为0度。并且,再次旋转的旋转方向可以是逆时针旋转也可以是顺时针旋转。再次旋转的旋转角度小于预设的角度阈值。
本申请实施例中,借鉴于现有技术中的数据增广方式,对第二训练样本集合中的图片样本进行再次旋转后,旋转后的图片样本和旋转前的图片样本一起组成第二训练样本集合。
通过对第一训练样本集合进行领域增广之后,得到对应不同领域的第二训练样本集合,然后对每个第二训练样本集合中的多个图片样本再次进行增广,使得单个第二训练样本集合中的数据量变大。最终得到的多个第二训练样本集合所包含的数据总量相比于第一训练样本集合包含的数据量得到了极大扩充,从而得到满足模型训练所需要的训练样本的数量,可以提高最终的模型的鲁棒性和测试精度。
步骤103,计算机设备基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型。
本申请实施例中,对于增广后的每一个领域的第二训练样本集合,单独对初始模型进行训练,最终产生多个中间模型,每个中间模型负责一个领域。
承接上文举例,对于旋转了0度的第二训练样本集合X0°,Y0°,可以利用第二训练样本集合对初始模型进行训练,得到训练好的模型为0度的中间模型。
相应的,对于旋转了45度的第二训练样本集合X45°,Y45°,可以利用第二训练样本集合对初始模型进行训练,得到训练好的模型为45度的中间模型。
依次类推,得到与各个第二训练样本集合对应的中间模型,各个中间模型的角度为训练该中间模型的第二训练样本集合相对第一训练样本集合的角度。
步骤104,计算机设备对多个中间模型进行模型蒸馏,得到目标模型。
本申请实施例中,模型蒸馏的方法可以采用任意的模型蒸馏的方法,例如特征对齐的方法。通过模型蒸馏可以将多个中间模型的知识迁移到最终的目标模型中。
本申请实施例中,提供了一种模型训练方法,将第一训练样本集合中的第一图片样本旋转至大于预设的角度阈值的样本旋转角度后,得到了多个第二训练样本集合,其中,预设的角度阈值为大于20度的角度。由于旋转角度较大,因此,可以认为旋转后得到的多个第二训练样本集合属于不同的领域。从而实现了通过领域变换来扩充数据的目的。这样不需要另外人为搜集其他领域的数据,充分发挥第一训练样本集合的潜能。进一步的,基于每个第二训练样本集合,分别对初始模型进行训练,得到与每个第二训练样本集合一一对应的中间模型,并使用模型蒸馏技术,以多个中间模型为老师来训练最终的目标模型,从而提升目标模型的鲁棒性和测试精度。
在一个实施例中,如图2所示,计算机设备对多个中间模型进行模型蒸馏,得到目标模型的过程可以包括以下内容:
步骤201,计算机设备获取目标图片。
本申请实施例中,目标图片为用于进行模型蒸馏的样本图片。
可选的,目标图片的数量可以为多个。
可选的,目标图片可以为第一训练样本集合包括的多个第一图片样本。
步骤202,计算机设备将目标图片分别旋转至与各中间模型对应的角度,得到旋转后的中间图片。
本申请实施例中,各中间模型对应的角度为训练该中间模型的第二训练样本集合相对第一训练样本集合的角度。
本申请实施例中,计算机设备对目标图片进行旋转的过程可以包括以下内容:利用矩阵旋转函数将目标图片分别旋转至各中间模型对应的角度。其中,矩阵旋转函数可以例如是torch.rotation(x,90),其中,x表示目标图片,90表示需要旋转的角度。torch.rotation(x,90)表示,将目标图片旋转90度。
承接上文举例可知,各中间模型对应的角度可以是0度,45度,90度,135度,180度,225度,270度和315度。因此,可以利用矩阵旋转函数将目标图片分别旋转至0度,45度,90度,135度,180度,225度,270度和315度,得到0度的中间图片、45度的中间图片、90度的中间图片,……,315度的中间图片。
可选的,当目标图片为第一训练样本集合包括的多个第一图片样本时,旋转后的中间图片可以是指各个第二训练样本集合中的图片样本。
步骤203,计算机设备将中间图片分别输入至各中间模型中,并根据各中间模型的输出结果获取目标图片的标签。
本申请实施例中,计算机设备将中间图片分别输入至各中间模型中是指:将0度的中间图片输入到0度的中间模型中,将45度的中间图片输入到45度的中间模型中,……,将315度的中间图片输入到315度的中间模型中。
对于每个中间图片,每个中间模型可以对输入的中间图片进行测试,例如若中间模型为目标检测模型,其中目标不限于是人脸、人体姿态、车辆或者其他目标,则中间模型可以对中间图片进行目标检测,并输出目标检测结果。
对于同一目标图片的多个中间图片,每个中间模型可以输出对应于该目标图片的目标检测结果。
本申请实施例中,计算机设备根据各中间模型的输出结果获取目标图片的标签的过程可以包括以下内容:计算机设备可以根据每个中间模型的输出结果确定出该目标图片的标签。
例如,中间模型为车辆检测模型,中间模型输出的检测结果可以包括存在车辆或者不存在车辆两种情况。假设中间模型的数量为8个,将某个目标图片基于步骤202处理后可以得到该目标图片对应的8个中间图片,将该8个中间图片分别输入到对应的中间模型后,每个中间模型可以输出检测结果,总共得到8种输出结果,每个中间模型输出的检测结果可以用于指示目标图片存在车辆或者不存在车辆。然后,计算机设备可以将该8种输出结果中占比超过50%的输出结果作为该目标图片的标签。例如8种输出结果中7种为存在车辆,1种为不存在车辆,那么存在车辆这一输出结果的总输出结果中的占比超过了50%,因此,将存在车辆这一输出结果确定为该目标图片的标签。
目标图片的数量为多个,对于多个目标图片可以分别按照上述方式进行处理,可以得到多个目标图片的标签。
在一种可选的实现方式中,中间模型的输出结果为特征图。特征图(英文:FeatureMap)为数据矩阵。
本申请实施例中,计算机设备根据各中间模型的输出结果获取目标图片的标签的过程可以包括以下内容:计算机设备可以对各个中间模型输出的特征图进行特征融合,得到融合后的特征图,并将融合后的特征图作为目标图片的标签。
可选的,对各个中间模型输出的特征图进行特征融合的过程可以包括以下步骤:
步骤301,计算机设备将各中间模型输出的特征图进行按位相加,得到融合后的特征图。
步骤302,计算机设备将融合后的特征图确定为目标图片的标签。
本申请实施例中,可以实现将多个中间模型学习到的特征进行融合的目的。
步骤204,计算机设备基于目标图片和目标图片的标签对初始模型进行训练,得到目标模型。
本申请实施例中,多个目标图片和各个目标图片的标签可以用于对初始模型进行训练,由于各个目标图片的标签是基于各个中间模型对目标图片学习后的输出结果确定的,因此,基于各个目标图片的标签对初始模型进行训练,可以实现将最终得到的目标模型的输出结果与各个中间模型的输出结果对齐的目的,即实现了将多个中间模型的知识迁移到目标模型的目的,从而提升目标模型的鲁棒性和测试精度。
需要说明的是,在对目标模型进行模型测试时,对于一张测试图片,不需要对该测试图片进行任何旋转处理,可以直接将该测试图片输入至训练好的目标模型中,基于该目标模型对测试图片进行测试。
应该理解的是,虽然图1-图3的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1-图3中的至少一部分步骤可以包括多个步骤或者多个阶段,这些步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤中的步骤或者阶段的至少一部分轮流或者交替地执行。
在一个实施例中,如图4所示,提供了一种模型训练装置,包括:第一获取模块401,第二获取模块402、中间模型获取模块403和蒸馏模块404,其中:
第一获取模块401,用于获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
第二获取模块402,用于获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
中间模型获取模块403,用于基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
蒸馏模块404,用于对多个中间模型进行模型蒸馏,得到目标模型。
在本申请的一个实施例中,蒸馏模块404还用于:
获取目标图片;
将目标图片分别旋转至与各中间模型对应的角度,得到旋转后的中间图片;
将中间图片分别输入至各中间模型中,并根据各中间模型的输出结果获取目标图片的标签;
基于目标图片和目标图片的标签对初始模型进行训练,得到目标模型。
在本申请的一个实施例中,蒸馏模块404还用于:
利用矩阵旋转函数将目标图片分别旋转至各中间模型对应的角度。
在本申请的一个实施例中,蒸馏模块404还用于:
将各中间模型输出的特征图进行按位相加,得到融合后的特征图;
将融合后的特征图确定为目标图片的标签。
在本申请的一个实施例中,各样本旋转角度之间的角度差相等,且各样本旋转角度大于20度。
在本申请的一个实施例中,样本旋转角度包括0度,45度,90度,135度,180度,225度,270度和315度。
关于模型训练装置的具体限定可以参见上文中对于模型训练方法的限定,在此不再赘述。上述模型训练装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
在一个实施例中,提供了一种计算机设备,该计算机设备可以是服务器,其内部结构图可以如图5所示。该计算机设备包括通过***总线连接的处理器、存储器和网络接口。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作***、计算机程序和数据库。该内存储器为非易失性存储介质中的操作***和计算机程序的运行提供环境。该计算机设备的数据库用于存储矩阵旋转函数。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种模型训练方法。
本领域技术人员可以理解,图5中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,提供了一种计算机设备,包括存储器和处理器,存储器中存储有计算机程序,该处理器执行计算机程序时实现以下步骤:
获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
对多个中间模型进行模型蒸馏,得到目标模型。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:获取目标图片;
将目标图片分别旋转至与各中间模型对应的角度,得到旋转后的中间图片;
将中间图片分别输入至各中间模型中,并根据各中间模型的输出结果获取目标图片的标签;
基于目标图片和目标图片的标签对初始模型进行训练,得到目标模型。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:利用矩阵旋转函数将目标图片分别旋转至各中间模型对应的角度。
在一个实施例中,中间模型的输出结果为特征图,处理器执行计算机程序时还实现以下步骤:
将各中间模型输出的特征图进行按位相加,得到融合后的特征图;
将融合后的特征图确定为目标图片的标签。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:各样本旋转角度之间的角度差相等,且各样本旋转角度大于20度。
在一个实施例中,处理器执行计算机程序时还实现以下步骤:样本旋转角度包括0度,45度,90度,135度,180度,225度,270度和315度。
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现以下步骤:
获取第一训练样本集合,第一训练样本集合包括多个第一图片样本;
获取多个样本旋转角度,对于各样本旋转角度,将第一训练样本集合中的各第一图片样本旋转至样本旋转角度,得到与样本旋转角度对应的第二训练样本集合;其中,各样本旋转角度大于预设的角度阈值;
基于各第二训练样本集合分别对初始模型进行训练,得到与各第二训练样本集合一一对应的中间模型;
对多个中间模型进行模型蒸馏,得到目标模型。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:获取目标图片;
将目标图片分别旋转至与各中间模型对应的角度,得到旋转后的中间图片;
将中间图片分别输入至各中间模型中,并根据各中间模型的输出结果获取目标图片的标签;
基于目标图片和目标图片的标签对初始模型进行训练,得到目标模型。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:利用矩阵旋转函数将目标图片分别旋转至各中间模型对应的角度。
在一个实施例中,中间模型的输出结果为特征图,计算机程序被处理器执行时还实现以下步骤:
将各中间模型输出的特征图进行按位相加,得到融合后的特征图;
将融合后的特征图确定为目标图片的标签。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:各样本旋转角度之间的角度差相等,且各样本旋转角度大于20度。
在一个实施例中,计算机程序被处理器执行时还实现以下步骤:样本旋转角度包括0度,45度,90度,135度,180度,225度,270度和315度。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和易失性存储器中的至少一种。非易失性存储器可包括只读存储器(Read-Only Memory,ROM)、磁带、软盘、闪存或光存储器等。易失性存储器可包括随机存取存储器(Random Access Memory,RAM)或外部高速缓冲存储器。作为说明而非局限,RAM可以是多种形式,比如静态随机存取存储器(Static Random Access Memory,SRAM)或动态随机存取存储器(Dynamic Random Access Memory,DRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。
Claims (10)
1.一种模型训练方法,其特征在于,所述方法包括:
获取第一训练样本集合,所述第一训练样本集合包括多个第一图片样本;
获取多个样本旋转角度,对于各所述样本旋转角度,将所述第一训练样本集合中的各所述第一图片样本旋转至所述样本旋转角度,得到与所述样本旋转角度对应的第二训练样本集合;其中,各所述样本旋转角度大于预设的角度阈值;
基于各所述第二训练样本集合分别对初始模型进行训练,得到与各所述第二训练样本集合一一对应的中间模型;
对所述多个中间模型进行模型蒸馏,得到目标模型。
2.根据权利要求1所述的方法,其特征在于,所述对所述多个中间模型进行模型蒸馏,得到目标模型,包括:
获取目标图片;
将所述目标图片分别旋转至与各所述中间模型对应的角度,得到旋转后的中间图片;
将所述中间图片分别输入至各所述中间模型中,并根据各所述中间模型的输出结果获取所述目标图片的标签;
基于所述目标图片和所述目标图片的标签对所述初始模型进行训练,得到所述目标模型。
3.根据权利要求2所述的方法,其特征在于,所述将所述目标图片分别旋转至与各所述中间模型对应的角度,包括:
利用矩阵旋转函数将所述目标图片分别旋转至各所述中间模型对应的角度。
4.根据权利要求2所述的方法,其特征在于,所述中间模型的输出结果为特征图,所述根据各所述中间模型的输出结果获取所述目标图片的标签,包括:
将各所述中间模型输出的特征图进行按位相加,得到融合后的特征图;
将所述融合后的特征图确定为所述目标图片的标签。
5.根据权利要求1所述的方法,其特征在于,各所述样本旋转角度之间的角度差相等,且各所述样本旋转角度大于20度。
6.根据权利要求1所述的方法,其特征在于,所述样本旋转角度包括0度,45度,90度,135度,180度,225度,270度和315度。
7.一种模型训练装置,其特征在于,所述装置包括:
第一获取模块,用于获取第一训练样本集合,所述第一训练样本集合包括多个第一图片样本;
第二获取模块,用于获取多个样本旋转角度,对于各所述样本旋转角度,将所述第一训练样本集合中的各所述第一图片样本旋转至所述样本旋转角度,得到与所述样本旋转角度对应的第二训练样本集合;其中,各所述样本旋转角度大于预设的角度阈值;
中间模型获取模块,用于基于各所述第二训练样本集合分别对初始模型进行训练,得到与各所述第二训练样本集合一一对应的中间模型;
蒸馏模块,用于对所述多个中间模型进行模型蒸馏,得到目标模型。
8.根据权利要求7所述的装置,其特征在于,所述蒸馏模块还用于:
获取目标图片;
将所述目标图片分别旋转至与各所述中间模型对应的角度,得到旋转后的中间图片;
将所述中间图片分别输入至各所述中间模型中,并根据各所述中间模型的输出结果获取所述目标图片的标签;
基于所述目标图片和所述目标图片的标签对所述初始模型进行训练,得到所述目标模型。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至6中任一项所述的方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至6中任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011162501.1A CN112329916A (zh) | 2020-10-27 | 2020-10-27 | 模型训练方法、装置、计算机设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011162501.1A CN112329916A (zh) | 2020-10-27 | 2020-10-27 | 模型训练方法、装置、计算机设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112329916A true CN112329916A (zh) | 2021-02-05 |
Family
ID=74295954
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011162501.1A Pending CN112329916A (zh) | 2020-10-27 | 2020-10-27 | 模型训练方法、装置、计算机设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112329916A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115130539A (zh) * | 2022-04-21 | 2022-09-30 | 腾讯科技(深圳)有限公司 | 分类模型训练、数据分类方法、装置和计算机设备 |
WO2023197409A1 (zh) * | 2022-04-13 | 2023-10-19 | 魔门塔(苏州)科技有限公司 | 速度控制模型的生成方法、车辆控制方法及装置 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110223281A (zh) * | 2019-06-06 | 2019-09-10 | 东北大学 | 一种数据集中含有不确定数据时的肺结节图像分类方法 |
WO2019233421A1 (zh) * | 2018-06-04 | 2019-12-12 | 京东数字科技控股有限公司 | 图像处理方法及装置、电子设备、存储介质 |
CN111242303A (zh) * | 2020-01-14 | 2020-06-05 | 北京市商汤科技开发有限公司 | 网络训练方法及装置、图像处理方法及装置 |
CN111310808A (zh) * | 2020-02-03 | 2020-06-19 | 平安科技(深圳)有限公司 | 图片识别模型的训练方法、装置、计算机***及存储介质 |
CN111598182A (zh) * | 2020-05-22 | 2020-08-28 | 北京市商汤科技开发有限公司 | 训练神经网络及图像识别的方法、装置、设备及介质 |
WO2020194077A1 (en) * | 2019-03-22 | 2020-10-01 | International Business Machines Corporation | Unification of models having respective target classes with distillation |
-
2020
- 2020-10-27 CN CN202011162501.1A patent/CN112329916A/zh active Pending
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019233421A1 (zh) * | 2018-06-04 | 2019-12-12 | 京东数字科技控股有限公司 | 图像处理方法及装置、电子设备、存储介质 |
WO2020194077A1 (en) * | 2019-03-22 | 2020-10-01 | International Business Machines Corporation | Unification of models having respective target classes with distillation |
CN110223281A (zh) * | 2019-06-06 | 2019-09-10 | 东北大学 | 一种数据集中含有不确定数据时的肺结节图像分类方法 |
CN111242303A (zh) * | 2020-01-14 | 2020-06-05 | 北京市商汤科技开发有限公司 | 网络训练方法及装置、图像处理方法及装置 |
CN111310808A (zh) * | 2020-02-03 | 2020-06-19 | 平安科技(深圳)有限公司 | 图片识别模型的训练方法、装置、计算机***及存储介质 |
CN111598182A (zh) * | 2020-05-22 | 2020-08-28 | 北京市商汤科技开发有限公司 | 训练神经网络及图像识别的方法、装置、设备及介质 |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023197409A1 (zh) * | 2022-04-13 | 2023-10-19 | 魔门塔(苏州)科技有限公司 | 速度控制模型的生成方法、车辆控制方法及装置 |
CN115130539A (zh) * | 2022-04-21 | 2022-09-30 | 腾讯科技(深圳)有限公司 | 分类模型训练、数据分类方法、装置和计算机设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US10936911B2 (en) | Logo detection | |
JP6957624B2 (ja) | ターゲット・ドメイン画像へのソース・ドメイン画像の変換 | |
CN112115783A (zh) | 基于深度知识迁移的人脸特征点检测方法、装置及设备 | |
US10810765B2 (en) | Image processing apparatus and image processing method | |
CN113112509B (zh) | 图像分割模型训练方法、装置、计算机设备和存储介质 | |
US20230042187A1 (en) | Behavior recognition method and system, electronic device and computer-readable storage medium | |
CN112329916A (zh) | 模型训练方法、装置、计算机设备及存储介质 | |
CN112434618A (zh) | 基于稀疏前景先验的视频目标检测方法、存储介质及设备 | |
CN116670687A (zh) | 用于调整训练后的物体检测模型以适应域偏移的方法和*** | |
CN110910375A (zh) | 基于半监督学习的检测模型训练方法、装置、设备及介质 | |
CN113361643A (zh) | 基于深度学习的通用标志识别方法、***、设备及存储介质 | |
CN116051873A (zh) | 关键点匹配方法、装置及电子设备 | |
Chowdhury et al. | Automated augmentation with reinforcement learning and GANs for robust identification of traffic signs using front camera images | |
CN116958148B (zh) | 输电线路关键部件缺陷的检测方法、装置、设备、介质 | |
CN113704276A (zh) | 地图更新方法、装置、电子设备及计算机可读存储介质 | |
CN112241705A (zh) | 基于分类回归的目标检测模型训练方法和目标检测方法 | |
CN112418264A (zh) | 检测模型的训练方法、装置、目标检测方法、设备和介质 | |
WO2023066142A1 (zh) | 全景图像的目标检测方法、装置、计算机设备和存储介质 | |
US20230401670A1 (en) | Multi-scale autoencoder generation method, electronic device and readable storage medium | |
CN113743434A (zh) | 一种目标检测网络的训练方法、图像增广方法及装置 | |
CN111507420A (zh) | 轮胎信息获取方法、装置、计算机设备和存储介质 | |
US20220270353A1 (en) | Data augmentation based on attention | |
CN112329915A (zh) | 模型训练方法、装置、计算机设备和存储介质 | |
CN112348060A (zh) | 分类向量生成方法、装置、计算机设备和存储介质 | |
CN113989374A (zh) | 用于非连续观测情况下物体定位的方法、装置和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |