CN113887698A - 基于图神经网络的整体知识蒸馏方法和*** - Google Patents
基于图神经网络的整体知识蒸馏方法和*** Download PDFInfo
- Publication number
- CN113887698A CN113887698A CN202110982472.1A CN202110982472A CN113887698A CN 113887698 A CN113887698 A CN 113887698A CN 202110982472 A CN202110982472 A CN 202110982472A CN 113887698 A CN113887698 A CN 113887698A
- Authority
- CN
- China
- Prior art keywords
- graph
- knowledge
- model
- teacher
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 83
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 31
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 27
- 238000012549 training Methods 0.000 claims abstract description 24
- 230000002776 aggregation Effects 0.000 claims abstract 2
- 238000004220 aggregation Methods 0.000 claims abstract 2
- 230000006870 function Effects 0.000 claims description 20
- 230000008569 process Effects 0.000 claims description 15
- 238000010276 construction Methods 0.000 claims description 13
- 238000004821 distillation Methods 0.000 claims description 13
- 238000012512 characterization method Methods 0.000 claims description 12
- 238000005516 engineering process Methods 0.000 claims description 10
- 239000011159 matrix material Substances 0.000 claims description 8
- 238000011524 similarity measure Methods 0.000 claims description 7
- 238000005259 measurement Methods 0.000 claims description 6
- 230000003044 adaptive effect Effects 0.000 claims description 5
- 238000004364 calculation method Methods 0.000 claims description 5
- 238000011160 research Methods 0.000 claims description 5
- 238000013507 mapping Methods 0.000 claims description 4
- 238000000354 decomposition reaction Methods 0.000 claims description 3
- 230000004044 response Effects 0.000 claims description 3
- 230000005055 memory storage Effects 0.000 abstract description 2
- 238000002474 experimental method Methods 0.000 description 4
- 230000004931 aggregating effect Effects 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 238000000605 extraction Methods 0.000 description 3
- 238000007670 refining Methods 0.000 description 3
- 238000011156 evaluation Methods 0.000 description 2
- 239000000284 extract Substances 0.000 description 2
- 238000000691 measurement method Methods 0.000 description 2
- 235000009499 Vanilla fragrans Nutrition 0.000 description 1
- 244000263375 Vanilla tahitensis Species 0.000 description 1
- 235000012036 Vanilla tahitensis Nutrition 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 230000000052 comparative effect Effects 0.000 description 1
- 238000012733 comparative method Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000002349 favourable effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/02—Knowledge representation; Symbolic representation
- G06N5/022—Knowledge engineering; Knowledge acquisition
- G06N5/025—Extracting rules from data
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Biophysics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明的目的是提供一种基于图神经网络的整体知识蒸馏方法,包括:给定老师和学生网络学习到的特征表示和分类预测的结果,以每个样本为节点,网络学习到的特征为节点的属性,分类预测结果的K近邻(KNN)关系为边,为每个网络构建一个属性图;使用拓扑结构自适应的图卷积神经网络聚合属性图中邻域样本的节点属性以及拓扑信息来提取整体性知识,表示为统一的基于图的嵌入向量;使用infoNCE估计最大化学生网络与老师网络的图嵌入表示的互信息,并使用特征记忆存储技术加速训练效率。该方法:可以同时整合老师网络中个体上的知识和关系上的知识,使学生网络学习到整体性的知识,从而提升学生网络的性能。
Description
技术领域
本发明涉及深度学习与计算机视觉领域,尤其涉及一种基于图神经网络的整体知识蒸馏方法和***。
背景技术
深度神经网络(DNN)在各种应用中取得了巨大成功。然而,它们的成功在很大程度上依赖于大量的计算和存储资源,而这些资源在嵌入式设备和移动设备中通常是不可用的。为了降低成本,同时保持令人满意的效果,一些模型压缩的技术开始成为热门的研究话题。而知识蒸馏技术正是其中一种方法,该技术可以将知识从统计较大的训练有素的老师网络迁移到体积较小的待学习的学生网络,从而提升学生网络的效果,同时保持学生网络体积小、运算快的特点。
从老师网络中提取的知识在知识蒸馏中起着核心作用。在现有的知识蒸馏方法中,按照提取知识的类型可以分为两类,即个体上的知识蒸馏和关系上的知识蒸馏。个体上的知识蒸馏是指利用老师网络独立地从每个数据实例中提取个体知识,并提供比离散标签更有利的监督,包括概率表示(logits)、特征表示和特征映射等。关系上的知识蒸馏是从成对样本的关系中提取的,通过对学生网络的训练,使这些关系在不同网络架构的学生网络与老师网络之间得到保持。
尽管上述两种知识蒸馏的方法都取得了成功,现有方法均独立地使用两类技术,忽略了其内在的相关性。尤其是当老师网络的能力有限时,独立提取一种类型的知识不足以用于学生网络的学习。直观地说,个体上的知识和关系上的知识可以看作是同一老师网络的自然相关的两个视图。两个相似的实例往往具有相似的个体特征以相似的关系模式,发掘这些知识对训练更具辨别力的学生网络学习至关重要。同时整合个体上的知识和关系上的相关知识,同时保留两者的内在统一性,对于知识蒸馏至关重要。
发明内容
本发明要克服现有技术的上述缺点,提供一种基于图神经网络的整体知识蒸馏方法(the Holistic Knowledge Distillation(HKD)method)和***,可以同时整合个体上的知识和关系上的知识。
本发明的目的是通过以下技术方案实现的:
一种基于图神经网络的整体知识蒸馏方法,包括:
步骤1、分别构建老师模型与学生模型的属性图。将图像输入老师模型和学生模型得到特征表示ft,fs(其中 dt和ds分别是老师模型和学生模型输出的特征表示的维度)以及分类预测的结果pt,ps,然后为老师模型及学生模型分别构建一个属性图Gt={At,Ft}、Gs={As,Fs},其中每个节点表示一个实例,节点属性表示学习到的特征表示,其中At,As是基于pt,ps构建的属性性图的邻接矩阵,并基于公式构造,其中是基于K近邻(Knearestneighbors,KNN)的图构造函数。。在整个训练过程中,Gt是固定不动的,Gs的属性及结构都是在动态变化的。
上述定义的属性图具有以下特性:首先,与现有关系知识分解方法构建的实例间完全连通图相比,KNN图将过滤掉最不相关的样本对。这一点尤其重要,因为在随机抽样的batch中只有少数样本是真正相关的,并且为节点表示的学习提供了足够的信息,本领域的研究人员对batch的概念应该是熟知的,这里不再赘述。其次,由于边是基于预测的概率构建的,因此该图能够对类间和类内信息进行建模。最后,可以利用图神经网络非常高效地从属性图中联合提取个体上的知识和关系上的知识。
步骤2、使用拓扑结构自适应的图卷积神经网络(Topology Adaptive GraphConvolution Network,TAGCN)聚合属性图中邻域样本的节点属性以及拓扑信息来提取整体知识,由此得到基于图的嵌入向量。
其中Θs 1和Θt 1是可学习的参数,gt,gs是上述图表示的维度。Dt,Ds是属性图度的对角阵,即公式(3)所示
步骤3、使用InfoNCE估计最大化学生模型与老师模型的图嵌入表示的互信息,并使用特征记忆化存储技术加速训练效率。
为了让学生模型尽可能地学习到老师模型整体性的知识,需要最大化Ht 与Hs的相似程度,现存的很多基于向量的(vector-wise)相似度度量的方法(如余弦相似度、欧氏距离)等是不适合整体性知识的蒸馏的,他们往往受限于老师模型与学生模型模型结构差异而导致的表征能力的差异,并且直接对齐Ht与Hs可能会导致学习的知识过于精炼。为了克服上述限制,使用互信息来度量学生模型从老师模型蒸馏信息的相似程度,即最大化Ht与Hs的互信息,如下述公式(4) 所示。
I(·)表示两个随机变量的互信息度量,受近期一些互信息估计方法方面的研究工作的启发,采用infoNCE来估计互信息,infoNCE与互信息的关系如公式(5) 如下:
其中f(·)是一个vector-wise的相似度度量函数,ht i,hs i是样本i在老师模型及学生模型中学习出来基于图的表征。
学生模型在学习老师模型的知识以外,还需要学习知识本身(如数据中的标签信息),交叉熵是常见分类任务中的损失函数,最终的损失函数如下公式(6)。
其中β是线性组合的权重。
因为infoNCE需要使用数据集中所有的样本作为负样本,对于规模较大的数据集来说,计算整体性的蒸馏损失的计算代价过于昂贵。为了避免在训练过程计算样本的表征,特征记忆化存储技术已经被广泛地应用。在实施例所述方法中Gt and Gs是在随机选择的mini-Batch中分别构建的,基于图的表征Ht和Hs反应的整体性知识则在不同的属性图中呈现,他们是不能使用特征记忆化存储技术来存储的,因此,只分别使用特征记忆化存储技术来存储Ft,Fs。最终逼近的整体性蒸馏损失被定义为公式(7)所示:
步骤3中最大化学生模型与老师模型的图嵌入表示的互信息的方法,使用infoNCE来最大化互信息,infoNCE与互信息的关系如公式(8)所示:
其中f(·)是一个vector-wise的相似度度量函数,ht i,hs i是样本i在老师模型及学生模型中学习出来基于图的表征。
实施本发明的一种基于图神经网络的整体知识蒸馏方法的***,包括依次连接的老师模型与学生模型的属性图构建模型、整体知识提取模块、学生模型与老师模型的图嵌入表示的互信息的最大化模块。
本发明的工作原理是:本发明的知识蒸馏方法,在老师模型的个体上的知识以及关系上的知识的基础上,进一步使用图神经网络提取来自老师模型的整体性知识,学生模型通过学习来自老师模型的整体性知识,模型性能提升较其他知识蒸馏方法更加明显。
本发明的优点是:学生模型不仅可以学习到来自老师模型的个体上的知识,以及关系上的知识,而且可以学习更加复杂的整体性的知识,从而使得学生模型性能提升较其他知识蒸馏方法更加明显。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的知识蒸馏方法分类的示意图;
图2为本发明实施例提供的一种基于图神经网络的整体知识蒸馏方法的示意图;
图3为本发明实施例提供的一种基于图神经网络的整体知识蒸馏方法的简要流程示意图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。
本发明实施例提供一种基于图神经网络的整体知识蒸馏方法,该技术可以同时蒸馏个体上的知识和关系上的知识。
本发明实施例提出知识蒸馏框架中,给定老师模型和学生模型学习到的特征表示和分类预测的结果,首先为每个模型构建一个属性图,其中每个节点表示一个实例,节点属性表示学习到的特征表示,实例之间的边,使用分类预测结果的 K近邻(KNN)关系进行构造。
使用拓扑结构自适应的图卷积神经网络(Topology Adaptive GraphConvolution Network(TAGCN))聚合属性图中邻域样本的节点属性来提取整体知识,表示为统一的基于图的嵌入向量。使用InfoNCE估计最大化老师模型与学生模型的图嵌入表示的互信息,并使用特征记忆存储技术加速训练效率。通过老师模型属性图中的复杂知识,学生模型的性能可以得到相比于目前的知识蒸馏方法更加明显地提升。
具体的,如图3所示,为本发明实施例提供的一种基于图神经网络的整体知识蒸馏方法的示意图,其主要包括如下步骤:
步骤1、分别构建老师模型与学生模型的属性图。将图像输入老师模型和学生模型得到特征表示ft,fs(其中 dt和ds分别是老师模型和学生模型输出的特征表示的维度)以及分类预测的结果pt,ps,然后为老师模型及学生模型分别构建一个属性图Gt={At,Ft}、Gs={As,Fs},其中每个节点表示一个实例,节点属性表示学习到的特征表示,其中 At,As是基于pt,ps构建的属性性图的邻接矩阵,并基于公式构造,其中是基于KNN 的图构造函数。在整个训练过程中,Gt是固定不动的,Gs的属性及结构都是在动态变化的。
上述定义的属性图具有以下特性:首先,与现有关系知识分解方法构建的实例间完全连通图相比,KNN图将过滤掉最不相关的样本对。这一点尤其重要,因为在随机抽样的batch中只有少数样本是真正相关的,并且为节点表示的学习提供了足够的信息,本领域的研究人员对batch的概念应该是熟知的,这里不再赘述。其次,由于边是基于预测的概率构建的,因此该图能够对类间和类内信息进行建模。最后,可以利用图神经网络非常高效地从属性图中联合提取个体上的知识和关系上的知识。
步骤2、使用拓扑结构自适应的图卷积神经网络(Topology Adaptive GraphConvolution Network,TAGCN)聚合属性图中邻域样本的节点属性以及拓扑信息来提取整体知识,由此得到基于图的嵌入向量。
其中Θs l和Θt l是可学习的参数,gt,gs是上述图表示的维度。Dt,Ds是属性图度的对角阵,即公式(3)所示
步骤3、使用InfoNCE估计最大化学生模型与老师模型的图嵌入表示的互信息,并使用特征记忆化存储技术加速训练效率。
为了让学生模型尽可能地学习到老师模型整体性的知识,需要最大化Ht 与Hs的相似程度,现存的很多基于向量的(vector-wise)相似度度量的方法(如余弦相似度、欧氏距离)等是不适合整体性知识的蒸馏的,他们往往受限于老师模型与学生模型模型结构差异而导致的表征能力的差异,并且直接对齐Ht与Hs可能会导致学习的知识过于精炼。为了克服上述限制,使用互信息来度量学生模型从老师模型蒸馏信息的相似程度,即最大化Ht与Hs的互信息,如下述公式(4) 所示。
I(·)表示两个随机变量的互信息度量,受近期一些互信息估计方法方面的研究工作的启发,采用infoNCE来估计互信息,infoNCE与互信息的关系如公式(5) 如下:
其中f(·)是一个vector-wise的相似度度量函数,ht i,hs i是样本i在老师模型及学生模型中学习出来基于图的表征。
学生模型在学习老师模型的知识以外,还需要学习知识本身(如数据中的标签信息),交叉熵是常见分类任务中的损失函数,最终的损失函数如下公式(6)。
其中β是线性组合的权重。
因为infoNCE需要使用数据集中所有的样本作为负样本,对于规模较大的数据集来说,计算整体性的蒸馏损失的计算代价过于昂贵。为了避免在训练过程计算样本的表征,特征记忆化存储技术已经被广泛地应用。在实施例所述方法中Gt and Gs是在随机选择的mini-Batch中分别构建的,基于图的表征Ht和Hs反应的整体性知识则在不同的属性图中呈现,他们是不能使用特征记忆化存储技术来存储的,因此,只分别使用特征记忆化存储技术来存储Ft,Fs。最终逼近的整体性蒸馏损失被定义为公式(7)所示:
步骤3中最大化学生模型与老师模型的图嵌入表示的互信息的方法,使用infoNCE来最大化互信息,infoNCE与互信息的关系如公式(8)所示:
其中f(·)是一个vector-wise的相似度度量函数,ht i,hs i是样本i在老师模型及学生模型中学习出来基于图的表征。
本发明实施例提供的上述方案,本知识蒸馏方法可以使得学生模型更加有效地学习到老师模型的整体性知识,较其他知识蒸馏方法而言,学生模型性能提升会更加明显。
实施本发明的一种基于图神经网络的整体知识蒸馏方法的***,包括依次连接的老师模型与学生模型的属性图构建模型、整体知识提取模块、学生模型与老师模型的图嵌入表示的互信息的最大化模块。老师模型与学生模型的属性图构建模型、整体知识提取模块、学生模型与老师模型的图嵌入表示的互信息的最大化模块分别对应上述步骤1、2、3的内容。
为了说明本发明实施例上述方案的效果,结合实验进行说明,实验在几个图像分类领域经典的数据集上展开。
一、实验数据集
实验涉及两个基准数据集,其相关描述如下表:
数据集 | 类别数量 | 训练集大小 | 测试集大小 | 图像尺寸 |
Tiny-ImageNet | 200 | 100000 | 10000 | 224*224*3 |
Cifar-100 | 20 | 50000 | 10000 | 32*32*3 |
二、模型结构
老师模型及学生模型共使用了四种结构,分别为ResNet,VGG,ShuffleNet,MobileNet,这些结构均为该领域研究人员熟悉的网络结构,这里不再赘述。
三、基线方法
为了体现该方法的优越性,本发明实施例对比了近期的知识蒸馏方法,这些方法可归纳为两类,其区别见本发明实施例提供的知识蒸馏分类的示意图。
具体而言,
第一类是个体上的知识蒸馏方法,包括学习logits的vanilla KD,学习attentionMap的AT,学习特征表示的PKT,CRD和SSKD。
第二类是关系上的知识蒸馏方法,这类方法学习成对的关系类型的知识,包括RKD,CCKD。
上述对比方法,均使用作者开源的代码进行复现,为了保持各种方法之间训练样本的一致性,我们去除了SSKD代码中的数据增强。
四、实验结论
采用不同的老师模型及学生模型训练该发明中实施例及各种对比方法,在CIFAR100
数据集上的表现如下(评估指标为准确率)。
在TinyImageNet数据集上的表现如下(评估指标为准确率)。
从在两个数据集上的对比实验可以发现,本方法(HKD)性能在不同的老师模型学生模型组合下均明显好于对比方法。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例可以通过软件实现,也可以借助软件加必要的通用硬件平台的方式来实现。基于这样的理解,上述实施例的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等) 中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
本发明可应用于图像分类。深度神经网络(DNN)在各种应用中取得了巨大成功。然而,它们的成功在很大程度上依赖于大量的计算和存储资源,而这些资源在嵌入式设备和移动设备中通常是不可用的,即在有限的计算资源条件下,一些对计算资源要求较高的大模型是无法进行部署的。为了在有限的计算资源条件下,使得可用的小模型保持令人满意的效果,一些知识蒸馏的方法(让学生模型同时学习老师模型的学到的知识以及知识本身的学习方法)被广泛研究,
下面介绍一种应用本发明方法的一种图像分类方法:
1.首先准备图像分类模型所需的数据集,该数据集的制作是该领域技术人员所熟识的过程,这里不做赘述。
2.使用交叉熵作为损失函数,训练一个模型体积较大,性能出色的老师模型。
老师模型训练完成后,将不会在后续的训练过程更新参数。下面展开描述学生模型训练方法,即本方法的核心内容。
3学生模型的训练过程;
3.1、分别构建老师模型与学生模型的属性图。将图像输入老师模型和学生模型得到特征表示ft,fs(其中 dt和ds分别是老师模型和学生模型输出的特征表示的维度)以及分类预测的结果pt,ps,然后为老师模型及学生模型分别构建一个属性图Gt={At,Ft}、Gs={As,Fs},其中每个节点表示一个实例,节点属性表示学习到的特征表示,其中 At,As是基于pt,ps构建的属性性图的邻接矩阵,并基于公式构造,其中是基于KNN的图构造函数。在整个训练过程中,Gt是固定不动的,Gs的属性及结构都是在动态变化的。
3.2、使用拓扑结构自适应的图卷积神经网络(Topology Adaptive GraphConvolution Network,TAGCN)聚合属性图中邻域样本的节点属性以及拓扑信息来提取整体知识,由此得到基于图的嵌入向量。
其中Θs l和Θt l是可学习的参数,gt,gs是上述图表示的维度。Dt,Ds是属性图度的对角阵,即公式(3)所示
3.3、使用InfoNCE估计最大化学生模型与老师模型的图嵌入表示的互信息,并使用特征记忆化存储技术加速训练效率。
为了让学生模型尽可能地学习到老师模型整体性的知识,需要最大化Ht与Hs的相似程度,现存的很多基于向量的(vector-wise)相似度度量的方法(如余弦相似度、欧氏距离)等是不适合整体性知识的蒸馏的,他们往往受限于老师模型与学生模型模型结构差异而导致的表征能力的差异,并且直接对齐Ht与Hs可能会导致学习的知识过于精炼。为了克服上述限制,使用互信息来度量学生模型从老师模型蒸馏信息的相似程度,即最大化Ht与Hs的互信息,如下述公式(4)所示。
I(·)表示两个随机变量的互信息度量,受近期一些互信息估计方法方面的研究工作的启发,采用infoNCE来估计互信息,infoNCE与互信息的关系如公式(5) 如下:
其中f(·)是一个vector-wise的相似度度量函数,ht i,hs i是样本i在老师模型及学生模型中学习出来基于图的表征。
学生模型在学习老师模型的知识以外,还需要学习知识本身(如数据中的标签信息),交叉熵是常见分类任务中的损失函数,最终的损失函数如下公式(6)。
其中β是线性组合的权重。
因为infoNCE需要使用数据集中所有的样本作为负样本,对于规模较大的数据集来说,计算整体性的蒸馏损失的计算代价过于昂贵。为了避免在训练过程计算样本的表征,特征记忆化存储技术已经被广泛地应用。在实施例所述方法中Gt and Gs是在随机选择的mini-Batch中分别构建的,基于图的表征Ht和Hs反应的整体性知识则在不同的属性图中呈现,他们是不能使用特征记忆化存储技术来存储的,因此,只分别使用特征记忆化存储技术来存储Ft,Fs。最终逼近的整体性蒸馏损失被定义为公式(7)所示:
4.部署模型,将训练好的学生模型在所需场景下进行部署。即可实现在有限的计算条件下实现,使用性能更好的学生模型进行图像分类。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。
Claims (2)
1.一种基于图神经网络的整体知识蒸馏方法,包括:
步骤1、分别构建老师模型与学生模型的属性图。将图像输入老师模型和学生模型得到特征表示ft,fs(其中dt和ds分别是老师模型和学生模型输出的特征表示的维度)以及分类预测的结果pt,ps,然后为老师模型及学生模型分别构建一个属性图Gt={At,Ft}、Gs={As,Fs},其中每个节点表示一个实例,节点属性表示学习到的特征表示,其中At,As是基于pt,ps构建的属性性图的邻接矩阵,并基于公式构造,其中是基于K近邻(Knearest neighbors,KNN)的图构造函数。
在整个训练过程中,Gt是固定不动的,Gs的属性及结构都是在动态变化的。
上述定义的属性图具有以下特性:首先,与现有关系知识分解方法构建的实例间完全连通图相比,KNN图将过滤掉最不相关的样本对。这一点尤其重要,因为在随机抽样的batch中只有少数样本是真正相关的,并且为节点表示的学习提供了足够的信息,本领域的研究人员对batch的概念应该是熟知的,这里不再赘述。其次,由于边是基于预测的概率构建的,因此该图能够对类间和类内信息进行建模。最后,可以利用图神经网络非常高效地从属性图中联合提取个体上的知识和关系上的知识。
步骤2、使用图卷积神经网络聚合属性图中邻域样本的节点属性以及拓扑信息来提取整体知识,由此得到基于图的嵌入向量。
如下述公式(1)、(2)所示,使用拓扑结构自适应的图卷积神经网络(TopologyAdaptive Graph Convolution Network,TAGCN)同时学习对节点的属性信息及拓扑信息进行学习,来提取整体性的知识,得到老师模型以及学生模型基于图的表示和
其中Θs l和Θt l是可学习的参数,gt,gs是上述图表示的维度。Dt,Ds是属性图度的对角阵,即公式(3)所示
步骤3、最大化学生模型与老师模型的图嵌入表示的互信息,并使用特征记忆化存储技术加速训练效率。
为了让学生模型尽可能地学习到老师模型整体性的知识,需要最大化Ht与Hs的相似程度,现存的很多基于向量的(vector-wise)相似度度量的方法(如余弦相似度、欧氏距离)等是不适合整体性知识的蒸馏的,他们往往受限于老师模型与学生模型模型结构差异而导致的表征能力的差异,并且直接对齐Ht与Hs可能会导致学习的知识过于精炼。为了克服上述限制,使用互信息来度量学生模型从老师模型蒸馏信息的相似程度,即最大化Ht与Hs的互信息,如下述公式(4)所示。
I(·)表示两个随机变量的互信息度量,受近期一些互信息估计方法方面的研究工作的启发,采用infoNCE来估计互信息,infoNCE与互信息的关系如公式(5)所示:
其中f(·)是一个vector-wise的相似度度量函数,ht i,hs i是样本i在老师模型及学生模型中学习出来基于图的表征。
学生模型在学习老师模型的知识以外,还需要学习知识本身(如数据中的标签信息),交叉熵是常见分类任务中的损失函数,最终的损失函数如下公式(6)。
其中β是线性组合的权重。
因为infoNCE需要使用数据集中所有的样本作为负样本,对于规模较大的数据集来说,计算整体性的蒸馏损失的计算代价过于昂贵。为了避免在训练过程计算样本的表征,特征记忆化存储技术已经被广泛地应用。在实施例所述方法中Gtand Gs是在随机选择的mini-Batch中分别构建的,基于图的表征Ht和Hs反应的整体性知识则在不同的属性图中呈现,他们是不能使用特征记忆化存储技术来存储的,因此,只分别使用特征记忆化存储技术来存储Ft,Fs。最终逼近的整体性蒸馏损失被定义为公式(7)所示:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110982472.1A CN113887698B (zh) | 2021-08-25 | 2021-08-25 | 基于图神经网络的整体知识蒸馏方法和*** |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110982472.1A CN113887698B (zh) | 2021-08-25 | 2021-08-25 | 基于图神经网络的整体知识蒸馏方法和*** |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113887698A true CN113887698A (zh) | 2022-01-04 |
CN113887698B CN113887698B (zh) | 2024-06-14 |
Family
ID=79011512
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110982472.1A Active CN113887698B (zh) | 2021-08-25 | 2021-08-25 | 基于图神经网络的整体知识蒸馏方法和*** |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113887698B (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115101119A (zh) * | 2022-06-27 | 2022-09-23 | 山东大学 | 基于网络嵌入的isoform功能预测*** |
CN117058437A (zh) * | 2023-06-16 | 2023-11-14 | 江苏大学 | 一种基于知识蒸馏的花卉分类方法、***、设备及介质 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200302295A1 (en) * | 2019-03-22 | 2020-09-24 | Royal Bank Of Canada | System and method for knowledge distillation between neural networks |
CN112116030A (zh) * | 2020-10-13 | 2020-12-22 | 浙江大学 | 一种基于向量标准化和知识蒸馏的图像分类方法 |
WO2021023202A1 (zh) * | 2019-08-07 | 2021-02-11 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法 |
CN112861936A (zh) * | 2021-01-26 | 2021-05-28 | 北京邮电大学 | 一种基于图神经网络知识蒸馏的图节点分类方法及装置 |
CN113095480A (zh) * | 2021-03-24 | 2021-07-09 | 重庆邮电大学 | 一种基于知识蒸馏的可解释图神经网络表示方法 |
-
2021
- 2021-08-25 CN CN202110982472.1A patent/CN113887698B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200302295A1 (en) * | 2019-03-22 | 2020-09-24 | Royal Bank Of Canada | System and method for knowledge distillation between neural networks |
WO2021023202A1 (zh) * | 2019-08-07 | 2021-02-11 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法 |
CN112116030A (zh) * | 2020-10-13 | 2020-12-22 | 浙江大学 | 一种基于向量标准化和知识蒸馏的图像分类方法 |
CN112861936A (zh) * | 2021-01-26 | 2021-05-28 | 北京邮电大学 | 一种基于图神经网络知识蒸馏的图节点分类方法及装置 |
CN113095480A (zh) * | 2021-03-24 | 2021-07-09 | 重庆邮电大学 | 一种基于知识蒸馏的可解释图神经网络表示方法 |
Non-Patent Citations (1)
Title |
---|
葛仕明;赵胜伟;***;李晨钰;: "基于深度特征蒸馏的人脸识别", 北京交通大学学报, no. 06, 15 December 2017 (2017-12-15) * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115101119A (zh) * | 2022-06-27 | 2022-09-23 | 山东大学 | 基于网络嵌入的isoform功能预测*** |
CN115101119B (zh) * | 2022-06-27 | 2024-05-17 | 山东大学 | 基于网络嵌入的isoform功能预测*** |
CN117058437A (zh) * | 2023-06-16 | 2023-11-14 | 江苏大学 | 一种基于知识蒸馏的花卉分类方法、***、设备及介质 |
CN117058437B (zh) * | 2023-06-16 | 2024-03-08 | 江苏大学 | 一种基于知识蒸馏的花卉分类方法、***、设备及介质 |
Also Published As
Publication number | Publication date |
---|---|
CN113887698B (zh) | 2024-06-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109271522B (zh) | 基于深度混合模型迁移学习的评论情感分类方法及*** | |
CN108132968B (zh) | 网络文本与图像中关联语义基元的弱监督学习方法 | |
CN111950594B (zh) | 基于子图采样的大规模属性图上的无监督图表示学习方法和装置 | |
CN110929161B (zh) | 一种面向大规模用户的个性化教学资源推荐方法 | |
CN113065974B (zh) | 一种基于动态网络表示学习的链路预测方法 | |
CN113887698A (zh) | 基于图神经网络的整体知识蒸馏方法和*** | |
CN110993037A (zh) | 一种基于多视图分类模型的蛋白质活性预测装置 | |
CN116304367B (zh) | 基于图自编码器自监督训练用于获得社区的算法及装置 | |
CN114299362A (zh) | 一种基于k-means聚类的小样本图像分类方法 | |
CN112115971B (zh) | 一种基于异质学术网络进行学者画像的方法及*** | |
CN111598252B (zh) | 基于深度学习的大学计算机基础知识解题方法 | |
CN113554100A (zh) | 异构图注意力网络增强的Web服务分类方法 | |
Chen et al. | RRGCCAN: Re-ranking via graph convolution channel attention network for person re-identification | |
CN112131261A (zh) | 基于社区网络的社区查询方法、装置和计算机设备 | |
CN115577283A (zh) | 一种实体分类方法、装置、电子设备及存储介质 | |
Shu et al. | Correntropy-based dual graph regularized nonnegative matrix factorization with L p smoothness for data representation | |
CN112446739A (zh) | 一种基于分解机和图神经网络的点击率预测方法及*** | |
CN111339258A (zh) | 基于知识图谱的大学计算机基础习题推荐方法 | |
CN116109834A (zh) | 一种基于局部正交特征注意力融合的小样本图像分类方法 | |
CN115828988A (zh) | 一种基于自监督的异构图表示学习方法 | |
CN113222018B (zh) | 一种图像分类方法 | |
CN111782964B (zh) | 一种社区帖子的推荐方法 | |
CN116484067A (zh) | 目标对象匹配方法、装置及计算机设备 | |
Longjiang | Test of English vocabulary recognition based on natural language processing and corpus system | |
CN114332469A (zh) | 模型训练方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant |