CN109344897B - 一种基于图片蒸馏的通用物体检测***及其实现方法 - Google Patents
一种基于图片蒸馏的通用物体检测***及其实现方法 Download PDFInfo
- Publication number
- CN109344897B CN109344897B CN201811150901.3A CN201811150901A CN109344897B CN 109344897 B CN109344897 B CN 109344897B CN 201811150901 A CN201811150901 A CN 201811150901A CN 109344897 B CN109344897 B CN 109344897B
- Authority
- CN
- China
- Prior art keywords
- fast rcnn
- model
- rcnn
- frequency
- detection
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于图片蒸馏的通用物体检测***及其实现方法,该***包括:Faster RCNN模型,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;Wae Faster RCNN检测模型,将输入图像分解成两个分辨率只有原图一半的子图,构建并利用Wae Faster RCNN网络结构分别对低频和高频子图进行物体检测,将两个子图的检测结果进行融合得到最终检测结果;训练指导单元,对Wae Faster RCNN检测模型进行训练,并在训练时引入知识蒸馏机制,利用已训练好的Faster RCNN模型的输出作为软目标来指导Wae Faster RCNN模型的训练。
Description
技术领域
本发明涉及计算机视觉技术领域,特别是涉及一种基于图片蒸馏的通用物体检测***及其实现方法。
背景技术
通用物体检测是计算机视觉领域最基础的研究方向,它的具体任务是对给定图像,输出该图像包含的物体的边界框和类别。近年来,随着卷积神经网络的发展,通用物体检测已取得重大进展。目前基于CNN的通用物体检测方法主要分为两种:以RCNN,FastRCNN,Faster RCNN,Mask RCNN为代表的基于分类的通用物体检测方法和以YOLO系列、SSD为代表的基于回归的物体检测方法。基于分类的通用物体检测方法一般检测精度较高于基于回归的通用物体检测方法,应用较为广泛,但其检测速度相对较慢。
具体地说,RCNN提出应用候选框策略来解决检测问题,即先用传统方法对图片预测一系列可能含有物体的候选框,再对候选框进行分类和位置微调。RCNN需要提前保存图像的候选框且每个候选框要单独经过网络提取特征,占用内存大且检测时间长;Fast RCNN采用ROI Pooling对此进行改进,使得每张图片仅需经过网络一次,速度有所提高,但仍然偏慢,Faster RCNN在Fast RCNN的基础上,提出了RPN(Region Proposal Network)来提取候选框,速度较传统方法有明显提高,但仍远远不够,Mask RCNN进一步改进Faster RCNN,添加了一个分支使用现有的检测对目标进行并行预测,提高了对小物体的检测精度,而且Mask RCNN的检测速度在5fps,已经是速度比较快的基于分类的通用物体检测框架了,但这个速度离实时检测还有些遥远。
发明内容
为克服上述现有技术存在的不足,本发明之目的在于提供一种基于图片蒸馏的通用物体检测***及其实现方法,以提高基于分类的通用物体检测技术的检测速度。
为达上述及其它目的,本发明提出一种基于图片蒸馏的通用物体检测***,包括:
Faster RCNN模型,用于构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
Wae Faster RCNN检测模型,用于将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
训练指导单元,用于对所述Wae Faster RCNN检测模型进行训练,并在所述WaeFaster RCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练。
优选地,所述Wae Faster RCNN检测模型包括:
图像分解单元,用于利用训练好的Anto-Encoder模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
检测单元,用于构建所述Wae Faster RCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测;
融合处理单元,用于对低频子图与高频子图的检测结果进行融合,得到融合后的检测结果。
优选地,所述图像分解单元采用类小波自动编码器WAE进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。
优选地,对于低频子图与高频子图,所述检测单元分别构建所述Wae Faster RCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版FasterRCNN的RPN和Fast RCNN。
优选地,所述轻量版Faster RCNN的部分卷积层通道数为所述完整版Faster RCNN的四分之一。
优选地,所述融合处理单元将低频子图的检测结果和高频子图的检测结果进行融合,作为最终的检测结果。
优选地,所述训练指导单元利用训练好的Faster RCNN模型的输出作为软目标对所述Wae Faster RCNN检测模型的Fast RCNN部分的训练进行指导。
为达到上述目的,本发明还提供一种基于图片蒸馏的通用物体检测***的实现方法,包括如下步骤:
步骤S1,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
步骤S2,将输入图像分解成两个分辨率只有原图一半的子图,构建Wae FasterRCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
步骤S3,对所述Wae Faster RCNN检测模型进行训练,并在Wae Faster RCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练。
优选地,步骤S2进一步包括;
步骤S201,利用训练好的分类模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
步骤S202,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,对于低频子图与高频子图,分别构建所述WaeFaster RCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版Faster RCNN的RPN和Fast RCNN;
步骤S203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。
优选地,于步骤S3中,利用所述Faster RCNN模型的Fast RCNN得到的候选框得分指导所述Wae Faster RCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到所述Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据所述Faster RCNN模型得到的软目标Soft target与所述Wae Faster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据所述Wae Faster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,得到总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
与现有技术相比,本发明一种基于图片蒸馏的通用物体检测***及其实现方法通过采用类小波自动编码器将输入图像分解成两个分辨率只有原图一半的子图,然后对两个子图进行后续检测步骤,最后将两个子图的检测结果进行平均得到最终检测结果,本发明由于仅采用分辨率只有原图一半的子图进行检测使得检测速度提高了两倍,但不可避免地会导致精度的下降,因此在训练时引入知识蒸馏的机制,用复杂的但是检测精度高的Faster RCNN模型的输出作为软目标来指导检测模型的训练,从而保证检测精度。
附图说明
图1为本发明一种基于图片蒸馏的通用物体检测***的结构示意图;
图2为本发明具体实施例中基于图片蒸馏的通用物体检测***的架构示意图;
图3为本发明具体实施例中Faster RCNN模型得到软目标的过程示意图;
图4为本发明具体实施例中Wae Faster RCNN检测模型的训练过程示意图;
图5为本发明一种基于图片蒸馏的通用物体检测***的实现方法的步骤流程图。
具体实施方式
以下通过特定的具体实例并结合附图说明本发明的实施方式,本领域技术人员可由本说明书所揭示的内容轻易地了解本发明的其它优点与功效。本发明亦可通过其它不同的具体实例加以施行或应用,本说明书中的各项细节亦可基于不同观点与应用,在不背离本发明的精神下进行各种修饰与变更。
图1为本发明一种基于图片蒸馏的通用物体检测***的结构示意图。如图1所示,本发明一种基于图片蒸馏的通用物体检测***,包括:
Faster RCNN模型10,用于构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型。由于这里Faster RCNN模型的构建与训练采用的是现有技术,在此不予赘述。
Wae Faster RCNN检测模型20,用于将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果。
训练指导单元30,用于对Wae Faster RCNN检测模型进行训练,并在Wae FasterRCNN检测模型训练时引入知识蒸馏机制,用复杂的但是检测精度高的训练好的FasterRCNN模型的输出作为软目标(soft target)来指导Wae Faster RCNN检测模型的训练。
具体地,Wae Faster RCNN检测模型20进一步包括:
图像分解单元201,用于利用训练好的Auto-Encoder(自编码器)模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图。在本发明具体实施例中,图像分解单元201应用了类小波自动编码器(Wavelet-like Auto-Encoder,简称WAE)进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。在本发明具体实施例中,图像分解的网络结构如表1所示:
表1
其中,含有“conv”的表示卷积层,括号内为卷积层参数,分别为卷积核个数、填充0个数、卷积核大小,步长,“relu”表示激活层,含“CA”的表示该层输出为低频子图,含“CH”的表示该层输出为高频子图,粗体表示该层的输出即为网络输出。
检测单元202,用于构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测。在本发明具体实施例中,对于低频子图与高频子图,分别构建低频子网络和高频子网络。Wae Faster RCNN网络的RPN(RegionProposal Network)部分,低频子网络对低频子图应用完整版Faster RCNN的RPN,高频子网络对高频子图应用轻量版Faster RCNN的RPN,其中轻量版Faster RCNN的RPN部分卷积层通道数是完整版的四分之一。在本发明具体实施例中,Wae Faster RCNN网络的低频子网络和高频子网络的RPN部分结构如下表2所示:
表2
其中,包含“conv”的表示卷积层,括号内为卷积层参数,分别为卷积核个数、填充0个数、卷积核大小,步长。“relu”表示激活层,“batchnorm”表示批量归一化层,“maxpool”表示最大池化层,括号内为最大池化层参数,分别为卷积核大小和下采样步长,“eltwise”开头的表示eltwise层,括号内为eltwise层参数,表示对每对元素的操作,非斜体部分表示RPN与Fast RCNN共享的网络结构,即主干网络,斜体部分表示RPN特有的网络结构,含“CA”的为低频子网络的部分,含“CH”的为高频子网络的部分,粗体表示该层的输出即为网络输出,表中断开的部分无特殊操作,只是为了方便表示,对断开部分上一行重新进行了排列。
Wae Faster RCNN网络的Fast RCNN部分,对低频子图应用完整版Faster RCNN的Fast RCNN,对高频子图应用轻量版Faster RCNN的Fast RCNN,其中轻量版Faster Rcnn网络的RPN部分卷积层通道数是完整版的四分之一,这里使用的Fast RCNN不完全和FasterRCNN中的一致,主要是对全卷积层的神经元个数做了修改。Wae Faster RCNN网络结构的Fast RCNN部分的具体网络结构如表3所示:
表3
其中,包含“conv”的表示卷积层,括号内为卷积层参数,分别为卷积核个数、填充0个数、卷积核大小,步长,表中“relu”表示激活层。“maxpool”表示最大池化层,括号内为最大池化层参数,分别为卷积核大小和下采样步长。“fc”开头的表示全连接层,括号内为全连接参数,为神经元个数。“ROIPooling”表示感兴趣区域池化层,括号内为感兴趣区域池化层的参数,分别为卷积核宽度、卷积核长度,空间缩放尺度(该层与输入图像相比缩小的倍数),“dropout”表示dropout层,括号内为dropout层参数,表示丢失率。“batchnorm”开头的表示批量归一化层。“concat”开头的表示连接层,括号内为连接层参数,表示按某一维度连接,“eltwise”开头的表示eltwise层,括号内为eltwise层参数,表示对每对元素的操作。非斜体部分表示RPN与Fast Rcnn共享的网络结构,即主干网络,斜体部分表示Fast RCNN特有的网络结构。含“CA”的为低频子网络的部分,含“CH”或“fusion”的为高频子网络的部分,粗体表示该层的输出即为网络输出。
融合处理单元203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。在本发明具体实施例中,融合处理单元203将低频子图与高频子图的检测结果进行平均,得到最终检测结果。
在本发明中,训练指导单元30采用Faster RCNN来指导Wae Faster RCNN检测模型的训练。经过实验发现,Wae Faster RCNN的RPN阶段生成的候选框与Faster RCNN的质量相当,差别只在于Fast RCNN部分。因此,训练指导单元30只对Fast RCNN部分的训练进行指导。具体的,训练指导单元30用Faster RCNN的Fast RCNN得到的候选框得分指导WaeFaster RCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据Faster RCNN模型得到的软目标Soft target与Wae Faster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据Wae Faster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,这样总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
图2为本发明具体实施例中基于图片蒸馏的通用物体检测***的架构示意图。如图2所示,左边的Teacher model为复杂模型,即Faster RCNN模型,右边的Student model为Wae Faster RCNN检测模型,其参数需要训练,它以Image I作为输入,经过Wae encodinglayer(即图像分解单元)将Image I分解成两个子图(左边是低频子图,右边是高频子图)。对于低频子图,应用复杂的模型(本发明采用如Teacher model的Faster RCNN模型,由于输入图片的分辨率减半,速度会比对原图应用teacher model快),得到检测结果(Studentmodel的左分支)。对于高频子图,应用简化的复杂模型(本发明将如teacher model的Faster RCNN模型的通道数变为原来的四分之一),得到检测结果(Student model的右分支)。将两个分支的结果进行融合得到最终结果。
虽然Student model将输入图片变为原来的一半会加快检测速度,但无疑会带来精度的下降,所以在训练的时候要引入知识蒸馏来保证精度,知识蒸馏就是用训练好的复杂模型(即左边的Teacher model)的输出来指导简单模型(右边的Student model)的训练。
训练时,将相同的图片输入Teacher model和Student model,将Teacher model得到的软目标Soft target与Student model得到的软输出Soft output计算软损失Softloss(这个过程就是知识蒸馏),同时将Student model得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,总的分类部分的损失函数classify loss=Hardloss+λSoft loss,λ是权重。
图3为本发明具体实施例中Faster RCNN模型得到软目标的过程示意图。具体地,输入图像,经过CNN,RoI Pooling,NN得到分类结果teacher_cls和边界框位置teacher_bbox(到目前为止是Faster Rcnn模型的Fast Rcnn检测物体的过程),对于分类结果teacher_cls,先除以一个温度系数T,再进过Softmax变换,即得到软化的概率分布Softtarget(软目标)St。
以下将配合图4来具体说明本发明具体实施例中Wae Faster RCNN检测模型的训练过程,在本发明具体实施例中,Wae Faster RCNN检测模型的训练过程包括如下四个阶段
第一阶段:训练Wae Faster RCNN检测模型的RPN部分。用训练好的WAE分类网络进行Wae Faster RCNN模型的初始化。固定两个conv3_1之前的权值,只微调conv3_1之后的权值。RPN的低频子网络,高频子网络,两者输出的平均都有各自的损失函数,其损失函数类比原Faster RCNN的RPN损失函数得到。
第二阶段:训练Wae Faster RCNN检测模型的Fast RCNN部分。用训练好的WAE分类网络进行初始化,固定两个conv3_1之前的权值,只微调conv3_1之后的权值。在每次迭代时,先将当前处理的图片及对应的候选框输入到Faster RCNN,进行前向传播,得到原Faster RCNN的候选框类别得分teacher_cls,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标,图3中的St。将同样的图片及候选框输入到Wae FasterRCNN的Fast RCNN部分,进行前向传播,该过程如图4所示。低频子网络输出候选框分数CA_cls和候选框位置CA_bbox,高频子网络输出候选框分数CH_cls和候选框位置CH_bbox。将CA_cls与CH_cls进行平均得到Avg_cls,对CA_bbox和CH_bbox进行平均得到Avg_bbox,对CA_cls进行两种操作:除以温度参数T并做softmax变换得到CA_cls_soft和直接做softmax变换得到CA_cls_hard。对CH_cls和Avg_cls类似。对于低频子网络,分类损失有两部分组成:CA_cls_hard与真实值cls的交叉熵损失和CA_cls_soft与St的交叉熵损失,赋予第一个损失较小权重,定位损失为CA_bbox与真实值bbox的Smooth L1损失。高频子网络和两个子网络平均之后计算的损失类似。
第三阶段:用第二阶段得到的权值初始化Wae Faster RCNN的RPN网络,固定conv5_1以及之前的层,只微调RPN特有的层。
第四阶段:用第三阶段得到的权值初始化Wae Faster RCNN的Fast RCNN网络,固定conv5_1以及之前的层,只微调Fast RCNN特有的层。
图5为本发明一种基于图片蒸馏的通用物体检测***的实现方法的步骤流程图。如图5所示,本发明一种基于图片蒸馏的通用物体检测***的实现方法,包括如下步骤:
步骤S1,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型。由于这里Faster RCNN模型的构建与训练采用的是现有技术,在此不予赘述。
步骤S2,将输入图像分解成两个分辨率只有原图一半的子图,构建Wae FasterRCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果。
步骤S3,对Wae Faster RCNN检测模型进行训练,并在Wae Faster RCNN检测模型训练时引入知识蒸馏机制,用复杂的但是检测精度高的训练好的Faster RCNN模型的输出作为软目标(soft target)来指导Wae Faster RCNN检测模型的训练。
具体地,步骤S2进一步包括:
步骤S201,利用训练好的Auto-Encoder模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图。在本发明具体实施例中,应用了类小波自动编码器(Wavelet-like Auto-Encoder,简称WAE)进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息。
步骤S202,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测。在本发明具体实施例中,对于低频子图与高频子图,分别构建低频子网络和高频子网络。Wae Faster RCNN网络的RPN(Region ProposalNetwork)部分,低频子网络对低频子图应用完整版Faster RCNN的RPN,高频子网络对高频子图应用轻量版Faster RCNN的RPN,其中轻量版Faster RCNN的RPN部分卷积层通道数是完整版的四分之一。Wae Faster RCNN网络的Fast RCNN部分,对低频子图应用完整版FasterRCNN的Fast RCNN,对高频子图应用轻量版Faster RCNN的Fast RCNN,其中轻量版FasterRcnn网络的RPN部分卷积层通道数是完整版的四分之一,这里使用的Fast RCNN不完全和Faster RCNN模型中的一致,主要是对全卷积层的神经元个数做了修改。
步骤S203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。在本发明具体实施例中,将低频子图与高频子图的检测结果进行平均,得到最终检测结果。
于步骤S3中,采用Faster RCNN模型的输出来指导Wae Faster RCNN检测模型的训练。经过实验发现,Wae Faster RCNN的RPN阶段生成的候选框与Faster RCNN的质量相当,差别只在于Fast RCNN部分。因此,Faster RCNN模型的输出只对Fast RCNN部分的训练进行指导。具体的,于步骤S3中,用Faster RCNN的Fast RCNN得到的候选框得分指导Wae FasterRCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据Faster RCNN模型得到的软目标Soft target与Wae Faster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据Wae Faster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hard loss,这样总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
综上所述,本发明一种基于图片蒸馏的通用物体检测***及其实现方法通过采用类小波自动编码器将输入图像分解成两个分辨率只有原图一半的子图,然后对两个子图进行后续检测步骤,最后将两个子图的检测结果进行平均得到最终检测结果,本发明由于仅采用分辨率只有原图一半的子图进行检测使得检测速度提高了两倍,但不可避免地会导致精度的下降,因此在训练时引入知识蒸馏的机制,用复杂的但是检测精度高的Faster RCNN模型的输出作为软目标来指导检测模型的训练,从而保证检测精度。
上述实施例仅例示性说明本发明的原理及其功效,而非用于限制本发明。任何本领域技术人员均可在不违背本发明的精神及范畴下,对上述实施例进行修饰与改变。因此,本发明的权利保护范围,应如权利要求书所列。
Claims (6)
1.一种基于图片蒸馏的通用物体检测***,包括:
Faster RCNN模型,用于构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
Wae Faster RCNN检测模型,用于将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
训练指导单元,用于对所述Wae Faster RCNN检测模型进行训练,并在所述Wae FasterRCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练;
所述Wae Faster RCNN检测模型包括:
图像分解单元,用于利用训练好的Anto-Encoder模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
检测单元,用于构建所述Wae Faster RCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测;
融合处理单元,用于对低频子图与高频子图的检测结果进行融合,得到融合后的检测结果;
所述图像分解单元采用类小波自动编码器WAE进行图像分解,以将输入图像分解成分辨率只有原图一半的低频子图和高频子图,两个子图分别包含原图的低频信息和高频信息;
对于低频子图与高频子图,所述检测单元分别构建所述Wae Faster RCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版Faster RCNN的RPN和FastRCNN。
2.如权利要求1所述的一种基于图片蒸馏的通用物体检测***,其特征在于:所述轻量版Faster RCNN的部分卷积层通道数为所述完整版Faster RCNN的四分之一。
3.如权利要求1所述的一种基于图片蒸馏的通用物体检测***,其特征在于:所述融合处理单元将低频子图的检测结果和高频子图的检测结果进行融合,作为最终的检测结果。
4.如权利要求1所述的一种基于图片蒸馏的通用物体检测***,其特征在于:所述训练指导单元利用训练好的Faster RCNN模型的输出作为软目标对所述Wae Faster RCNN检测模型的Fast RCNN部分的训练进行指导。
5.一种基于图片蒸馏的通用物体检测***的实现方法,包括如下步骤:
步骤S1,构建Faster RCNN的网络结构,并进行训练,得到训练好的Faster RCNN模型;
步骤S2,将输入图像分解成两个分辨率只有原图一半的子图,构建Wae Faster RCNN网络结构,利用所述Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,然后将两个子图的检测结果进行融合得到最终检测结果;
步骤S3,对所述Wae Faster RCNN检测模型进行训练,并在Wae Faster RCNN检测模型训练时引入知识蒸馏机制,利用训练好的Faster RCNN模型的输出作为软目标来指导所述Wae Faster RCNN检测模型的训练;
步骤S2进一步包括;
步骤S201,利用训练好的分类模型将输入图像分解成两个分辨率只有原图一半的子图,分别为低频子图和高频子图;
步骤S202,构建Wae Faster RCNN网络结构,利用Wae Faster RCNN网络结构分别对低频子图和高频子图进行物体检测,对于低频子图与高频子图,分别构建所述Wae FasterRCNN网络结构的低频子网络和高频子网络,该低频子网络的RPN和Fast RCNN,采用完整版Faster RCNN的RPN和Fast RCNN,该高频子网络的RPN和Fast RCNN,采用轻量版FasterRCNN的RPN和Fast RCNN;
步骤S203,用于对低频子图与高频子图的检测结果进行融合,得到融合的检测结果。
6.如权利要求5所述的一种基于图片蒸馏的通用物体检测***的实现方法,其特征在于,于步骤S3中,利用所述Faster RCNN模型的Fast RCNN得到的候选框得分指导所述WaeFaster RCNN检测模型的Fast RCNN的候选框得分的训练,即在每次迭代时,先将当前处理的图片及对应的候选框输入到所述Faster RCNN模型,进行前向传播,得到Faster RCNN模型的候选框类别得分,将该得分除以温度参数T,再做softmax变换,得到软化的概率分布,即软目标St,再将同样的图片及候选框输入到Wae Faster RCNN检测模型的Fast RCNN部分,进行前向传播,根据所述Faster RCNN模型得到的软目标Soft target与所述WaeFaster RCNN检测模型得到的软输出Soft output计算软损失Soft loss,并根据所述WaeFaster RCNN检测模型得到的硬输出Hard output和真实标签Hard target计算硬损失Hardloss,得到总的分类部分的损失函数classify loss=Hard loss+λSoft loss,λ是权重。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811150901.3A CN109344897B (zh) | 2018-09-29 | 2018-09-29 | 一种基于图片蒸馏的通用物体检测***及其实现方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811150901.3A CN109344897B (zh) | 2018-09-29 | 2018-09-29 | 一种基于图片蒸馏的通用物体检测***及其实现方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN109344897A CN109344897A (zh) | 2019-02-15 |
CN109344897B true CN109344897B (zh) | 2022-03-25 |
Family
ID=65307678
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201811150901.3A Active CN109344897B (zh) | 2018-09-29 | 2018-09-29 | 一种基于图片蒸馏的通用物体检测***及其实现方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109344897B (zh) |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110335242A (zh) * | 2019-05-17 | 2019-10-15 | 杭州数据点金科技有限公司 | 一种基于多模型融合的轮胎x光病疵检测方法 |
CN112307976B (zh) * | 2020-10-30 | 2024-05-10 | 北京百度网讯科技有限公司 | 目标检测方法、装置、电子设备以及存储介质 |
CN112101573B (zh) * | 2020-11-16 | 2021-04-30 | 智者四海(北京)技术有限公司 | 一种模型蒸馏学习方法、文本查询方法及装置 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN103390164A (zh) * | 2012-05-10 | 2013-11-13 | 南京理工大学 | 基于深度图像的对象检测方法及其实现装置 |
CN103679677A (zh) * | 2013-12-12 | 2014-03-26 | 杭州电子科技大学 | 一种基于模型互更新的双模图像决策级融合跟踪方法 |
CN107563381A (zh) * | 2017-09-12 | 2018-01-09 | 国家新闻出版广电总局广播科学研究院 | 基于全卷积网络的多特征融合的目标检测方法 |
CN107886117A (zh) * | 2017-10-30 | 2018-04-06 | 国家新闻出版广电总局广播科学研究院 | 基于多特征提取和多任务融合的目标检测算法 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107358258B (zh) * | 2017-07-07 | 2020-07-07 | 西安电子科技大学 | 基于nsct双cnn通道和选择性注意机制的sar图像目标分类 |
CN108470183B (zh) * | 2018-02-05 | 2020-06-16 | 西安电子科技大学 | 基于聚类细化残差模型的极化sar分类方法 |
-
2018
- 2018-09-29 CN CN201811150901.3A patent/CN109344897B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN103390164A (zh) * | 2012-05-10 | 2013-11-13 | 南京理工大学 | 基于深度图像的对象检测方法及其实现装置 |
CN103679677A (zh) * | 2013-12-12 | 2014-03-26 | 杭州电子科技大学 | 一种基于模型互更新的双模图像决策级融合跟踪方法 |
CN107563381A (zh) * | 2017-09-12 | 2018-01-09 | 国家新闻出版广电总局广播科学研究院 | 基于全卷积网络的多特征融合的目标检测方法 |
CN107886117A (zh) * | 2017-10-30 | 2018-04-06 | 国家新闻出版广电总局广播科学研究院 | 基于多特征提取和多任务融合的目标检测算法 |
Non-Patent Citations (2)
Title |
---|
Learning a Wavelet-like Auto-Encoder to Accelerate Deep Neural Networks;Tianshui Chen et al;《arXiv》;20171220;第1-9页 * |
Learning Efficient Object Detection Models with Knowledge Distillation;Guobin Chen et al;《31st Conference on Neural Information Processing Systems (NIPS 2017)》;20171209;第1-10页 * |
Also Published As
Publication number | Publication date |
---|---|
CN109344897A (zh) | 2019-02-15 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Sun et al. | Deep RGB-D saliency detection with depth-sensitive attention and automatic multi-modal fusion | |
CN108320297B (zh) | 一种视频目标实时跟踪方法及*** | |
CN109344897B (zh) | 一种基于图片蒸馏的通用物体检测***及其实现方法 | |
CN110443173B (zh) | 一种基于帧间关系的视频实例分割方法及*** | |
CN111902825A (zh) | 多边形对象标注***和方法以及训练对象标注***的方法 | |
CN113657560B (zh) | 基于节点分类的弱监督图像语义分割方法及*** | |
CN111898432B (zh) | 一种基于改进YOLOv3算法的行人检测***及方法 | |
Batsos et al. | Recresnet: A recurrent residual cnn architecture for disparity map enhancement | |
US20230351618A1 (en) | System and method for detecting moving target based on multi-frame point cloud | |
CN114049381A (zh) | 一种融合多层语义信息的孪生交叉目标跟踪方法 | |
CN111046767B (zh) | 一种基于单目图像的3d目标检测方法 | |
KR102305230B1 (ko) | 객체 경계정보의 정확도 개선방법 및 장치 | |
US11948078B2 (en) | Joint representation learning from images and text | |
Xu et al. | Int: Towards infinite-frames 3d detection with an efficient framework | |
Hong et al. | Unified 3d and 4d panoptic segmentation via dynamic shifting networks | |
CN117649657A (zh) | 基于改进Mask R-CNN的骨髓细胞检测*** | |
CN113361431A (zh) | 一种基于图推理的人脸遮挡检测的网络模型及方法 | |
CN117115911A (zh) | 一种基于注意力机制的超图学习动作识别*** | |
CN115861223A (zh) | 一种太阳能电池板缺陷检测方法及*** | |
Cao et al. | Separable-programming based probabilistic-iteration and restriction-resolving correlation filter for robust real-time visual tracking | |
CN113192186A (zh) | 基于单帧图像的3d人体姿态估计模型建立方法及其应用 | |
CN112348102A (zh) | 一种基于查询的自底向上视频定位方法和*** | |
Gong et al. | ASAFormer: Visual tracking with convolutional vision transformer and asymmetric selective attention | |
Sun et al. | A Metaverse text recognition model based on character-level contrastive learning | |
Yuan et al. | Density-guided Translator Boosts Synthetic-to-Real Unsupervised Domain Adaptive Segmentation of 3D Point Clouds |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |