CN113361627A - 一种面向图神经网络的标签感知协同训练方法 - Google Patents

一种面向图神经网络的标签感知协同训练方法 Download PDF

Info

Publication number
CN113361627A
CN113361627A CN202110697015.8A CN202110697015A CN113361627A CN 113361627 A CN113361627 A CN 113361627A CN 202110697015 A CN202110697015 A CN 202110697015A CN 113361627 A CN113361627 A CN 113361627A
Authority
CN
China
Prior art keywords
node
neural network
label
graph
nodes
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110697015.8A
Other languages
English (en)
Inventor
王杰
贺华瑞
张占秋
陈佳俊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Science and Technology of China USTC
Original Assignee
University of Science and Technology of China USTC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Science and Technology of China USTC filed Critical University of Science and Technology of China USTC
Priority to CN202110697015.8A priority Critical patent/CN113361627A/zh
Publication of CN113361627A publication Critical patent/CN113361627A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种面向图神经网络的标签感知协同训练方法,包括:步1,用图神经网络将已知标签沿连边传播估计出各结点初步标签取值分布,衡量各结点的初步标签取值分布属于某类的初步预测置信度;步2,为初步预测置信度最高的预定个数结点标注伪标签形成增广的有标注图结构数据集;步3,利用图神经网络学习各结点的向量表示并做出全局预测,得出未标注结点的最终预测值;步4,将最终预测值作为未标注结点的初始标签,重复步1至步4,得出各未标注结点的最终分类结果。通过充分利用少量已知标签建模全局标签依赖性,结合标签传播和图神经网络特征传播优势,将全局标签依赖信息隐式融入图神经网络训练过程,得到更客观结点表示,提升分类性能。

Description

一种面向图神经网络的标签感知协同训练方法
技术领域
本发明涉及图神经网络的半监督目标分类领域,尤其涉及一种面向图神经网络的标签感知协同训练方法。
背景技术
图神经网络以图结构数据作为输入,根据“相邻的结点具有相似性”这一假设,对图中结点采用聚合其周围结点信息作为自身的表征的方式,将图中结点映射为连续向量空间中的向量,使得结构上相近的结点在嵌入空间中有相似的向量表示。近年来,面向半监督目标分类的图神经网络技术在诸如网页分类、语音识别、图像检测和蛋白质性质预测等领域取得了巨大的成就。
现有的图神经网络(包括图卷积网络-GCN、图注意力网络-GAT等)训练方式主要关注局部的结点特征信息,而对少量的已知标签利用不足,没有考虑全局的标签依赖信息。例如,图卷积网络(GCN)通过堆叠一阶Chebyshev过滤层来简化图信号的频域卷积操作,堆叠层数即考虑的邻居结点跳数,层数越少则考虑的局部范围越小;图注意力网络(GAT)通过为不同邻居结点赋值大小不一的权重来区分亲疏远近的邻居结点。上述两种经典的图神经网络都没有充分利用已知标签的信息,仅将已知标签作为监督信号,独立地对每个结点做误差反向传播。
这样的训练过程隐式地假设了给定邻域特征后,各结点的标签是条件独立的。该假设违背了传统的统计关系学习方法所认为的各结点的标签之间是存在依赖和影响关系的的基本观点。这也就限制了图神经网络的性能表现,造成现有图神经网络训练方式的缺陷与不足。
发明内容
针对现有技术所存在的问题,本发明的目的是提供一种面向图神经网络的标签感知协同训练方法,能解决现有训练方法并未有效利用各结点的标签之间存在的依赖和影响关系,存在限制图神经网络的性能表现,造成现有图神经网络训练方式存在缺陷与不足的问题。
本发明的目的是通过以下技术方案实现的:
本发明实施方式提供一种面向图神经网络的标签感知协同训练方法,用于通过图神经网络为图结构数据集中的未标注结点进行分类,所述图结构数据集由已标注结点集合和未标注结点集合组成,所述已标注结点集合占所述图结构数据集的总结点数量小于等于1%,包括:
步骤1,设定图神经网络的网络层数L、协同训练回合数K、标签传播迭代轮数T、标签传播权重α和协同训练每回合标注伪标签个数m,将所述图结构数据集对应的图邻接矩阵A、归一化邻接矩阵
Figure BDA0003128892970000021
结点特征矩阵X和已标注结点的已知标签YL输入所述图神经网络进行处理;
步骤2,所述图神经网络通过将已标注结点的已知标签沿着连边传播估计出每个结点初步的标签取值分布F(T),用基尼系数衡量每个结点初步的标签取值分布属于某类的初步预测置信度;
所述每个结点初步的标签取值分布F(T)的迭代计算公式为:F(t)=αSF(t-1)+(1-α)Y,t取值为1至T,
Figure BDA0003128892970000022
Figure BDA0003128892970000023
表示未标注结点的初始标签,是所述图神经网络对未标注结点的预测值,第一回合协同训练中
Figure BDA0003128892970000024
为空,每一回合训练中所述公式F(t)=αSF(t-1)+(1-α)Y的迭代执行次数为T次;
步骤3,为初步预测置信度最高的预定m个结点标注伪标签形成增广的有标注图结构数据集Ltrain
步骤4,利用所述图神经网络基于所述增广的有标注图结构数据集Ltrain,学习各结点的向量表示并做出全局预测,得出未标注结点的最终预测值
Figure BDA0003128892970000025
步骤5,将所述图神经网络对未标注结点的最终预测值
Figure BDA0003128892970000026
作为未标注结点的初始标签,按步骤1设定的协同训练回合数K重复进行所述步骤2至步骤5,得出各未标注结点的最终分类结果。
由上述本发明提供的技术方案可以看出,本发明实施例提供的面向图神经网络的标签感知协同训练方法,其有益效果为:
通过充分利用少量标签信息建模全局标签依赖性,结合了标签传播方式和主流图神经网络的特征传播方式的优势,捕捉全局标签依赖信息并将其隐式融入到图神经网络训练过程中,以得到更客观的结点表示,提升模型的分类性能。经实验验证,在没有引入额外的参数和计算量的情况下,本发明的方法在主流半监督结点分类数据集上达到了优越性能,超越了之前方法的性能表现。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
图1为本发明实施例提供的面向图神经网络的标签感知协同训练方法的流程图;
图2为本发明实施例提供的面向图神经网络的标签感知协同训练方法整体框架示意图。
具体实施方式
下面结合本发明的具体内容,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。
参见图1,本发明实施例提供一种面向图神经网络的标签感知协同训练方法,用于通过图神经网络为图结构数据集中的未标注结点进行分类,所述图结构数据集由已标注结点集合和未标注结点集合组成,所述已标注结点集合占所述图结构数据集的总结点数量小于等于1%,包括:
步骤1,设定图神经网络的网络层数L、协同训练回合数K、标签传播迭代轮数T、标签传播权重α和协同训练每回合标注伪标签个数m,将所述图结构数据集对应的图邻接矩阵A、归一化邻接矩阵
Figure BDA0003128892970000031
结点特征矩阵X和已标注结点的已知标签YL输入所述图神经网络进行处理;
步骤2,所述图神经网络通过将已标注结点的已知标签沿着连边传播估计出每个结点初步的标签取值分布F(T),用基尼系数衡量每个结点初步的标签取值分布属于某类的初步预测置信度,作为估计出的全局标签依赖性;
所述每个结点初步的标签取值分布F(T)的迭代计算公式为:F(t)=αSF(t-1)+(1-α)Yt取值为1至T,
Figure BDA0003128892970000032
Figure BDA0003128892970000041
表示未标注结点的初始标签,是所述图神经网络对未标注结点的预测值,第一回合协同训练中
Figure BDA0003128892970000042
为空,每一回合训练中所述公式F(t)=αSF(t-1)+(1-α)Y的迭代执行次数为T次;
步骤3,根据估计出的全局标签依赖性,为初步预测置信度最高的预定m个结点标注伪标签形成增广的有标注图结构数据集Ltrain
步骤4,利用所述图神经网络基于所述增广的有标注图结构数据集Ltrain,学习各结点的向量表示并做出全局预测,得出未标注结点的最终预测值
Figure BDA0003128892970000043
步骤5,将所述图神经网络对未标注结点的最终预测值
Figure BDA0003128892970000044
作为未标注结点的初始标签,按步骤1设定的协同训练回合数K重复进行所述步骤2至步骤5,得出各未标注结点的最终分类结果。
上述方法的步骤3中,若图神经网络为图卷积网络,则图神经网络基于所述增广的有标注图结构数据集执行以下的更新公式,得到结点最终向量表示并做出全局预测,更新公式为:
Figure BDA0003128892970000045
Figure BDA0003128892970000046
Figure BDA0003128892970000047
其中,
Figure BDA0003128892970000048
是结点i在第l层图神经网络后的向量表示;X是结点特征矩阵;αij是图邻接矩阵A的第i行、第j列元素,表示结点i和结点j之间是否有连边;di是图结构数据集中结点i的度;dj是图结构数据集中结点j的度;H(0)是全体结点的初始向量表示矩阵;H(L)是全体结点在第L层图神经网络后的向量表示矩阵;
Figure BDA0003128892970000049
是根据H(L)得到的对各结点的概率预测向量;
Figure BDA00031288929700000410
是经伪标签增广后的结点标签;
Figure BDA00031288929700000411
是计算出的交叉熵损失值。
上述方法的方法步骤3中,若图神经网络为图注意力网络,则图神经网络基于所述增广的有标注图结构数据集执行以下的更新公式,得到结点最终向量表示并做出全局预测,更新公式为:
Figure BDA0003128892970000051
Figure BDA0003128892970000052
Figure BDA0003128892970000053
Figure BDA0003128892970000054
Figure BDA0003128892970000055
其中,αij表示结点j对结点i的重要程度,即注意力;LeakyReLU是激活函数名;||表示向量拼接;σ(·)是非线性激活函数;K是注意力机制的头数;
Figure BDA0003128892970000056
和W是可学习的参数矩阵;H(0)是全体结点的初始向量表示矩阵;H(L)是全体结点在第L层图神经网络后的向量表示矩阵;
Figure BDA0003128892970000057
是根据H(L)得到的对各结点的概率预测向量;
Figure BDA0003128892970000058
是经伪标签增广后的结点标签;
Figure BDA0003128892970000059
是计算出的交叉熵损失值。
本发明的方法,由于充分利用了少量标签信息对全局标签依赖性建模,结合了标签传播算法和主流图神经网络特征方法的优势,捕捉全局标签依赖信息并将其隐式融入到图神经网络训练过程中,以得到更客观的结点表示,提升模型的分类性能。
下面对本发明实施例具体作进一步地详细描述。
本发明实施例为基于图神经网络为图上的少量标签结点分类任务提供一种高性能的协同训练方法,其所涉及的技术术语包括:
半监督结点分类:给定图结构数据集,用
Figure BDA00031288929700000510
表示,其中
Figure BDA00031288929700000511
是结点的集合,假设有Nn个结点,即
Figure BDA00031288929700000512
Figure BDA00031288929700000513
是连边的集合,假设有Ne条边,即|ε|=Ne;图
Figure BDA00031288929700000514
的邻接矩阵为
Figure BDA00031288929700000515
其中aij=1表示存在点i到点j的连边,即(vi,vj)∈ε,否则aij=0;每个结点vi都有d维特征向量
Figure BDA0003128892970000061
和标签向量yi∈{0,1}C,yi是C维独热编码向量,C是结点类别数;半监督结点分类任务中,图结构数据集合S由已标注结点集合L和未标注结点集合U两部分组成,有标签结点数ml远小于无标签结点数mu;目标是根据图结构数据集
Figure BDA0003128892970000062
所有结点特征信息X和已知标签
Figure BDA0003128892970000063
学习函数
Figure BDA0003128892970000064
来预测未标注结点的标签YU
图神经网络:是当前性能最优的结点分类方法,它以少量的已知标签YL、所有结点特征X和图邻接矩阵A为输入;对于每个结点,图神经网络综合该结点及其邻居结点的特征信息,得到该结点的最终向量表示;而后每个结点的向量表示被独立地用于预测其标签,并通过将已知标签作为监督信号,执行误差反向传播以纠正结点的向量表示。
参见图1、2,本发明的整体框架示意如图2所示,该方法用于通过图神经网络为图结构数据集中的未标注结点进行分类,所述图结构数据集由已标注结点集合和未标注结点集合组成,所述已标注结点集合占所述图结构数据集的总结点数量小于等于1%,具体包括以下步骤:
步骤1,设定图神经网络的网络层数L、协同训练回合数K、标签传播迭代轮数T、标签传播权重α和协同训练每回合标注伪标签个数m,将所述图结构数据集对应的图邻接矩阵A、归一化邻接矩阵
Figure BDA0003128892970000065
结点特征矩阵X和已标注结点的已知标签YL输入所述图神经网络进行处理;
步骤2,所述图神经网络通过将已标注结点的已知标签沿着连边传播估计出每个结点初步的标签取值分布F(T),用基尼系数衡量每个结点初步的标签取值分布属于某类的初步预测置信度,作为估计出的全局标签依赖性;
所述每个结点初步的标签取值分布F(T)的迭代计算公式为:F(t)=αSF(t-1)+(1-α)Y,
Figure BDA0003128892970000066
Figure BDA0003128892970000067
表示未标注结点的初始标签,是所述图神经网络对未标注结点的预测值,第一回合协同训练中
Figure BDA0003128892970000068
为空,每一回合训练中所述公式F(t)=αSF(t-1)+(1-α)Y的迭代执行次数为T次;
步骤3,根据估计出的全局标签依赖性,为初步预测置信度最高的预定m个结点标注伪标签形成增广的有标注图结构数据集Ltrain
步骤4,利用所述图神经网络基于所述增广的有标注图结构数据集Ltrain,学习各结点的向量表示并做出全局预测,得出未标注结点的最终预测值
Figure BDA0003128892970000071
步骤5,将所述图神经网络对未标注结点的最终预测值
Figure BDA0003128892970000072
作为未标注结点的初始标签,按步骤1设定的协同训练回合数K重复进行所述步骤2至步骤5,得出各未标注结点的最终分类结果。
经过预定次数的多次迭代执行,该标签感知协同训练方法能在考虑全局标签信息依赖性同时建模局部特征信息相关性。可以看出,这种协同训练方法并不会在原有图神经网络的基础上引入额外的参数,在保障训练效率的同时达到了有效建模全局标签依赖性的目标。
下面将从标签传播和特征传播两部分对本发明的方法进行具体说明。
本发明的方法吸取了经典的标签传播算法的优点,但在标签传播过程不显式考虑结点特征信息,给定图邻接矩阵A、部分已知标签YL和上一轮图神经网络对未标注结点(即无标签数据)的预测值
Figure BDA0003128892970000073
(若是第一轮则该预测值为空),本发明方法的标签传播采用以下更新公式:
Figure BDA0003128892970000074
F(0)=Y (2)
F(t)=αSF(t-1)+(1-α)Y (3);
其中,α∈[0,1]是预设的超参数,以平衡对迭代结果相对初始标签值的重要程度;S是对A的归一化邻接矩阵,本发明采用表现最好的随机游走归一化方式,即
Figure BDA0003128892970000075
更新公式F(t)=αSF(t-1)+(1-α)Y被迭代执行T次,最终矩阵F(T)即体现了各结点的标签依赖情况;
通过计算各结点标签依赖的基尼系数,得到各结点的预测置信度;
实验表明,标签传播过程以及置信度计算可以找出边界点,即连接了不同类标签的结点,这对于结点分类、社区发现等任务是有帮助的。
特征传播:
图神经网络的本质是局部特征传播,L层的图神经网络将结点的特征信息在它的L跳邻居内传播。在特征传播步骤,本发明方法并不改动图神经网络模型,使得该训练方法(即LACING)可以广泛地适用于现存的大多数图神经网络模型,特征传播阶段以图邻接矩阵A、结点特征X和已知标签YL为输入,采用以下更新公式进行更新:
Figure BDA0003128892970000081
Figure BDA0003128892970000082
Figure BDA0003128892970000083
其中,
Figure BDA0003128892970000084
是对称归一化的邻接矩阵,即
Figure BDA0003128892970000085
Figure BDA0003128892970000086
Θ是训练过程中的可学习参数矩阵;σ(·)是非线性激活函数,此处采用ReLU。以上是一般的特征传播过程,在实际中可替换为GCN、GAT、SGC等知名的图神经网络模型。
本发明的协同训练方法(即LACING)能充分利用少量标签信息来建模全局标签依赖性,结合了标签传播算法和主流图神经网络的特征传播算法的优势,捕捉全局标签依赖信息并将其隐式融入到图神经网络训练过程中,以得到更客观的结点表示,提升模型的分类性能。经实验验证有效。在没有引入额外的参数和计算量的情况下,该方法在主流半监督结点分类数据集上达到了优越性能,超越了之前方法的性能表现,实验结果见图1所示。根据前文的分析和说明,实验性能的提升来源于对全局标签依赖信息的考虑。
表1是本发明方法与现有方法在3个结点分类数据集上的性能对比
Figure BDA0003128892970000091
上述表1中结果是运行十次取平均的精度(%)结果。
Figure BDA0003128892970000101
以上是本发明方法的伪代码流程,根据伪代码可以复现前文的实验。
本发明的训练方法可以应用于现有的多种图神经网络模型,下面以应用于不同类型的图神经网络的实施例对该方法进行具体说明。
实施例
本实施例1是面向图神经网络的标签感知协同训练方法,是一种以图卷积网络(即GCN)为特征传播的主干模型(可称为LGCN)进行的标签感知协同训练方法,参考上述伪代码流程,该方法包括:设定协同训练回合数K、GCN的网络层数L、标签传播迭代轮数T、标签传播权重α、协同训练每回合打伪标签个数m,输入图邻接矩阵A、结点特征矩阵X和部分已知标签YL,在每回合协同训练中,首先,本实施例的方法将图神经网络的预测值作为未标注结点的初始标签(第一回合为空),并按照以下三步公式进行更新(其中第三步公式迭代执行T次):
Figure BDA0003128892970000111
F(0)=Y
F(t)=αSF(t-1)+(1-α)Y;
由此实现将少量标签信息沿着连边传播来估计出全局标签依赖性,即得到每个结点初步的标签取值分布F(T);再通过基尼系数衡量每个结点初步的标签取值分布F(T)属于某类的预测置信度,为预测置信度最高的m个结点打上伪标签,进而得到增广的有标注数据集Ltrain;最后,图卷积网络基于该增广的有标注图结构数据集Ltrain,执行以下的更新公式:
H(0)=X
Figure BDA0003128892970000112
Figure BDA0003128892970000113
Figure BDA0003128892970000114
得到结点最终向量表示并做出全局预测;
按上述各步骤重复K个回合的协同训练后,即得出最终的分类结果。
实施例2
本实施例1是面向图神经网络的标签感知协同训练方法,是一种以图注意力网络(即GAT)为特征传播的主干模型(可称为LGAT)进行的标签感知协同训练方法,参考上述伪代码流程,该方法包括:设定协同训练轮数K、GAT的网络层数L、标签传播迭代轮数T、标签传播权重α、协同训练每回合打伪标签个数m,输入图邻接矩阵A、结点特征矩阵X和部分已知标签YL;在每回合协同训练中,首先,本实施例的方法将图神经网络的预测值作为未标注结点的初始标签(第一回合为空),并按照以下三步公式进行更新(其中第三步公式迭代执行T次):
Figure BDA0003128892970000121
F(0)=Y
F(t)=αSF(t-1)+(1-α)Y;
由此实现利用少量标签信息沿着连边传播来估计出全局标签依赖性,即得到每个结点初步的标签取值分布F(T),再用基尼系数衡量每个结点初步的标签取值分布F(T)属于某类的预测置信度,为预测置信度最高的m个结点打上伪标签,从而得到增广的有标注数据集Ltrain;最后,图注意力网络基于该增广的有标注图结构数据集Ltrain,执行以下的更新公式:
H(0)=X
Figure BDA0003128892970000122
Figure BDA0003128892970000123
Figure BDA0003128892970000124
Figure BDA0003128892970000125
Figure BDA0003128892970000126
得到结点最终向量表示并做出全局预测;
按上述各步骤重复K个回合的协同训练后,即得出最终的分类结果。
本领域普通技术人员可以理解:实现上述实施例方法中的全部或部分流程是可以通过程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存储记忆体(Random Access Memory,RAM)等。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。

Claims (3)

1.一种面向图神经网络的标签感知协同训练方法,其特征在于,用于通过图神经网络为图结构数据集中的未标注结点进行分类,所述图结构数据集由已标注结点集合和未标注结点集合组成,所述已标注结点集合占所述图结构数据集的总结点数量小于等于1%,包括:
步骤1,设定图神经网络的网络层数L、协同训练回合数K、标签传播迭代轮数T、标签传播权重α和协同训练每回合标注伪标签个数m,将所述图结构数据集对应的图邻接矩阵A、归一化邻接矩阵
Figure FDA0003128892960000011
结点特征矩阵X和已标注结点的已知标签YL输入所述图神经网络进行处理;
步骤2,所述图神经网络通过将已标注结点的已知标签沿着连边传播估计出每个结点初步的标签取值分布F(T),用基尼系数衡量每个结点初步的标签取值分布属于某类的初步预测置信度;
所述每个结点初步的标签取值分布F(T)的迭代计算公式为:F(t)=αSF(t-1)+(1-α)Y,t取值为1至T,F(0)=Y,
Figure FDA0003128892960000012
Figure FDA0003128892960000013
表示未标注结点的初始标签,是所述图神经网络对未标注结点的预测值,第一回合协同训练中
Figure FDA0003128892960000014
为空,每一回合训练中所述公式F(t)=αSF(t -1)+(1-α)Y的迭代执行次数为T次;
步骤3,为初步预测置信度最高的预定m个结点标注伪标签形成增广的有标注图结构数据集Ltrain
步骤4,利用所述图神经网络基于所述增广的有标注图结构数据集Ltrain,学习各结点的向量表示并做出全局预测,得出未标注结点的最终预测值
Figure FDA0003128892960000015
步骤5,将所述图神经网络对未标注结点的最终预测值
Figure FDA0003128892960000016
作为未标注结点的初始标签,按所述步骤1设定的协同训练回合数K重复进行所述步骤2至步骤5,得出各未标注结点的最终分类结果。
2.根据权利要求1所述的面向图神经网络的标签感知协同训练方法,其特征在于,所述方法的步骤3中,若图神经网络为图卷积网络,则图神经网络基于所述增广的有标注图结构数据集执行以下的更新公式,得到结点最终向量表示并做出全局预测,更新公式为:
Figure FDA0003128892960000021
Figure FDA0003128892960000022
Figure FDA0003128892960000023
其中,
Figure FDA0003128892960000024
是结点i在第l层图神经网络后的向量表示;X是结点特征矩阵;αij是图邻接矩阵A的第i行、第j列元素,表示结点i和结点j之间是否有连边;di是图结构数据集中结点i的度;dj是图结构数据集中结点j的度;H(0)是全体结点的初始向量表示矩阵;H(L)是全体结点在第L层图神经网络后的向量表示矩阵;
Figure FDA0003128892960000025
是根据H(L)得到的对各结点的概率预测向量;
Figure FDA0003128892960000026
是经伪标签增广后的结点标签;
Figure FDA0003128892960000027
是计算出的交叉熵损失值。
3.根据权利要求1所述的面向图神经网络的标签感知协同训练方法,其特征在于,所述方法步骤3中,若图神经网络为图注意力网络,则图神经网络基于所述增广的有标注图结构数据集执行以下的更新公式,得到结点最终向量表示并做出全局预测,更新公式为:
H(0)=X
Figure FDA0003128892960000028
Figure FDA0003128892960000029
Figure FDA00031288929600000210
Figure FDA00031288929600000211
Figure FDA00031288929600000212
其中,αij表示结点j对结点i的重要程度,即注意力;LeakyReLU是激活函数名;||表示向量拼接;σ(·)是非线性激活函数;K是注意力机制的头数;
Figure FDA0003128892960000031
和W是可学习的参数矩阵;H(0)是全体结点的初始向量表示矩阵;H(L)是全体结点在第L层图神经网络后的向量表示矩阵;
Figure FDA0003128892960000032
是根据H(L)得到的对各结点的概率预测向量;
Figure FDA0003128892960000033
是经伪标签增广后的结点标签;
Figure FDA0003128892960000034
是计算出的交叉熵损失值。
CN202110697015.8A 2021-06-23 2021-06-23 一种面向图神经网络的标签感知协同训练方法 Pending CN113361627A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110697015.8A CN113361627A (zh) 2021-06-23 2021-06-23 一种面向图神经网络的标签感知协同训练方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110697015.8A CN113361627A (zh) 2021-06-23 2021-06-23 一种面向图神经网络的标签感知协同训练方法

Publications (1)

Publication Number Publication Date
CN113361627A true CN113361627A (zh) 2021-09-07

Family

ID=77535986

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110697015.8A Pending CN113361627A (zh) 2021-06-23 2021-06-23 一种面向图神经网络的标签感知协同训练方法

Country Status (1)

Country Link
CN (1) CN113361627A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113780584A (zh) * 2021-09-28 2021-12-10 京东科技信息技术有限公司 标签预测方法、设备、存储介质及程序产品
CN113807247A (zh) * 2021-09-16 2021-12-17 清华大学 基于图卷积网络的行人重识别高效标注方法及装置
CN116032665A (zh) * 2023-03-28 2023-04-28 北京芯盾时代科技有限公司 一种网络群体的发现方法、装置、设备及存储介质
CN116127386A (zh) * 2023-04-19 2023-05-16 浪潮电子信息产业股份有限公司 一种样本分类方法、装置、设备和计算机可读存储介质
WO2023221592A1 (zh) * 2022-05-20 2023-11-23 腾讯科技(深圳)有限公司 模型协同训练方法及相关装置

Cited By (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113807247A (zh) * 2021-09-16 2021-12-17 清华大学 基于图卷积网络的行人重识别高效标注方法及装置
CN113807247B (zh) * 2021-09-16 2024-04-26 清华大学 基于图卷积网络的行人重识别高效标注方法及装置
CN113780584A (zh) * 2021-09-28 2021-12-10 京东科技信息技术有限公司 标签预测方法、设备、存储介质及程序产品
CN113780584B (zh) * 2021-09-28 2024-03-05 京东科技信息技术有限公司 标签预测方法、设备、存储介质
WO2023221592A1 (zh) * 2022-05-20 2023-11-23 腾讯科技(深圳)有限公司 模型协同训练方法及相关装置
CN116032665A (zh) * 2023-03-28 2023-04-28 北京芯盾时代科技有限公司 一种网络群体的发现方法、装置、设备及存储介质
CN116032665B (zh) * 2023-03-28 2023-06-30 北京芯盾时代科技有限公司 一种网络群体的发现方法、装置、设备及存储介质
CN116127386A (zh) * 2023-04-19 2023-05-16 浪潮电子信息产业股份有限公司 一种样本分类方法、装置、设备和计算机可读存储介质
CN116127386B (zh) * 2023-04-19 2023-08-08 浪潮电子信息产业股份有限公司 一种样本分类方法、装置、设备和计算机可读存储介质

Similar Documents

Publication Publication Date Title
CN109919108B (zh) 基于深度哈希辅助网络的遥感图像快速目标检测方法
CN113361627A (zh) 一种面向图神经网络的标签感知协同训练方法
CN110717526B (zh) 一种基于图卷积网络的无监督迁移学习方法
CN109800692B (zh) 一种基于预训练卷积神经网络的视觉slam回环检测方法
CN113361334B (zh) 基于关键点优化和多跳注意图卷积行人重识别方法及***
CN110569901A (zh) 一种基于通道选择的对抗消除弱监督目标检测方法
CN108537264B (zh) 基于深度学习的异源图像匹配方法
CN113326731A (zh) 一种基于动量网络指导的跨域行人重识别算法
Ou et al. Multi-label zero-shot learning with graph convolutional networks
CN112364747B (zh) 一种有限样本下的目标检测方法
CN113128667B (zh) 一种跨域自适应的图卷积平衡迁移学习方法与***
CN113255366B (zh) 一种基于异构图神经网络的方面级文本情感分析方法
CN111898665A (zh) 基于邻居样本信息引导的跨域行人再识别方法
CN114881125A (zh) 基于图一致性和半监督模型的标签含噪图像分类方法
CN116010813A (zh) 基于图神经网络融合标签节点影响度的社区检测方法
Das et al. Group incremental adaptive clustering based on neural network and rough set theory for crime report categorization
Li et al. Transductive distribution calibration for few-shot learning
CN117131348B (zh) 基于差分卷积特征的数据质量分析方法及***
Christilda et al. Enhanced hyperspectral image segmentation and classification using K-means clustering with connectedness theorem and swarm intelligent-BiLSTM
CN110717402B (zh) 一种基于层级优化度量学习的行人再识别方法
CN116912576A (zh) 基于脑网络高阶结构的自适应图卷积脑疾病分类方法
CN115457345A (zh) 一种利用基于Graphormer的上下文推理网络进行图片预测分类的方法
CN114882279A (zh) 基于直推式半监督深度学习的多标签图像分类方法
CN112307914B (zh) 一种基于文本信息指导的开放域图像内容识别方法
CN113779520A (zh) 基于多层属性分析的跨空间目标虚拟身份关联方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20210907

RJ01 Rejection of invention patent application after publication