CN114997378A - 归纳式图神经网络剪枝方法、***、设备及存储介质 - Google Patents

归纳式图神经网络剪枝方法、***、设备及存储介质 Download PDF

Info

Publication number
CN114997378A
CN114997378A CN202210896971.3A CN202210896971A CN114997378A CN 114997378 A CN114997378 A CN 114997378A CN 202210896971 A CN202210896971 A CN 202210896971A CN 114997378 A CN114997378 A CN 114997378A
Authority
CN
China
Prior art keywords
mask
graph
neural network
graph data
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210896971.3A
Other languages
English (en)
Inventor
何向南
隋勇铎
王翔
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Science and Technology of China USTC
Original Assignee
University of Science and Technology of China USTC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Science and Technology of China USTC filed Critical University of Science and Technology of China USTC
Priority to CN202210896971.3A priority Critical patent/CN114997378A/zh
Publication of CN114997378A publication Critical patent/CN114997378A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种归纳式图神经网络剪枝方法、***、设备及存储介质,通过图数据剪枝器可以通过全局视角准确地衡量图数据中边的重要性,而且图数据剪枝器的可训练参数远少于UGS引入的参数,且数据规模变化时保持恒定不变,此外,图数据剪枝器可以将掩码生成的机制推广到全新的图,而无需重新训练,从而更有效地修剪看不见的和大规模的图数据,在不同的现实世界图学习任务或应用程序中变得更具可扩展性和灵活性。总之,本发明使用的归纳式图神经网络剪枝方案,可以有效地找到图彩票,能够在保持性能的同时对输入图数据和图神经网络进行共同稀疏化。

Description

归纳式图神经网络剪枝方法、***、设备及存储介质
技术领域
本发明涉及模型压缩与图神经网络领域,尤其涉及一种归纳式图神经网络剪枝方法、***、设备及存储介质。
背景技术
目前图神经网络(GNN)已经成为处理图结构数据一种最流行的模型。这种成功归咎于图神经网络的消息传递机制,即中心节点将邻居节点的信息进行聚合,并更新自己的节点表示。这种学习方式可以有效地将图的结构信息融入到节点的表示中。
随着图神经网络的发展,构建更深的图神经网络并将其部署在更大尺度的图结构数据的需求也越来越大。尽管加深图神经网络在大尺度的图数据上展现出非常优秀的潜能,然而由于图神经网络参数规模的扩大以及图数据尺度和规模的扩大,也带来了极其昂贵的计算代价,这限制了在计算资源受限情况下的应用范围。以交易网络中的欺诈检测为例,用户节点的规模很容易达到数百万甚至更大,使得基于图神经网络的检测器模型难以堆叠深层并实时地预测恶意行为。因此,剪枝过度参数化的图神经网络非常需要,这自然地使研究者们考虑以下问题:能否在保持性能的同时对输入图数据和图神经网络进行共同稀疏化。
最近,对于图神经网络的剪枝算法UGS被提出,目的是为了找到图彩票(GLT)。图彩票是网络参数和输入图的一个更小的子集。图彩票的基础是彩票假说(LTH),彩票假说推测了在任何稠密的、随机初始化的神经网络都包含一个稀疏子网络,可以独立训练该子网络以实现与稠密网络相当的性能。具体地,UGS在输入图数据中的每个边和网络参数中的每个权重上使用可训练掩码,以学习和判断它们的重要性。在使用掩码训练图神经网络时,使用迭代基于幅度的剪枝(IMP)策略,在每次迭代中丢弃掩码值最低的边和权重。
尽管UGS是有效的,但它存在以下局限性:(1)UGS为图数据的每个边独立地设置可训练的掩码。也就是说,图数据边上的掩码仅受限于给定的图数据,这使得UGS在归纳学习的设置中是无法应用的,因为边缘掩码很难推广到全新的图数据。(2)单独为图数据的每条边应用掩码只能提供对该边的局部理解,而不是整个图(如在节点分类任务中)或更多个图(如在图分类任务中)的全局视角。此外,创建可训练图数据边的掩码的方式会使可学习的参数量翻倍,这在某种程度上违背了剪枝的目的。(3)不理想的图剪枝会对网络权重的剪枝也会产生负面影响。更糟糕的是,低质量的权重修剪会反过来放大图数据边掩码的误导信号。它们可能会相互影响,形成恶性循环。
发明内容
本发明的目的是提供一种归纳式图神经网络剪枝方法、***、设备及存储介质,可以同时适用于归纳式和直推式图学习的场景,克服了现有工作只能适用于直推式学习的局限性,适用于更加广泛的图学习应用场景。
本发明的目的是通过以下技术方案实现的:
一种归纳式图神经网络剪枝方法,包括:
联合训练:对于输入图数据,通过图数据剪枝器生成输入图数据的节点表示,并基于节点表示预测节点之间的边的重要性得分,获得输入图数据的第一掩码;将第一掩码应用于所述输入图数据上,获得软掩码图数据,并利用所述软掩码图数据联合训练原始的图神经网络与图数据剪枝器,训练时,将原始的图神经网络预测的节点表示作为图数据剪枝器的监督信号;
联合稀疏化:原始的图神经网络与图数据剪枝器联合训练完毕后,对于输入图数据,利用训练后的图数据剪枝器获得第二掩码,并对第二掩码进行剪枝,获得二值图掩码;对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码;将二值图掩码应用于输入图数据,将二进制模型掩码应用于原始的图神经网络,并进行稀疏性检查,若满足稀疏性要求,则完成图神经网络剪枝,若不满足稀疏性要求,则再次执行联合训练与联合稀疏化,直至满足稀疏性要求。
一种归纳式图神经网络剪枝***,包括:
联合训练单元,用于联合训练,所述联合训练包括:对于输入图数据,通过图数据剪枝器生成输入图数据的节点表示,并基于节点表示预测节点之间的边的重要性得分,获得输入图数据的第一掩码;将第一掩码应用于所述输入图数据上,获得软掩码图数据,并利用所述软掩码图数据联合训练原始的图神经网络与图数据剪枝器,训练时,将原始的图神经网络预测的节点表示作为图数据剪枝器的监督信号;
联合稀疏化单元,用于联合稀疏化,所述联合稀疏化包括:原始的图神经网络与图数据剪枝器联合训练完毕后,利用训练后的图数据剪枝器获得第二掩码,并对掩码进行剪枝,获得二值图掩码;对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码;将二值图掩码应用于输入图数据,将二进制模型掩码应用于原始的图神经网络,并进行稀疏性检查,若满足稀疏性要求,则完成图神经网络剪枝,若不满足稀疏性要求,则再次执行联合训练与联合稀疏化,直至满足稀疏性要求。
一种处理设备,包括:一个或多个处理器;存储器,用于存储一个或多个程序;
其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现前述的方法。
一种可读存储介质,存储有计算机程序,当计算机程序被处理器执行时实现前述的方法。
由上述本发明提供的技术方案可以看出,(1)通过图数据剪枝器可以通过全局视角准确地衡量图数据中边的重要性。(2)图数据剪枝器的可训练参数远少于UGS引入的参数,且数据规模变化时保持恒定不变。(3)图数据剪枝器可以将掩码生成的机制推广到全新的图,而无需重新训练,从而更有效地修剪未知的和大规模的图数据,在不同的现实世界图学习任务或应用程序中变得更具可扩展性和灵活性。总之,本发明使用的归纳式图神经网络剪枝方案,可以有效地找到图彩票,能够在保持性能的同时对输入图数据和图神经网络进行共同稀疏化。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的一种归纳式图神经网络剪枝方法的框架图;
图2为本发明实施例提供的在MNIST数据集上的剪枝性能曲线;
图3为本发明实施例提供的在社交数据集COLLAB上的剪枝性能曲线;
图4为本发明实施例提供的在社交数据集RED-M5K上的剪枝性能曲线;
图5为本发明实施例提供的在生物分子数据集NCI1上的剪枝性能曲线;
图6为本发明实施例提供的子数据集OGBG-PPA上的剪枝性能曲线;
图7为本发明实施例提供的子数据集OGBG-CODE2上的剪枝性能曲线;
图8为本发明实施例提供的数据集CORA上的剪枝性能曲线;
图9为本发明实施例提供的数据集PPI上的剪枝性能曲线;
图10为本发明实施例提供的一种归纳式图神经网络剪枝***的示意图;
图11为本发明实施例提供的一种处理设备的示意图。
具体实施方式
下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。
首先对本文中可能使用的术语进行如下说明:
术语“和/或”是表示两者任一或两者同时均可实现,例如,X和/或Y表示既包括“X”或“Y”的情况也包括“X和Y”的三种情况。
“第一”与“第二”为区分不同阶段的同类特征的标识,不同阶段的同类特征的内容可以相同也可以不同。
术语“包括”、“包含”、“含有”、“具有”或其它类似语义的描述,应被解释为非排它性的包括。例如:包括某技术特征要素(如原料、组分、成分、载体、剂型、材料、尺寸、零件、部件、机构、装置、步骤、工序、方法、反应条件、加工条件、参数、算法、信号、数据、产品或制品等),应被解释为不仅包括明确列出的某技术特征要素,还可以包括未明确列出的本领域公知的其它技术特征要素。
下面对本发明所提供的一种归纳式图神经网络剪枝方法、***、设备及存储介质进行详细描述。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。本发明实施例中未注明具体条件者,按照本领域常规条件或制造商建议的条件进行。
实施例一
如之前背景技术部分介绍的UGS,经研究,UGS所有局限性都归因于其直推学习的特性。因此,研究如何在归纳设置中进行组合修剪对于获得高质量的图彩票至关重要,基于此,本发明实施例提供一种归纳式图神经网络剪枝方法(ICPG),如图1所示,其主要包括:
步骤1、联合训练(左侧虚线框部分)。
本步骤的主要流程如下:
(1)将输入图数据(稠密图数据)表示为
Figure 826364DEST_PATH_IMAGE001
,其中,
Figure 752732DEST_PATH_IMAGE002
表示输入图数据的邻接矩阵,X表示输入图数据中节点的特征矩阵。通过图数据剪枝器生成输入图数据的节点表示,并基于节点表示预测节点之间的边的重要性得分,获得输入图数据的第一掩码
Figure 567104DEST_PATH_IMAGE003
本发明实施例中,图结构数据(简称为图数据)是一种客观存在的、具有确切技术含义的技术性数据。以节点分类任务为例,分类的目标是以节点为单位,判断每一节点的类别,对于社交网络,输入图数据中的节点为每一个用户,通过用户信息可以提取相应的节点特征,进而构成节点的特征矩阵X,用户之间的交互关系即为节点之间的边,进而构成邻接矩阵
Figure 61670DEST_PATH_IMAGE002
。以图分类任务为例,分类的目标是图数据为单元,例如,化学分子数据构成了一个图数据,图分类的任务是判断该化学分子所属类别,图数据中的节点为原子,通过对每一个原子进行特征提取获得相应的节点特征,进而构成节点的特征矩阵X,原子之间的化学键即为节点之间的边,进而构成邻接矩阵
Figure 594283DEST_PATH_IMAGE002
(2)将第一掩码应用于所述输入图数据上,获得软掩码图数据,表示为:
Figure 58762DEST_PATH_IMAGE004
其中,
Figure 868586DEST_PATH_IMAGE005
表示软掩码图数据,
Figure 658688DEST_PATH_IMAGE006
表示逐元素相乘运算。
通过上述(1)~(2)的处理充分体现了图数据剪枝器对每个边重要性的决定,使得不太重要的边具有较低的掩码值。
(3)利用所述软掩码图数据联合训练原始的稠密图神经网络(简称为原始的图神经网络)与图数据剪枝器。训练时,通常使用给定的损失函数,例如图分类任务中常用交叉熵损失函数。本发明基于随机梯度下降算法同时优化GNN和图数据剪枝器的模型参数。由于原始GNN预测的节点表示作为图数据剪枝器的输入,因此作为监督信号可以指导图数据剪枝器实现更准确的剪枝决策。
步骤2、联合稀疏化(右侧虚线框部分)。
原始的图神经网络与图数据剪枝器联合训练的损失函数收敛后(即训练完毕),可以共同稀疏图数据和模型,本步骤的主要流程如下:
(1)对于输入图数据
Figure 52498DEST_PATH_IMAGE001
,利用训练后的图数据剪枝器获得第二掩码,并对第二掩码进行剪枝,获得二值图掩码。
具体的:根据第二掩码中每一个位置的非0的掩码值对输入图数据中相应节点的边的重要性进行排序,其中,若节点之间通过边连接,则掩码值为相应节点之间的边的重要性得分,若节点之间未通过边连接,则掩码值为0;设定第一比率
Figure 55089DEST_PATH_IMAGE007
,按照排序中的最低数值与第一比率
Figure 109632DEST_PATH_IMAGE007
,确定需要进行剪枝的掩码值,将第二掩码中需要剪枝的掩码值置为0,其余掩码值置为1,获得二值图掩码
Figure 946001DEST_PATH_IMAGE008
例如,第二掩码矩阵
Figure 453206DEST_PATH_IMAGE009
,定义的剪枝比例为20%(即
Figure 134854DEST_PATH_IMAGE010
),对第二掩码矩阵中非0的掩码值进行排序得到以下大小顺序:0.8,0.7,0.4,0.2,0.1,一共只有5个非零掩码值,因此,需要剪枝掉最小的(20%*5向下取整)1个掩码值,具体操为:将需要剪枝的掩码值置为0(掩码值0.1置为0),其余非0掩码值置为1(即其余4个掩码值全部置为1),最终得到的二值图掩码
Figure 43904DEST_PATH_IMAGE011
(2)对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码。
具体的:设定第二比率
Figure 910229DEST_PATH_IMAGE012
,根据参数的幅度大小对参数进行排序,按照排序中的最小参数值与第二比率
Figure 45676DEST_PATH_IMAGE012
,确定需要进行剪枝的参数值,将训练后的图神经网络的参数中需要剪枝的参数值为0,其余参数值置为1,获得二进制模型掩码
Figure 390069DEST_PATH_IMAGE013
(3)将二值图掩码应用于输入图数据,将二进制模型掩码应用于原始的图神经网络,并进行稀疏性检查,若满足给定的稀疏性的要求,则完成图神经网络剪枝,若不满足稀疏性要求,需要使用当前获得的二值图掩码
Figure 153626DEST_PATH_IMAGE008
、二进制模型掩码
Figure 564753DEST_PATH_IMAGE013
以及图数据剪枝器和图神经网络模型原始的初始化参数来更新输入图数据和图神经网络,再次执行联合训练与联合稀疏化(即迭代执行步骤1与步骤2),直至满足稀疏性要求。
本发明实施例中,稀疏性检查即为判断当前是否满足终止条件,稀疏性的要求可以理解为预先设定的一个超参数。以稀疏性的要求为需要剪掉输入图数据中40%的边为例,每次迭代会根据设定的第一比率
Figure 312130DEST_PATH_IMAGE007
剪掉一部分边(即将相应掩码值置为0),因此,通过计算
Figure 335580DEST_PATH_IMAGE014
,求解n并取最小的整数,n即为迭代执行步骤1与步骤2的次数。
本发明实施例上述方案,可以同时适用于归纳式和直推式图学习的场景,克服了现有工作只能适用于直推式学习的局限性,适用于更加广泛的图学习应用场景,具体可以应用到基于图神经网络的图数据处理等任务中,极大地降低图神经网络模型对大规模、大尺度图结构数据的推断阶段的运算代价,并可以扩展到各种不同图结构数据处理任务中,例如节点分类任务,图分类任务等。
本领域技术人员可以理解,归纳式和直推式学习为行业术语。(1)直推式学习的定义:在训练过程中不仅需要用到训练集的全部信息,还用到测试集数据(不带标签)的信息。这也就意味着,只要有新的样本进来,模型就得重新训练,因此模型的泛化性和通用性较差。(2)归纳式学习的定义:模型在训练集上训练并在测试集上测试,其中训练集与测试集之间是相斥的,即测试集中的任何信息都是没有在训练集中出现过的。归纳式学习往往需要模型本身具备一定的通用性和泛化能力,才可以在测试阶段适用于全新的未见过的测试数据中。这是其他图剪枝算法,例如UGS无法做到的(无法剪枝未见过的图数据),这也是我们本发明的一个很大的优势和亮点。从上面的定义可以看出,归纳式学习的难度要比直推式学习的难度更大,适用范围更广,因此所有可以实现归纳式学习的算法都可以适用于直推式学习的场景。换言之,本发明也可以看成是从直推式学习到归纳式学习的一种推广。
为了便于理解,下面针对本发明所涉及的图数据剪枝器以及联合稀疏化后的性能判定做详细介绍。
一、图数据剪枝器。
本发明的目的并不是将掩码分配给输入图数据的每个边,而是在给定的输入图数据上,设计一个可训练的模型来来预测图数据边的掩码,该可训练的模型为基于图神经网络结构的预测模型,称之为图数据剪枝器,其参数在训练中所观察到的图数据中共享。
本发明实施例中,将图数据剪枝器表示为图神经网络编码器和后续评分函数的组合,主要包括:
1、图神经网络编码器。
它主要用于根据输入图像数据,生成节点表示,表示为:
Figure 219223DEST_PATH_IMAGE015
其中,H表示节点表示矩阵,包含所有节点表示,
Figure 427350DEST_PATH_IMAGE016
Figure 271809DEST_PATH_IMAGE017
表示实数集,d表示单个节点表示的维度,
Figure 958006DEST_PATH_IMAGE018
表示节点集合
Figure 696155DEST_PATH_IMAGE019
中的节点数目,节点表示矩阵H的第ih i 为节点v i 的节点表示;
Figure 950549DEST_PATH_IMAGE001
表示输入图数据,
Figure 406939DEST_PATH_IMAGE020
表示图神经网络编码器。
2、多层感知机。
它主要用于基于节点表示预测节点之间的边的重要性分数,表示为:
Figure 270727DEST_PATH_IMAGE021
其中,MLP表示多层感知机,h i h j 分别节点v i 与节点v j 的节点表示,[,]表示拼接操作,
Figure 128962DEST_PATH_IMAGE022
表示节点v i 与节点v j 之间的边(i,j)的重要性分数。
3、sigmoid函数层。
它主要用于将边的重要性分数投影至指定区间,具体为(0,1)区间,获得边的重要性得分,表示为:
Figure 678892DEST_PATH_IMAGE023
其中,
Figure 497943DEST_PATH_IMAGE024
表示节点v i 与节点v j 之间的边(i,j)的重要性得分,
Figure 525942DEST_PATH_IMAGE025
表示sigmoid函数。
4、掩码矩阵输出层。
它主要用于根据边的重要性得分组成掩码矩阵,联合训练时,掩码矩阵为第一掩码
Figure 114049DEST_PATH_IMAGE026
,联合稀疏化时,掩码矩阵为第二掩码,也就是说,步骤1与步骤2中使用图数据剪枝器执行的过程是相同的。
掩码矩阵中第i行第j列的掩码值为:节点i与节点j之间的边的重要性得分,或者为0。具体的:以第一掩码
Figure 834881DEST_PATH_IMAGE027
为例,如果边(i,j)存在,则
Figure 265862DEST_PATH_IMAGE028
,否则,
Figure 238497DEST_PATH_IMAGE029
,其中,
Figure 540166DEST_PATH_IMAGE030
表示掩码矩阵
Figure 71379DEST_PATH_IMAGE026
中第i行第j列的掩码值。
以联合训练阶段为例,可以将图数据剪枝器执行的过程总结为:
Figure 724077DEST_PATH_IMAGE031
其中,GraphMasker表示图数据剪枝器,
Figure 359458DEST_PATH_IMAGE032
表示输入图数据,
Figure 656578DEST_PATH_IMAGE033
表示图数据剪枝器的参数,包括图神经网络编码器与多层感知机中的参数。
本发明实施例提供的图数据剪枝器与UGS相比,主要具有如下优点:
(1)全局的视角:尽管UGS生成的图数据的边的掩码可能保持对局部重要性的保真度,但它们无助于描绘整个图数据的总体情况。与 UGS不同的是,本发明实施例提供的图数据剪枝器采用全局视角来审视所有的图数据,能够准确地识别出边的集合。具体来说,由于图的边通常相互协作以进行预测,而不是单独工作,它们形成了一个集合,而类似于分子图的功能团,社交网络的社区。考虑到这种联合效应,图数据剪枝器能够更准确地衡量图中边的重要性。
(2)轻量级边掩码:利用UGS剪枝具有数百万条边或节点的图数据时,在现实世界场景中如此大规模的数据集下,一对一为每条边分配掩码的成本是不现实的。此外,UGS引入了额外的可训练参数,其规模与边数保持相同。并且比被修剪的原始参数大得多。因此,它在某种程度上违反了修剪的目的。在本发明实施例提供的图数据剪枝器中,附加参数仅是等式中的
Figure 984791DEST_PATH_IMAGE033
,并且在数据规模变化时保持恒定不变。
(3)归纳式剪枝:图数据剪枝器可以将掩码生成的机制归纳式地推广到全新的图,而无需重新训练,从而更有效地修剪看不见的和大规模的图。因此,它使剪枝算法在不同的现实世界图学习任务或应用程序中变得更具可扩展性和灵活性。
二、联合稀疏化后的性能判定。
通过本发明上述方案对输入图数据与图神经网络进行联合稀疏化,此时通过性能测试来判断是否构成图彩票。通常来说,学习图彩票目的是使得给定的输入图数据
Figure 859206DEST_PATH_IMAGE001
和网络初始化参数
Figure 439223DEST_PATH_IMAGE034
稀疏化,以减少计算代价,并同时保留性能。如之前联合稀疏化所述,获得二值图掩码
Figure 715484DEST_PATH_IMAGE008
与二进制模型掩码
Figure 89965DEST_PATH_IMAGE013
,并分别应用于输入图数据与图神经网络的参数。将二值图掩码应用于输入图数据后,获得的图数据称为稀疏图数据
Figure 451676DEST_PATH_IMAGE035
Figure 428859DEST_PATH_IMAGE036
;将二进制模型掩码应用于原始初始化的图神经网络参数
Figure 199107DEST_PATH_IMAGE037
上,获得的网络称为子网络
Figure 869122DEST_PATH_IMAGE038
,其中,
Figure 718130DEST_PATH_IMAGE039
表示二进制模型掩码应用于原始初始化的图神经网络参数,
Figure 374370DEST_PATH_IMAGE040
;若稀疏图数据
Figure 625223DEST_PATH_IMAGE035
与参数
Figure 341506DEST_PATH_IMAGE039
满足给定的稀疏性的要求,则利用稀疏图数据
Figure 677810DEST_PATH_IMAGE035
对子网络
Figure 262375DEST_PATH_IMAGE038
进行训练,获得收敛的网络参数
Figure 508679DEST_PATH_IMAGE041
,判断使用收敛的网络参数
Figure 254918DEST_PATH_IMAGE041
的子网络
Figure 464138DEST_PATH_IMAGE042
性能是否满足要求,即是否可以实现与输入图数据
Figure 586815DEST_PATH_IMAGE043
以及原始的图神经网络相当的性能,若性能满足设定要求,则稀疏图数据
Figure 812260DEST_PATH_IMAGE035
与子网络
Figure 604766DEST_PATH_IMAGE038
构成一个图彩票。
为了说明本发明实施例上述方案的有效性,提供如下验证实验。
一、图分类任务中的剪枝性能验证。
首先在图分类任务中去寻找图彩票,结果如图2~图5所示,图2对应数据集MNIST,图3对应社交数据集COLLAB,图4对应社交数据集RED-M5K,图5对应生物分子数据集NCI1,图2~图5中:ICPG代表本发明提供的方案,随机剪枝为参与对比的现有方案,基线表示的是原始GNN在稠密图数据上的性能,星型形状表示的极限稀疏度,表示的是使得性能不下降的最大稀疏度。通过观察结果,有以下发现:
1、图彩票广泛地存在于图分类任务中。利用ICPG,成功在不同类型的图数据上定位出了更稀疏的图彩票,对于生物分子数据集NCI1上,精确地在26.49%的图稀疏度范围内识别出图彩票。对于社交数据集COLLAB、RED-M5K上,在33.66%-40.13%稀疏度下识别出图彩票。对于数据集MNIST,可以实现43.13%的稀疏度。这些结果表明ICPG可以归纳式地定位高质量的图彩票。
2、图数据剪枝器有很好的泛化性。主流的图稀疏化技术无法归纳式地剪枝未见过的新的图数据,然而图数据剪枝器可以灵活地克服这个缺点。与随机剪枝相比,ICPG可以找到更加稀疏的图数据和子网络,并可以与随机剪枝保持一个很大的差距。例如在RED-M5K数据集上,ICPG可以实现40.13%的稀疏度,与随机剪枝相比,提高了25.87%,这实现了极大的提升。这表明图数据剪枝器可以精确地从训练图数据中捕捉更加重要的核心模式,并可以在未知的新的图数据上实现良好的泛化性。
3、图彩票的极限稀疏度依赖于图数据的属性。虽然ICPG在大多数图上实现了比随机剪枝更高的稀疏度,但在小部分图上改进并不明显,例如生物分子数据集NCI1。可以做出以下猜想:首先,这些图数据中的大多数边都是重要的,例如,某个边可能对应于一个关键的化学键,如果剪枝可能会极大地影响分子的化学性质。其次,图数据的大小比较小,只有几十个节点和边,所以对剪枝比较敏感。相反,较大的社交网络数据集,例如RED-M5K,每个图包含数百或数千条边,因此可能有大量冗余边表明对剪枝不敏感。
二、更大规模和尺度图分类数据集和节点分类数据集的剪枝性能的验证。
1、图数据剪枝器可以处理更大尺度和更大规模的图数据。图6与图7展示了数据集OGB的结果,OGBG-PPA与OGBG-CODE2为数据集OGB的两个子数据集,它们均为大规模大尺度图分类数据集。其中,子数据集OGBG-PPA平均每个图包含了2266.1条边以及243.4个节点,子数据集OGBG-CODE2包含了452,741个图数据。实验发现随机剪枝在OGBG-PPA数据集上仅仅能定位5%稀疏度的GLT,甚至在OGBG-CODE2上无法找到更加稀疏的GLT。尽管如此,ICPG可以在OGBG-PPA和OGBG-CODE2上分别找到14.26%和18.55%稀疏度的GLT。这优越的性能更加证实了ICPG的强大泛化性和可扩展性。
2、ICPG可以在节点分类任务中实现非常好的性能。此部分使用节点分类数据集:数据集CORA与数据集PPI。图8展示了数据集CORA上的结果,在数据集CORA上,本发明提出的ICPG剪枝算法可以超过随机剪枝,并保持一个比较大的差距。图9展示了数据集PPI上的结果,对于归纳式的数据集PPI,ICPG仍然可以实现非常好的性能,与随机剪枝相比有22.62%的提升,这进一步表明了ICPG在归纳式GNN剪枝的有效性。
三、剪枝性能以及计算代价验证。
1、性能比较。采用GraphSAGE和DropEdge作为基线方法,二者均为现有方案,且都是基于图稀疏化和有效训练的。为了公平地比较,通过调整GraphSAGE和DropEdge的超参数使得与ICPG保持相似的稀疏度。从表1看出,ICPG一致地优于所有的基线方法。
表1:不同稀疏度下的性能(准确率)比较
Figure 181241DEST_PATH_IMAGE044
2、推断阶段计算代价(MACs)。将实验中的稀疏度转化成推断时的MACs的减少,来衡量计算代价。结果如表2所示,与完整的基线模型GIN(图同构网络)相比,GIN + ICPG可以极大地降低计算代价在大概51.86%-93.71%左右,并同时适用于从小尺度的数据集到大尺度的数据集,而没有降低任何性能;其中,图同构网络属于图神经网络的一种,GIN + ICPG表示图1所示的框架中使用图同构网络。这些结果反应了本发明的实用性。
表2:推断阶段计算代价(MACs)比较
Figure 842030DEST_PATH_IMAGE045
需要说明的是,图2~图9所示8幅的性能曲线图中篇幅有限,因此,横坐标部分仅展示了部分相关位置点及前后若干位置点处的稀疏度值。
实施例二
本发明还提供一种归纳式图神经网络剪枝***,其主要基于前述实施例提供的方法实现,如图10所示,该***主要包括:
联合训练单元,用于联合训练,所述联合训练包括:对于输入图数据,通过图数据剪枝器生成输入图数据的节点表示,并基于节点表示预测节点之间的边的重要性得分,获得输入图数据的第一掩码;将第一掩码应用于所述输入图数据上,获得软掩码图数据,并利用所述软掩码图数据联合训练原始的图神经网络与图数据剪枝器,训练时,将原始的图神经网络预测的节点表示作为图数据剪枝器的监督信号;
联合稀疏化单元,用于联合稀疏化,所述联合稀疏化包括:原始的图神经网络与图数据剪枝器联合训练完毕后,利用训练后的图数据剪枝器获得第二掩码,并对掩码进行剪枝,获得二值图掩码;对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码;将二值图掩码应用于输入图数据,将二进制模型掩码应用于原始的图神经网络,并进行稀疏性检查,若满足稀疏性要求,则完成图神经网络剪枝,若不满足稀疏性要求,则再次执行联合训练与联合稀疏化,直至满足稀疏性要求。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将***的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。
实施例三
本发明还提供一种处理设备,如图11所示,其主要包括:一个或多个处理器;存储器,用于存储一个或多个程序;其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现前述实施例提供的方法。
进一步的,所述处理设备还包括至少一个输入设备与至少一个输出设备;在所述处理设备中,处理器、存储器、输入设备、输出设备之间通过总线连接。
本发明实施例中,所述存储器、输入设备与输出设备的具体类型不做限定;例如:
输入设备可以为触摸屏、图像采集设备、物理按键或者鼠标等;
输出设备可以为显示终端;
存储器可以为随机存取存储器(Random Access Memory,RAM),也可为非不稳定的存储器(non-volatile memory),例如磁盘存储器。
实施例四
本发明还提供一种可读存储介质,存储有计算机程序,当计算机程序被处理器执行时实现前述实施例提供的方法。
本发明实施例中可读存储介质作为计算机可读存储介质,可以设置于前述处理设备中,例如,作为处理设备中的存储器。此外,所述可读存储介质也可以是U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。

Claims (10)

1.一种归纳式图神经网络剪枝方法,其特征在于,包括:
联合训练:对于输入图数据,通过图数据剪枝器生成输入图数据的节点表示,并基于节点表示预测节点之间的边的重要性得分,获得输入图数据的第一掩码;将第一掩码应用于所述输入图数据上,获得软掩码图数据,并利用所述软掩码图数据联合训练原始的图神经网络与图数据剪枝器,训练时,将原始的图神经网络预测的节点表示作为图数据剪枝器的监督信号;
联合稀疏化:原始的图神经网络与图数据剪枝器联合训练完毕后,对于输入图数据,利用训练后的图数据剪枝器获得第二掩码,并对第二掩码进行剪枝,获得二值图掩码;对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码;将二值图掩码应用于输入图数据,将二进制模型掩码应用于原始的图神经网络,并进行稀疏性检查,若满足稀疏性要求,则完成图神经网络剪枝,若不满足稀疏性要求,则再次执行联合训练与联合稀疏化,直至满足稀疏性要求。
2.根据权利要求1所述的一种归纳式图神经网络剪枝方法,其特征在于,所述图数据剪枝器为基于图神经网络结构的预测模型,包括:
图神经网络编码器,用于根据输入图像数据,生成节点表示;
多层感知机,用于基于节点表示预测节点之间的边的重要性分数;
sigmoid函数层,用于将边的重要性分数投影至指定区间,获得边的重要性得分;
掩码矩阵输出层,用于根据边的重要性得分组成掩码矩阵;联合训练时,掩码矩阵为第一掩码,联合稀疏化时,掩码矩阵为第二掩码。
3.根据权利要求2所述的一种归纳式图神经网络剪枝方法,其特征在于,图数据剪枝器所涉及的计算流程表示为:
生成节点表示:
Figure 904840DEST_PATH_IMAGE001
;其中,H表示节点表示矩阵,包含所有节点表示;
Figure 222427DEST_PATH_IMAGE002
表示输入图数据,
Figure 516005DEST_PATH_IMAGE003
表示输入图数据的邻接矩阵,X表示输入图数据中节点的特征矩阵,
Figure 963167DEST_PATH_IMAGE004
表示图神经网络编码器;
预测边的重要性分数:
Figure 559364DEST_PATH_IMAGE005
;其中,MLP表示多层感知机,h i h j 分别节点v i 与节点v j 的节点表示,[,]表示拼接操作,
Figure 990346DEST_PATH_IMAGE006
表示节点v i 与节点v j 之间的边(i,j)的重要性分数;
计算边的重要性得分:
Figure 822035DEST_PATH_IMAGE007
;其中,
Figure 264649DEST_PATH_IMAGE008
表示节点v i 与节点v j 之间的边(i,j)的重要性得分,
Figure 156382DEST_PATH_IMAGE009
表示sigmoid函数;
利用所有边的重要性得分组成掩码矩阵,联合训练时,掩码矩阵为第一掩码,联合稀疏化时,掩码矩阵为第二掩码,如果边(i,j)存在,则掩码矩阵中第i行第j列的掩码值为
Figure 809080DEST_PATH_IMAGE008
,否则为0。
4.根据权利要求1所述的一种归纳式图神经网络剪枝方法,其特征在于,所述将第一掩码应用于所述输入图数据上,获得软掩码图数据包括:
将输入图数据记为
Figure 585406DEST_PATH_IMAGE002
,将第一掩码记为
Figure 741581DEST_PATH_IMAGE010
,软掩码图数据表示为:
Figure 443695DEST_PATH_IMAGE011
其中,
Figure 318111DEST_PATH_IMAGE012
表示软掩码图数据,
Figure 22761DEST_PATH_IMAGE013
表示逐元素相乘运算。
5.根据权利要求1所述的一种归纳式图神经网络剪枝方法,其特征在于,所述利用训练后的图数据剪枝器获得第二掩码,并对第二掩码进行剪枝,获得二值图掩码包括:
根据第二掩码中每一个位置的非0的掩码值对输入图数据中相应节点的边的重要性进行排序,其中,若节点之间通过边连接,则掩码值为相应节点之间的边的重要性得分,若节点之间未通过边连接,则掩码值为0;
设定第一比率
Figure 174388DEST_PATH_IMAGE014
,按照排序中的最低数值与第一比率
Figure 673503DEST_PATH_IMAGE014
,确定需要进行剪枝的掩码值,将第二掩码中需要剪枝的掩码值置为0,其余掩码值置为1,获得二值图掩码
Figure 176159DEST_PATH_IMAGE015
6.根据权利要求1所述的一种归纳式图神经网络剪枝方法,其特征在于,所述对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码包括:
设定第二比率
Figure 153342DEST_PATH_IMAGE016
,按照排序中的最小参数值与第二比率
Figure 549689DEST_PATH_IMAGE016
,确定需要进行剪枝的参数值,将训练后的图神经网络的参数中需要剪枝的参数值为0,其余参数值置为1,获得二进制模型掩码
Figure 829491DEST_PATH_IMAGE017
7.根据权利要求1所述的一种归纳式图神经网络剪枝方法,其特征在于,该方法还包括通过性能测试来判断联合稀疏化后的输入图数据与图神经网络是否构成图彩票,方式包括:
将二值图掩码应用于输入图数据后,获得的图数据称为稀疏图数据
Figure 944078DEST_PATH_IMAGE018
;将二进制模型掩码应用于原始的图神经网络后,获得的网络称为子网络
Figure 98854DEST_PATH_IMAGE019
,其中,
Figure 349706DEST_PATH_IMAGE020
表示二进制模型掩码应用于原始的图神经网络后的参数;
若稀疏图数据
Figure 925044DEST_PATH_IMAGE018
与子网络
Figure 402293DEST_PATH_IMAGE019
满足稀疏性要求,则利用稀疏图数据
Figure 986858DEST_PATH_IMAGE018
对子网络
Figure 92217DEST_PATH_IMAGE021
进行训练,获得收敛的网络参数
Figure 979402DEST_PATH_IMAGE022
;若使用收敛的网络参数
Figure 537422DEST_PATH_IMAGE022
的子网络
Figure 801044DEST_PATH_IMAGE023
性能满足要求,则稀疏图数据
Figure 760910DEST_PATH_IMAGE018
与子网络
Figure 943630DEST_PATH_IMAGE019
构成一个图彩票。
8.一种归纳式图神经网络剪枝***,其特征在于,基于权利要求1~7任一项所述的方法实现,该***包括:
联合训练单元,用于联合训练,所述联合训练包括:对于输入图数据,通过图数据剪枝器生成输入图数据的节点表示,并基于节点表示预测节点之间的边的重要性得分,获得输入图数据的第一掩码;将第一掩码应用于所述输入图数据上,获得软掩码图数据,并利用所述软掩码图数据联合训练原始的图神经网络与图数据剪枝器,训练时,将原始的图神经网络预测的节点表示作为图数据剪枝器的监督信号;
联合稀疏化单元,用于联合稀疏化,所述联合稀疏化包括:原始的图神经网络与图数据剪枝器联合训练完毕后,利用训练后的图数据剪枝器获得第二掩码,并对掩码进行剪枝,获得二值图掩码;对于训练后的图神经网络,根据参数的幅度大小对参数进行剪枝,获得二进制模型掩码;将二值图掩码应用于输入图数据,将二进制模型掩码应用于原始的图神经网络,并进行稀疏性检查,若满足稀疏性要求,则完成图神经网络剪枝,若不满足稀疏性要求,则再次执行联合训练与联合稀疏化,直至满足稀疏性要求。
9.一种处理设备,其特征在于,包括:一个或多个处理器;存储器,用于存储一个或多个程序;
其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现如权利要求1~7任一项所述的方法。
10.一种可读存储介质,存储有计算机程序,其特征在于,当计算机程序被处理器执行时实现如权利要求1~7任一项所述的方法。
CN202210896971.3A 2022-07-28 2022-07-28 归纳式图神经网络剪枝方法、***、设备及存储介质 Pending CN114997378A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210896971.3A CN114997378A (zh) 2022-07-28 2022-07-28 归纳式图神经网络剪枝方法、***、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210896971.3A CN114997378A (zh) 2022-07-28 2022-07-28 归纳式图神经网络剪枝方法、***、设备及存储介质

Publications (1)

Publication Number Publication Date
CN114997378A true CN114997378A (zh) 2022-09-02

Family

ID=83022644

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210896971.3A Pending CN114997378A (zh) 2022-07-28 2022-07-28 归纳式图神经网络剪枝方法、***、设备及存储介质

Country Status (1)

Country Link
CN (1) CN114997378A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116994309A (zh) * 2023-05-06 2023-11-03 浙江大学 一种公平性感知的人脸识别模型剪枝方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116994309A (zh) * 2023-05-06 2023-11-03 浙江大学 一种公平性感知的人脸识别模型剪枝方法
CN116994309B (zh) * 2023-05-06 2024-04-09 浙江大学 一种公平性感知的人脸识别模型剪枝方法

Similar Documents

Publication Publication Date Title
Ghorbani et al. Neuron shapley: Discovering the responsible neurons
CN111860638B (zh) 基于不平衡数据深度信念网络的并行入侵检测方法和***
CN103559504B (zh) 图像目标类别识别方法及装置
CN105224872B (zh) 一种基于神经网络聚类的用户异常行为检测方法
CN107292097B (zh) 基于特征组的中医主症选择方法
CN103927550B (zh) 一种手写体数字识别方法及***
CN111783841A (zh) 基于迁移学习和模型融合的垃圾分类方法、***及介质
Lin et al. Fairgrape: Fairness-aware gradient pruning method for face attribute classification
CN111311702B (zh) 一种基于BlockGAN的图像生成和识别模块及方法
CN114913379A (zh) 基于多任务动态对比学习的遥感图像小样本场景分类方法
CN115114484A (zh) 异常事件检测方法、装置、计算机设备和存储介质
Balakrishnan et al. Meticulous fuzzy convolution C means for optimized big data analytics: adaptation towards deep learning
CN114997378A (zh) 归纳式图神经网络剪枝方法、***、设备及存储介质
Datta et al. Computational intelligence for observation and monitoring: a case study of imbalanced hyperspectral image data classification
CN110992194A (zh) 一种基于含属性的多进程采样图表示学习模型的用户参考指数算法
CN109934352B (zh) 智能模型的自动进化方法
Zhang et al. Multi-weather classification using evolutionary algorithm on efficientnet
Parker et al. Nonlinear time series classification using bispectrum‐based deep convolutional neural networks
CN116467466A (zh) 基于知识图谱的编码推荐方法、装置、设备及介质
CN114265954B (zh) 基于位置与结构信息的图表示学习方法
Chen et al. Gaussian mixture embedding of multiple node roles in networks
Sufikarimi et al. Speed up biological inspired object recognition, HMAX
US20230041338A1 (en) Graph data processing method, device, and computer program product
Amorim et al. Supervised learning using local analysis in an optimal-path forest
CN110517326B (zh) 一种基于权重蜻蜓算法的比色传感器阵列优化方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20220902