CN114049515A - 图像分类方法、***、电子设备和存储介质 - Google Patents
图像分类方法、***、电子设备和存储介质 Download PDFInfo
- Publication number
- CN114049515A CN114049515A CN202111273347.XA CN202111273347A CN114049515A CN 114049515 A CN114049515 A CN 114049515A CN 202111273347 A CN202111273347 A CN 202111273347A CN 114049515 A CN114049515 A CN 114049515A
- Authority
- CN
- China
- Prior art keywords
- model
- teacher
- student
- result
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2413—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
- G06F18/24133—Distances to prototypes
- G06F18/24137—Distances to cluster centroïds
- G06F18/2414—Smoothing the distance, e.g. radial basis function networks [RBFN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/254—Fusion techniques of classification results, e.g. of results related to same input data
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种图像分类方法、***、电子设备和存储介质,图像分类方法包括以下步骤:构建数据集;构建至少两个分类模型并训练分类模型得到教师模型,输入训练数据至教师模型分别得到第一教师模型结果;合并第一教师模型结果得到第二教师模型结果;构建学生模型,输入训练数据至学生模型得到学生模型结果;根据第二教师模型结果与学生模型结果对学生模型进行训练得到目标模型;根据目标模型进行图像分类。本发明通过对多个教师模型进行加权,学生模型采取自主学习权重的方式,选取不同教师模型的优点进行学习,既保证了模型的计算速度又保证了模型预测的准确率,通过本发明能够快速精确的对图像进行检测和分类,保证了图片信息的准确性。
Description
技术领域
本发明涉及图像分类技术领域,特别涉及一种图像分类方法、***、电子设备和存储介质。
背景技术
随着人工智能时代的发展,深度学习模型已广泛应用于图像分类技术领域,复杂模型预测精度高但参数量过大,而简单模型的精度又比较低,实际应用中,我们需要参数量少,精度高的模型,而知识蒸馏作为一种重要的模型压缩手段,可以将复杂模型的知识迁移到简单模型中,使简单模型在参数量不变的情况下提高自身的精度,这其中的复杂模型也称为教师模型,简单模型则称为学生模型,这其中,教师模型的精度就比较重要。
现有技术中,图像分类算法主要是CNN(卷积神经网络)算法,以分层方式提取图像局部特征,但难以捕捉全局表示,于是又提出了Transformer(视觉转换器)算法,可以反映复杂的空间变换,从而构成全局表示,但忽略了局部特征。因此出现了很多将CNN和Transformer结合在一起的网络模型,而该网络模型作为教师模型就会变得过大,在知识蒸馏的过程中可能会因为容量差距过大而出现问题,反而使学生模型的性能下降,使得学生模型作为图像分类模型时,无法快速准确的对图像进行检测和分类。
发明内容
本发明要解决的技术问题是为了克服现有技术中图像分类方法存在图像分类模型性能不佳,无法快速准确的对图像进行检测和分类的缺陷,提供一种图像分类方法、***、电子设备和存储介质。
本发明是通过下述技术方案来解决上述技术问题:
根据本发明的第一方面,提供一种图像分类方法,包括以下步骤:
构建数据集,所述数据集包括若干个训练数据;
构建至少两个分类模型并训练所述分类模型以得到教师模型,输入所述训练数据至所述教师模型分别得到第一教师模型结果;
合并所述第一教师模型结果以得到第二教师模型结果;
构建一个学生模型,输入所述训练数据至所述学生模型得到学生模型结果;
根据所述第二教师模型结果与所述学生模型结果对所述学生模型进行训练以得到目标模型;
根据所述目标模型进行图像分类。
较佳地,所述构建数据集的步骤包括:
获取样本图像;
对所述样本图像进行数据增强操作,得到所述训练数据;
根据所述训练数据构建数据集。
较佳地,所述数据增强操作包括旋转、缩放、随机裁剪、平移、高斯噪声、色彩抖动、随机擦除和归一化处理中的至少一种。
较佳地,所述训练分类模型以得到教师模型的步骤包括:
将所述训练数据输入所述分类模型,通过交叉熵损失函数计算得到所述分类模型的第一交叉熵损失;
根据所述第一交叉熵损失训练所述分类模型以得到所述教师模型。
较佳地,所述合并所述第一教师模型结果并得到第二教师模型结果的步骤包括:
对所述教师模型分别设置可学习超参数,其中,所有可学习超参数之和为1;
根据所述可学习超参数计算所述第一教师模型结果的加权和以得到所述第二教师模型结果。
较佳地,所述根据所述第二教师模型结果与所述学生模型结果对所述学生模型进行训练以得到目标模型的步骤包括:
将所述训练数据输入所述学生模型,通过交叉熵损失函数计算得到所述学生模型的第二交叉熵损失;
根据所述第二教师模型结果与所述学生模型结果得到所述教师模型与所述学生模型的相对熵;
合并所述第二交叉熵损失与所述相对熵,得到所述学生模型的蒸馏损失;
根据所述蒸馏损失训练所述学生模型以得到所述目标模型。
较佳地,所述根据所述目标模型进行图像分类的步骤之前还包括对所述目标模型进行测试。
较佳地,所述对所述目标模型进行测试的步骤包括:
获取待测试图像;
输入所述待测试图像至所述目标模型,得到测试结果。
根据所述测试结果调整所述目标模型。
根据本发明的第二方面,提供一种图像分类***,包括数据集构建模块、模型构建模块、教师模型训练模块、模型结果获取模块、模型结果合并模块、学生模型训练模块和分类预测模块:
所述数据集构建模块用于构建数据集,所述数据集包括若干个训练数据;
所述模型构建模块用于构建至少两个分类模型,还用于构建一个学生模型;
所述教师模型训练模块用于训练所述分类模型以得到教师模型;
所述模型结果生成模块用于输入所述训练数据至所述教师模型分别得到第一教师模型结果,还用于输入所述训练数据至所述学生模型得到学生模型结果;
所述模型结果合并模块用于合并所述第一教师模型结果以得到第二教师模型结果;
所述学生模型训练模块用于根据所述第二教师模型结果与所述学生模型结果对所述学生模型进行训练以得到目标模型;
所述分类预测模块用于根据所述目标模型进行图像分类,还用于对所述目标模型进行测试。
根据本发明的第三方面,提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本发明的图像分类方法。
根据本发明的第四方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现本发明的图像分类方法。
本发明的积极进步效果在于:
本发明通过构建多个教师模型并进行训练,进一步提高了教师模型的性能,并且本发明并不将多个教师模型合并成一个模型,而是对教师模型分别进行训练直至最优,通过设置可学习超参数对教师模型进行加权;构建学生模型,学生模型采取自主学习权重的方式,选取不同教师模型的优点进行学习,通过计算学生模型的蒸馏损失更新学生模型的参数以及可学习超参数,使学生模型能更好的学习到教师模型。本发明在将教师模型蒸馏至学生模型的过程中,保证学生模型参数量不变,并确保了学生模型的性能,既保证了模型的计算速度又保证了模型预测的准确率,通过本发明能够快速精确的对图像进行检测和分类,保证了图片信息的准确性,提升了用户的体验。
附图说明
图1为本发明实施例1的图像分类方法的流程示意图。
图2为本发明实施例1的图像分类方法中的模型训练的流程示意图。
图3为本发明实施例1的图像分类方法中的步骤107的流程示意图。
图4本发明实施例2的图像分类***的结构示意图。
图5发明实施例3的电子设备的结构示意图。
具体实施方式
下面通过实施例的方式进一步说明本发明,但并不因此将本发明限制在所述的实施例范围之中。
实施例1
本实施例提供一种图像分类方法,如图1所示,该图像分类方法包括以下步骤:
S101、构建数据集,数据集包括若干个训练数据。
图像分类可以根据不同的图像特征把图像划归为若干个类别中的某一种,作为一种示例,本申请实施例中的算法模型应用于OTA(在线旅游)。构建数据集中所选取的图片为OTA下酒店各个类型的图片,包括但不限于室内、室外、房型、风景等图片,以便保证数据集的多样性。为了便于说明,设定本申请实施例中的数据集是用于训练图片旋转判断模型,当然本发明并不仅限于判断图片旋转。
作为可选的一种实施方式,通过设备拍摄或者爬取网络图片等方法来获取不同的图片,默认选取的图片均为正向图,然后对选取的图片进行四个方向的旋转,得到旋转图的样本图像,并构建四个类别的数据集。由于部分图片是没有方向的,即旋转后也为正常图片,如一些俯视图、无方向物件等,则将这部分图片挑选出来组成一个类别,由此一共构建出五个类别的数据集。作为可选的一种实施方式,每个类别至少挖掘出五百个样本图像,每个样本图像仅有一个类别。
为了加强图片旋转训练模型的鲁棒性,需要增大数据的训练量,即对数据集中的样本图像进行数据增强。数据增强操作包括旋转、缩放、随机裁剪、平移、高斯噪声、色彩抖动、随机擦除和归一化处理等,属于本领域的常规技术手段,本实施例中的数据增强操作只针对在数据集中用于训练集的样本图像。
作为可选的一种实施方式,将训练集中的所有图像进行短边尺寸相同的缩放,并将样本图像进行统一大小的随机裁剪,对裁切后的图片像素点进行归一化操作,即把像素点对应到0到1之间,从而消除图像特征中单位和尺度差异的影响,以及做一些色彩抖动操作,调整图像的亮度,饱和度和对比度来对图像的颜色方面做增强,以消除图像在不同背景中存在的差异性,另外可以在图像上随机叠加一些适量的噪声,还可以在图片上随机选取一块区域并擦除图像信息,最后将这些处理后的图片作为训练模型的训练数据,当然本实施例并不限于上述数据增强操作。
S102、构建至少两个分类模型并训练分类模型以得到教师模型。
根据任务特点,搭建一些对应的深度卷积神经网络,作为可选的一种实施方式,本实施例中的分类模型可使用CNN模型和Transformer模型,分别构建CNN网络结构的模型和Transformer网络结构的模型,当然本实施例中的CNN模型和Transformer模型也可以替换为其他的模型。
CNN网络用于提取训练图像的局部信息,在CNN网络的基础上加上attention(注意力)结构,即CNN网络在经典ResNet(深度残差网络)网络的基础上加入attention模块,将中间的特征图分为多组,使用不同大小的卷积核进行计算得到不同的attention区域,并将attention区域映射到原特征图,使网络关注到有用区域部分,抑制无用区域信息。其中对于类别预测,使用了label smooth(标签平滑正则化)软化标签信息。
Transformer网络用于提取图像的全局信息,将图片切割为多个相同大小的patch(子图像块),并转为embedding(将离散变量转为连续向量)向量后,加入每个patch的位置信息和类别信息,输入Transformer网络的编码器后进行分类,其中编码器由self-attention(自注意力机制)和卷积层组成。
作为可选的一种实施方式,CNN网络训练初始化和Transformer网络训练初始化均采取基于ImageNet(大规模带标签图像数据集)数据集分别训练的预训练模型,同时Transformer网络的预训练模型蒸馏了ResNet网络的模型结果。CNN模型和Transformer模型通过上述的预训练模型进行迁移学习,在迁移学习的过程中,对模型进行短边尺寸相同的缩放后,设置统一大小的裁切,保证后续输入网络的图片大小相同,另外通过调整图片分辨率,并设置颜色抖动、随机擦除等,加强模型的鲁棒性。
作为可选的一种实施方式,将训练集中的训练数据分别送到CNN网络和Transformer网络的输入层,在网络的最高层得到对类别信息的预测结果,并通过softmax(归一化指数函数)函数将预测结果转化为概率值,其中,Pi是正确类别对应输出节点的概率值,n为分类的类别个数,yi为第i个节点的输出值,经过softmax处理之后可以将多分类的输出值转换为范围在[0,1]和为1的概率分布:
对输出的结果进行交叉熵损失计算,用L表示交叉熵表示交叉熵损失:
使用上述交叉熵损失函数对CNN模型和Transformer模型进行反向传播,从而调整CNN模型和Transformer模型的模型参数,直到模型参数不再变化,即训练若干个周期后模型收敛,在训练集上分别迭代CNN模型和Transformer模型直至最优,同时将CNN模型和Transformer模型作为后续的教师模型。
S103、输入训练数据至教师模型分别得到第一教师模型结果。
第一教师模型结果用于表征在多个教师模型中任意一个教师模型的预测结果,如图2所示,将图片输入教师模型1得到教师模型1结果,将图片输入教师模型2得到教师模型2结果。在本实施例中,教师模型1表示为CNN模型,教师模型2表示为Transformer模型。作为可选的一种实施方式,将图片输入CNN模型得到CNN模型的预测结果,将图片输入Transformer模型得到Transformer模型的预测结果,用Zt1表示CNN模型的预测结果,用Zt2表示Transformer模型的预测结果。
S104、合并第一教师模型结果以得到第二教师模型结果。
第二教师模型结果用于表征多个教师模型结果的加权和。为了合并不同教师模型的预测结果,对不同的教师模型的预测结果进行加权,作为可选的一种实施方式,对每个教师模型分别设置可学习超参数,所有教师模型的可学习超参数之和为1。在本实施例中,设置CNN模型的可学习超参数为∈,那么Transformer模型的可学习超参数就为1-∈,再将两个模型的结果集成输出,为了便于说明,在这里用Ψ表示softmax函数,用Zt表示集成后的教师模型预测结果,计算CNN结果和Transformer结果的加权和,集成公式如下:
Zt=∈*Ψ(Zt1)+(1-∈)Ψ(Zt2)
将CNN模型的预测结果和Transformer模型的预测结果代入上述的公式,即可获得集成后的教师模型预测结果。
S105、构建一个学生模型,输入训练数据至学生模型得到学生模型结果。
构建一个学生模型进行学习,为了更好的学习教师模型,构建的学生模型一般会选择一个网络结构较小并且与教师模型结构一致的网络模型。在本实施例中,教师模型有CNN模型和Transformer模型两种教师模型,作为可选的一种实施方式,选取未训练的CNN模型作为学生模型,当然学生模型并不仅限于CNN模型。
学生模型结果用于表征学生模型的预测结果,参照图2,在每一轮的训练过程中,将图片输入学生模型,获取图片在学生模型中的预测结果,用Zs表示学生模型的预测结果。
S106、根据第二教师模型结果与学生模型结果对学生模型进行训练以得到目标模型。
如图2所示,使用蒸馏损失函数对学生模型进行训练,蒸馏损失函数包括两部分,第一部分为学生模型的交叉熵损失函数,第二部分为学生模型和教师模型的相对熵。作为可选的一种实施方式,将教师模型的预测结果、学生模型的预测结果和图片原标签结合,进行蒸馏损失计算,使得学生模型既能学习到原标签结果,也能学习到教师模型的预测结果。
为了便于说明,用Loss表示蒸馏损失,LCE表示学生模型的交叉熵损失,KL表示学生模型和教师模型的相对熵,其中相对熵也叫KL散度,τ是参数,初始值取1,y表示图片原标签,Zt表示集成后的教师模型预测结果,Zs表示学生模型的预测结果,具体的蒸馏损失Loss计算公式如下:
Loss=0.5LCE(Ψ(Zs),y)+0.5τ2KL(Ψ(Zs/τ),Ψ(Zt/τ))
使用上述蒸馏损失函数对学生模型进行反向传播,训练学生模型直至最优以得到目标模型,目标模型即训练好的学生模型。本实施例的可学习超参数表示为教师模型结果的权重,作为可选的一种实施方式,通过调整可学习超参数,来调整教师模型的分布情况,使学生模型能更好的选取不同教师模型的优点进行学习,在迭代学习的训练中,不断更新可学习超参数直至学生模型的预测结果近似最优,由此得到教师模型结果的权重。在得到教师模型结果的权重后,通过进一步的迭代训练,更新学生模型的网络参数,直到学生模型收敛,得到训练好的学生模型。教师模型可以减少原始训练集的标记错误,并且通过labelsmooth软化后的标签含有更丰富的信息,让学生模型更容易学习。学生模型通过知识蒸馏的方式从多个教师模型中进行学习,来达到近似或比教师模型直接融合更佳的预测效果。
学生模型贴近教师模型的预测准确率,也可以保持自身作为小模型的预测效率。在本实施例中,设定目标模型的任务为判断图片是否为旋转图,并对旋转图进行纠错。当然本实施例中可根据不同图像分类场景建立不同的图像信息挖掘模型,其中的分类处理包括但不限于进行目标分类或检测等。S107、对目标模型进行测试。
考虑到上线部署的便利性,图像分类仅需使用训练好的学生模型(即目标模型)对图片进行前向预测,并输出图像的类别。为了使目标模型得到好的预测效果,在将目标模型部署到真实的场景中之前,还需要对目标模型进行测试,如图3所示,步骤107包括:
S201、获取待测试图像。
其中,待测试图像和用于训练模型的图像为不同的数据,作为可选的一种实施方式,可在构建数据集后对数据集进行划分,通常将数据集的80%样本图像作为训练数据,20%的样本图像作为测试数据。为了保证训练集和测试集数据分布的一致性,在对数据集进行划分时,每个类别的数据集样本都要基本遵循8:2这个比例。作为可选的另一种方式,可在模型构建完成之后,获取新的样本图像作为待测试图像,通常数量不超过构建模型前的数据集的30%。在本实施例中,对线上酒店图片进行挖掘,包括但不限于室内、室外、房型、风景等不同类型,每个类型的图片均衡,并保证与训练集中五个类别的数据分布一致,当然本实施例并不仅限于上述划分比例。
S202、输入待测试图像至目标模型得到测试结果。
在本实施例中,目标模型即图片旋转判断模型,将待测试图像输入图片旋转判断模型,获取图片的测试结果,通过图片旋转判断模型对图片不同角度进行识别,并输出图像的类别,即对应四个方向的旋转以及无方向图。
为了便于说明,设置图像的类别为A、B、C、D、E,假设输入的待测试图像有200个,且测试结果如下:
S203、根据测试结果调整目标模型。
当测试的类别与实际类别一致时,则直接输出图像的类别,如上表所示,预测正确的样本图像为24+26+28+30+32=140,则将这140个待测试图像以及它们的测试结果直接输出。作为可选的一种实施方式,可以计算出在总共200个待测试图像中,模型的准确率(Accuracy)=140/200=70%。
作为可选的一种实施方式,将上表的测试结果按照不同类别合并为二分问题,并计算出不同类别下待测试图像在模型预测时的精确率(Precision)和召回率(Recall),综合Precision与Recall计算调和平均值,得到模型分数(Score):
Score的取值范围为0到1,1代表模型的输出最好,0代表模型的输出结果最差,以上表为例,具体方法如下:
A类:
Precision(A)=24/(24+10)=70.6%,Recall(A)=24/(24+14)=63.2%
Score(A)=(2*70.6%*63.2%)/(70.6%+63.2%)=66.7%
B类:
Precision(B)=26/(26+11)=70.3%,Recall(B)=26/(26+13)=66.7%
Score(B)=(2*70.3%*66.7%)/(70.3%+66.7%)=68.5%
C类:
Precision(C)=28/(28+12)=70%,Recall(C)=28/(28+12)=70%
Score(C)=(2*70%*70%)/(70%+70%)=70%
D类:
Precision(D)=30/(30+13)=69.8%,Recall(D)=30/(30+11)=73.2%
Score(D)=(2*69.8%*73.2%)/(69.8%+73.2%)=71.5%
E类:
Precision(E)=32/(32+14)=69.6%,Recall(D)=32/(32+10)=76.2%
Score(E)=(2*69.6%*76.2%)/(69.6%+76.2%)=72.8%
可以看出E类的测试结果最好,而A类的测试模型更差,作为可选的一种实施方式,针对测试结果,可以总结错误类型,并针对性的完善目标模型,比如在数据集中增加A类样本图像,对目标模型进行训练并调整目标模型的可学习超参数,当然本实施例并不限于上述对测试结果的分析方法。
S108、根据目标模型进行图像分类。
实际使用时,模型预测的结果是一个概率值,比如判定预测结果属于A类的概率值在50%以上,则判定图像属于A类,这里的判定值就是阈值,当然本实施例的阈值并不仅限于50%。为保证线上准确率,需要选择合适的阈值,若预测结果高于阈值则输出旋转图的旋转类别,并将旋转图交给人工进行处理,若低于阈值则过滤,并输出为不旋转。
本实施例通过构建CNN模型和Transfomer模型两个分类模型作为教师模型,CNN模型可提取图像的局部特征细节,Transfomer模型可以反映复杂的空间变换和远距离特征相关性,从而提取图像的全局特征。分别对训练上述两个分类模型并作为教师模型,通过设置可学习超参数对CNN模型和Transfomer模型进行加权;构建一个学生模型,学生模型采取自主学习权重的方式,选取不同教师模型的优点进行学习,通过计算学生模型的蒸馏损失更新学生模型的参数以及可学习超参数,使学生模型能更好的学习到教师模型。在将教师模型蒸馏至学生模型的过程中,保证学生模型参数量不变,并确保了学生模型的性能,既保证了模型的计算速度又保证了模型预测的准确率。在本实施例中,通过最终训练并测试完成的学生模型对线上的旋转图片进行旋转检测,能够快速准确的发现图片问题,将旋转的图片筛选出来,可大幅度节省运营维护成本,保证图片信息的准确性,有效提升OTA场景下用户的体验。
实施例2
本实施例提供一种图像分类***,如图4所示,该图像分类***包括数据集构建模块31、模型构建模块32、教师模型训练模块33、模型结果获取模块34、模型结果合并模块35、学生模型训练模块36、分类预测模块37。
作为一种示例,本申请实施例中的算法模型应用于OTA。构建数据集中所选取的图片为OTA下酒店各个类型的图片,包括但不限于室内、室外、房型、风景等图片,以便保证数据集的多样性。为了便于说明,设定本申请实施例中的数据集是用于训练图片旋转判断模型,当然本发明并不仅限于判断图片旋转。
作为可选的一种实施方式,数据集构建模块31通过设备拍摄或者爬取网络图片等方法来获取不同的图片,默认选取的图片均为正向图,然后数据集构建模块31对选取的图片进行四个方向的旋转,得到旋转图的样本图像,并构建四个类别的数据集。由于部分图片是没有方向的,即旋转后也为正常图片,如一些俯视图、无方向物件等,则数据集构建模块31将这部分图片挑选出来组成一个类别,由此一共构建出五个类别的数据集。作为可选的一种实施方式,每个类别至少挖掘出五百个样本图像,每个样本图像仅有一个类别。
为了加强图片旋转训练模型的鲁棒性,需要增大数据的训练量,数据集构建模块31对数据集中的样本图像进行数据增强。数据增强操作包括旋转、缩放、随机裁剪、平移、高斯噪声、色彩抖动、随机擦除和归一化处理等,属于本领域的常规技术手段。数据增强操作只针对在数据集中用于训练集的样本图像。
作为可选的一种实施方式,数据集构建模块31将训练集中的所有图像进行短边尺寸相同的缩放,并将样本图像进行统一大小的随机裁剪,对裁切后的图片像素点进行归一化操作,即把像素点对应到0到1之间,从而消除图像特征中单位和尺度差异的影响,以及做一些色彩抖动操作,调整图像的亮度,饱和度和对比度来对图像的颜色方面做增强,以消除图像在不同背景中存在的差异性,另外可以在图像上随机叠加一些适量的噪声,还可以在图片上随机选取一块区域并擦除图像信息,最后数据集构建模块31最后将这些处理后的图片作为训练模型的训练数据,当然本实施例并不限于上述数据增强操作。
根据任务的特点,模型构建模块32搭建一些对应的深度卷积神经网络,作为可选的一种实施方式,本实施例中的分类模型可使用CNN模型和Transformer模型,模型构建模块32分别构建CNN网络结构的模型和Transformer网络结构的模型,需要说明的是,本实施例中的CNN模型和Transformer模型可以替换为其他的模型。
CNN网络用于提取训练图像的局部信息,模型构建模块32在CNN网络的基础上加上attention结构,即CNN网络在经典ResNet网络的基础上加入attention模块,将中间的特征图分为多组,使用不同大小的卷积核进行计算得到不同的attention区域,并将attention区域映射到原特征图,使网络关注到有用区域部分,抑制无用区域信息。其中对于类别预测,使用了label smooth软化标签信息。
Transformer网络用于提取图像的全局信息,模型构建模块32将图片切割为多个相同大小的patch,并转为embedding向量后,加入每个patch的位置信息和类别信息,输入Transformer网络的编码器后进行分类,其中编码器由self-attention和卷积层组成。
模型构建模块32还用于构建一个学生模型进行学习,为了更好的学习教师模型,构建的学生模型一般会选择一个网络结构较小并且与教师模型结构一致的网络模型。在本实施例中,教师模型有CNN模型和Transformer模型两种教师模型,作为可选的一种实施方式,模型构建模块32选取未训练的CNN模型作为学生模型,当然学生模型并不仅限于CNN模型。
作为可选的一种实施方式,教师模型训练模块33用于训练教师模型,作为可选的一种实施方式,CNN网络训练初始化和Transformer网络训练初始化均采取基于ImageNet数据集分别训练的预训练模型,同时Transformer网络的预训练模型蒸馏了ResNet网络的模型结果。CNN模型和Transformer模型通过上述的预训练模型进行迁移学习,在迁移学习的过程中,对模型进行短边尺寸相同的缩放后,设置统一大小的裁切,保证后续输入网络的图片大小相同,另外通过调整图片分辨率,并设置颜色抖动、随机擦除等,加强模型的鲁棒性。
作为可选的一种实施方式,教师模型训练模块33将训练集中的训练数据分别送到CNN网络和Transformer网络的输入层,在网络的最高层得到对类别信息的预测结果,并通过softmax函数将预测结果转化为概率值,其中,Pi是正确类别对应输出节点的概率值,n为分类的类别个数,yi为第i个节点的输出值,经过softmax处理之后可以将多分类的输出值转换为范围在[0,1]和为1的概率分布:
教师模型训练模块33对输出的结果进行交叉熵损失计算,用L表示交叉熵表示交叉熵损失:
教师模型训练模块33使用上述交叉熵损失函数对CNN模型和Transformer模型进行反向传播,从而调整CNN模型和Transformer模型的模型参数,直到模型参数不再变化,即训练若干个周期后模型收敛,在训练集上分别迭代CNN模型和Transformer模型直至最优,同时将训练完成的CNN模型和Transformer模型作为后续的教师模型。
模型结果生成模块34用于输入训练数据至教师模型得到教师模型的预测结果,在本实施例中,如图2所示,图中的教师模型1表示为CNN模型,教师模型2表示为Transformer模型。作为可选的一种实施方式,将图片输入CNN模型得到CNN模型的预测结果,将图片输入Transformer模型得到Transformer模型的预测结果,用Zt1表示CNN模型的预测结果,用Zt2表示Transformer模型的预测结果。
模型结果生成模块34还用于输入训练数据至学生模型得到学生模型的预测结果,参照图2,在每一轮的训练过程中,将图片输入学生模型,获取图片在学生模型中的预测结果,用Zs表示学生模型的预测结果。
模型结果合并模块35用于对不同教师模型的预测结果进行加权,并计算这些预测结果的加权和。作为可选的一种实施方式,对每个教师模型分别设置可学习超参数,所有教师模型的可学习超参数之和为1。在本实施例中,模型结果合并模块35设置CNN模型的可学习超参数为∈,那么Transformer模型的可学习超参数就为1-∈,再将两个模型的结果集成输出,为了便于说明,在这里用Ψ表示softmax函数,用Zt表示集成后的教师模型预测结果,计算CNN结果和Transformer结果的加权和,集成公式如下:
Zt=∈*Ψ(Zt1)+(1-∈)Ψ(Zt2)
模型结果合并模块35将CNN模型的预测结果和Transformer模型的预测结果代入上述的公式,即可得到集成后的教师模型预测结果。
学生模型训练模块36用于训练学生模型,如图2所示,使用蒸馏损失函数对学生模型进行训练,蒸馏损失函数包括两部分,第一部分为学生模型的交叉熵损失函数,第二部分为学生模型和教师模型的相对熵。作为可选的一种实施方式,将教师模型的输出结果、学生模型的输出结果和图片原标签结合,进行蒸馏损失计算,学生模型既能学习到原标签结果,也能学习到教师模型的预测结果。
为了便于说明,用Loss表示蒸馏损失,LCE表示学生模型的交叉熵损失,KL表示学生模型和教师模型的相对熵,其中相对熵也叫KL散度,τ是参数,初始值取1,y表示图片原标签,Zt表示集成后的教师模型预测结果,Zs表示学生模型的预测结果,具体的Loss计算公式如下:
Loss=0.5LCE(Ψ(Zs),y)+0.5τ2KL(Ψ(Zs/τ),Ψ(Zt/τ))
学生模型训练模块36使用上述蒸馏损失函数对学生模型进行反向传播,训练学生模型直至最优以得到目标模型,目标模型即训练好的学生模型。本实施例的可学习超参数表示为教师模型结果的权重,作为可选的一种实施方式,学生模型训练模块36通过调整可学习超参数,来调整教师模型的分布情况,使学生模型能更好的选取不同教师模型的优点进行学习,在迭代学习的训练中,不断更新可学习超参数直至学生模型的预测结果近似最优,由此得到教师模型结果的权重。在得到教师模型结果的权重后,学生模型训练模块36通过进一步的迭代训练,更新学生模型的网络参数,直到学生模型收敛,得到训练好的学生模型。教师模型可以减少原始训练集的标记错误,并且通过label smooth软化后的标签含有更丰富的信息,让学生模型更容易学习。学生模型通过知识蒸馏的方式从多个教师模型中进行学习,来达到近似或比教师模型直接融合更佳的预测效果。
分类预测模块37用于使用训练好的学生模型(即目标模型)对图片进行前向预测,并输出图像的类别。学生模型贴近教师模型的预测准确率,也可以保持自身作为小模型的预测效率。在本实施例中,设定目标模型的任务为判断图片是否为旋转图,并对旋转图进行纠错。当然本实施例中可根据不同图像分类场景建立不同的图像信息挖掘模型,其中的分类处理包括但不限于进行目标分类或检测等。
为了使目标模型得到好的预测效果,在将目标模型部署到真实的场景中之前,分类预测模块37还用于对目标模型进行测试。
作为可选的一种实施方式,数据集构建模块31可在构建数据集后对数据集进行划分,通常将数据集的80%样本图像作为训练数据,20%的样本图像作为测试数据。为了保证训练集和测试集数据分布的一致性,数据集构建模块31在对数据集进行划分时,每个类别的数据集样本都要基本遵循8:2这个比例。作为可选的另一种方式,数据集构建模块31可在模型构建完成之后,获取新的样本图像作为待测试图像,通常数量不超过构建模型前的数据集的30%。在本实施例中,对线上酒店图片进行挖掘,包括但不限于室内、室外、房型、风景等不同类型,每个类型的图片均衡,并保证与训练集中的数据分布一致,当然本实施例并不仅限于上述划分比例。
在本实施例中,目标模型即图片旋转判断模型,分类预测模块37将待测试图像输入图片旋转判断模型,获取图片的测试结果,通过图片旋转判断模型对图片不同角度进行识别,并输出图像的类别,即对应四个方向的旋转以及无方向图。
当测试的类别与实际类别一致时,则直接输出图像的类别,作为可选的一种实施方式,可以计算出模型的准确率(Accuracy)。
作为可选的一种实施方式,将上表的测试结果按照不同类别合并为二分问题,并计算出不同类别下待测试图像在模型预测时的精确率(Precision)和召回率(Recall),综合Precision与Recall计算调和平均值,得到模型分数(Score):
Score的取值范围为0到1,1代表模型的输出最好,0代表模型的输出结果最差。
作为可选的一种实施方式,分类预测模块37用于针对测试结果,将错误的错误的图像类别记录下来,总结错误类型,并通知人工针对性的完善目标模型。
实际使用时,分类预测模块37预测的结果是一个概率值,比如判定预测结果属于某一类别的概率值在50%以上,则判定图像属于该类别,这里的判定值就是阈值,当然本实施例的阈值并不仅限于50%。为保证线上准确率,分类预测模块37需要选择合适的阈值,若预测结果高于阈值则输出旋转图的旋转类别,并将旋转图交给人工进行处理,若低于阈值则过滤,并输出为不旋转。
实施例3
图5为本实施例提供的一种电子设备的结构示意图。所述电子设备包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现实施例1的图像分类方法。图5显示的电子设备40仅仅是一个示例,不应对本发明实施例的功能和使用范围带来任何限制。
如图5所示,电子设备40可以以通用计算设备的形式表现,例如其可以为服务器设备。电子设备40的组件可以包括但不限于:上述至少一个处理器41、上述至少一个存储器42、连接不同***组件(包括存储器42和处理器41)的总线43。
总线43包括数据总线、地址总线和控制总线。
存储器42可以包括易失性存储器,例如随机存取存储器(RAM)421和/或高速缓存存储器422,还可以进一步包括只读存储器(ROM)423。
存储器42还可以包括具有一组(至少一个)程序模块424的程序/实用工具425,这样的程序模块424包括但不限于:操作***、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。
处理器41通过运行存储在存储器42中的计算机程序,从而执行各种功能应用以及数据处理,例如本发明实施例1的图像分类方法。
电子设备40也可以与一个或多个外部设备44(例如键盘、指向设备等)通信。这种通信可以通过输入/输出(I/O)接口45进行。并且,模型生成的设备40还可以通过网络适配器46与一个或者多个网络(例如局域网(LAN),广域网(WAN)和/或公共网络,例如因特网)通信。如图所示,网络适配器46通过总线43与模型生成的设备40的其它模块通信。应当明白,尽管图中未示出,可以结合模型生成的设备40使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理器、外部磁盘驱动阵列、RAID(磁盘阵列)***、磁带驱动器以及数据备份存储***等。
应当注意,尽管在上文详细描述中提及了电子设备的若干单元/模块或子单元/模块,但是这种划分仅仅是示例性的并非强制性的。实际上,根据本发明的实施方式,上文描述的两个或更多单元/模块的特征和功能可以在一个单元/模块中具体化。反之,上文描述的一个单元/模块的特征和功能可以进一步划分为由多个单元/模块来具体化。
实施例5
本实施例提供了一种计算机可读存储介质,其上存储有计算机程序,所述程序被处理器执行时实现实施例1的图像分类方法的步骤。
其中,可读存储介质可以采用的更具体可以包括但不限于:便携式盘、硬盘、随机存取存储器、只读存储器、可擦拭可编程只读存储器、光存储器件、磁存储器件或上述的任意合适的组合。
在可能的实施方式中,本发明还可以实现为一种程序产品的形式,其包括程序代码,当所述程序产品在终端设备上运行时,所述程序代码用于使所述终端设备执行实现实施例1的图像分类方法的步骤。
其中,可以以一种或多种程序设计语言的任意组合来编写用于执行本发明的程序代码,所述程序代码可以完全地在用户设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户设备上部分在远程设备上执行或完全在远程设备上执行。
虽然以上描述了本发明的具体实施方式,但是本领域的技术人员应当理解,这仅是举例说明,本发明的保护范围是由所附权利要求书限定的。本领域的技术人员在不背离本发明的原理和实质的前提下,可以对这些实施方式做出多种变更或修改,但这些变更和修改均落入本发明的保护范围。
Claims (11)
1.一种图像分类方法,其特征在于,包括以下步骤:
构建数据集,所述数据集包括若干个训练数据;
构建至少两个分类模型并训练所述分类模型以得到教师模型;
输入所述训练数据至所述教师模型分别得到第一教师模型结果;
合并所述第一教师模型结果以得到第二教师模型结果;
构建一个学生模型,输入所述训练数据至所述学生模型得到学生模型结果;
根据所述第二教师模型结果与所述学生模型结果对所述学生模型进行训练以得到目标模型;
根据所述目标模型进行图像分类。
2.根据权利要求1所述的图像分类方法,其特征在于,所述构建数据集的步骤包括:
获取样本图像;
对所述样本图像进行数据增强操作,得到所述训练数据;
根据所述训练数据构建数据集。
3.根据权利要求2所述的图像分类方法,其特征在于,所述数据增强操作包括旋转、缩放、随机裁剪、平移、高斯噪声、色彩抖动、随机擦除和归一化处理中的至少一种。
4.根据权利要求1所述的图像分类方法,其特征在于,所述训练所述分类模型以得到教师模型的步骤包括:
将所述训练数据输入所述分类模型,通过交叉熵损失函数计算得到所述分类模型的第一交叉熵损失;
根据所述第一交叉熵损失训练所述分类模型以得到所述教师模型。
5.根据权利要求1所述的图像分类方法,其特征在于,所述合并所述第一教师模型结果并得到第二教师模型结果的步骤包括:
对所述教师模型分别设置可学习超参数,其中,所有可学习超参数之和为1;
根据所述可学习超参数计算所述第一教师模型结果的加权和以得到所述第二教师模型结果。
6.根据权利要求1所述的图像分类方法,其特征在于,所述根据所述第二教师模型结果与所述学生模型结果对所述学生模型进行训练以得到目标模型的步骤包括:
将所述训练数据输入所述学生模型,通过交叉熵损失函数计算得到所述学生模型的第二交叉熵损失;
根据所述第二教师模型结果与所述学生模型结果得到所述教师模型与所述学生模型的相对熵;
合并所述第二交叉熵损失与所述相对熵,得到所述学生模型的蒸馏损失;
根据所述蒸馏损失训练所述学生模型以得到所述目标模型。
7.根据权利要求1所述的图像分类方法,其特征在于,所述根据所述目标模型进行图像分类的步骤之前还包括对所述目标模型进行测试。
8.根据权利要求7所述的图像分类方法,其特征在于,所述对所述目标模型进行测试的步骤包括:
获取待测试图像;
输入所述待测试图像至所述目标模型,得到测试结果;
根据所述测试结果调整所述目标模型。
9.一种图像分类***,其特征在于,包括数据集构建模块、模型构建模块、教师模型训练模块、模型结果生成模块、模型结果合并模块、学生模型训练模块和分类预测模块:
所述数据集构建模块用于构建数据集,所述数据集包括若干个训练数据;
所述模型构建模块用于构建至少两个分类模型,还用于构建一个学生模型;
所述教师模型训练模块用于训练所述分类模型以得到教师模型;
所述模型结果生成模块用于输入所述训练数据至所述教师模型分别得到第一教师模型结果,还用于输入所述训练数据至所述学生模型得到学生模型结果;
所述模型结果合并模块用于合并所述第一教师模型结果以得到第二教师模型结果;
所述学生模型训练模块用于根据所述第二教师模型结果与所述学生模型结果对所述学生模型进行训练以得到目标模型;
所述分类预测模块用于根据所述目标模型进行图像分类。
10.一种电子设备,其特征在于,包括存储器以及与所述存储器连接的处理器,所述处理器执行存储在所述存储器上的计算机程序时实现权利要求1-8中任一项所述图像分类方法。
11.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-8中任一项所述的图像分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111273347.XA CN114049515A (zh) | 2021-10-29 | 2021-10-29 | 图像分类方法、***、电子设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111273347.XA CN114049515A (zh) | 2021-10-29 | 2021-10-29 | 图像分类方法、***、电子设备和存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114049515A true CN114049515A (zh) | 2022-02-15 |
Family
ID=80206545
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111273347.XA Pending CN114049515A (zh) | 2021-10-29 | 2021-10-29 | 图像分类方法、***、电子设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114049515A (zh) |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114927190A (zh) * | 2022-06-17 | 2022-08-19 | 吉林大学 | 一种基于半监督-迁移学习的分布式隐私保护方法及*** |
CN114972877A (zh) * | 2022-06-09 | 2022-08-30 | 北京百度网讯科技有限公司 | 一种图像分类模型训练方法、装置及电子设备 |
CN116166889A (zh) * | 2023-02-21 | 2023-05-26 | 深圳市天下房仓科技有限公司 | 酒店产品筛选方法、装置、设备及存储介质 |
CN117437459A (zh) * | 2023-10-08 | 2024-01-23 | 昆山市第一人民医院 | 基于决策网络实现用户膝关节髌骨软化状态分析方法 |
WO2024099032A1 (zh) * | 2022-11-09 | 2024-05-16 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置和计算机设备 |
-
2021
- 2021-10-29 CN CN202111273347.XA patent/CN114049515A/zh active Pending
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114972877A (zh) * | 2022-06-09 | 2022-08-30 | 北京百度网讯科技有限公司 | 一种图像分类模型训练方法、装置及电子设备 |
CN114927190A (zh) * | 2022-06-17 | 2022-08-19 | 吉林大学 | 一种基于半监督-迁移学习的分布式隐私保护方法及*** |
WO2024099032A1 (zh) * | 2022-11-09 | 2024-05-16 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置和计算机设备 |
CN116166889A (zh) * | 2023-02-21 | 2023-05-26 | 深圳市天下房仓科技有限公司 | 酒店产品筛选方法、装置、设备及存储介质 |
CN116166889B (zh) * | 2023-02-21 | 2023-12-12 | 深圳市天下房仓科技有限公司 | 酒店产品筛选方法、装置、设备及存储介质 |
CN117437459A (zh) * | 2023-10-08 | 2024-01-23 | 昆山市第一人民医院 | 基于决策网络实现用户膝关节髌骨软化状态分析方法 |
CN117437459B (zh) * | 2023-10-08 | 2024-03-22 | 昆山市第一人民医院 | 基于决策网络实现用户膝关节髌骨软化状态分析方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11798132B2 (en) | Image inpainting method and apparatus, computer device, and storage medium | |
CN114049515A (zh) | 图像分类方法、***、电子设备和存储介质 | |
US20210358170A1 (en) | Determining camera parameters from a single digital image | |
CN110728295B (zh) | 半监督式的地貌分类模型训练和地貌图构建方法 | |
CN106537379A (zh) | 细粒度图像相似性 | |
Lin et al. | Local and global encoder network for semantic segmentation of Airborne laser scanning point clouds | |
Pham et al. | Road damage detection and classification with YOLOv7 | |
CN111274981B (zh) | 目标检测网络构建方法及装置、目标检测方法 | |
Hu et al. | Boosting lightweight depth estimation via knowledge distillation | |
CN110443279B (zh) | 一种基于轻量级神经网络的无人机图像车辆检测方法 | |
CN111723660A (zh) | 一种用于长形地面目标检测网络的检测方法 | |
CN110765882A (zh) | 一种视频标签确定方法、装置、服务器及存储介质 | |
CN106407978B (zh) | 一种结合似物度的无约束视频中显著物体检测方法 | |
US20230281974A1 (en) | Method and system for adaptation of a trained object detection model to account for domain shift | |
CN113378897A (zh) | 基于神经网络的遥感图像分类方法、计算设备及存储介质 | |
JP2019185787A (ja) | 地理的地域内のコンテナのリモート決定 | |
KR20210026542A (ko) | 기하학적 이미지를 이용한 인공신경망 기반 단백질 결합 화합물의 생물학적 활성 예측 시스템 | |
CN111027551B (zh) | 图像处理方法、设备和介质 | |
CN115953330B (zh) | 虚拟场景图像的纹理优化方法、装置、设备和存储介质 | |
CN117371511A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
US20230401670A1 (en) | Multi-scale autoencoder generation method, electronic device and readable storage medium | |
CN111144422A (zh) | 一种飞机部件的定位识别方法和*** | |
CN115512207A (zh) | 一种基于多路特征融合及高阶损失感知采样的单阶段目标检测方法 | |
CN114202694A (zh) | 基于流形混合插值和对比学习的小样本遥感场景图像分类方法 | |
CN113609957A (zh) | 一种人体行为识别方法及终端 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |