CN111639744A - 学生模型的训练方法、装置及电子设备 - Google Patents

学生模型的训练方法、装置及电子设备 Download PDF

Info

Publication number
CN111639744A
CN111639744A CN202010297966.1A CN202010297966A CN111639744A CN 111639744 A CN111639744 A CN 111639744A CN 202010297966 A CN202010297966 A CN 202010297966A CN 111639744 A CN111639744 A CN 111639744A
Authority
CN
China
Prior art keywords
model
feature
student model
student
distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010297966.1A
Other languages
English (en)
Other versions
CN111639744B (zh
Inventor
曾凡高
张有才
危夷晨
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Megvii Technology Co Ltd
Original Assignee
Beijing Megvii Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Megvii Technology Co Ltd filed Critical Beijing Megvii Technology Co Ltd
Priority to CN202010297966.1A priority Critical patent/CN111639744B/zh
Publication of CN111639744A publication Critical patent/CN111639744A/zh
Application granted granted Critical
Publication of CN111639744B publication Critical patent/CN111639744B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明提供了一种学生模型的训练方法、装置及电子设备,涉及人工智能领域,学生模型通过知识蒸馏方式向已训练好的教师模型学习,学生模型和教师模型均为物体检测模型,该方法包括:获取训练样本的候选样本区域;分别通过学生模型和教师模型对训练样本的候选样本区域进行特征提取,得到学生模型提取出的第一特征和教师模型提取出的第二特征;获取第一特征的置信度;根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失;基于蒸馏损失更新学生模型的参数。本发明可以使学生模型能够针对不同样本进行不同程度的参数更新,使训练好的学生模型具有更优秀的性能,从而提升物体检测效果。

Description

学生模型的训练方法、装置及电子设备
技术领域
本发明涉及人工智能领域,尤其是涉及一种学生模型的训练方法、装置及电子设备。
背景技术
知识蒸馏是一种模型压缩常见方法,在教师-学生框架中,将复杂、学习能力强的教师模型学到的特征表示“知识”蒸馏出来,传递给参数量小、学习能力弱的学生网络。由于物体检测的知识蒸馏中样本数量通常较大,而样本质量却参差不及,诸如样本中可能包括脏样本或过难样本,如果一味要求学生模型在所有样本上进行模仿,会严重影响学生模型的性能,学生模型在训练过程中的蒸馏效果较差,从而导致学生模型在物体检测时的检测效果不佳。
发明内容
有鉴于此,本发明的目的在于提供一种学生模型的训练方法、装置及电子设备,使得学生模型能够针对不同样本进行不同程度的参数更新,使训练好的学生模型具有更优秀的性能,从而提升物体检测效果。
为了实现上述目的,本发明实施例采用的技术方案如下:
第一方面,本发明实施例提供了一种学生模型的训练方法,学生模型通过知识蒸馏方式向已训练好的教师模型学习,学生模型和教师模型均为物体检测模型,方法包括:获取训练样本的候选样本区域;分别通过学生模型和教师模型对训练样本的候选样本区域进行特征提取,得到学生模型提取出的第一特征和教师模型提取出的第二特征;获取第一特征的置信度;根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失;基于蒸馏损失更新学生模型的参数。
进一步,获取第一特征的置信度的步骤,包括:将第一特征输入至方差生成网络中,得到方差生成网络输出的第一特征的方差,通过方差表征第一特征的置信度;其中,方差生成网络包括卷积层和/或全连接层,且方差与置信度呈负相关。
进一步,根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失的步骤,包括:按照如下公式确定学生模型和教师模型之间的蒸馏损失:
Figure BDA0002451931770000021
其中,d为特征维度,N为样本数量;
Figure BDA0002451931770000022
为第一特征;
Figure BDA0002451931770000023
为第二特征;
Figure BDA0002451931770000024
为方差。
进一步,基于蒸馏损失更新学生模型的参数的步骤,包括:获取学生模型执行物体检测任务的任务损失;根据任务损失和蒸馏损失更新学生模型的参数。
进一步,获取训练样本的候选样本区域的步骤,包括:将训练样本输入至候选区域提取网络,得到候选样本区域。
进一步,获取训练样本的候选样本区域的步骤,包括:根据携带有真值框的标注信息确定训练样本的候选样本区域。
进一步,方法还包括:将待检测图像输入训练后的学生模型,基于训练后的学生模型对待检测图像进行物体检测,得到物体检测结果。
第二方面,本发明实施例还提供一种学生模型的训练装置,学生模型通过知识蒸馏方式向已训练好的教师模型学习,学生模型和教师模型均为物体检测模型,装置包括:获取模块,用于获取训练样本的候选样本区域;特征提取模块,用于分别通过学生模型和教师模型对训练样本的候选样本区域进行特征提取,得到学生模型提取出的第一特征和教师模型提取出的第二特征;置信度获取模块,用于获取第一特征的置信度;蒸馏损失确定模块,用于根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失;参数更新模块,用于基于蒸馏损失更新学生模型的参数。
第三方面,本发明实施例提供了一种电子设备,包括:处理器和存储装置;存储装置上存储有计算机程序,计算机程序在被处理器运行时执行如上述实施方式任一项的方法。
第四方面,本发明实施例提供了一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,计算机程序被处理器运行时执行如上述实施方式任一项的方法的步骤。
本发明实施例提供了一种学生模型的训练方法、装置及电子设备,学生模型通过知识蒸馏方式向已训练好的教师模型学习,学生模型和教师模型均为物体检测模型,该方法首先获取训练样本的候选样本区域,并分别通过学生模型和教师模型对该训练样本的候选样本区域进行特征提取,得到学生模型提取出的第一特征和教师模型提取出的第二特征,然后获取第一特征的置信度,根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失,最后基于蒸馏损失更新学生模型的参数。上述方式通过在训练前获取学生模型针对训练样本提取的第一特征的置信度,可以确定样本质量对学生模型的影响,蒸馏损失的确定也与学生模型针对训练样本提取的特征的置信度相关,从而基于蒸馏损失对学生模型进行训练,学生模型提取的不同样本的特征的置信度不同,蒸馏损失不同,相应的模型参数更新效果也不同,也即在学生模型训练过程中针对不同训练样本的知识迁移强度不同,学生模型可针对不同样本进行自适应知识迁移,这种方式使得学生模型能够针对不同样本进行不同程度的参数更新,使训练好的学生模型具有更优秀的性能,从而提升物体检测效果。
本发明实施例的其他特征和优点将在随后的说明书中阐述,或者,部分特征和优点可以从说明书推知或毫无疑义地确定,或者通过实施本发明实施例的上述技术即可得知。
为使本发明的上述目的、特征和优点能更明显易懂,下文特举较佳实施例,并配合所附附图,作详细说明如下。
附图说明
为了更清楚地说明本发明具体实施方式或现有技术中的技术方案,下面将对具体实施方式或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了本发明实施例所提供的一种电子设备的结构图;
图2示出了本发明实施例所提供的一种学生模型的训练方法的流程图;
图3示出了本发明实施例所提供的另一种学生模型的训练方法的流程图;
图4示出了本发明实施例所提供的一种蒸馏框架的结构示意图;
图5示出了本发明实施例所提供的一种学生模型的训练装置的结构框图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合附图对本发明的技术方案进行描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。
目前,在将知识蒸馏应用于物体检测的方案,通常都探究蒸馏的位置(诸如特征图的位置、特征图上的区域、分类的概率等),却对所有蒸馏的样本一视同仁。考虑到物体检测任务的样本质量存在较大差异,低质量样本(诸如模糊图片等)的数量也远比其他视觉任务(诸如人脸识别)要多,一味要求学生模型在所有样本上进行模仿,会严重影响学生模型的性能,进而导致物体检测的蒸馏方法效果较差。为改善此问题,本发明实施例提供了一种学生模型的训练方法、装置及电子设备,以下对本发明实施例进行详细介绍。
实施例一:
首先,参照图1来描述用于实现本发明实施例的一种学生模型的训练方法、装置及电子设备的示例电子设备100。
如图1所示的一种电子设备的结构示意图,电子设备100包括一个或多个处理器102、一个或多个存储装置104、输入装置106、输出装置108以及图像采集装置110,这些组件通过总线***112和/或其它形式的连接机构(未示出)互连。应当注意,图1所示的电子设备100的组件和结构只是示例性的,而非限制性的,根据需要,所述电子设备也可以具有其他组件和结构。
所述处理器102可以采用数字信号处理器(DSP)、现场可编程门阵列(FPGA)、可编程逻辑阵列(PLA)中的至少一种硬件形式来实现,所述处理器102可以是中央处理单元(CPU)、图形处理单元(GPU)或者具有数据处理能力和/或指令执行能力的其它形式的处理单元中的一种或几种的组合,并且可以控制所述电子设备100中的其它组件以执行期望的功能。
所述存储装置104可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器102可以运行所述程序指令,以实现下文所述的本发明实施例中(由处理器实现)的客户端功能以及/或者其它期望的功能。在所述计算机可读存储介质中还可以存储各种应用程序和各种数据,例如所述应用程序使用和/或产生的各种数据等。
所述输入装置106可以是用户用来输入指令的装置,并且可以包括键盘、鼠标、麦克风和触摸屏等中的一个或多个。
所述输出装置108可以向外部(例如,用户)输出各种信息(例如,图像或声音),并且可以包括显示器、扬声器等中的一个或多个。
所述图像采集装置110可以拍摄用户期望的图像(例如照片、视频等),并且将所拍摄的图像存储在所述存储装置104中以供其它组件使用。
示例性地,用于实现根据本发明实施例的学生模型的训练方法、装置及电子设备的示例电子设备可以被实现为诸如机器人、智能手机、平板电脑、计算机等智能终端。
实施例二:
本实施例提供了一种学生模型的训练方法,学生模型和教师模型均为物体检测模型,物体检测模型可以采用用于执行物体检测任务的单阶段检测器或两阶段检测器,本发明实施例对物体检测模型的结构不作限定。为便于对本实施例提供的物体检测蒸馏中学生模型的训练方法进行理解,参见图2所示的一种学生模型的训练方法的流程图,学生模型通过知识蒸馏方式向已训练好的教师模型学习,该方法主要包括如下步骤S202至步骤S210:
步骤S202,获取训练样本的候选样本区域。
在一种实施方式中,训练样本的候选样本区域可以通过模型生成,诸如通过检测器输出得到训练样本的若干个候选样本区域,具体实施时,可以将训练样本输入至候选区域提取网络,得到候选样本区域。在另一种实施方式中,也可以根据携带有真值框的标注信息确定训练样本的候选样本区域,真值框也即用于显示训练样本中目标图像的真实位置的框,通过携带有真值框的标注信息可以确定训练样本中的待提取特征的位置,从而确定训练样本中所需提取的候选样本区域。每个训练样本的候选样本区域可以为一个,也可以为多个。在知识蒸馏的教师-学生框架中,候选样本区域包括学生模型的待蒸馏特征,该待蒸馏特征也即学生模型要和教师模型学习的特征。
步骤S204,分别通过学生模型和教师模型对训练样本的候选样本区域进行特征提取,得到学生模型提取出的第一特征和教师模型提取出的第二特征。
教师模型可以为复杂模型或者组合模型,与教师模型相比,学生模型大多为相对简单的模型,学生模型和教师模型均为物体检测模型,可以为单阶段(one-stage)检测器,诸如SSD(Single Shot MultiBox Detector,单阶段多框检测器)、YOLO(You Only LookOnce:Unified,Real-Time Object Detection,基于单个神经网络的目标检测***)等,也可以两阶段(two-stage)检测器,诸如CNN(Convolutional Neural Network,卷积神经网络)、Fast-RCNN(Faster Region-based Convolutional Neural Network,超快速神经网络)等。通过学生模型对上述训练样本的候选样本区域进行特征提取,可以得到学生模型提取的第一特征,通过教师模型对上述训练样本的候选样本区域进行特征提取,可以得到教师模型提取的第二特征。
步骤S206,获取第一特征的置信度。
第一特征的置信度可以理解为第一特征的可信程度,诸如,如果训练样本是脏样本或过难样本,则学生模型针对此类训练样本提取出的第一特征的可信程度不高,也即置信度不高。在一种实施方式中,通过求取第一特征对应的方差的方式来获取置信度。诸如,将学生模型提取的第一特征输入至方差生成网络中,得到方差生成网络输出的第一特征的方差,该方差可以表征第一特征的置信度,其中,该方差生成网络可以包括卷积层和/或全连接层,也即通过两个卷积层,或者,两个全连接层,或者,一个卷积层和一个全连接层实现。通过该方差生成网络得到的方差与置信度呈负相关,也即,方差越大,置信度越小。可以理解为:在蒸馏之前,学生模型可以根据有方差生成网络得到的方差结果确定知识迁移的强度,进而可以选择性的进行知识蒸馏,诸如,可以对方差较小(也即置信度较大)的高质量样本增加训练,从而提升模型性能;由于脏样本或过难样本的方差较大(置信度较小),因此可以相应的对该类样本减少训练,降低此类样本对学生模型的不良影响。具体实施时,可以通过设置蒸馏损失函数的方式来增加训练或减少训练。
步骤S208,根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失。
在一种实施方式中,可以按照如下公式确定学生模型和教师模型之间的蒸馏损失:
Figure BDA0002451931770000081
其中,d为特征维度,N为样本数量;
Figure BDA0002451931770000082
为第一特征;
Figure BDA0002451931770000083
为第二特征;
Figure BDA0002451931770000084
为方差。上述学生模型提取的第一特征、教师模型提取的第二特征以及第一特征的方差的维度均保持一致。通过上述蒸馏损失的公式可知,样本的方差与损失函数呈反比,也即样本方差越大,则损失函数越小,损失函数小就意味着反向传播调整的学生模型的参数不大。诸如,由于提取出的脏样本和过难样本的特征方差较大,根据方差与置信度呈反比的关系可以确定该特征的置信度较小,损失函数也较小,从而使得脏样本/过难样本对学生模型的参数影响不大。
步骤S210,基于蒸馏损失更新学生模型的参数。
模型蒸馏学习的损失函数可以分为两个部分,学生模型和教师模型之间的蒸馏损失(也可以称为自适应迁移损失),以及学生模型执行物体检测任务的任务损失,参见如下公式:
L=Ltask+Ldistill
其中,L为模型蒸馏学习中的整个损失函数;Ltask为与任务相关的损失函数。
根据任务损失Ltask和蒸馏损失Ldistill确定模型蒸馏过程中的整个损失函数L,并根据L更新学生模型的参数,也即对学生模型进行训练,直至损失函数收敛或者达到预设的停止训练条件,得到训练好的学生模型。
本发明实施例提供的上述学生模型的训练方法,通过在训练前获取学生模型提取的第一特征的置信度,可以确定样本质量对学生模型的参数影响,从而可以准确调整训练时不同训练样本的知识迁移强度,降低了由于脏样本或过难样本导致的训练误差,改善了学生模型与教师模型在物体检测上的蒸馏效果,从而提升了物体检测模型的性能。
在一种实施方式中,通过上述学生模型的训练方法将学生模型进行训练后,可以将待检测图像输入该训练后的学生模型,以便基于训练后的学生模型对待检测图像进行物体检测,得到物体检测结果。需要注意的是,方差生成分支可以只在训练过程中用到,训练结束之后,在实际测试时学生模型不再需要经过方差生成分支,直接使用训练好的模型参数即可。另外,在整个过程中,教师模型的参数无需进行修改或调整。
综上所述,通过在训练前获取学生模型针对训练样本提取的第一特征的置信度,可以确定样本质量对学生模型的影响,蒸馏损失的确定也与学生模型针对训练样本提取的特征的置信度相关,从而基于蒸馏损失对学生模型进行训练的过程中,不同样本的特征的置信度不同,蒸馏损失不同,相应的模型参数更新效果也不同,也即在学生模型训练过程中针对不同训练样本的知识迁移强度不同,学生模型可针对不同样本进行自适应知识迁移,这种方式使得学生模型能够针对不同样本进行不同程度的参数更新,使训练好的学生模型具有更优秀的性能。
实施例三:
在前述实施例的基础上,本实施例提供了一种应用前述学生模型的训练方法的具体示例,参见如图3所示的另一种学生模型的训练方法的流程图,该方法可以应用于如图4所示的知识蒸馏框架中,该方法主要包括如下步骤S302至步骤S308:
步骤S302,获取样本图像的候选样本区域。
在一种实施方式中,样本图像的候选样本区域(图中未示出)可以通过检测器输出,检测器可以采用单阶段检测器,也可以采用两阶段检测器。诸如,在基于两阶段检测器Faster R-CNN的蒸馏框架中,候选样本区域为两阶段检测器第一次粗略检测得到的候选区域,也即将样本图像输入该框架中的RPN(Region Proposal Network,区域生成网络)后得到候选样本区域。
在另一种实施方式中,样本候选区域也可以根据预先进行标注的标注信息获取,诸如,在基于单阶段检测器RetinaNet的蒸馏框架中,可以通过标注信息中的真值框生成若干蒸馏区域(也即候选样本区域),通过将蒸馏区域内的像素的权重设为1,除蒸馏区域以外的区域的像素权重设为0(权重为0表示实际不参与损失函数计算),进行获取样本图像的候选样本区域。
步骤S304,对于每个样本候选区域,分别得到教师模型的特征以及学生模型的特征和学生模型的特征的方差。
对于每个候选样本区域,分别输入教师模型和学生模型,输出通过教师模型提取的特征(也即前述实施例提及的第二特征)以及通过学生模型提取得到的特征(也即前述实施例提及的第一特征)。另外,对于样本候选区域,可以输入预先建立的方差生成网络中,从而通过该方差生成网络输出确定学生模型的方差,并根据该方差确定学生模型的第一特征的置信度,该置信度与方差呈负相关。方差生成网络可以通过两个卷积层和/或全连接层实现,图3中方差生成网络以两个全连接层作示例。
可选的,也可以将方差与置信度的转换过程也归于方差生成网络(也可以称为方差生成分支)中,也即通过方差生成分支生成方差以及将方差转换为置信度,此时该分支也可以称为置信度生成分支。
在一种实施方式中,当蒸馏框架为上述基于两阶段检测器Faster R-CNN的蒸馏框架时,将学生模型(也可以称为学生网络)提取的特征输入该框架中的RPN(RegionProposal Network)后得到样本候选区域,并将该样本候选区域输入上述方差生成分支或置信度生成分支,得到置信度。
在另一种实施方式中,当蒸馏框架为上述基于单阶段检测器RetinaNet的蒸馏框架时,提取上述权重为1的蒸馏区域的特征,并将其输入至方差生成分支或置信度生成分支,作为该蒸馏区域内每个像素的置信度。
步骤S306,根据学生模型的特征、学生模型的特征的方差以及教师模型的特征确定自适应迁移损失。
在一种实施方式中,自适应迁移损失(也即上述实施例的蒸馏损失)可以按照如下公式进行确定:
Figure BDA0002451931770000111
其中,d为特征维度,N为样本数量;
Figure BDA0002451931770000121
为学生模型的特征;
Figure BDA0002451931770000122
为教师模型的特征;
Figure BDA0002451931770000123
为学生模型的特征的方差,
Figure BDA0002451931770000124
Figure BDA0002451931770000125
输出的维度d保持一致。
步骤S308,确定任务损失,并基于自适应迁移损失和任务损失更新学生模型的参数。
任务损失为模型在进行物体检测任务时的损失,在模型进行蒸馏学习时的损失函数通常包括任务损失和自适应迁移损失,也即:
L=Ltask+Ldistill
其中,L为模型蒸馏学习中的整个损失函数;Ltask为任务损失;Ldistill为自适应迁移损失。
根据确定的模型蒸馏学习中的整个损失函数L,更新学生模型的参数,教师模型的参数保持不变,直至收敛。在训练结束后,对待检测图片进行测试时,直接输入待检测图片,无需经过蒸馏分支,直接按照未加蒸馏的检测器进行检测即可。
本实施例提供的学生模型的训练方法,通过在检测蒸馏中样本级的置信度获得方法(通过生成样本的若干候选区域,并在候选区域上分别提取特征并据此特征生成置信度),学生模型提取的不同样本的特征的置信度不同,蒸馏损失不同,相应的模型参数更新效果也不同,也即在学生模型训练过程中针对不同训练样本的知识迁移强度不同,学生模型训练过程中可以根据生成的置信度来调节检测任务中知识迁移的强度,从而获得更好的知识迁移效果。另外,该方法在测试过程完全不增加计算量,因此可以保证在不降低计算效率的情况下提升知识迁移的效果。
实施例四:
对于实施例二中所提供的学生模型的训练方法,本发明实施例提供了一种学生模型的训练装置,参见图5所示的一种学生模型的训练装置的结构框图,该装置包括以下模块:
获取模块502,用于获取训练样本的候选样本区域;
特征提取模块504,用于分别通过学生模型和教师模型对训练样本的候选样本区域进行特征提取,得到学生模型提取出的第一特征和教师模型提取出的第二特征;
置信度获取模块506,用于获取第一特征的置信度;
蒸馏损失确定模块508,用于根据第一特征、第二特征和第一特征的置信度确定学生模型和教师模型之间的蒸馏损失;
参数更新模块510,用于基于蒸馏损失更新学生模型的参数。
本发明实施例提供的上述学生模型的训练装置,通过在训练前获取学生模型针对训练样本提取的第一特征的置信度,可以确定样本质量对学生模型的影响,蒸馏损失的确定也与学生模型针对训练样本提取的特征的置信度相关,从而基于蒸馏损失对学生模型进行训练,学生模型提取的不同样本的特征的置信度不同,蒸馏损失不同,相应的模型参数更新效果也不同,也即在学生模型训练过程中针对不同训练样本的知识迁移强度不同,学生模型可针对不同样本进行自适应知识迁移,这种方式使得学生模型能够针对不同样本进行不同程度的参数更新,使训练好的学生模型具有更优秀的性能。
在一种实施方式中,上述置信度获取模块506,进一步用于将第一特征输入至方差生成网络中,得到方差生成网络输出的第一特征的方差,通过方差表征第一特征的置信度;其中,方差生成网络包括卷积层和/或全连接层,且方差与置信度呈负相关。
在一种实施方式中,上述蒸馏损失确定模块508,进一步用于按照如下公式确定学生模型和教师模型之间的蒸馏损失:
Figure BDA0002451931770000141
其中,d为特征维度,N为样本数量;
Figure BDA0002451931770000142
为第一特征;
Figure BDA0002451931770000143
为第二特征;
Figure BDA0002451931770000144
为方差。
在一种实施方式中,上述参数更新模块510,进一步用于获取学生模型执行物体检测任务的任务损失;根据任务损失和蒸馏损失更新学生模型的参数。
在一种实施方式中,上述获取模块502,进一步用于将训练样本输入至候选区域提取网络,得到候选样本区域。
在一种实施方式中,上述获取模块502,进一步用于根据携带有真值框的标注信息确定训练样本的候选样本区域。
在一种实施方式中,上述装置还包括:检测模块,用于将待检测图像输入训练后的学生模型,基于训练后的学生模型对待检测图像进行物体检测,得到物体检测结果。
本实施例所提供的装置,其实现原理及产生的技术效果和前述实施例相同,为简要描述,装置实施例部分未提及之处,可参考前述方法实施例中相应内容。
综上所述,本发明实施例提供的学生模型的训练方法、装置及电子设备,通过在训练前获取学生模型针对训练样本提取的第一特征的置信度,可以确定样本质量对学生模型的影响,学生模型提取的不同样本的特征的置信度不同,蒸馏损失不同,相应的模型参数更新效果也不同,也即在学生模型训练过程中针对不同训练样本的知识迁移强度不同,学生模型可针对不同样本进行自适应知识迁移,这种方式使得学生模型能够针对不同样本进行不同程度的参数更新,使训练好的学生模型具有更优秀的性能。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的***具体工作过程,可以参考前述实施例中的对应过程,在此不再赘述。
本发明实施例所提供的学生模型的训练方法、装置及电子设备的计算机程序产品,包括存储了程序代码的计算机可读存储介质,所述程序代码包括的指令可用于执行前面方法实施例中所述的方法,具体实现可参见方法实施例,在此不再赘述。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
在本发明的描述中,需要说明的是,术语“中心”、“上”、“下”、“左”、“右”、“竖直”、“水平”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。此外,术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性。
最后应说明的是:以上所述实施例,仅为本发明的具体实施方式,用以说明本发明的技术方案,而非对其限制,本发明的保护范围并不局限于此,尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特征进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的精神和范围,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以所述权利要求的保护范围为准。

Claims (10)

1.一种学生模型的训练方法,其特征在于,所述学生模型通过知识蒸馏方式向已训练好的教师模型学习,所述学生模型和所述教师模型均为物体检测模型,所述方法包括:
获取训练样本的候选样本区域;
分别通过所述学生模型和所述教师模型对所述训练样本的候选样本区域进行特征提取,得到所述学生模型提取出的第一特征和所述教师模型提取出的第二特征;
获取所述第一特征的置信度;
根据所述第一特征、所述第二特征和所述第一特征的置信度确定所述学生模型和所述教师模型之间的蒸馏损失;
基于所述蒸馏损失更新所述学生模型的参数。
2.根据权利要求1所述的方法,其特征在于,获取所述第一特征的置信度的步骤,包括:
将所述第一特征输入至方差生成网络中,得到所述方差生成网络输出的所述第一特征的方差,通过所述方差表征所述第一特征的置信度;其中,所述方差生成网络包括卷积层和/或全连接层,且所述方差与所述置信度呈负相关。
3.根据权利要求2所述的方法,其特征在于,根据所述第一特征、所述第二特征和所述第一特征的置信度确定所述学生模型和所述教师模型之间的蒸馏损失的步骤,包括:
按照如下公式确定所述学生模型和所述教师模型之间的蒸馏损失:
Figure FDA0002451931760000011
其中,d为特征维度,N为样本数量;
Figure FDA0002451931760000021
为所述第一特征;
Figure FDA0002451931760000022
为所述第二特征;
Figure FDA0002451931760000023
为所述方差。
4.根据权利要求1所述的方法,其特征在于,基于所述蒸馏损失更新所述学生模型的参数的步骤,包括:
获取所述学生模型执行物体检测任务的任务损失;
根据所述任务损失和所述蒸馏损失更新所述学生模型的参数。
5.根据权利要求1所述的方法,其特征在于,所述获取训练样本的候选样本区域的步骤,包括:
将所述训练样本输入至候选区域提取网络,得到候选样本区域。
6.根据权利要求1所述的方法,其特征在于,所述获取训练样本的候选样本区域的步骤,包括:
根据携带有真值框的标注信息确定训练样本的候选样本区域。
7.根据权利要求1至5任一项所述的方法,其特征在于,所述方法还包括:
将待检测图像输入训练后的学生模型,基于所述训练后的学生模型对所述待检测图像进行物体检测,得到物体检测结果。
8.一种学生模型的训练装置,其特征在于,所述学生模型通过知识蒸馏方式向已训练好的教师模型学习,所述学生模型和所述教师模型均为物体检测模型,所述装置包括:
获取模块,用于获取训练样本的候选样本区域;
特征提取模块,用于分别通过所述学生模型和所述教师模型对所述训练样本的候选样本区域进行特征提取,得到所述学生模型提取出的第一特征和所述教师模型提取出的第二特征;
置信度获取模块,用于获取所述第一特征的置信度;
蒸馏损失确定模块,用于根据所述第一特征、所述第二特征和所述第一特征的置信度确定所述学生模型和所述教师模型之间的蒸馏损失;
参数更新模块,用于基于所述蒸馏损失更新所述学生模型的参数。
9.一种电子设备,其特征在于,包括:处理器和存储装置;
所述存储装置上存储有计算机程序,所述计算机程序在被所述处理器运行时执行如权利要求1至7任一项所述的方法。
10.一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,其特征在于,所述计算机程序被处理器运行时执行上述权利要求1至7任一项所述的方法的步骤。
CN202010297966.1A 2020-04-15 2020-04-15 学生模型的训练方法、装置及电子设备 Active CN111639744B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010297966.1A CN111639744B (zh) 2020-04-15 2020-04-15 学生模型的训练方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010297966.1A CN111639744B (zh) 2020-04-15 2020-04-15 学生模型的训练方法、装置及电子设备

Publications (2)

Publication Number Publication Date
CN111639744A true CN111639744A (zh) 2020-09-08
CN111639744B CN111639744B (zh) 2023-09-22

Family

ID=72331330

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010297966.1A Active CN111639744B (zh) 2020-04-15 2020-04-15 学生模型的训练方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN111639744B (zh)

Cited By (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112965703A (zh) * 2021-03-02 2021-06-15 华南师范大学 克服多头领导的教师主导人工智能教育机器人
CN113361396A (zh) * 2021-06-04 2021-09-07 思必驰科技股份有限公司 多模态的知识蒸馏方法及***
CN113658173A (zh) * 2021-08-31 2021-11-16 北京华文众合科技有限公司 基于知识蒸馏的检测模型的压缩方法、***和计算设备
CN113762051A (zh) * 2021-05-13 2021-12-07 腾讯科技(深圳)有限公司 模型训练方法、图像检测方法、装置、存储介质及设备
CN113822125A (zh) * 2021-06-24 2021-12-21 华南理工大学 唇语识别模型的处理方法、装置、计算机设备和存储介质
CN113850012A (zh) * 2021-06-11 2021-12-28 腾讯科技(深圳)有限公司 数据处理模型生成方法、装置、介质及电子设备
CN114330510A (zh) * 2021-12-06 2022-04-12 北京大学 模型训练方法、装置、电子设备和存储介质
CN114492793A (zh) * 2022-01-27 2022-05-13 北京百度网讯科技有限公司 一种模型训练和样本生成方法、装置、设备及存储介质
CN114677565A (zh) * 2022-04-08 2022-06-28 北京百度网讯科技有限公司 特征提取网络的训练方法和图像处理方法、装置
CN114882324A (zh) * 2022-07-11 2022-08-09 浙江大华技术股份有限公司 目标检测模型训练方法、设备及计算机可读存储介质
CN115082920A (zh) * 2022-08-16 2022-09-20 北京百度网讯科技有限公司 深度学习模型的训练方法、图像处理方法和装置
WO2022257614A1 (zh) * 2021-06-10 2022-12-15 北京百度网讯科技有限公司 物体检测模型的训练方法、图像检测方法及其装置
WO2023030523A1 (zh) * 2021-09-06 2023-03-09 北京字节跳动网络技术有限公司 用于内窥镜的组织腔体定位方法、装置、介质及设备
CN116309245A (zh) * 2022-09-07 2023-06-23 南京江源测绘有限公司 基于深度学习的地下排水管道缺陷智能检测方法与***

Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018169708A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
CN108664893A (zh) * 2018-04-03 2018-10-16 福州海景科技开发有限公司 一种人脸检测方法及存储介质
CN108805185A (zh) * 2018-05-29 2018-11-13 腾讯科技(深圳)有限公司 模型的训练方法、装置、存储介质及计算机设备
CN108830813A (zh) * 2018-06-12 2018-11-16 福建帝视信息科技有限公司 一种基于知识蒸馏的图像超分辨率增强方法
US20190287515A1 (en) * 2018-03-16 2019-09-19 Microsoft Technology Licensing, Llc Adversarial Teacher-Student Learning for Unsupervised Domain Adaptation
US20190385086A1 (en) * 2018-06-13 2019-12-19 Fujitsu Limited Method of knowledge transferring, information processing apparatus and storage medium
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN110738309A (zh) * 2019-09-27 2020-01-31 华中科技大学 Ddnn的训练方法和基于ddnn的多视角目标识别方法和***
CN110837761A (zh) * 2018-08-17 2020-02-25 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
CN110880036A (zh) * 2019-11-20 2020-03-13 腾讯科技(深圳)有限公司 神经网络压缩方法、装置、计算机设备及存储介质
CN110956615A (zh) * 2019-11-15 2020-04-03 北京金山云网络技术有限公司 图像质量评估模型训练方法、装置、电子设备及存储介质

Patent Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2018169708A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
US20190287515A1 (en) * 2018-03-16 2019-09-19 Microsoft Technology Licensing, Llc Adversarial Teacher-Student Learning for Unsupervised Domain Adaptation
CN108664893A (zh) * 2018-04-03 2018-10-16 福州海景科技开发有限公司 一种人脸检测方法及存储介质
CN108805185A (zh) * 2018-05-29 2018-11-13 腾讯科技(深圳)有限公司 模型的训练方法、装置、存储介质及计算机设备
CN108830813A (zh) * 2018-06-12 2018-11-16 福建帝视信息科技有限公司 一种基于知识蒸馏的图像超分辨率增强方法
US20190385086A1 (en) * 2018-06-13 2019-12-19 Fujitsu Limited Method of knowledge transferring, information processing apparatus and storage medium
CN110837761A (zh) * 2018-08-17 2020-02-25 北京市商汤科技开发有限公司 多模型知识蒸馏方法及装置、电子设备和存储介质
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN110738309A (zh) * 2019-09-27 2020-01-31 华中科技大学 Ddnn的训练方法和基于ddnn的多视角目标识别方法和***
CN110956615A (zh) * 2019-11-15 2020-04-03 北京金山云网络技术有限公司 图像质量评估模型训练方法、装置、电子设备及存储介质
CN110880036A (zh) * 2019-11-20 2020-03-13 腾讯科技(深圳)有限公司 神经网络压缩方法、装置、计算机设备及存储介质

Cited By (21)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112965703A (zh) * 2021-03-02 2021-06-15 华南师范大学 克服多头领导的教师主导人工智能教育机器人
CN112965703B (zh) * 2021-03-02 2022-04-01 华南师范大学 克服多头领导的教师主导人工智能教育机器人
CN113762051A (zh) * 2021-05-13 2021-12-07 腾讯科技(深圳)有限公司 模型训练方法、图像检测方法、装置、存储介质及设备
CN113762051B (zh) * 2021-05-13 2024-05-28 腾讯科技(深圳)有限公司 模型训练方法、图像检测方法、装置、存储介质及设备
CN113361396A (zh) * 2021-06-04 2021-09-07 思必驰科技股份有限公司 多模态的知识蒸馏方法及***
CN113361396B (zh) * 2021-06-04 2023-12-26 思必驰科技股份有限公司 多模态的知识蒸馏方法及***
WO2022257614A1 (zh) * 2021-06-10 2022-12-15 北京百度网讯科技有限公司 物体检测模型的训练方法、图像检测方法及其装置
CN113850012A (zh) * 2021-06-11 2021-12-28 腾讯科技(深圳)有限公司 数据处理模型生成方法、装置、介质及电子设备
CN113850012B (zh) * 2021-06-11 2024-05-07 腾讯科技(深圳)有限公司 数据处理模型生成方法、装置、介质及电子设备
CN113822125A (zh) * 2021-06-24 2021-12-21 华南理工大学 唇语识别模型的处理方法、装置、计算机设备和存储介质
CN113822125B (zh) * 2021-06-24 2024-04-30 华南理工大学 唇语识别模型的处理方法、装置、计算机设备和存储介质
CN113658173A (zh) * 2021-08-31 2021-11-16 北京华文众合科技有限公司 基于知识蒸馏的检测模型的压缩方法、***和计算设备
WO2023030523A1 (zh) * 2021-09-06 2023-03-09 北京字节跳动网络技术有限公司 用于内窥镜的组织腔体定位方法、装置、介质及设备
CN114330510A (zh) * 2021-12-06 2022-04-12 北京大学 模型训练方法、装置、电子设备和存储介质
CN114492793A (zh) * 2022-01-27 2022-05-13 北京百度网讯科技有限公司 一种模型训练和样本生成方法、装置、设备及存储介质
CN114677565A (zh) * 2022-04-08 2022-06-28 北京百度网讯科技有限公司 特征提取网络的训练方法和图像处理方法、装置
CN114882324A (zh) * 2022-07-11 2022-08-09 浙江大华技术股份有限公司 目标检测模型训练方法、设备及计算机可读存储介质
CN115082920B (zh) * 2022-08-16 2022-11-04 北京百度网讯科技有限公司 深度学习模型的训练方法、图像处理方法和装置
CN115082920A (zh) * 2022-08-16 2022-09-20 北京百度网讯科技有限公司 深度学习模型的训练方法、图像处理方法和装置
CN116309245B (zh) * 2022-09-07 2024-01-19 南京唐壹信息科技有限公司 基于深度学习的地下排水管道缺陷智能检测方法与***
CN116309245A (zh) * 2022-09-07 2023-06-23 南京江源测绘有限公司 基于深度学习的地下排水管道缺陷智能检测方法与***

Also Published As

Publication number Publication date
CN111639744B (zh) 2023-09-22

Similar Documents

Publication Publication Date Title
CN111639744A (zh) 学生模型的训练方法、装置及电子设备
US20220058426A1 (en) Object recognition method and apparatus, electronic device, and readable storage medium
CN111754596B (zh) 编辑模型生成、人脸图像编辑方法、装置、设备及介质
CN108664893B (zh) 一种人脸检测方法及存储介质
CN109583501B (zh) 图片分类、分类识别模型的生成方法、装置、设备及介质
CN111741330B (zh) 一种视频内容评估方法、装置、存储介质及计算机设备
WO2023040510A1 (zh) 图像异常检测模型训练方法、图像异常检测方法和装置
CN111709409A (zh) 人脸活体检测方法、装置、设备及介质
CN109359539B (zh) 注意力评估方法、装置、终端设备及计算机可读存储介质
CN111242222B (zh) 分类模型的训练方法、图像处理方法及装置
CN112950581A (zh) 质量评估方法、装置和电子设备
CN111695421B (zh) 图像识别方法、装置及电子设备
WO2021090771A1 (en) Method, apparatus and system for training a neural network, and storage medium storing instructions
CN114511041B (zh) 模型训练方法、图像处理方法、装置、设备和存储介质
JP2019152964A (ja) 学習方法および学習装置
CN110413551B (zh) 信息处理装置、方法及设备
CN112070040A (zh) 一种用于视频字幕的文本行检测方法
CN110956131A (zh) 单目标追踪方法、装置及***
CN111639667B (zh) 图像识别方法、装置、电子设备及计算机可读存储介质
TWI803243B (zh) 圖像擴增方法、電腦設備及儲存介質
CN116343007A (zh) 目标检测方法、装置、设备和存储介质
US11688175B2 (en) Methods and systems for the automated quality assurance of annotated images
CN116977271A (zh) 缺陷检测方法、模型训练方法、装置及电子设备
CN113947771B (zh) 图像识别方法、装置、设备、存储介质以及程序产品
CN115760908A (zh) 基于胶囊网络感知特征的绝缘子跟踪方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant