CN117523320B - 一种基于关键点的图像分类模型训练方法及终端 - Google Patents

一种基于关键点的图像分类模型训练方法及终端 Download PDF

Info

Publication number
CN117523320B
CN117523320B CN202410004010.6A CN202410004010A CN117523320B CN 117523320 B CN117523320 B CN 117523320B CN 202410004010 A CN202410004010 A CN 202410004010A CN 117523320 B CN117523320 B CN 117523320B
Authority
CN
China
Prior art keywords
key point
image
data set
classification model
mask
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202410004010.6A
Other languages
English (en)
Other versions
CN117523320A (zh
Inventor
梁浩
朱志发
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Santachi Video Technology Shenzhen Co ltd
Original Assignee
Santachi Video Technology Shenzhen Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Santachi Video Technology Shenzhen Co ltd filed Critical Santachi Video Technology Shenzhen Co ltd
Priority to CN202410004010.6A priority Critical patent/CN117523320B/zh
Publication of CN117523320A publication Critical patent/CN117523320A/zh
Application granted granted Critical
Publication of CN117523320B publication Critical patent/CN117523320B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Databases & Information Systems (AREA)
  • Multimedia (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开一种基于关键点的图像分类模型训练方法及终端,获取图像数据集和对应的关键点数据集,基于关键点数据集构建对应的关键点特征矩阵集,通过对关键点特征矩阵集和图像数据集构建关键点掩码和图像掩码,从数据上对两种不同的数据集进行区分;同时,通过关键点特征矩阵集与图像数据集对预设的分类模型进行合并训练,使得分类模型在进行训练的过程中,分类模型能够有效提取关键点附近的特征,从而提高对图像分类的精确度。此外,基于数据集对应的掩码对分类模型进行训练,使得分类模型能够同时包含图像分类结果和关键点检测结果,实现无需舍弃关键点检测结果。

Description

一种基于关键点的图像分类模型训练方法及终端
技术领域
本发明涉及图像分类模型训练技术领域,尤其涉及一种基于关键点的图像分类模型训练方法及终端。
背景技术
在图像处理中,关键点本质上是一种特征,是对一个固定区域或者空间物理关系的抽象描述,描述的是一定领域范围内的组合或上下文关系。它不仅仅是一个点信息,或代表一个位置,更代表着上下文与周围领域的组合关系。因此,可通过关键点检测后的图像数据集作为一组特征关系对图像分类模型进行训练,以提高模型的精确性。
目前,基于关键点检测进行图像分类模型训练的方法主要包括以下两种:
一种是先使用常规分类算法的骨干网络,将池化和全连接部分替换成关键点检测所需要的上采样层,构建成基于heatmap(热力图)方式的关键点检测模型,然后在关键点检测的数据集上进行训练,使分类算法的骨干网络所提取的特征聚焦到关键点上;在完成关键点检测数据集的训练后,将上采样层替换为原有的池化和全连接层,使用一个较小的学习率,在目标图像分类数据集上完成训练;该方案借助关键点检测数据集使模型关注到关键点特征,然后在学习到关键点特征的模型基础上学习目标数据的特征,是通过迁移学习完成目标训练的;但是该方案需要训练两次模型,且最终只能保留图像分类的模型,关键点检测模型作为中间产物被丢弃。
另一种是使用常规的图像分类模型做骨架,但是对最后的全连接层做一定的修改,增加关键点的字段,用于计算目标的关键点损失;该方法通过图像分类的损失和关键点的损失,同时对模型的特征提取进行约束,帮助模型快速定位到关键特征。但是该方法需要基于一个已有的关键点检测算法检测数据集中的关键点作为关键点损失计算的依据,由于这些关键点是检测出来的,会存在一定的误差,对模型的精度产生干扰;同时该方案是基于坐标回归的方式检测关键点,对于数据集的规模要求较高,在小规模数据集上基本没有效果。
发明内容
本发明所要解决的技术问题是:提供一种基于关键点的图像分类模型训练方法及终端,通过关键点检测提高图像分类模型精确性的同时,保留关键点检测模型与图像分类模型。
为了解决上述技术问题,本发明采用的技术方案为:
一种基于关键点的图像分类模型训练方法,包括步骤:
获取图像数据集以及与所述图像数据集相对应的关键点数据集;
根据所述关键点数据集生成对应的关键点热力图,并基于所述关键点热力图构建关键点特征矩阵集;
分别对所述关键点特征矩阵集和图像数据集构建对应的数据集掩码;
根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练,得到最优图像分类模型。
为了解决上述技术问题,本发明采用的另一种技术方案为:
一种基于关键点的图像分类模型训练终端,包括存储器、处理器及存储在所述存储器上并在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的一种基于关键点的图像分类模型训练方法中的各个步骤。
本发明的有益效果在于:获取图像数据集和对应的关键点数据集,基于关键点数据集构建对应的关键点特征矩阵集,通过对关键点特征矩阵集和图像数据集分别构建数据集掩码,从数据上对两种不同的数据集进行区分;同时,通过关键点特征矩阵集与图像数据集对预设的分类模型进行合并训练,使得分类模型在进行训练的过程中,分类模型能够有效提取关键点附近的特征,从而提高对图像分类的精确度。此外,基于数据集对应的掩码对分类模型进行训练,使得分类模型能够同时包含图像分类结果和关键点检测结果,实现无需舍弃关键点检测结果。
附图说明
图1为本发明实施例提供的一种基于关键点的图像分类模型训练方法的步骤流程图;
图2为本发明实施例提供的一种基于关键点的图像分类模型训练终端的结构示意图;
标号说明:
300、一种基于关键点的图像分类模型训练终端;301、存储器;302、处理器。
具体实施方式
为详细说明本发明的技术内容、所实现目的及效果,以下结合实施方式并配合附图予以说明。
请参照图1,本发明实施例提供了一种基于关键点的图像分类模型训练方法,包括步骤:
获取图像数据集以及与所述图像数据集相对应的关键点数据集;
根据所述关键点数据集生成对应的关键点热力图,并基于所述关键点热力图构建关键点特征矩阵集;
分别对所述关键点特征矩阵集和图像数据集构建对应的数据集掩码;
根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练,得到最优图像分类模型。
从上述描述可知,本发明的有益效果在于:获取图像数据集和对应的关键点数据集,基于关键点数据集构建对应的关键点特征矩阵集,通过对关键点特征矩阵集和图像数据集分别构建数据集掩码,从数据上对两种不同的数据集进行区分;同时,通过关键点特征矩阵集与图像数据集对预设的分类模型进行合并训练,使得分类模型在进行训练的过程中,分类模型能够有效提取关键点附近的特征,从而提高对图像分类的精确度。此外,基于数据集对应的掩码对分类模型进行训练,使得分类模型能够同时包含图像分类结果和关键点检测结果,实现无需舍弃关键点检测结果。
进一步地,所述数据集掩码的类型包括关键点掩码和图像掩码;
所述分别对所述关键点特征矩阵集和图像数据集构建对应的数据集掩码包括:
对所述关键点特征矩阵集中的每一关键点样本构建关键点掩码和图像掩码,并对所述关键点掩码进行标识;
对所述图像数据集中的每一图像样本构建关键点掩码和图像掩码,并对所述图像掩码进行标识。
由上述描述可知,通过构建关键点掩码和图像掩码以区分关键点数据集和图像数据集中的样本数据,以实现后续对关键点数据集和图像数据集对应的损失函数运算,从而实现对图像分类和关键点检测两个输出结果的按需保留。
进一步地,所述根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练之前,还包括:
获取原始分类模型,所述原始分类模型包括卷积层、池化层和全连接层;
将所述池化层和所述全连接层组合为所述原始分类模型的分类头,将所述卷积层组合为所述原始分类模型的骨干层,并在所述骨干层与所述分类头之间创建上采样层,得到预设的分类模型。
由上述描述可知,通过将池化层和全连接层组合得到分类头,以实现对图像分类结果的训练学习;通过将卷积层组合为骨干层,以实现对图像和关键点图像的统一特征提取;通过在模型中创建上采样层,以实现对关键点检测的训练学习;通过上采样层加快原始分类模型对关键特征的学习和提取,提高图像分类的准确率。
进一步地,所述根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练,得到最优图像分类模型包括:
将所述关键点特征矩阵集和所述图像数据集输入预设的分类模型进行分类,得到对应的预测结果;
根据所述数据集掩码和所述预测结果分别计算所述关键点特征矩阵集和所述图像数据集对应的损失函数值;
根据所述关键点特征矩阵集和所述图像数据集对应的损失函数值分别对所述分类模型进行迭代训练,得到多个不同参数的初始图像分类模型以及对应的精确度;
选择所述精确度最高的初始图像分类模型作为最优图像分类模型。
由上述描述可知,通过数据集掩码从两个数据集混合的预测结果中,分别计算出不同数据集对应的损失函数值,实现图像分类结果和关键点检测结果的同时保留。并且通过对应的损失函数值对应优化分类模型,在提高关键点检测的精确性的同时,关键点检测同步加快了图像分类对目标关键特征的学习和提取,从而提高了模型的精确度。
进一步地,所述根据所述数据集掩码和所述预测结果分别计算所述关键点特征矩阵集和所述图像数据集对应的损失函数值包括:
根据所述预测结果计算所述关键点特征矩阵集和所述图像数据集中每一样本对应的原始损失函数值;
根据所述数据集掩码将所述原始损失函数值进行分类,得到所述关键点特征矩阵集和所述图像数据集对应的损失函数值。
由上述描述可知,由于关键点特征矩阵集与图像数据集是合并输入预设的分类模型进行迭代训练,则对应输出的预测结果也是两种数据集的混合结果。但,每一预测结果对应的样本数据有对应的数据集掩码,从而通过样本数据的数据集掩码将预测结果进行对应的损失函数计算,区分不同数据集的损失函数值,避免了合并数据集无法区分损失函数值的问题,在一个图像分类模型中实现图像分类和关键点检测的同步优化。
进一步地,所述根据所述关键点数据集生成对应的关键点热力图包括:
将所述关键点数据集中的每一样本对应的标注目标进行裁剪,得到目标图像;
将所述样本中的所有关键点映射至所述目标图像中,得到关键点坐标;
根据所述目标图像和所述关键点坐标生成所述样本对应的关键点热力图。
由上述描述可知,通过对关键点数据集中样本的标注目标进行裁剪,使得关键点更加聚焦于标注目标,同时基于目标图像构建关键点热力图,以此提高同类特征的聚合性。
进一步地,所述基于所述关键点热力图构建关键点特征矩阵集包括:
根据预设的分类模型的输入尺度调整所述关键点热力图的尺度,得到标准热力图;
以所述标准热力图中的每一关键点坐标为中心,分别生成预设维度的高斯核数据;
将所有所述高斯核数据进行合并,得到与所述每一样本对应的关键点特征矩阵;
根据所述关键点数据集中所有样本对应的关键点特征矩阵得到关键点特征矩阵集。
由上述描述可知,通过生成高斯核数据,实现将低维的样本数据转变为高维数据,而高维数据使得同样类型的样本能够更好的凝聚在一起,提高关键点检测的分类精确性。
进一步地,所述骨干层与所述上采样层的学习率相同,所述分类头的学习率为所述骨干层的学习率的十分之一。
由上述描述可知,骨干层与上采样层的学习率相同,保证特征提取的学习速度同步,以实现关键点检测对图像分类目标的关键特征的提取与学习;分类头的学习率小,保证模型寻找最优解的精确性,从而提高图像分类算法的准确率。
进一步地,所述上采样层包括采样操作和与所述采样操作对应的卷积操作。
请参照图2,本发明另一实施例提供了一种基于关键点的图像分类模型训练终端,包括存储器、处理器及存储在所述存储器上并在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的一种基于关键点的图像分类模型训练方法中的各个步骤。
从上述描述可知,本发明的有益效果在于:获取图像数据集和对应的关键点数据集,基于关键点数据集构建对应的关键点特征矩阵集,通过对关键点特征矩阵集和图像数据集分别构建数据集掩码,从数据上对两种不同的数据集进行区分;同时,通过关键点特征矩阵集与图像数据集对预设的分类模型进行合并训练,使得分类模型在进行训练的过程中,分类模型能够有效提取关键点附加的特征,从而提高对图像分类的精确度。此外,基于数据集对应的掩码对分类模型进行训练,使得分类模型能够同时包含图像分类结果和关键点检测结果,实现无需舍弃关键点检测结果。
本发明实施例提供了一种基于关键点的图像分类模型训练方法及终端,可应用于图像分类模型的训练,通过关键点检测提高图像分类模型精确性的同时,保留关键点检测模型与图像分类模型,以下通过具体实施例来说明:
请参照图1,本发明的实施例一为:
一种基于关键点的图像分类模型训练方法,包括步骤:
S1、获取图像数据集以及与图像数据集相对应的关键点数据集。
需要说明的是,图像数据集的每一样本数据均标注有识别目标对应的特征;关键点数据集的每一样本数据均标注有该识别目标对应的所有关键点。例如,在图像数据集的每一样本数据均标注有行人的特征,则对应的关键点数据集的每一样本数据均标注有行人对应的所有关键点。
在一些实施例中,在图像数据集中,对每个需要识别的目标均标注有多个特征,每个特征的标注格式为二进制表示,其中,1表示该识别目标具备该特征;0表示该识别目标不具备该特征。例如,需要识别的目标为行人,标注的特征A1为该行人是否佩戴安全帽,若该特征A1=1,则表示该行人佩戴有安全帽;若该特征A1=0,则表示该行人未佩戴安全帽。
在一些实施例中,在关键点数据集中,对每个需要识别的目标均标注有多个关键点,每个关键点的标注格式为(x,y,v),其中,x和y为关键点坐标,v表示该关键点是否存在或可见,当v=0表示该关键点不在图像中,当v=1表示该关键点在图像中但是被遮挡,当v=2表示该关键点清晰可见。
在一些实施例中,将关键点数据集和图像数据集分别划分为对应的训练集和测试集,训练集用于完成预设分类模型的训练,测试集用于验证最优分类模型的训练效果。
S2、根据关键点数据集生成对应的关键点热力图,并基于关键点热力图构建关键点特征矩阵集。
具体地,步骤S2包括:
S21、将关键点数据集中的每一样本对应的标注目标进行裁剪,得到目标图像。
S22、将样本中的所有关键点映射至目标图像中,得到关键点坐标。
具体地,将关键点数据集原本样本中的关键点坐标转换为目标图像中的坐标。
S23、根据目标图像和关键点坐标生成关键点样本对应的关键点热力图。
需要说明的是,在生成关键点样本对应的关键点热力图时,对于v不等于2的关键点(x,y,v),其在关键点热力图中的值设置为0。即表示关键点热力图中体现的关键点为在样本图像中清晰可见的。
S24、根据预设的分类模型的输入尺度调整关键点热力图的尺度,得到标准热力图。
在一些实施例中,若预设的分类模型,其对应的输入图像的尺度为W*H(宽度为W,高度为H),则标准热力图的尺度为1/4W*1/4H,即标准热力图的尺度为预设分类模型输入图像尺度的1/16。例如,当输入图像的尺度为256*256,则标准热力图的尺度为64*64。
S25、以标准热力图中的每一关键点坐标为中心,分别生成预设维度的高斯核数据。
需要说明的是,该关键点坐标为v等于2的关键点。其中高斯核数据的中心点为1,其越往***数值越小,逐渐趋于0。
S26、将所有高斯核数据进行合并,得到与每一样本对应的关键点特征矩阵。
S27、根据关键点数据集中所有样本对应的关键点特征矩阵得到关键点特征矩阵集。
S3、分别对关键点特征矩阵集和图像数据集构建对应的数据集掩码。其中数据集掩码的类型包括关键点掩码和图像掩码。
具体地,步骤S3包括:
S31、对关键点特征矩阵集中的每一关键点样本构建关键点掩码和图像掩码,并对关键点掩码进行标识。
S32、对图像数据集中的每一图像样本构建关键点掩码和图像掩码,并对图像掩码进行标识。
具体地,每一关键点样本与每一图像样本均对应有关键点掩码和图像掩码。
在一些实施例中,关键点掩码表示为kpts_mask;图像掩码表示为classify_mask,则对于关键点特征矩阵集中的每一关键点样本,将其关键点掩码kpts_mask标识为1,则其图像掩码classify_mask等于0;对于图像数据集中的每一图像样本将其图像掩码classify_mask标识为1,则其关键点掩码kpts_mask等于0。
S4、根据数据集掩码、关键点特征矩阵集以及图像数据集对预设的分类模型进行训练,得到最优图像分类模型。
具体地,步骤S4包括:
S41、将关键点特征矩阵集和图像数据集输入预设的分类模型进行分类,得到对应的预测结果。
需要说明的是,预设分类模型输入数据与输出数据的数据维度保持不变。例如,对于关键点特征矩阵集,其输入预设分类模型的数据维度为【B,n,w,h】,则预设分类模型输出的预测结果的数据维度也为【B,n,w,h】,而关键点特征矩阵集对应的关键点掩码的数据维度为【B,1,1,1】;其中,n表示一个目标所标注的关键点个数,w和h表示标准热力图的宽度和高度。对于图像数据集,其输入预设分类模型的数据维度为【B,N】,而图像数据集对应的图像掩膜的数据维度为【B,1】;其中,N表示一个目标所标注的特征个数。由此便于后续进行对应数据集的损失函数值计算。
S42、根据数据集掩码和预测结果分别计算关键点特征矩阵集和图像数据集对应的损失函数值。
具体地,步骤S42包括:
S421、根据预测结果计算关键点特征矩阵集和图像数据集中每一样本对应的原始损失函数值;
S422、根据数据集掩码将原始损失函数值进行分类,得到关键点特征矩阵集和图像数据集对应的损失函数值。
在一些实施例中,由于关键点特征矩阵集和图像数据集中每一样本均对应有关键点掩码和图像掩码;将每一样本的原始损失函数值与其关键点掩码相乘,即可得到该样本在关键点特征矩阵集的损失函数值;将每一样本的原始损失函数值与其图像掩码相乘,即可得到该样本在图像数据集的损失函数值。
S43、根据关键点特征矩阵集和图像数据集对应的损失函数值对分类模型进行迭代训练,得到多个不同参数的初始图像分类模型以及对应的精确度。
需要说明的是,将关键点特征矩阵集和图像数据集对应的损失函数值进行相加,即得到模型的损失函数值,根据模型的损失函数值对分类模型进行迭代训练。
在一些实施例中,根据图像数据集的真实标签、预测结果以及图像掩码来确定每一初始图像分类模型的精确度。具体地,先将图像数据集的预测结果preds使用sigmoid激活函数,将预测结果preds的数值转化到(0,1)之间,得到preds_1;设定合理的阈值(一般采用0.5)将preds_1二值化,使得preds_1的数值转换为0或1,得到preds_2;再通过逻辑非运算和逻辑异或运算,计算preds_2和图像数据集的实际标签lables是否相等,得到preds_true;通过逻辑与运算,计算preds_true和图像掩码classify_mask,得到图像样本上的预测正确的样本preds_true_mask。最后计算preds_true_mask和图像掩码classify_mask个数的比值,得到模型的预测精确度accuracy,该数值在0到1之间,越接近1代表模型的精度越高。
S44、选择精确度最高的初始图像分类模型作为最优图像分类模型。
具体地,在步骤S4之前,还包括:
S401、获取原始分类模型,原始分类模型包括卷积层、池化层和全连接层;
S402、将池化层和全连接层组合为原始分类模型的分类头,将卷积层组合为原始分类模型的骨干层,并在骨干层与分类头之间创建上采样层,得到预设的分类模型。
在一些实施例中,得到图像分类模型可以将上采样层删除,仅保留分类头和骨干层,即无需输出关键点特征矩阵集的损失函数值。相较于原始的分类模型,本方案的部署模型的结构和大小与其完全一致,但是模型的精度有明显的提升。
需要说明的是,当输入样本在推理经过骨干层之后,分别输入上采样层和分类头进行对应处理,得到对应的两个输出。
在一种可选的实施方式中,上采样层包括采样操作和与采样操作对应的卷积操作。骨干层与上采样层的学习率相同,分类头的学习率为骨干层的学习率的十分之一。
本发明的实施例二为:
实施例一的一种基于关键点的图像分类模型训练方法应用于实际场景中。包括以下步骤:
S110、获取COCO2017人体关键点数据集A(关键点数据集)和PA100K人体属性分类数据集B(图像数据集)。
其中,A数据集关注人物的关键点,每个人共有17个关键点,按照7:1划分为训练集和测试集,则训练集共有136532个样本,测试集共有19631个样本;B数据集关注人物属性,每个人共有26个类别,按照7:1划分为训练集和测试集,则训练集共有87553个目标,测试集共有12440个样本。
S120、将A数据集和B数据集进行合并训练前的转换。
S1201、对A数据集中每一个样本数据构建对应的掩膜数据,其中掩膜数据用于标识B数据集中所标注的26个类别是否在该样本数据中存在;具体地,B数据集中每个人共有26个类别,则掩膜数据的总数为26(分类标签)+17(关键点标签)=43个,若该类别在其样本数据中存在,则该掩膜数据对应的数值为1,否则该掩膜数据对应的数值为0。由于A数据集的样本数据关注人物的关键点,所以B数据集中的类别在A数据集的样本数据中均不存在,所以A数据集中每一个样本数据对应的掩膜数据中的分类标签为26个0。再将根据A数据集中所标注的关键点生成样本图像对应的关键点热力图,并基于关键点热力图构建关键点特征矩阵集,A数据集中每一个关键点热力图构建对应的图像掩码为0,关键点掩码为1。
S1202、对B数据集每一个样本数据构建对应的掩膜数据,具体地,B数据集中每一个样本数据对应的掩膜数据中的关键点标签为17个0。同时,将B数据集中每一个样本数据构建对应的图像掩码为1,关键点掩码为0。
由此得到A数据集与B数据集合并后的数据集C,数据集C包括以下数据:原始样本图像image、与原始样本图像image对应的分类标签classify_label、图像掩码classify_mask、关键点标签kpts_label以及关键点掩码kpts_mask。其中,上述数据的格式分别为image(3,256,256)、classify_label(26)、classify_mask(1)、kpts_label(17,64,64)以及kpts_mask(1,1,1)。后续需要确保每一批次获取到的样本数据均包括上述数据,若以N表示每一批次获取到的样本数目,则得到数据格式对应为image(N,3,256,256)、classify_label(N,26)、classify_mask(N,1)、kpts_label(N,17,64,64)以及kpts_mask(N,1,1,1)
S130、将C数据集输入预设的分类模型中进行训练,得到多个不同参数的初始图像分类模型以及对应的精确度。
在本实施例中,原始分类模型以resnet18模型为例,该模型由卷积层、池化层以及全连接层构成,在推理过程中对模型进行5次下采样,得到原始输入的宽和高的1/32的输出数据B1,然后进行全局池化和全连接。在本申请中,将池化层和全连接层组合为classify_head(分类头),其他部分组合为模型的backbone(骨干层),然后在骨干层和分类头之间新增一个upsample(上采样层),该层包含三次上采样操作和对应的卷积操作,得到预设的分类模型。分类模型在推理经过backbone(骨干层)后,分别将其送入upsample(上采样层)和classify_head(分类头),得到对应的两个输出,作为模型的最终输出。
S1301、将每一批次获取的样本数据输入预设的分类模型中进行分类,得到B数据集对应的预测结果为classify_pred(N*26),A数据集对应的预测结果为kpts_pred(N*17*64*64)。
S1302、根据A数据集对应的预测结果kpts_pred、对应的关键点标签kpts_label和关键点掩码kpts_mask,计算A数据集对应的损失函数值kpts_loss。根据B数据集对应的预测结果classify_pred、对应的分类标签classify_label和图像掩码classify_mask,计算B数据集对应的损失函数值classify_loss。将A数据集对应的损失函数值kpts_loss与B数据集对应的损失函数值classify_loss相加得到模型的最终损失函数值loss。
S1303、根据模型的最终损失函数值loss对预测的分类模型进行迭代训练,同时确定迭代得到的初始图像分类模型的精确度。具体地,分类算法的全部标签数目为S1,预测准确的数目为S2,正样本的标签数目为T1,预测为正样本的标签数目为T2,预测正确的数目为T3;模型整体的准确率为accuracy=S2/S1,正样本的召回率为Recall=T3/T1,正样本的准确率为Precision=T3/T2,正样本的f1值为:f1=(2*Recall * Precision)/(Recall+Precision)。
S140、当分类模型的训练次数达到预设次数100时,停止训练,选择精确度accuracy最高的初始图像分类模型作为最优图像分类模型。
以下为本实施例所得到的最优图像分类模型与传统的图像分类模型对图像进行分类的比对结果:
数据指标一:在PA100K数据集中共有10万张图像,其中boots属性一共有595个正样本,通过随机划分后训练集包含519个正样本,测试集包含76个正样本。
对于没有使用关键点辅助训练的分类模型,训练集中519个样本识别出来的个数为0,测试集中76个正本识别出来的个数也为0,训练集与测试集对应的f1值均为0。
对于本申请使用关键点辅助训练的分类模型,训练集中519个样本,有381个样本被识别为正样本,其中368个为正确识别,其f1值为0.817779;测试集中76个正样本,25个样本被识别为正样本,其中23个是正确识别,其f1值为0.455447,对小样本的类别精度提升明显。
数据指标二:在PA100K数据集共有10万张图像,每个图像有26个属性,总计260万个属性值,按照7:1随机划分训练集和测试集,训练集中共包含2276378的属性值,其中是正样本的属性值(期望值为1)的数目为493506,测试集共包含323440个属性标签,其中是正样本的属性值(期望值为1)的数目为70034。
对于没有使用关键点辅助训练的分类模型,训练集预测正确的数目为2221260,准确率为0.975787,预测为正样本的数目为460356,正确数目为449372,其f1值为0.975787;测试集预测正确的数目为309948,准确率为0.958286,预测为正样本的数目为65226,正确数目为60884,其f1值为0.900252。
对于本申请使用关键点辅助训练的分类模型,训练集的预测正确数目为2264771,准确率为0.994901,预测为正样本的数目为487239,正确数目为484569,其f1值为0.988166,测试集的预测正确数目为311901,准确率为0.964324,预测为正样本的数目为67495,正确数目为62995,其f1值为0.916099。
综合数据指标二可得本申请的图像分类模型:
1)训练集属性预测正确的数目由2221260提升到2264771,增加43511个;
2)测试集属性预测正确的数目由60884提升到62995,增加2111个
3)测试集的正样本的f1指标由0.900252提升到0.916099,提升0.016,属于较为明显的提升。
请参照图2,本发明的实施例三为:
一种基于关键点的图像分类模型训练终端300,包括存储器301、处理器302及存储在所述存储器301上并在所述处理器302上运行的计算机程序,所述处理器302执行所述计算机程序时实现实施例一和实施例二的一种基于关键点的图像分类模型训练方法中的各个步骤。
综上所述,本发明提供的一种基于关键点的图像分类模型训练方法及终端,获取图像数据集和对应的关键点数据集,基于关键点数据集构建对应的关键点特征矩阵集,通过对关键点特征矩阵集和图像数据集构建关键点掩码和图像掩码,从数据上对两种不同的数据集进行区分;同时,通过关键点特征矩阵集与图像数据集对预设的分类模型进行合并训练,使得分类模型在进行训练的过程中,分类模型能够有效提取关键点附近的特征,从而提高对图像分类的精确度。此外,基于数据集对应的掩码对分类模型进行训练,使得分类模型能够同时包含图像分类结果和关键点检测结果,实现无需舍弃关键点检测结果。
以上所述仅为本发明的实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等同变换,或直接或间接运用在相关的技术领域,均同理包括在本发明的专利保护范围内。

Claims (6)

1.一种基于关键点的图像分类模型训练方法,其特征在于,包括步骤:
获取图像数据集以及与所述图像数据集相对应的关键点数据集;
根据所述关键点数据集生成对应的关键点热力图,并基于所述关键点热力图构建关键点特征矩阵集;
分别对所述关键点特征矩阵集和图像数据集构建对应的数据集掩码;
根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练,得到最优图像分类模型;
所述数据集掩码的类型包括关键点掩码和图像掩码;
所述分别对所述关键点特征矩阵集和图像数据集构建对应的数据集掩码包括:
对所述关键点特征矩阵集中的每一关键点样本构建关键点掩码和图像掩码,并对所述关键点掩码进行标识;
对所述图像数据集中的每一图像样本构建关键点掩码和图像掩码,并对所述图像掩码进行标识;
所述根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练之前,还包括:
获取原始分类模型,所述原始分类模型包括卷积层、池化层和全连接层;
将所述池化层和所述全连接层组合为所述原始分类模型的分类头,将所述卷积层组合为所述原始分类模型的骨干层,并在所述骨干层与所述分类头之间创建上采样层,得到预设的分类模型;
所述根据所述数据集掩码、所述关键点特征矩阵集以及所述图像数据集对预设的分类模型进行训练,得到最优图像分类模型包括:
将所述关键点特征矩阵集和所述图像数据集输入预设的分类模型进行分类,得到对应的预测结果;
根据所述数据集掩码和所述预测结果分别计算所述关键点特征矩阵集和所述图像数据集对应的损失函数值;
根据所述关键点特征矩阵集和所述图像数据集对应的损失函数值分别对所述分类模型进行迭代训练,得到多个不同参数的初始图像分类模型以及对应的精确度;
选择所述精确度最高的初始图像分类模型作为最优图像分类模型;
所述根据所述数据集掩码和所述预测结果分别计算所述关键点特征矩阵集和所述图像数据集对应的损失函数值包括:
根据所述预测结果计算所述关键点特征矩阵集和所述图像数据集中每一样本对应的原始损失函数值;
根据所述数据集掩码将所述原始损失函数值进行分类,得到所述关键点特征矩阵集和所述图像数据集对应的损失函数值;
将每一样本的原始损失函数值与其关键点掩码相乘,得到所述样本在关键点特征矩阵集的损失函数值;将每一样本的原始损失函数值与其图像掩码相乘,得到所述样本在图像数据集的损失函数值。
2.根据权利要求1所述的一种基于关键点的图像分类模型训练方法,其特征在于,所述根据所述关键点数据集生成对应的关键点热力图包括:
将所述关键点数据集中的每一样本对应的标注目标进行裁剪,得到目标图像;
将所述样本中的所有关键点映射至所述目标图像中,得到关键点坐标;
根据所述目标图像和所述关键点坐标生成所述样本对应的关键点热力图。
3.根据权利要求2所述的一种基于关键点的图像分类模型训练方法,其特征在于,所述基于所述关键点热力图构建关键点特征矩阵集包括:
根据预设的分类模型的输入尺度调整所述关键点热力图的尺度,得到标准热力图;
以所述标准热力图中的每一关键点坐标为中心,分别生成预设维度的高斯核数据;
将所有所述高斯核数据进行合并,得到与所述每一样本对应的关键点特征矩阵;
根据所述关键点数据集中所有样本对应的关键点特征矩阵得到关键点特征矩阵集。
4.根据权利要求1所述的一种基于关键点的图像分类模型训练方法,其特征在于,所述骨干层与所述上采样层的学习率相同,所述分类头的学习率为所述骨干层的学习率的十分之一。
5.根据权利要求1所述的一种基于关键点的图像分类模型训练方法,其特征在于,所述上采样层包括采样操作和与所述采样操作对应的卷积操作。
6.一种基于关键点的图像分类模型训练终端,包括存储器、处理器及存储在所述存储器上并在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1-5任意一项所述的一种基于关键点的图像分类模型训练方法中的各个步骤。
CN202410004010.6A 2024-01-03 2024-01-03 一种基于关键点的图像分类模型训练方法及终端 Active CN117523320B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202410004010.6A CN117523320B (zh) 2024-01-03 2024-01-03 一种基于关键点的图像分类模型训练方法及终端

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202410004010.6A CN117523320B (zh) 2024-01-03 2024-01-03 一种基于关键点的图像分类模型训练方法及终端

Publications (2)

Publication Number Publication Date
CN117523320A CN117523320A (zh) 2024-02-06
CN117523320B true CN117523320B (zh) 2024-05-24

Family

ID=89762994

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202410004010.6A Active CN117523320B (zh) 2024-01-03 2024-01-03 一种基于关键点的图像分类模型训练方法及终端

Country Status (1)

Country Link
CN (1) CN117523320B (zh)

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110503097A (zh) * 2019-08-27 2019-11-26 腾讯科技(深圳)有限公司 图像处理模型的训练方法、装置及存储介质
CN114495089A (zh) * 2021-12-21 2022-05-13 西安电子科技大学 基于多尺度异源特征自适应融合的三维目标检测方法
CN115170870A (zh) * 2022-06-22 2022-10-11 苏州体素信息科技有限公司 基于深度学习的婴儿行为特征分类方法和***

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20230086880A1 (en) * 2021-09-20 2023-03-23 Revery.ai, Inc. Controllable image-based virtual try-on system

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110503097A (zh) * 2019-08-27 2019-11-26 腾讯科技(深圳)有限公司 图像处理模型的训练方法、装置及存储介质
CN114495089A (zh) * 2021-12-21 2022-05-13 西安电子科技大学 基于多尺度异源特征自适应融合的三维目标检测方法
CN115170870A (zh) * 2022-06-22 2022-10-11 苏州体素信息科技有限公司 基于深度学习的婴儿行为特征分类方法和***

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Fast face-swap using convolutional neural networks;Korshunova I. et al;IEEE International Conference on Computer Vision;20171231;第3677-3685页 *
融合全局时序和局部空间特征的伪造人脸视频检测方法;陈鹏 等;信息安全学报;20200315(第2期);第78-88页 *

Also Published As

Publication number Publication date
CN117523320A (zh) 2024-02-06

Similar Documents

Publication Publication Date Title
US10410353B2 (en) Multi-label semantic boundary detection system
WO2021073417A1 (zh) 表情生成方法、装置、设备及存储介质
CN110837836B (zh) 基于最大化置信度的半监督语义分割方法
Nakajima et al. Full-body person recognition system
CN103455542B (zh) 多类识别器以及多类识别方法
EP2291722B1 (en) Method, apparatus and computer program product for providing gesture analysis
EP3869385B1 (en) Method for extracting structural data from image, apparatus and device
JP2018200685A (ja) 完全教師あり学習用のデータセットの形成
Kadam et al. Detection and localization of multiple image splicing using MobileNet V1
CN106447625A (zh) 基于人脸图像序列的属性识别方法及装置
CN111177507B (zh) 多标记业务处理的方法及装置
WO2021057148A1 (zh) 基于神经网络的脑组织分层方法、装置、计算机设备
Li et al. Deep representation of facial geometric and photometric attributes for automatic 3d facial expression recognition
CN110956167A (zh) 一种基于定位字符的分类判别强化分离的方法
Naqvi et al. Feature quality-based dynamic feature selection for improving salient object detection
Boutell et al. Multi-label Semantic Scene Classfication
Chakraborty et al. Application of daisy descriptor for language identification in the wild
CN112016592B (zh) 基于交叉领域类别感知的领域适应语义分割方法及装置
CN111797704B (zh) 一种基于相关物体感知的动作识别方法
CN117523320B (zh) 一种基于关键点的图像分类模型训练方法及终端
Hisham et al. A Systematic Literature Review of the Mobile Application for Object Recognition for Visually Impaired People
Hiremani et al. Human and Machine Vision Based Indian Race Classification Using Modified-Convolutional Neural Network.
Boroujerdi et al. Deep interactive region segmentation and captioning
CN115705688A (zh) 基于人工智能的古代及近现代艺术品鉴定方法和***
CN112598056A (zh) 一种基于屏幕监控的软件识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant