CN117556866A - 一种无源域图的数据域适应网络构建方法 - Google Patents
一种无源域图的数据域适应网络构建方法 Download PDFInfo
- Publication number
- CN117556866A CN117556866A CN202410028518.XA CN202410028518A CN117556866A CN 117556866 A CN117556866 A CN 117556866A CN 202410028518 A CN202410028518 A CN 202410028518A CN 117556866 A CN117556866 A CN 117556866A
- Authority
- CN
- China
- Prior art keywords
- domain
- target
- subdomain
- graph
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 230000006978 adaptation Effects 0.000 title claims abstract description 22
- 238000010276 construction Methods 0.000 title abstract description 16
- 238000010586 diagram Methods 0.000 title description 4
- 238000000034 method Methods 0.000 claims abstract description 49
- 238000012549 training Methods 0.000 claims abstract description 40
- 230000003044 adaptive effect Effects 0.000 claims abstract description 25
- 230000004927 fusion Effects 0.000 claims abstract description 23
- 238000013528 artificial neural network Methods 0.000 claims abstract description 18
- 238000009826 distribution Methods 0.000 claims abstract description 16
- 239000000203 mixture Substances 0.000 claims abstract description 9
- 238000005457 optimization Methods 0.000 claims abstract description 7
- 230000006870 function Effects 0.000 claims description 17
- 230000009977 dual effect Effects 0.000 claims description 14
- 238000004821 distillation Methods 0.000 claims description 11
- 238000004364 calculation method Methods 0.000 claims description 6
- 230000008447 perception Effects 0.000 claims description 4
- 238000012795 verification Methods 0.000 claims description 4
- 239000011159 matrix material Substances 0.000 claims description 3
- 230000008569 process Effects 0.000 abstract description 12
- 230000001149 cognitive effect Effects 0.000 abstract description 4
- 238000009825 accumulation Methods 0.000 abstract description 3
- 238000007418 data mining Methods 0.000 abstract description 2
- 238000012360 testing method Methods 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 3
- 108020001568 subdomains Proteins 0.000 description 3
- 238000004458 analytical method Methods 0.000 description 2
- 230000002860 competitive effect Effects 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 239000000463 material Substances 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 230000002159 abnormal effect Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 230000019771 cognition Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000003631 expected effect Effects 0.000 description 1
- 239000002360 explosive Substances 0.000 description 1
- 238000007499 fusion processing Methods 0.000 description 1
- 238000010438 heat treatment Methods 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000007500 overflow downdraw method Methods 0.000 description 1
- 102000004169 proteins and genes Human genes 0.000 description 1
- 108090000623 proteins and genes Proteins 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/09—Supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及图数据挖掘技术领域,提供一种无源域图的数据域适应网络构建方法。包括:通过图神经网络对目标图进行预测,获得软标签预测结果;通过软标签预测结果对双学生网络进行蒸馏预训练,获得双学生网络模型;通过高斯混合模型拟合双学生网络模型输出的多个节点的损失值,获得拟合值并将双学生网络模型的输出结果划分为源域相似子域及目标特定子域;对源域相似子域及目标特定子域进行拓扑感知数据融合,获得源域相似子域节点的硬标签预测结果;基于软标签预测结果及硬标签预测结果对双学生网络模型进行迭代训练优化,获得数据域适应网络。本发明能够获得代表全图数据分布的高质量训练样本,还降低了模型训练过程中认知偏差的积累。
Description
技术领域
本发明涉及图数据挖掘技术领域,尤其涉及一种无源域图的数据域适应网络构建方法。
背景技术
大数据时代使生产生活中产生和记录的人类数据呈现***式增长,由于图结构对实体和关系有强大的表示能力,将复杂***的数据抽象为图形进行分析已成为一种通用做法,图数据广泛应用在交通***、社交网络、电子商务、蛋白质反应关系分析等领域。节点分类是图上的一个基本任务,给定一个图,通过包括图神经网络在内的技术对图进行建模,预测图上的节点类别。节点分类在现实问题中应用广泛,例如将文本间关系建模成图然后进行文本分类以及在金融网络中实现异常用户的检测等问题就可以抽象成图上的节点分类问题。
当前对节点分类这一任务的基本研究思路为对给定的图针对性地训练一个模型,然后应用此模型进行节点类别的预测。但由于有良好标注信息的图数据匮乏以及图神经网络训练困难等问题的存在,这类范式所需成本开销巨大。一个基本的通用思路为事先预训练好一个图神经网络模型,然后再对下游的未标注图数据直接预测给出节点分类结果。但这类方法忽略了预训练模型使用的源域图数据和下游的目标域未标注图数据之间的分布差异,往往难以得到预期效果。另一个改进思路是为了减小这种分布差异,同时使用完全标注的源域图数据和未标注目标域图数据对模型进行训练,设计相应的域对齐约束,但这种方法忽视了大量源域数据传输带来的传输负担和潜在的隐私风险,无法获得源域图数据是一种更常见的应用场景。
发明内容
本发明旨在至少解决相关技术中存在的技术问题之一。为此,本发明提供一种无源域图的数据域适应网络构建方法。
本发明提供一种无源域图的数据域适应网络构建方法,包括:
S1:通过线上图神经网络对待预测的目标图进行预测,获得软标签预测结果;
S2:通过所述软标签预测结果对双学生网络进行蒸馏预训练,获得双学生网络模型;
S3:通过高斯混合模型拟合所述双学生网络模型输出的多个节点的损失值,获得拟合值,并根据所述拟合值将所述双学生网络模型的输出结果划分为源域相似子域及目标特定子域;
S4:对所述源域相似子域及所述目标特定子域进行拓扑感知数据融合,获得源域相似子域节点的硬标签预测结果;
S5:基于所述软标签预测结果及所述硬标签预测结果对所述双学生网络模型进行迭代训练优化,获得数据域适应网络。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S1中的所述目标图的无源域适应特性包括:
源域图数据及源域模型的模型参数均不可接触;
源域模型的查询接口可接触,且所述查询接口能够对目标图中存在分布漂移的下游目标域上的图数据中的节点进行类别预测。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S1中,由所述线上图神经网络的查询接口,对待预测的目标图进行预测。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S1中的所述软标签预测结果的表达式为:
;
其中,为软标签预测结果,/>为只提供API接口供查询的线上图神经网络,/>为目标图的目标域上的图数据的邻接矩阵,/>为目标图上的节点特征张量。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S3中,计算获得所述损失值的损失函数的表达式为:
;
;
;
;其中,/>为计算获得双学生网络模型输出的多个节点的损失值的损失函数,/>为蒸馏损失,/>为互信息损失,/>为目标图中的节点索引值,/>为目标图中的节点的集合,/>为目标图中节点/>的数学期望,/>为Kullback-Leibler散度,/>为双学生网络模型对节点/>的预测结果,/>节点/>的软标签,/>为转置操作,/>表示信息熵计算。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S3中所述拟合值的表达式为:
;
;
其中,为包括两个单元的高斯混合模型,/>为节点/>的损失值,/>为拟合值,/>为节点/>的硬标签预测结果的独热编码表示,/>为分类类别总数,/>为分类类别个数索引值。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S3中,
所述源域相似子域包括与双学生网络模型中的第一学生网络对应的第一源域相似子域及与双学生网络模型中的第二学生网络对应的第二源域相似子域;
所述目标特定子域包括与双学生网络模型中的第一学生网络对应的第一目标特定子域及与双学生网络模型中的第二学生网络对应的第二目标特定子域。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S4进一步包括:
S41:对所述目标图进行卷积,获得所述目标图中多个节点的第一嵌入表示;
S42:对于每个所述源域相似子域的节点的第一嵌入表示,均选择一个所述目标特定子域的节点的第一嵌入表示,融合初始输入特征并逐层进行拓扑感知数据融合,获得多个所述源域相似子域中的节点的第二嵌入表示;
S43:将所述源域相似子域的伪标签与所述目标特定子域的伪标签进行融合,获得融合结果;
S44:通过所述第二嵌入表示对所述融合结果中的融合后的伪标签进行配合训练,获得硬标签预测结果。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S4中的所述硬标签预测结果的交叉熵损失函数的表达式为:
;
其中,为硬标签预测结果的交叉熵损失函数,/>为源域相似子域的节点索引值,/>为源域相似子域,/>为目标特定子域的节点索引值,/>为目标特定子域,/>为源域相似子域的节点/>的数学期望,/>为目标特定子域的节点/>的数学期望,/>为以/>为超参的贝塔分布中产生的系数,/>为节点/>的交叉熵损失,/>为节点/>的交叉熵损失。
根据本发明提供的一种无源域图的数据域适应网络构建方法,步骤S5还包括:
S51:通过所述目标图中的验证集对迭代训练优化的双学生网络模型进行评估,直至双学生网络模型针对目标图的节点获得的预测指标稳定,停止对双学生网络模型进行迭代训练优化,获得数据域适应网络。
本发明提供的一种无源域图的数据域适应网络构建方法,针对当前预训练好的神经网络模型对目标域的图上节点预测忽略了分布外泛化、现实应用中源域图数据和源域模型参数无法获取的问题,在下游目标域上图数据中测试的时候自适应学习一个新的模型进行预测,摒弃对源域图数据和源域模型参数的需求的同时缓解分布外表现差的问题。
首先,从线上模型查询接口获取目标域图数据上节点类别预测结果的软标签,随机初始化目标域模型,即包含两个图神经网络的双学生模型,使用软标签分别对两个学生网络模型进行蒸馏预热;其次,预热到一定的训练轮次后,两个学生网络分别基于逐节点损失使用高斯混合模型进行拟合,分别将目标域上的节点划分为源域相似和目标域特定的两个子域,两个学生网络互相交换划分结果以减小模型训练过程中模型认知偏差;再次,对划分的两个子域之间进行拓扑感知数据融合,避免只针对源域相似子域进行有监督训练带来的选择偏差;然后,软标签蒸馏和目标域上针对源域相似子域硬标签有监督训练两种训练方式迭代进行,不断优化模型;最后,训练至模型收敛,直接给出双学生网络模型对测试集上的节点类别预测结果。
基于双学生的软标签蒸馏算法,基于本发明提供的方法得到的双学生网络的无源域图数据域适应框架DS-GDA能够快速获取线上源域模型中的知识,从而避免对源域图数据和源域模型参数的需求;基于目标域图上节点划分成子域再进行子域间拓扑融合,DS-GDA能够自适应地针对目标域上的数据进行优化,并获取有干净标签的高质量样本同时避免子域间的分布偏差;为降低模型训练过程的认知偏差,DS-GDA采用了双学生架构,两个图神经网络互相提供给对方子域划分结果。与本领域现有技术相比,本发明能够摆脱对源域图数据和源域模型参数的依赖,仅仅使用源于模型的软标签预测输出结果,更加符合实际应用场景;同时基于子域划分和子域间数据融合,能够获取更能代表全图数据分布的高质量训练样本;此外双学生架构进一步降低了模型训练过程中认知偏差的积累。
本发明的附加方面和优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
为了更清楚地说明本发明或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明提供的一种无源域图的数据域适应网络构建方法流程图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,下面将结合本发明中的附图,对本发明中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。以下实施例用于说明本发明,但不能用来限制本发明的范围。
在本发明实施例的描述中,需要说明的是,术语“中心”、“纵向”、“横向”、“上”、“下”、“前”、“后”、“左”、“右”、“竖直”、“水平”、“顶”、“底”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明实施例和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明实施例的限制。此外,术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性。
在本发明实施例的描述中,需要说明的是,除非另有明确的规定和限定,术语“相连”、“连接”应做广义理解,例如,可以是固定连接,也可以是可拆卸连接,或一体连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明实施例中的具体含义。
在本发明实施例中,除非另有明确的规定和限定,第一特征在第二特征“上”或“下”可以是第一和第二特征直接接触,或第一和第二特征通过中间媒介间接接触。而且,第一特征在第二特征“之上”、“上方”和“上面”可是第一特征在第二特征正上方或斜上方,或仅仅表示第一特征水平高度高于第二特征。第一特征在第二特征“之下”、“下方”和“下面”可以是第一特征在第二特征正下方或斜下方,或仅仅表示第一特征水平高度小于第二特征。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明实施例的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不必须针对的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任一个或多个实施例或示例中以合适的方式结合。此外,在不相互矛盾的情况下,本领域的技术人员可以将本说明书中描述的不同实施例或示例以及不同实施例或示例的特征进行结合和组合。
下面结合图1描述本发明的实施例。
本发明提供一种无源域图的数据域适应网络构建方法,包括:
S1:通过线上图神经网络对待预测的目标图进行预测,获得软标签预测结果;
其中,步骤S1中的所述目标图的无源域适应特性包括:
源域图数据及源域模型的模型参数均不可接触;
源域模型的查询接口可接触,且所述查询接口能够对目标图中存在分布漂移的下游目标域上的图数据中的节点进行类别预测。
进一步的,无源域图的适应问题,定义为:给定一个从源域中的图形训练的预训练GNN模型和一个未标记的目标域上的图,无源图领域自适应的目标是提高目标域图上的下游任务,下游任务可以是节点分类、链接预测、图分类等等,预训练的GNN模型仅用于响应查询请求并提供结果,模型的参数不可访问,基本假设是源图和目标图相关,但是生成自不同的分布。
其中,步骤S1中,由所述线上图神经网络的查询接口,对待预测的目标图进行预测。
其中,步骤S1中的所述软标签预测结果的表达式为:
;
其中,为软标签预测结果,/>为只提供API接口供查询的线上图神经网络,/>为目标图的目标域上的图数据的邻接矩阵,/>为目标图上的节点特征张量。
S2:通过所述软标签预测结果对双学生网络进行蒸馏预训练,获得双学生网络模型;
进一步的,本阶段的主要目标是输入目标域图数据给线上源域图模型查询接口获取其对目标域图数据上节点分类预测的软标签结果,以及使用这些软标签对目标域模型进行蒸馏预热以快速获取源域模型知识。
S3:通过高斯混合模型拟合所述双学生网络模型输出的多个节点的损失值,获得拟合值,并根据所述拟合值将所述双学生网络模型的输出结果划分为源域相似子域及目标特定子域;
进一步的,本阶段的目标是为了进一步从目标域的图数据上学习到新信息,将目标域上的图中的节点根据学习难易程度进行划分,具体来说,按照目标域上的节点根据其与源域模型给出的预测结果伪标签计算出的损失值对节点进行划分为两个子域,分别为源域相似子域和目标域特定子域,将源域相似子域中的伪标签视作干净的标签。
具体过程为:(1)首先双学生分别给出损失值的计算结果:框架中两个独立(不共享参数)的图神经网络分别给出对目标图上的节点进行预测,分别得到预测结果,并计算节点损失;(2)根据逐节点损失值可以目标图上的所有节点划分为两个子域,基于神经网络先拟合正确样本后拟合困难样本的记忆特性,首先使用高斯混合函数进行拟合。
其中,步骤S3中,计算获得所述损失值的损失函数的表达式为:
;
;
;
;
其中,为计算获得双学生网络模型输出的多个节点的损失值的损失函数,为蒸馏损失,/>为互信息损失,/>为目标图中的节点索引值,/>为目标图中的节点的集合,/>为目标图中节点/>的数学期望,/>为Kullback-Leibler散度,/>为双学生网络模型对节点/>的预测结果,/>节点/>的软标签,/>为转置操作,/>表示信息熵计算。
进一步的,双学生模型中包括两个作为学生的图神经网络,它们将分别进行预热蒸馏,蒸馏过程的核心部分在于损失函数的计算,该过程中使用的损失函数可以表示为上述的计算获得所述损失值的损失函数,其可以用来缓解模型偏向预测类别较多样本的程度。
其中,步骤S3中所述拟合值的表达式为:
;
;
其中,为包括两个单元的高斯混合模型,/>为节点/>的损失值,/>为拟合值,/>为节点/>的硬标签预测结果的独热编码表示,/>为分类类别总数,/>为分类类别个数索引值。
进一步的,上式中的拟合值表示节点划分到均值更小的高斯函数的概率值组成的张量,将概率值进行排名,获取概率值排名占比前百分之的节点作为源域相似子域,其余的所有节点视作目标域特定子域。
其中,步骤S3中,
所述源域相似子域包括与双学生网络模型中的第一学生网络对应的第一源域相似子域及与双学生网络模型中的第二学生网络对应的第二源域相似子域;
所述目标特定子域包括与双学生网络模型中的第一学生网络对应的第一目标特定子域及与双学生网络模型中的第二学生网络对应的第二目标特定子域。
进一步的,两个学生网络交换彼此的划分结果,即第一学生网络处理第二学生网络划分的结果第二源域相似子域和第二目标特定子域,第二学生网络处理第一学生网络划分的结果第一源域相似子域和第一目标特定子域,将源域模型对源域相似子域中的节点的预测结果的硬标签视作干净标签。
S4:对所述源域相似子域及所述目标特定子域进行拓扑感知数据融合,获得源域相似子域节点的硬标签预测结果;
进一步的,本阶段的目标是为了缓解模型只学习目标域图上的划分结果的源域相似节点导致的偏差问题,使用目标域特定的子域中的节点样本和源域相似节点进行数据融合,为了更好地适应图神经网络地消息传递架构,采用的数据融合方法进一步考虑了拓扑信息。
其中,步骤S4进一步包括:
S41:对所述目标图进行卷积,获得所述目标图中多个节点的第一嵌入表示;
S42:对于每个所述源域相似子域的节点的第一嵌入表示,均选择一个所述目标特定子域的节点的第一嵌入表示,融合初始输入特征并逐层进行拓扑感知数据融合,获得多个所述源域相似子域中的节点的第二嵌入表示;
S43:将所述源域相似子域的伪标签与所述目标特定子域的伪标签进行融合,获得融合结果;
S44:通过所述第二嵌入表示对所述融合结果中的融合后的伪标签进行配合训练,获得硬标签预测结果。
进一步的,步骤S4的具体过程为:(1)数据融合的第一个阶段为获取逐层的节点嵌入表示,经过一般的图卷积操作获取目标图上的所有节点的嵌入表示,每一层用来更新并获取节点的嵌入表示;(2)数据融合的第二阶段为进行子域间节点数据融合,对每一个来自于源域相似子域的节点,均选择一个来自于目标域特定子域的节点进行数据融合处理;(3)基于步骤前述步骤共两阶段的图卷积操作,可以得到源域相似子域中节点的最终预测输出,另外对源域相似子域中的节点的输入特征和伪标签也更新为和来自目标特定子域中的节点的融合结果,于是可以对源域相似子域中的节点采用融合后的伪标签进行有监督训练。
其中,步骤S4中的所述硬标签预测结果的交叉熵损失函数的表达式为:
;
其中,为硬标签预测结果的交叉熵损失函数,/>为源域相似子域的节点索引值,/>为源域相似子域,/>为目标特定子域的节点索引值,/>为目标特定子域,/>为源域相似子域的节点/>的数学期望,/>为目标特定子域的节点/>的数学期望,/>为以/>为超参的贝塔分布中产生的系数,/>为节点/>的交叉熵损失,/>为节点/>的交叉熵损失。
S5:基于所述软标签预测结果及所述硬标签预测结果对所述双学生网络模型进行迭代训练优化,获得数据域适应网络。
进一步的,本阶段是每一个训练轮次中先根据进行蒸馏训练,再根据/>进行硬标签训练,不断迭代训练优化双学生网络模型,这种训练方式能够有效抽取线上源域模型中的知识,同时也能从目标域图数据中学习到更多信息。
其中,步骤S5还包括:
S51:通过所述目标图中的验证集对迭代训练优化的双学生网络模型进行评估,直至双学生网络模型针对目标图的节点获得的预测指标稳定,停止对双学生网络模型进行迭代训练优化,获得数据域适应网络。
进一步的,训练直到模型收敛以及验证集上的分类指标持续多个轮次都不再上升,使用双学生网络模型的预测的融合结果作为目标域图上节点分类结果。
下面对本发明提供一种无源域图的数据域适应网络构建方法训练得到的基于双学生网络的无源域图数据域适应框架DS-GDA进行有效性验证,实验结果表明,本发明相比于其他方法,尽管不接触源域数据和源域模型参数,在多个数据集上仍然能取得有竞争力或更好的结果:
本发明在跨域迁移和时间漂移的图数据集上进行了测试,跨域迁移图数据集Twitch-e中包含多个地区的社交网络数据,自然地会存在地区之间的分布差异,任务为对图上节点进行二分类,源域模型在其中的某个图上进行预训练,在其余的图上进行测试;时间漂移的图数据集使用了Elliptic和OGB-Arxiv,对原始数据集按照时间片进行划分,每个数据集都划分出了多个图,任务为对图上节点进行二分类或多分类,源域模型在时间靠前的时间片的图上进行预训练得到,将时间靠后的时间片上的图视作目标域,用来评测结果。实验对比的现有方法包括直接使用源域模型进行目标域图数据预测的方法、在源域模型训练过程提高模型泛化能力的方法、能够显式接触到源域数据的方法、能够显式接触到源域模型参数的方法。
实验结果显示,本发明提出的框架DS-GDA相比于其他的相关方法,均能取得有竞争力的表现或最佳表现。在跨域图节点分类中,通常与不进行域适应的源域模型直接预测的方法相比,在所有数据集上性能约提高4%;在随时间分布漂移的数据集上,与不进行域适应的方法相比,分别能提升10.16%和5.25%,这表明DS-GDA在分布外泛化方面表现良好。同时需要说明的是,DS-GDA在过程中并不能接触到源域图数据和源域模型参数,因此相比于能够显示接触到源域数据或源域模型参数的方法,DS-GDA的限制会更强,但仍然能够取得优异的表现,这说明了DS-GDA在无源域的问题设置下的有效性,为证明模型各个模块的有效性。
综上所述,本发明提供一种无源域图的数据域适应网络构建方法,通过双学生网络蒸馏架构,并将目标域数据进行子域划分同时进行子域数据融合以避免选择偏差,能够充分利用源域模型返回的软标签结果与目标域未标注图数据,很好地处理不接触源域图数据和源域模型参数的场景下,将源域模型有效泛化到目标域图上节点分类任务上。
本发明能够摆脱对源域图数据和源域模型参数的依赖,仅仅使用源于模型的软标签预测输出结果,更加符合实际应用场景;同时基于子域划分和子域间数据融合,能够获取更能代表全图数据分布的高质量训练样本;此外双学生架构进一步降低了模型训练过程中认知偏差的积累。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (10)
1.一种无源域图的数据域适应网络构建方法,其特征在于,包括:
S1:通过线上图神经网络对待预测的目标图进行预测,获得软标签预测结果;
S2:通过所述软标签预测结果对双学生网络进行蒸馏预训练,获得双学生网络模型;
S3:通过高斯混合模型拟合所述双学生网络模型输出的多个节点的损失值,获得拟合值,并根据所述拟合值将所述双学生网络模型的输出结果划分为源域相似子域及目标特定子域;
S4:对所述源域相似子域及所述目标特定子域进行拓扑感知数据融合,获得源域相似子域节点的硬标签预测结果;
S5:基于所述软标签预测结果及所述硬标签预测结果对所述双学生网络模型进行迭代训练优化,获得数据域适应网络。
2.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S1中的所述目标图的无源域适应特性包括:
源域图数据及源域模型的模型参数均不可接触;
源域模型的查询接口可接触,且所述查询接口能够对目标图中存在分布漂移的下游目标域上的图数据中的节点进行类别预测。
3.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S1中,由所述线上图神经网络的查询接口,对待预测的目标图进行预测。
4.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S1中的所述软标签预测结果的表达式为:
;
其中,为软标签预测结果,/>为只提供API接口供查询的线上图神经网络,/>为目标图的目标域上的图数据的邻接矩阵,/>为目标图上的节点特征张量。
5.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S3中,计算获得所述损失值的损失函数的表达式为:
;
;
;
;
其中,为计算获得双学生网络模型输出的多个节点的损失值的损失函数,/>为蒸馏损失,/>为互信息损失,/>为目标图中的节点索引值,/>为目标图中的节点的集合,/>为目标图中节点/>的数学期望,/>为Kullback-Leibler散度,/>为双学生网络模型对节点的预测结果,/>节点/>的软标签,/>为转置操作,/>表示信息熵计算。
6.根据权利要求5所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S3中所述拟合值的表达式为:
;
;
其中,为包括两个单元的高斯混合模型,/>为节点/>的损失值,/>为拟合值,/>为节点/>的硬标签预测结果的独热编码表示,/>为分类类别总数,/>为分类类别个数索引值。
7.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S3中,
所述源域相似子域包括与双学生网络模型中的第一学生网络对应的第一源域相似子域及与双学生网络模型中的第二学生网络对应的第二源域相似子域;
所述目标特定子域包括与双学生网络模型中的第一学生网络对应的第一目标特定子域及与双学生网络模型中的第二学生网络对应的第二目标特定子域。
8.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S4进一步包括:
S41:对所述目标图进行卷积,获得所述目标图中多个节点的第一嵌入表示;
S42:对于每个所述源域相似子域的节点的第一嵌入表示,均选择一个所述目标特定子域的节点的第一嵌入表示,融合初始输入特征并逐层进行拓扑感知数据融合,获得多个所述源域相似子域中的节点的第二嵌入表示;
S43:将所述源域相似子域的伪标签与所述目标特定子域的伪标签进行融合,获得融合结果;
S44:通过所述第二嵌入表示对所述融合结果中的融合后的伪标签进行配合训练,获得硬标签预测结果。
9.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S4中的所述硬标签预测结果的交叉熵损失函数的表达式为:
;
其中,为硬标签预测结果的交叉熵损失函数,/>为源域相似子域的节点索引值,为源域相似子域,/>为目标特定子域的节点索引值,/>为目标特定子域,/>为源域相似子域的节点/>的数学期望,/>为目标特定子域的节点/>的数学期望,/>为以/>为超参的贝塔分布中产生的系数,/>为节点/>的交叉熵损失,/>为节点/>的交叉熵损失。
10.根据权利要求1所述的一种无源域图的数据域适应网络构建方法,其特征在于,步骤S5还包括:
S51:通过所述目标图中的验证集对迭代训练优化的双学生网络模型进行评估,直至双学生网络模型针对目标图的节点获得的预测指标稳定,停止对双学生网络模型进行迭代训练优化,获得数据域适应网络。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410028518.XA CN117556866B (zh) | 2024-01-09 | 2024-01-09 | 一种无源域图的数据域适应网络构建方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410028518.XA CN117556866B (zh) | 2024-01-09 | 2024-01-09 | 一种无源域图的数据域适应网络构建方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117556866A true CN117556866A (zh) | 2024-02-13 |
CN117556866B CN117556866B (zh) | 2024-03-29 |
Family
ID=89823424
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410028518.XA Active CN117556866B (zh) | 2024-01-09 | 2024-01-09 | 一种无源域图的数据域适应网络构建方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117556866B (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2022001489A1 (zh) * | 2020-06-28 | 2022-01-06 | 北京交通大学 | 一种无监督域适应的目标重识别方法 |
CN114882521A (zh) * | 2022-03-30 | 2022-08-09 | 河北工业大学 | 基于多分支网络的无监督行人重识别方法及装置 |
CN115641613A (zh) * | 2022-11-03 | 2023-01-24 | 西安电子科技大学 | 一种基于聚类和多尺度学习的无监督跨域行人重识别方法 |
CN117152503A (zh) * | 2023-08-23 | 2023-12-01 | 北京理工大学 | 一种基于伪标签不确定性感知的遥感图像跨域小样本分类方法 |
-
2024
- 2024-01-09 CN CN202410028518.XA patent/CN117556866B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2022001489A1 (zh) * | 2020-06-28 | 2022-01-06 | 北京交通大学 | 一种无监督域适应的目标重识别方法 |
CN114882521A (zh) * | 2022-03-30 | 2022-08-09 | 河北工业大学 | 基于多分支网络的无监督行人重识别方法及装置 |
CN115641613A (zh) * | 2022-11-03 | 2023-01-24 | 西安电子科技大学 | 一种基于聚类和多尺度学习的无监督跨域行人重识别方法 |
CN117152503A (zh) * | 2023-08-23 | 2023-12-01 | 北京理工大学 | 一种基于伪标签不确定性感知的遥感图像跨域小样本分类方法 |
Non-Patent Citations (2)
Title |
---|
TAKASHI FUKUDA ETC.: ""Implicit Transfer of Privileged Acoustic Information in a Generalized Knowledge Distillation Framework"", 《INTERSPEECH 2020》, 29 October 2020 (2020-10-29), pages 41 - 44 * |
苑婧等: ""融合多教师模型的知识蒸馏文本分类"", 《电子技术应用》, no. 11, 30 November 2023 (2023-11-30), pages 42 - 48 * |
Also Published As
Publication number | Publication date |
---|---|
CN117556866B (zh) | 2024-03-29 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Wang et al. | Knowledge graph embedding via graph attenuated attention networks | |
Liu et al. | Trust beyond reputation: A computational trust model based on stereotypes | |
CN111241419B (zh) | 一种基于用户关系嵌入模型的下一个兴趣点推荐方法 | |
CN108563755A (zh) | 一种基于双向循环神经网络的个性化推荐***及方法 | |
CN112949929B (zh) | 一种基于协同嵌入增强题目表示的知识追踪方法及*** | |
Shi et al. | Attentional memory network with correlation-based embedding for time-aware POI recommendation | |
CN115577185B (zh) | 基于混合推理和中智群决策的慕课推荐方法及装置 | |
CN116127190B (zh) | 一种数字地球资源推荐***及方法 | |
Wang et al. | Education Data‐Driven Online Course Optimization Mechanism for College Student | |
Xia et al. | learning behavior interest propagation strategy of MOOCs based on multi entity knowledge graph | |
CN113742586B (zh) | 一种基于知识图谱嵌入的学习资源推荐方法及*** | |
CN113283488B (zh) | 一种基于学习行为的认知诊断方法及*** | |
Lumbantoruan et al. | I-cars: an interactive context-aware recommender system | |
Liang et al. | Graph path fusion and reinforcement reasoning for recommendation in MOOCs | |
CN117556866B (zh) | 一种无源域图的数据域适应网络构建方法 | |
Lin et al. | Incremental event detection via an improved knowledge distillation based model | |
CN117035013A (zh) | 一种采用脉冲神经网络预测动态网络链路的方法 | |
CN116861923A (zh) | 多视图无监督图对比学习模型构建方法、***、计算机、存储介质及应用 | |
Nadimpalli et al. | Towards personalized learning paths in adaptive learning management systems: bayesian modelling of psychological theories | |
Saha et al. | Predicting preference tags to improve item recommendation | |
Liu et al. | Multi-teacher Self-training for Semi-supervised Node Classification with Noisy Labels | |
CN111460318A (zh) | 基于显性和隐性信任的协同过滤推荐方法 | |
Bai et al. | GPR-OPT: A Practical Gaussian optimization criterion for implicit recommender systems | |
Yun et al. | Interpretable educational recommendation: an open framework based on Bayesian principal component analysis | |
CN113987332B (zh) | 一种推荐模型的训练方法、电子设备及计算机存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |