CN115115886A - 基于teacher-student模型的半监督目标检测方法 - Google Patents

基于teacher-student模型的半监督目标检测方法 Download PDF

Info

Publication number
CN115115886A
CN115115886A CN202210811820.3A CN202210811820A CN115115886A CN 115115886 A CN115115886 A CN 115115886A CN 202210811820 A CN202210811820 A CN 202210811820A CN 115115886 A CN115115886 A CN 115115886A
Authority
CN
China
Prior art keywords
model
data
teacher
image
label
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210811820.3A
Other languages
English (en)
Inventor
王磊
王瑞生
白洋
张依漪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beihang University
Beijing Jinghang Computing Communication Research Institute
Original Assignee
Beihang University
Beijing Jinghang Computing Communication Research Institute
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beihang University, Beijing Jinghang Computing Communication Research Institute filed Critical Beihang University
Priority to CN202210811820.3A priority Critical patent/CN115115886A/zh
Publication of CN115115886A publication Critical patent/CN115115886A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • G06V20/10Terrestrial scenes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • G06V10/806Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V2201/00Indexing scheme relating to image or video recognition or understanding
    • G06V2201/07Target detection

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Multimedia (AREA)
  • Artificial Intelligence (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

提供了基于teacher‑student模型的半监督目标检测方法。包括以下步骤:获取半监督目标检测数据集D;在有标注数据集DL上按照全监督目标检测方法,用模型预测样本
Figure DDA0003740662500000011
的标注得到teacher模型;利用的teacher模型,对无标签数据XU做预测,生成伪标签,并用伪标签更新标注集合YU;对无标注样本集合XU进行数据增强第五步,训练student模型;将student模型的权重参数θs以指数滑动平均的方式更新至teacher模型的权重θt中;进行数轮迭代,以最终的teacher模型作为所述基于teacher‑student模型的半监督目标检测方法训练的目标模型。

Description

基于teacher-student模型的半监督目标检测方法
技术领域
本发明涉及图像处理和计算机视觉领域,特别是涉及半监督学习下的目标检测方法。
背景技术
目标检测(Object Detection)是计算机视觉和数字图像处理的一个热门方向,广泛应用于机器人导航、智能视频监控、工业检测、航空航天等诸多领域,通过计算机视觉减少对人力资本的消耗,具有重要的现实意义。因此,目标检测也就成为了近年来理论和应用的研究热点,它是图像处理和计算机视觉学科的重要分支,也是智能监控***的核心部分,同时目标检测也是泛身份识别领域的一个基础性的算法,对后续的人脸识别、步态识别、人群计数、实例分割等任务起着至关重要的作用。与图片分类任务相比,目标检测多出一个回归任务,即不仅要用算法判断图片中是否存在对象,还要在图片中标记出它的位置,对图像上标注的标定框进行回归预测。
近年来,由于深度学习的广泛运用,目标检测算法得到了较为快速的发展。然而通常目标检测模型是基于大量人工精确标注的数据集训练的,这些方法要求每一张训练的图像都有精确充分的高质量标注,即全监督方法。而往往一张图像中有多个物体,各自可能属于不同类别,这些都需要人工一一进行标注;有的物体更是由于本身物体较小、环境影响或图像失真导致肉眼难以辨认,更进一步地增加了标注数据集所需的时间与精力。因此,为了降低标注带来的大量人力消耗,如何充分使用小样本的标注数据集成为研究的一大热点。
基于此出发,半监督学习试图在小样本的标注数据基础上引入更多的无标注数据来增强模型性能。目前的大多数半监督学习方面研究针对的是图像分类任务,对于标注成本更大的目标检测却没有充分的探索。因此开展半监督目标检测方法的进一步研究具有重要的意义。
发明内容
现有的半监督目标检测方法通常采用自训练或伪标签生成的方式。一方面,自训练包括数据增强、一致正则化等要点;其核心思想在于,一个表现好的模型对于加入少量干扰后的图像,输出的预测结果应与加入干扰前保持一致;在输出置信度较低的无标签数据上应用此方法尤其有效。另一方面,伪标签生成是一种简单而高效的方法,其核心思想在于,在少部分有标签数据上训练初始模型,并使用此模型对无标签的数据做预测,将生成的结果视作新的标签,称为伪标签。以上两种方式试图充分利用无标签数据提升模型预测能力。经实践证明这两种方式单独其中一种都不足以实现较为理想的效果,两者相互结合、相互补充,可以大大改善模型预测质量。本方法即是这样两种思想的有机结合。
同时,目标检测中通常会遇到类别不平衡的问题,该问题导致模型倾向于预测主要的类别,却忽视了数量较少的类别,从而导致在部分类别物体上检测精度较低。该情况在半监督目标检测使用伪标签方法时尤其突出,仅在采样的小部分有标签数据集上训练的teacher模型非常易于过拟合,使预测结果偏向于某一类或数类,而对其余类别预测则更加不准确;且常见的伪标签方法简单迭代中,一轮迭代中的teacher模型通常为上一轮的student模型,模型参数变化较大,生成的预测结果波动也较大,使得student模型没有一个稳定的teacher模型为其提供学***均准确度)指标比不使用EMA的高出4个百分点,证明了本方法的有效性。
本发明的具体内容如下:
一种基于teacher-student模型的半监督目标检测框架包括以下步骤。
第一步,获取半监督目标检测数据集D,其中,半监督目标检测数据集D的元素来源于有标注数据集DL和无标注数据集DU
获取有标注数据集
Figure BDA0003740662480000021
N为有标注数据集中样本的数目,其样本的集合记为
Figure BDA0003740662480000022
Figure BDA0003740662480000023
分别为第1,2,…,N个样本;相应的标注的集合记为
Figure BDA0003740662480000024
Figure BDA0003740662480000025
分别为
Figure BDA0003740662480000026
对应的标注信息,对于DL中每一个元素
Figure BDA0003740662480000027
可记为由样本
Figure BDA0003740662480000028
和标注
Figure BDA0003740662480000029
形成的二元组
Figure BDA00037406624800000210
i为正整数;
获取无标注数据集
Figure BDA00037406624800000211
M为无标注数据集中样本的数目对于Du中每一个元素
Figure BDA00037406624800000212
可记为由样本
Figure BDA00037406624800000213
和标注
Figure BDA00037406624800000214
形成的二元组
Figure BDA00037406624800000215
Figure BDA00037406624800000216
来自于无标注样本集合
Figure BDA00037406624800000217
Figure BDA00037406624800000218
来自于无标注样本的标注集合
Figure BDA00037406624800000219
其中YU
Figure BDA00037406624800000220
其中的元素
Figure BDA00037406624800000221
也为空,在接下来生成伪标签的步骤中为元素
Figure BDA00037406624800000222
赋值;
第二步,在有标注数据集DL上按照全监督目标检测方法,对于每个输入的样本
Figure BDA00037406624800000223
用模型预测样本的标注,使之和真实标注
Figure BDA00037406624800000224
尽量保持一致,通过优化损失函数训练得到teacher模型。
第三步,利用teacher模型,对无标签数据XU做预测,将得到的预测结果进行置信度阈值过滤后生成伪标签(pesudo label),并用伪标签更新标注集合YU
第四步,对无标注样本集合XU进行数据增强,包括遮挡、平移、旋转、翻转、颜色变换、色彩抖动和/或高斯模糊,得到的集合记为记为
Figure BDA00037406624800000225
若使用了平移、旋转、翻转的增强方式,则对标注集合YU的元素即标签也做相应的几何变换,记为
Figure BDA00037406624800000226
得到扩充无标签数据集
Figure BDA00037406624800000227
第五步,训练student模型;训练起始时将teacher模型的全部参数赋值给student模型初始化,并让student模型在扩充后的全部数据集
Figure BDA00037406624800000228
上再次训练;训练时,在每个批次中将DL
Figure BDA00037406624800000229
两部分数据进行一定比例的混合,作为一个批次(batch)的数据送入student模型,并以标注集合YL和前述步骤中生成的集合
Figure BDA00037406624800000230
作为监督信息进行训练,计算并优化损失函数,得到student模型;还利用一致正则化对扩充无标签数据集
Figure BDA00037406624800000231
在训练中产生的损失函数作约束;迭代训练数个轮次,直至模型损失函数收敛至稳定结果;
在专利号为CN202110286708的中国专利申请中详细介绍了一致正则化(consistency regularization)。
第六步,将student模型的权重参数θs以指数滑动平均(EMA)的方式更新至teacher模型的权重θt中;该步骤要求:对于student模型中全部可以训练的参数θs,以
Figure BDA00037406624800000232
Figure BDA00037406624800000233
的公式将teacher模型权重θt迭代更新;其中j为当前迭代次数,α为参数更新权重,α值越小,单次的参数更新程度越大。
第七步,重复第三步至第六步的操作,进行数轮迭代,每一轮迭代中将teacher模型输出的集合
Figure BDA00037406624800000234
作为监督信息训练student模型,并将student模型参数更新至teacher模型中,在下一轮迭代中使用更新的teacher模型对XU重新预测,将生成的结果作为新的伪标签信息监督student模型的训练;如此循环往复,直至student模型损失函数在迭代中收敛稳定;以最终的teacher模型作为所述基于teacher-student模型的半监督目标检测方法训练的目标模型。
所述第一步中的数据集,数据文件为常见RGB图像,其标注信息为图像内所有标定框的位置、大小以及各标定框内物体所属类别。对于本方法的半监督领域来说,有一定比例(通常为50%~99%)的图像数据其标注信息是无法获取的。
所述第二步中的目标检测学习方法包括:
一、使用特征提取网络(backbone)提取图像特征,其中特征提取网络包括VGG(https://arxiv.org/pdf/1409.1556,Very Deep Convolutional Networks,2014),ResNet(httρs://arxiv.org/pdf/1512.03385.pdf,Deep Residual Learning for ImageRecognition,,2015),Mobilenet(https://arxiv.org/pdf/1704.04861,MobileNets:Efficient Convolutional Neural Networks for Mobile Vision Applications,2017),RetinaNet(https://arxiv.org/pdf/1708.02002,Focal Loss for Dense ObjectDetection,2018)或者EfficientNet(https://arxiv.org/pdf/1905.11946,EfficientNet:Rethinking Model Scaling for Convolutional Neural Networks,2019)的一种或多种;
二、使用特征加工网络(neck)对图像特征做进一步提取与优化,其中特征加工网络包括(C)BAM(https://arxiv.org/pdf/1807.06514,BAM:Bottleneck AttentionModule,2018),SPP(https://arxiv.org/pdf/1406.4729,Spatial Pyramid Pooling inDeep Convolutional Networks for Visual Recognition,2014),FPN(https://arxiv.org/pdf/1612.03144,Feature Pyramid Networks for Object Detection,2016)和/或NAS-FPN(https://arxiv.org/pdf/1904.07392,NAS-FPN:Learning ScalableFeature Pyramid Architecture for Object Detection,2019);
三、使用目标检测头(Head)预测目标的种类与位置,其中目标检测头包括SSD(https://arxiv.org/pdf/1512.02325,SSD:Single Shot MultiBox Detector,2015),YOL0(https://arxiv.org/pdf/1506.02640,You 0nly Look Once:Unified,Real-Time0bject Detection,2015),或者Faster RCNN(https://arxiv.org/pdf/1506.01497,Faster R-CNN:Towards Real-Time Object Detection with Region ProposalNetworks,2015)的一种或多种;
以上的各类网络皆由卷积神经网络(CNN)实现。在Mobilenet网络中卷积的实现形式略有不同,但网络仍为卷积结构。
四、优化损失函数将输出的预测结果与标注信息保持一致,损失函数定义为
Figure BDA0003740662480000031
Figure BDA0003740662480000032
此处数据来源于DL,即xi为样本集合XL的元素,yi为标注集合YL的元素;θ为teacher模型内的可学习参数;损失函数L(yi|xi,θ)的形式包括L1,最小均方误差(MSE),交叉熵损失或者Focal loss等。其中L1与MSE常用于回归任务,交叉熵与Focal loss常用于多类别分类任务。L1的具体形式为
Figure BDA0003740662480000033
MSE的具体形式为
Figure BDA0003740662480000034
交叉熵损失为
Figure BDA0003740662480000035
Focal loss为
Figure BDA0003740662480000036
其中α为一超参数,用于调节难易样本权重系数。
所述第三步中,使用所述teacher模型,将XU作为待预测数据输入所述teacher模型,得到的结果即为该部分数据的预测标签,称其为伪标签,并用伪标签更新标注集合YU。至此全部图像数据都具有了其相应标签,后续步骤中网络一次抽取数个数据标签对进行训练。
所述第四步中,使用的数据增强方法包括但不限于图像的遮挡、平移、旋转、翻转、颜色变换、色彩抖动和/或高斯模糊;其中,遮挡会将图片上随机选取的一定大小方框填充为纯色色块,以起到遮挡屏蔽的作用;平移指将图像整体向各个方向平移一定距离;旋转指图像以图像中心作顺时针或逆时针一定角度的旋转;翻转指以图像的中心垂直轴或水平轴作翻转;颜色变换指将数据集的RGB图像转换为HSV或者灰度图像等;色彩抖动通过对构成图像的色相产生位移,造成临近点状差异的色彩交叉效果;高斯模糊指图像与正态分布做卷积,使图像模糊化,可视作一低通滤波器。各种数据增强方法并无固定参数,参数设置视应用场景以及数据集特性而定。其中在图像数据应用平移、旋转、翻转等几何变换影响到相应标定框时,标定框亦做相应几何变化,以使图像数据与标定框标注保持一致。增强后的图像标注对补充至相应数据集DL
Figure BDA0003740662480000041
所述第五步中,训练过程中数据的混合方法如下:一个批次的训练中,分别从DL
Figure BDA0003740662480000042
中随机各取出部分数据,两者之间的比例定为一固定值δ,该值的选取视具体项目而定。
可选地,所述第五步中,所述Student模型的损失函数形式与所述teacher模型保持一致,即
Figure BDA0003740662480000043
仅数据来源变更为全部数据集
Figure BDA0003740662480000044
一致正则损失的实现则在专利号为CN202110286708的中国专利申请中有详细说明。
可选地,所述第五步中,teacher模型与student模型具有完全一致的可学习参数,且训练前student模型参数数值也与teacher模型完全相同,故称student模型为teacher模型的一份完全拷贝。
所述第六步中,α的取值小于1,通常大于0.9,以保证模型参数的稳步更新,避免剧烈的波动。EMA方法可以使参数的更新更为平滑,减少了因为异常点导致的偏移和抖动,使模型在平滑的流形上逐步优化,一步步稳定地趋向极值点,最终能够带来更好的整体效果。
所述第七步中,迭代终止的标准须以student模型的最终损失在最后数个迭代中趋于平稳而无明显下降为准,而非在单轮迭代中损失平稳。
与现有技术相比,本发明具有如下创新点:
1.其他使用伪标签的技术通常为使用训练不够充分的初始模型一次性生成伪标签,并在接下来的训练中反复使用该伪标签而并未更新,导致伪标签质量低下,使得错误不断地传递累加。本发明使用teacher-student框架协同更新模型,不断地重新生成新的伪标签,保证伪标签的质量随着迭代次数逐步上升,更有利于模型的输出向着真实的数据分布优化。
2.其他技术没有解决半监督目标检测中突出的类别不平衡问题。本发明针对该问题使用EMA指数滑动平均控制模型更新的幅度,既减轻了在小样本数据集上训练的初始模型容易将类别不平衡的偏差在迭代中放大的情况,又能够不断更新模型,充分吸收一个时间序列上模型训练时接收到的信息。同时,不准确的伪标签在训练时还会引入噪声,处于EMA控制下的更新过程还能减轻此噪声带来的影响。
3.经实验,在控制其他因素不变的情况下,不采用EMA的模型预测的平均准确度mAP为16.9%,而采用EMA的方法达到了21%,证明了方法的有效性(该模型仅用于EMA方法有效性测试,其数值结果不代表模型实际使用效果)。
根据本申请实施例提供的方法执行流程,运行在例如个人计算机、服务器、嵌入式计算设备、云计算平台等设备中。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。
图1为根据本发明的基于teacher-student模型的半监督目标检测方法的流程图。
图2为本发明一种实施例使用的DOTA数据集标注范例。
图3为本发明一种实施例采用ResNet特征提取网络,FPN特征加工网络,FasterRCNN特征检测头以及MSE与Focal loss损失函数时的处理流程示意图。
具体实施方式
下面结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
根据本申请的实施例,所使用数据来源于DOTA遥感影像数据集(https://captain-whu.github.io/DOTA/dataset.html)。遥感影像相比于生活中的一般图像有着小目标、物体排列密集、方向任意的特点,带来了更多挑战。DOTA数据集本身为全监督数据集,为构建半监督数据集以进行试验,抛弃其中部分数据的标签,使得有标签数据在全部训练数据中占比为ρ,该值可取为1%~50%不等。实验中取为10%。
1.获取DOTA遥感影像数据集(参看图1)。选取15000张图像作为训练数据,其中1500张图像保留其标注信息作为有标注数据集,13500张图像作为无标注数据集,作为图3中的图像集输入。
2.如图3所示搭建模型,并利用步骤1选取的包含1500张图像的有标注数据集训练teacher模型。模型结构如图3上部的Teacher模型的流程所示,训练方法为标准的全监督目标检测方法,具体为:ResNet作为特征提取网络(backbone)提取图像特征,并将提取的特征送入特征金字塔(FPN)网络做多尺度融合检测以提升小物体检测精度;将FPN输出的特征图送入Faster RCNN检测头做标定框的预测及分类,输出的预测结果输入使用最小均方误差(MSE)的标定框回归分支和使用Focal loss的标定框分类分支。Focal loss可以一定程度上缓解模型倾向简单易分样本的程度,更好地学习较难样本。
3.将训练得到的模型作为初始teacher模型(设置j=1,j为当前迭代次数),将步骤1中选取的13500张无标注图像送入teacher模型,将teacher模型预测图像产生的输出作为对应图像数据的伪标签标注(如图3中部的Teacher模型的流程所示)。
4.从teacher模型拷贝参数和结构生成student模型,在student模型上用全部15000张图像及其标注(1500张图像的有标注数据集,以及13500张图像的伪标签标注的数据集)再次训练,直至损失函数收敛(如图3的下部的Student模型的流程所示)。
5.按照
Figure BDA0003740662480000051
的EMA参数更新方式将student模型参数更新至teacher模型中,见图3中标注的“EMA参数更新”。其中,θt为teacher模型中全部可以训练的权重参数,θs为student模型的参数,j为当前迭代次数,α为参数更新权重,α值越小,单次的参数更新程度越大。迭代次数j递增(j=j+1)。
重复3、4、5步骤,直至数个迭代中student模型的损失平稳值无明显下降。以此时的teacher模型作为训练好的最终模型。
尽管已描述了本申请的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本申请范围的所有变更和修改。显然,本领域的技术人员可以对本申请进行各种改动和变型而不脱离本申请的精神和范围。这样,倘若本申请的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包含这些改动和变型在内。

Claims (9)

1.一种基于teacher-student模型的半监督目标检测方法,其特征在于,包括以下步骤:
第一步,获取半监督目标检测数据集D,其中,半监督目标检测数据集D的元素来源于有标注数据集DL和无标注数据集DU
获取有标注数据集
Figure FDA0003740662470000011
N为有标注数据集中样本的数目,其样本的集合记为
Figure FDA0003740662470000012
Figure FDA0003740662470000013
分别为第1,2,…,N个样本;相应的标注的集合记为
Figure FDA0003740662470000014
Figure FDA0003740662470000015
分别为
Figure FDA0003740662470000016
对应的标注信息,对于DL中每一个元素
Figure FDA0003740662470000017
Figure FDA0003740662470000018
可记为由样本
Figure FDA0003740662470000019
和标注
Figure FDA00037406624700000110
形成的二元组
Figure FDA00037406624700000111
i为正整数;
获取无标注数据集
Figure FDA00037406624700000112
M为无标注数据集中样本的数目对于DU中每一个元素
Figure FDA00037406624700000113
Figure FDA00037406624700000114
可记为由样本
Figure FDA00037406624700000115
和标注
Figure FDA00037406624700000116
形成的二元组
Figure FDA00037406624700000117
Figure FDA00037406624700000118
来自于无标注样本集合
Figure FDA00037406624700000119
Figure FDA00037406624700000120
来自于无标注样本的标注集合
Figure FDA00037406624700000121
其中YU
Figure FDA00037406624700000122
其中的元素
Figure FDA00037406624700000123
也为空,在接下来生成伪标签的步骤中为元素
Figure FDA00037406624700000124
赋值;
第二步,在有标注数据集DL上按照全监督目标检测方法,对于每个输入的样本
Figure FDA00037406624700000125
用模型预测样本
Figure FDA00037406624700000126
的标注,使之和真实标注
Figure FDA00037406624700000127
尽量保持一致,通过优化损失函数训练得到teacher模型;-
第三步,利用的teacher模型,对无标签数据XU做预测,将得到的预测结果进行置信度阈值过滤后生成伪标签,并用伪标签更新标注集合YU
第四步,对无标注样本集合XU进行数据增强,包括遮挡、平移、旋转、翻转、颜色变换、色彩抖动和/或高斯模糊,得到的集合记为记为
Figure FDA00037406624700000128
对第三步中更新后的标注集合YU也做相应调整,将调整后的标注集合记为
Figure FDA00037406624700000129
得到扩充无标签数据集
Figure FDA00037406624700000130
第五步,训练student模型;训练起始时将teacher模型的全部参数赋值给student模型初始化,并让student模型在扩充后的全部数据集
Figure FDA00037406624700000131
上再次训练;训练时,在每个批次中将DL
Figure FDA00037406624700000132
两部分数据进行一定比例的混合,作为一个批次的数据送入student模型,并以标注集合YL和前述步骤中生成的集合
Figure FDA00037406624700000133
作为监督信息进行训练,计算并优化损失函数,得到student模型;还利用一致正则化对扩充无标签数据集
Figure FDA00037406624700000134
在训练中产生的损失函数作约束;迭代训练数个轮次,直至模型损失函数收敛至稳定结果;
第六步,将student模型的权重参数θs以指数滑动平均的方式更新至teacher模型的权重θt中;该步骤要求:对于student模型中全部可以训练的参数θs,以
Figure FDA00037406624700000135
的公式将teacher模型权重θt迭代更新,其中j为当前迭代次数,a为参数更新权重;
第七步,按照第三步至第六步的操作进行数轮迭代,每一轮迭代中将teacher模型输出的集合
Figure FDA00037406624700000136
作为监督信息训练student模型,并将student模型参数更新至teacher模型中,在下一轮迭代中使用更新的teacher模型对XU重新预测,将生成的结果作为新的伪标签信息监督student模型的训练;如此循环往复,直至student模型损失函数在迭代中收敛稳定;以最终的teacher模型作为所述基于teacher-student模型的半监督目标检测方法训练的目标模型。
2.根据权利要求1所述的方法,其中所述第一步包括收集图像数据,建立数据集;对其中部分图像数据作完全且精确的标定框标注,该部分图像数据在所接收的图像数据中的比例不超过20%;该部分图像数据集即为XL,其相应的标注记为YL;其余无标注的图像数据集记为XU
3.根据权利要求1所述的方法,其中所述第二步中的目标检测学习方法包括:
一、使用特征提取网络提取图像特征,其中特征提取网络包括VGG,ResNet,Mobilenet,RetinaNet或者EfficientNet;
二、使用特征加工网络对图像特征做进一步提取与优化,其中特征加工网络包括BAM,CBAM,SPP,FPN和/或NAS-FPN;
三、使用目标检测头预测目标的种类与位置,其中目标检测头包括SSD,YOLO,FasterRCNN或者CenterNet;
四、优化损失函数将输出的预测结果与标注信息保持一致,损失函数定义为
Figure FDA0003740662470000021
Figure FDA0003740662470000022
此处数据来源于DL,即xi为样本集合XL的元素,yi为标注集合YL的元素;θ为teacher模型内的可学习参数;损失函数P(xi|yi,θ)的形式包括MSE,L2,IoU loss或者Focalloss。
4.根据权利要求2所述的方法,所述第三步中,使用所述teacher模型,将XU作为待预测数据输入所述teacher模型,得到的结果即为该部分数据的预测标签,称其为伪标签,并用伪标签更新标注集合YU
5.根据权利要求1所述的方法,所述第四步中,使用的数据增强方法包括图像的遮挡、平移、旋转、翻转、颜色变换、色彩抖动和/或高斯模糊;其中,遮挡会将图片上随机选取的一定大小方框填充为纯色色块,以起到遮挡屏蔽的作用;平移指将图像整体向各个方向平移一定距离;旋转指图像以图像中心作顺时针或逆时针的旋转;翻转指以图像的中心垂直轴或水平轴作翻转;颜色变换指将数据集的RGB图像转换为HSV或者灰度图像等;色彩抖动通过对构成图像的色相产生位移,造成临近点状差异的色彩交叉效果;高斯模糊指图像与正态分布做卷积,使图像模糊化,可视作一低通滤波器。
6.根据权利要求1所述的方法,所述第五步中,训练过程中数据的混合方法如下:一个批次的训练中,分别从DL
Figure FDA0003740662470000023
中随机各取出部分数据,两者之间的比例定为一固定值δ,该值的选取视具体项目而定。
7.根据权利要求3所述的方法,所述第五步中,所述Student模型的损失函数形式与所述teacher模型保持一致,即
Figure FDA0003740662470000024
仅数据来源变更为全部数据集
Figure FDA0003740662470000025
8.根据权利要求1所述的方法,所述第五步中,teacher模型与student模型具有完全一致的可学习参数,且训练前student模型参数数值也与teacher模型完全相同,故称student模型为teacher模型的一份完全拷贝。
9.一种信息处理设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的程序,行所述程序时实现,其特征在于,所述计算机程序由处理器执行时实现如权利要求1至8任一项所述的方法。
CN202210811820.3A 2022-07-11 2022-07-11 基于teacher-student模型的半监督目标检测方法 Pending CN115115886A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210811820.3A CN115115886A (zh) 2022-07-11 2022-07-11 基于teacher-student模型的半监督目标检测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210811820.3A CN115115886A (zh) 2022-07-11 2022-07-11 基于teacher-student模型的半监督目标检测方法

Publications (1)

Publication Number Publication Date
CN115115886A true CN115115886A (zh) 2022-09-27

Family

ID=83332386

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210811820.3A Pending CN115115886A (zh) 2022-07-11 2022-07-11 基于teacher-student模型的半监督目标检测方法

Country Status (1)

Country Link
CN (1) CN115115886A (zh)

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110298415A (zh) * 2019-08-20 2019-10-01 视睿(杭州)信息科技有限公司 一种半监督学习的训练方法、***和计算机可读存储介质
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN112926673A (zh) * 2021-03-17 2021-06-08 清华大学深圳国际研究生院 一种基于一致性约束的半监督目标检测方法
CN113688665A (zh) * 2021-07-08 2021-11-23 华中科技大学 一种基于半监督迭代学习的遥感影像目标检测方法及***
US20220172456A1 (en) * 2019-03-08 2022-06-02 Google Llc Noise Tolerant Ensemble RCNN for Semi-Supervised Object Detection

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20220172456A1 (en) * 2019-03-08 2022-06-02 Google Llc Noise Tolerant Ensemble RCNN for Semi-Supervised Object Detection
CN110298415A (zh) * 2019-08-20 2019-10-01 视睿(杭州)信息科技有限公司 一种半监督学习的训练方法、***和计算机可读存储介质
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN112926673A (zh) * 2021-03-17 2021-06-08 清华大学深圳国际研究生院 一种基于一致性约束的半监督目标检测方法
CN113688665A (zh) * 2021-07-08 2021-11-23 华中科技大学 一种基于半监督迭代学习的遥感影像目标检测方法及***

Similar Documents

Publication Publication Date Title
TWI742382B (zh) 透過電腦執行的、用於車輛零件識別的神經網路系統、透過神經網路系統進行車輛零件識別的方法、進行車輛零件識別的裝置和計算設備
JP6504590B2 (ja) 画像のセマンティックセグメンテーションのためのシステム及びコンピューター実施方法、並びに非一時的コンピューター可読媒体
US20210398294A1 (en) Video target tracking method and apparatus, computer device, and storage medium
CN108960086B (zh) 基于生成对抗网络正样本增强的多姿态人体目标跟踪方法
US20190295227A1 (en) Deep patch feature prediction for image inpainting
Yang et al. An improving faster-RCNN with multi-attention ResNet for small target detection in intelligent autonomous transport with 6G
CN110414616B (zh) 一种利用空间关系的遥感图像字典学习分类方法
CN111667027B (zh) 多模态图像的分割模型训练方法、图像处理方法及装置
US20230281974A1 (en) Method and system for adaptation of a trained object detection model to account for domain shift
CN110717953A (zh) 基于cnn-lstm组合模型的黑白图片的着色方法和***
Vallet et al. A multi-label convolutional neural network for automatic image annotation
CN113657387A (zh) 基于神经网络的半监督三维点云语义分割方法
CN111127360A (zh) 一种基于自动编码器的灰度图像迁移学习方法
TW202226077A (zh) 資訊處理裝置及資訊處理方法
CN111444923A (zh) 自然场景下图像语义分割方法和装置
US20220335572A1 (en) Semantically accurate super-resolution generative adversarial networks
CN116597136A (zh) 一种半监督遥感图像语义分割方法与***
CN117152606A (zh) 一种基于置信度动态学习的遥感图像跨域小样本分类方法
Wu et al. STR transformer: a cross-domain transformer for scene text recognition
Lenczner et al. Weakly-supervised continual learning for class-incremental segmentation
CN111062406B (zh) 一种面向异构领域适应的半监督最优传输方法
CN116935125A (zh) 通过弱监督实现的噪声数据集目标检测方法
CN112750128A (zh) 图像语义分割方法、装置、终端及可读存储介质
CN115115886A (zh) 基于teacher-student模型的半监督目标检测方法
CN114022509A (zh) 基于多个动物的监控视频的目标跟踪方法及相关设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination