CN116306969A - 基于自监督学习的联邦学习方法和*** - Google Patents
基于自监督学习的联邦学习方法和*** Download PDFInfo
- Publication number
- CN116306969A CN116306969A CN202310189525.3A CN202310189525A CN116306969A CN 116306969 A CN116306969 A CN 116306969A CN 202310189525 A CN202310189525 A CN 202310189525A CN 116306969 A CN116306969 A CN 116306969A
- Authority
- CN
- China
- Prior art keywords
- domain
- model
- data set
- global model
- self
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 59
- 238000012549 training Methods 0.000 claims abstract description 69
- 230000008569 process Effects 0.000 claims abstract description 17
- 238000013140 knowledge distillation Methods 0.000 claims description 17
- 238000010606 normalization Methods 0.000 claims description 17
- 238000009826 distribution Methods 0.000 claims description 15
- 239000000284 extract Substances 0.000 claims description 9
- 238000010586 diagram Methods 0.000 claims description 8
- 238000011176 pooling Methods 0.000 claims description 6
- 230000002829 reductive effect Effects 0.000 abstract description 2
- 238000004422 calculation algorithm Methods 0.000 description 36
- 230000006870 function Effects 0.000 description 18
- 238000004891 communication Methods 0.000 description 12
- 230000002776 aggregation Effects 0.000 description 8
- 238000004220 aggregation Methods 0.000 description 8
- 238000001514 detection method Methods 0.000 description 7
- 238000004821 distillation Methods 0.000 description 7
- 238000013528 artificial neural network Methods 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 5
- 238000000605 extraction Methods 0.000 description 5
- 238000001914 filtration Methods 0.000 description 5
- 238000012546 transfer Methods 0.000 description 5
- 238000004590 computer program Methods 0.000 description 4
- 238000012935 Averaging Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 238000013145 classification model Methods 0.000 description 2
- 238000013136 deep learning model Methods 0.000 description 2
- 238000013508 migration Methods 0.000 description 2
- 230000005012 migration Effects 0.000 description 2
- 210000002569 neuron Anatomy 0.000 description 2
- 230000001360 synchronised effect Effects 0.000 description 2
- 230000003313 weakening effect Effects 0.000 description 2
- 230000004075 alteration Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 239000000969 carrier Substances 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000006317 isomerization reaction Methods 0.000 description 1
- 230000000670 limiting effect Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000036961 partial effect Effects 0.000 description 1
- 238000005192 partition Methods 0.000 description 1
- 238000006116 polymerization reaction Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及一种基于自监督学习的联邦学习方法和***,联邦学习方法实施在多个参与方和中心节点之间,包括:各参与方利用私有数据集训练本地模型,且在训练过程中对域内数据集进行预测,获得预测值;所述中心节点利用域内数据集、以及对应所述域内数据集的预测值,训练全局模型;利用所述全局模型训练域分类器,所述域分类器从开放数据集中提取所述域内数据集。本申请并非通过简单地线性组合得到全局模型,使得全局模型具有更好的全局性能。此外,训练全局模型使用域内数据集,而非传统意义上的开放数据集,因此弱化了参与方对开放数据集的依赖,减小了开放数据集中的噪声对全局模型的负面影响。
Description
技术领域
本申请涉及深度学习领域,特别是涉及一种基于自监督学习的联邦学习方法和***。
背景技术
训练高精度、泛化性能强的深度神经网络模型通常需要大规模且多样化的数据集,但是当数据涉及到用户隐私和个人信息时这一要求变得难以满足。随着个人隐私保护意识的增强,用户会倾向于选择将自己的私人数据保存在本地而拒绝互联网公司收集数据的请求。在另一些场景如需要使用跨企业或跨部门的数据对模型进行训练时,法律会要求企业清晰的列出数据保护的责任方以及数据的使用范围,这些场景都为人工智能在现实生活中的发展提出了挑战。
为了克服这个问题,联邦学习(FL)为上述数据孤岛问题提供了一种解决方案。它要求所有参与方使用其私有数据集在本地训练深度模型,并通过特定的中心节点来对本地模型进行聚合从而得到一个目标一致的全局模型。虽然联邦学习被有效的应用在大规模私有数据集联合训练的场景下,但仍然存在一定的限性,以下两方面问题是传统的联邦学习亟待解决的问题。
(1):参与方数据非独立同分布问题:联邦学习假设每个参与方的私有数据是独立同分布(IID)的。此要求在小规模联邦学习上是较为满足的,多个参与方以同样的方式从相似的场景中收集数据。然而,当问题的范围扩展到多个地理位置或多种应用场景时,参与方的私有数据集往往是非独立同分布的(non-IID)。在这种情况下,各参与方所训练的本地模型在特征提取的能力上有着一定的参差,仅仅通过线性组合的方式得到的全局模型会有着较弱的全局性能。
(2)模型异构问题:传统的联邦学习要求每个参与方训练一个相同架构的本地模型。在参与方均配备相同硬件和软件的场景这个要求是适用的,但是当参与方涉及的跨度较大(从智能穿戴设备,到移动终端、再到数据中心的服务器)时,联邦学习只能做出模型性能和训练耗时之间的妥协,且由于内存的限制,往往会在训练时出现木桶效应,只能按照硬件条件最弱的参与方设置模型的大小。
现有的解决上述问题的研究思路是通过迁移学习将多个本地模型的知识聚合在全局模型中,以处理数据非独立同分布问题。具体为局部模型所学习到的知识被开放数据集进行统一的量化,随后在中心节点对这部分知识进行聚合从而将所有参与方的知识进行聚合。
这种基于知识蒸馏解决非独立同分步的方法需要一个共享的开放数据集作为知识传递的媒介,这对开放数据集的数据特征分布提出了很高要求,开放数据集和私有数据集在特征分布上的不一致会导致该参与方传递的知识具有误导性,从而对全局模型的泛化性能造成影响。
发明内容
基于此,有必要针对上述技术问题,提供一种基于自监督学习的联邦学习方法。
本申请基于自监督学习的联邦学习方法,实施在多个参与方和中心节点之间,包括:
各参与方利用私有数据集训练本地模型,且在训练过程中对域内数据集进行预测,获得预测值;
所述中心节点利用域内数据集、以及对应所述域内数据集的预测值,训练全局模型;
利用所述全局模型训练域分类器,所述域分类器从开放数据集中提取所述域内数据集。
可选的,所述全局模型、所述域分类器、以及各所述本地模型在训练过程中均迭代更新。
可选的,各所述本地模型在本轮对域内数据集进行预测时,使用上轮所述域分类器提取的域内数据集。
可选的,在首轮对域内数据集进行预测时,所述域内数据集随机提取于所述开放数据集。
可选的,各所述本地模型作为老师模型,所述全局模型作为学生模型,所述全局模型利用知识蒸馏的方式进行迭代更新;
各所述本地模型获得的预测值的均值,用于训练所述全局模型。
可选的,各所述本地模型为相同的结构类别,各所述本地模型的迭代,通过所述全局模型分发至各参与方的方式完成。
可选的,利用所述全局模型训练域分类器,包括:
所述全局模型产生输入样本的输出层信息;
所述域分类器获得所述输入样本、以及所述输入样本的输出层信息;
所述域分类器根据所述输出层信息得到评分,根据所述评分将符合预期的输入样本置入所述域内数据集。
可选的,根据所述评分将符合预期的输入样本置入所述域内数据集,具体包括:对所述评分排序后,按次序提取绝对数量或占比数量的对应的输入样本,置入所述域内数据集。
可选的,利用所述全局模型训练域分类器,包括:
利用所述全局模型的中间层信息,自监督地训练所述域分类器,所述中间层信息来源于所述全局模型中间层内在每个批量归一化层之前的特征图。
可选的,所述域分类器包括基底模型和多层感知机,所述基底模型为每轮迭代的所述全局模型,所述多层感知机作为检测头。
可选的,所述域分类器包括多层感知机,训练过程包括:
对输入样本进行数据增强,获得对比样本,所述输入样本和所述对比样本一一对应;
基于所述输入样本获得第一层次特征,基于所述对比样本获得第二层次特征,所述第一层次特征和所述第二层次特征一一对应;
利用所述第一层次特征、所述第二层次特征、以及二者的对应关系训练所述域分类器。
可选的,所述域分类器从开放数据集中提取所述域内数据集,包括:
所述域分类器接收所述第一层次特征、所述批量归一化层中的特征平均值串联,并输出二者的相对距离,所述相对距离用于将所述开放数据集划分为所述域内数据集和域外数据集。
可选的,所述域分类器从开放数据集中提取所述域内数据集,包括:
所述域分类器接收所述第一层次特征、所述批量归一化层中的特征平均值串联,并二者投影到嵌入空间中,所述相对距离为二者在投影在嵌入空间中的余弦距离,根据所述相对距离保留选择符合预期的、与所述第一层次特征相对应的输入样本,进而置入域内数据集。
可选的,基于所述对比样本获得第二层次特征,按照基于所述输入样本获得第一层次特征的方式进行;
基于所述输入样本获得第一层次特征,利用下式进行:
式中,x为输入样本;
v(x)为输入样本的第一层次特征;
fi表示全局模型对于输入样本x在第i个批量归一化层之前的特征图;
GAP表示将一张二维的特征图进行全局平均池化得到一个标量值;
本申请还提供一种基于自监督学习的联邦学习***,包括多个参与方和中心节点,实施有如本申请所述的基于自监督学习的联邦学习方法。
可选的,各参与方本地模型的结构类别至少包括两种,所述联邦学习方法包括:
所述全局模型的结构类别与其中一个参与方本地模型的结构类别相同,所述全局模型、所述域分类器、以及各所述本地模型在训练过程中均迭代更新,在每轮迭代更新的联邦聚合阶段中:
对于与所述全局模型的结构类别相同的本地模型,利用所述全局模型替换迭代所述本地模型;
对于与所述全局模型的结构类别不同的本地模型,利用所述域内数据集、以及对应所述域内数据集的预测值,以知识蒸馏的方式,更新迭代所述本地模型。
可选的,以知识蒸馏的方式,更新迭代所述本地模型,采用以下二者中的任意一种方式进行:
所述全局模型预测所述域内数据集,并根据预测值训练与所述全局模型结构类别不同的本地模型;
各所述本地模型获得的预测值的均值,用于训练与所述全局模型结构类别不同的本地模型。
本申请基于自监督学习的联邦学习方法和***至少具有以下效果:
本申请并非通过简单地线性组合得到全局模型,使得全局模型具有更好的全局性能。此外,训练全局模型使用域内数据集,而非传统意义上的开放数据集,因此弱化了参与方对开放数据集的依赖,减小了开放数据集中的噪声对全局模型的负面影响。
本申请基于利用域分类器减少对开放数据集的基础上,既适用于各本地模型为相同的结构类别、还适用于各参与方本地模型的结构类别至少包括两种,这两种情况,不仅全局模型具有良好的全局性能,也能够解决本地模型异构的问题。
本申请域分类器采用自监督学习的方式进行训练,包括基底模型,基底模型为每轮迭代的全局模型,大大提高了域分类器的训练速度,加快了联邦学习方法。
附图说明
图1为本申请一实施例中基于自监督学习的联邦学习方法的流程示意图;
图2为本申请一实施例中基于自监督学习的联邦学习***中参与方或中心节点使用的计算机设备的内部结构图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅用以解释本申请,并不用于限定本申请。
本申请一实施例中提供一种基于自监督学习的联邦学习方法,实施在多个参与方和中心节点之间。
例如有N个参与方分别表示为{P1,P2,…,PN},其私有本地数据集表示为{D1,D2,…,DN},其中每个数据集Di由三个部分组成:[Xi,Yi,Ii],分别代表特征空间、标签空间和样本编号空间。上述的N个参与方的私有数据为非独立同分布数据,本申请各实施例旨在参与方数据非独立同分布前提下训练一个全局模型Mfed使得其在所有参与方聚合的数据集Dagg={D1,D2,…,DN}上有着尽可能高的测试精度,在训练过程中,各参与方Pi需要保证其本地数据集Pi的隐私。基于自监督学习的联邦学习方法包括步骤S100~步骤S300。
步骤S100,各参与方利用私有数据集训练本地模型,且在训练过程中对域内数据集进行预测,获得预测值;
步骤S200,中心节点利用域内数据集、以及对应域内数据集的预测值,训练全局模型;
步骤S300,利用全局模型训练域分类器,域分类器从开放数据集中提取域内数据集。
步骤S100~步骤S200是联邦学习方法的训练循环过程,其中步骤S100中所使用的域内数据集源于步骤S300,步骤S300所使用的域分类器由步骤S200中的全局模型训练获得,步骤S200使用的预测值源于步骤S100。
本实施例步骤S200并非通过简单地线性组合得到全局模型,使得全局模型具有更好的全局性能。此外,步骤S100和步骤S200均使用域内数据集,而非传统意义上的开放数据集,因此弱化了参与方对开放数据集的依赖,减小了开放数据集中的噪声对全局模型的负面影响。步骤S300中域分类器的训练基于开放数据集(无标签数据集),以全局模型为基底模型,采用自监督学习的方式进行训练。
弱化参与方对开放数据集的依赖的实现原因,可通过域内数据检测来理解。域内数据检测,即从无标签的开放数据集中数据中检测域内数据和域外(Out Of Domain,简称为OOD)样本,例如图像分类模型在面对域外数据时,理想情况下应该输出各分类概率相同的高熵分布(低预测确定性),因为来自域外数据的输入图片不属于目标类别中的任何一类;但是通常情况下OOD类别可能存在某些特征与目标类别相似,从而导致这一类别的预测概率较高而从带来误判。一个直观的区分域内域外数据的方法则是对输出概率的熵进行判断,如利用验证集计算模型有关域内和域外样本有关熵的阈值,将所有预测概率高于该阈值的样本认为是域外样本。但是这一方法在开放世界中往往难以使用,一些未知类别的数据甚至是噪声数据也会使得模型输出有把握(高预测确定性)的低熵预测。
域分类器可以将无标签的开放数据集中的域内数据进行过滤,这部分域内数据DID将作为媒介对来自多个参与方的知识进行聚合。FedAvg算法在对模型聚合时直接采用线性平均的方式,但是在数据异构的场景下这一方法不再合适,现有的解决方案大多采用知识蒸馏算法将局部模型的知识蒸馏到全局模型中。
本实施例采用集成蒸馏的方式进行模型聚合,各参与方将各自的模型上传至中心节点,由中心节点计算模型在DU过滤出的域内数据集DID上的预测输出并直接进行知识蒸馏,随后各参与方将本地模型替换为全局模型从而进行下一轮的训练,整个过程类似于FedAvg算法,除了模型聚合采用了知识蒸馏的方式进行。通过这种方式,联邦学习方法能构建一个在全局数据集Dagg(所有私有数据集的集合)上精度较好的模型,。
在步骤S200中,全局模型、域分类器、以及各本地模型在训练过程中均迭代更新。在步骤S300中,各本地模型在本轮对域内数据集进行预测时,使用上轮域分类器提取的域内数据集。在首轮对域内数据集进行预测时,域内数据集随机提取于开放数据集。
域分类器包括基底模型和多层感知机,基底模型为每轮迭代的全局模型,多层感知机作为检测头。基底模型需要有较强的特征提取能力并将所有私有数据集都视为域内数据。将每轮迭代的全局模型作为域分类器的基底模型能够提高域分类器的训练速度。
步骤S200中,各本地模型作为老师模型,全局模型作为学生模型,全局模型利用知识蒸馏的方式进行迭代更新,各本地模型获得的预测值的均值,用于训练全局模型。
在选择集成蒸馏的学生模型时,本文选择使用新一轮局部参数平均后的全局模型,将分发全局模型至各参与方(作为本地模型)作为蒸馏的对象,尽管线性组合导致平均后的模型在测试精度上不如上一轮全局模型,但是本文认为参数平均为模型导入了新的特征提取能力,更适合作为新一轮学生模型的初始化参数。
在每轮通信中都会参照式(7)对域分类器的参数进行更新,本文认为随着模型在样本上的训练,其提取特征的能力逐渐增强,而前述引入的参数平均同样也会改变BN层所记录的统计信息,因此对域分类器进行实时更新是非常有必要的。
在一轮通信结束后对域分类器进行训练,选择中心节点上一轮的全局模型作为基底模型,这样的设定使得在参与方训练本地模型时中心节点可以并行的训练该域分类器,避免模型聚合时各参与方停滞时间过长。
若各本地模型为相同的结构类别,各本地模型的迭代,通过全局模型分发至各参与方的方式完成。
若各本地模型为相同的结构类别,上一实施例减少了基于自监督学习的联邦学习方法中,对于开放数据集的依赖。在减少对开放数据集依赖的基础上,在一个实施例中解决模型异构的问题。对于模型异构的问题,各参与方由于硬件限制所训练的本地模型{M1,M2,…,MN}是架构不一致且深度和宽度不同的异构模型,算法除了期望全局模型Mfed在全局数据集Dagg上有不错的精度之外,同时也关注各本地模型Mi的泛化性能和在数据集Di上的精度。
若各参与方本地模型的结构类别至少包括两种,全局模型的结构类别与其中一个参与方本地模型的结构类别相同,全局模型、域分类器、以及各本地模型在训练过程中均迭代更新,在每轮迭代更新的联邦聚合阶段中:
(1)对于与全局模型的结构类别不同的本地模型,利用域内数据集、以及对应域内数据集的预测值,以知识蒸馏的方式,更新迭代本地模型。即将全局模型作为老师模型,将对应的本地模型作为学生模型,以完成各个本地模型的迭代,数据集选用域内数据集。
(2)对于与全局模型的结构类别相同的本地模型,利用全局模型替换迭代本地模型。
对于与全局模型的结构类别不同的本地模型,利用全局模型替换迭代本地模型,采用以下二者中的任意一种方式进行:
方式一:全局模型预测域内数据集,并根据预测值训练与全局模型结构类别不同的本地模型;
方式二:各本地模型获得的预测值的均值,用于训练与全局模型结构类别不同的本地模型。
步骤S300包括步骤S310利用全局模型训练域分类器;步骤S320,域分类器从开放数据集中提取域内数据集。
步骤S310,利用全局模型训练域分类器,包括利用全局模型的输出层信息和/或全局模型的中间层信息。
在利用全局模型的输出层信息时,步骤S310包括步骤S311~步骤S313。步骤S311,全局模型产生输入样本的输出层信息;步骤S312,域分类器获得输入样本、以及输入样本的输出层信息;步骤S313,域分类器根据输出层信息得到评分(评分函数l(x)的结果),根据评分将符合预期的输入样本置入域内数据集。具体包括:对评分排序后,按次序提取绝对数量或占比数量的对应的输入样本,置入域内数据集。
域内数据检测的问题可以形式化地表示为一个分类问题,具体地,从无标签的开放数据集DU过滤出域内数据集DID,可以定义一个以样本x作为输入样本,域内置信度作为输出层信息的评分函数l(x),这样可以建立一个基于预定义阈值的简单域内数据分类器:
DID={x|l(x)>γ,x∈DU}#(1)
其中γ则是需要预定义的置信度。式(1)仅提供了评分函数l(x)的基本思路,对于具体的函数内容不再赘述。本实施例按照式(1)给出的基本思路基于输出预测熵值,在图像分类任务中深度学习模型在输出层给出的logits对应着模型对每个分类的评分,随后的Softmax函数作为一种归一化方式将评分转换为类别的预测概率。基于熵的算法可以理解为以one-hot编码作为先验分布对输出分布进行匹配。这种匹配仅仅涉及到模型的最后一层,而对中间层丰富的维度特征不作约束。
在利用全局模型的中间层信息,训练域分类器时,中间层信息来源于全局模型中间层内在每个批量归一化层之前的特征图。
对于中间层特征的利用可以帮助模型更好地判断域内数据和域外噪声。深度学习模型如ResNet等架构往往是由多个Block组成,这些结构相似的Block的堆叠做到了图像从低维到高维的特征匹配。从模型的结构上来讲,神经网络低维卷积数量小且维度低,特征图尺寸大分辨率高,导致整个神经元的感受野较小从而使得低维特征更多的是与图形图像有关的局部特征。而神经网络的高维卷积数量多且维度大,相应的特征图尺寸小分辨率低,整个神经元的感受野较大,能提取出更全局更与任务相关的特征。
对此本实施例在利用全局模型的中间层信息时,步骤S310的训练过程包括步骤S314~步骤S316。也就是说步骤S310包括步骤S311~步骤S313(利用全局模型的输出层信息)、或S314~步骤S316(利用全局模型的中间层信息)。
步骤S314,对输入样本进行数据增强,获得对比样本,输入样本和对比样本一一对应;步骤S315,基于输入样本获得第一层次特征,基于对比样本获得第二层次特征,第一层次特征和第二层次特征一一对应;步骤S316,利用第一层次特征、第二层次特征、以及二者的对应关系训练域分类器。
对比样本通过数据增强构造,例如随机裁剪、水平翻转以及图像属性(亮度、对比度、饱和度和色调)的随机变更。具体地,对于给定的无标签的开放数据集DU以及模型f(x;θ),首先对每一个样本xi随机地进行数据增强从而得到一对正样本{xi,x′i},即输入样本和对比样本一一对应。随后将数量为2N的批次输入到域分类器中,并使用式(2)来计算每一个样本的第二层次特征{v1,v′1,…,vN,v′N},即可以知晓第一层次特征和第二层次特征一一对应的对应关系。
在步骤S315中,基于对比样本获得第二层次特征,按照基于输入样本获得第一层次特征的方式进行;基于输入样本获得第一层次特征,利用下式进行:
式中,x为输入样本;
v(x)为输入样本的第一层次特征,用于探测域内数据;
fi表示全局模型对于输入样本x在第i个批量归一化层之前的特征图;
GAP表示将一张二维的特征图进行全局平均池化得到一个标量值;
GAP算子指的是全局平均池化(Global Average Pooling),fi代表模型对于输入样本x在第i个批量归一化层之前的特征图,是简单的串联操作。GAP算子将一张二维的特征图进行全局平均得到一个标量值,这个算子最初是被设计用来取代模型输出层的全连接层,GAP层能加强类别得分和对应的卷积之间的联系,同时也能拥有更好的空间信息。本实施例对卷积层的输出做了全局平均池化,目的是将每个通道的特征进行聚合,同时也让语义信息和BN层的统计信息进行对齐。
可以理解,在训练过程中,需要基于输入样本获得第一层次特征、基于对比样本获得第二层次特征,以完成对域分类器的训练。但是在应用过程中(步骤S320),不再需要基于对比样本获得第二层次特征。
在神经网络中,不同通道之间的特征通常具有很强的相互依赖性,朴素的距离度量会更加关注那些幅度变化较大的通道,但无法捕获多个通道之间的结构化信息。本实施例提供的域分类器使用的评分函数解决了这一问题。本实施例基于输入样本、和对比样本,进行自监督学习。
本实施例利用批量归一化层(Batch Normalization Layer,BN层)作为先验分布的域内数据检测算法,以便于中间层信息的利用。几乎所有的现代神经网络都具有BN层,这是因为随着神经网络朝着更深更宽的目标发展,上层有关特征的微小扰动会对下层造成极大的影响,因此为了消除每一批次数据不同导致的特征具有偏差的问题,BN层被添加在卷积层之后对特征图作归一化处理。训练过程中特征图的均值和方差会被存储在BN层中,我们可以利用这些统计数据构架一个用于检测域内数据的先验分布,这部分额外的信息使得我们的算法相对仅使用分类概率的算法有着更好的鲁棒性。
当在利用全局模型的中间层信息时,步骤S320,域分类器从开放数据集中提取域内数据集,包括:域分类器接收第一层次特征、批量归一化层中的特征平均值串联,并输出二者的相对距离即d(v(x),v*),相对距离用于将开放数据集划分为域内数据集和域外数据集。
现在可以得出本文所提出的域内分类器的形式化定义,设存储在BN层中的特征平均值串联为v*,则可以我们通过如下方式从无标签数据集中提取出域内数据:
DID={x|d(v(x),v*)>∈,x∈DU}#(3)
式中,v(x)见公式(2),d(v(x),v*)是衡量层次特征与先验分布之间的置信度可以有多种实现方式,如简单的距离函数‖v(x)-v*‖等。但是普通的距离函数可能并不合适,因为我们不知道层次特征不同维度之间的重要性,同时不同类别的噪声样本也具有不同的差异性供模型进行判断。
提取的域内数据提取后置入域内数据集DID内。中心节点会对无标签数据集DU进行遍历,按照式(3)选择这一轮通信中的域内数据集DID,本文按照各分类的样本数目对DU进行过滤,而非显式地设定阈值∈。
相对距离用于将开放数据集划分为域内数据集和域外数据集包括:按照排列次序提取开放数据集中绝对数量或占比数量输入样本,置入域内数据集。这有利于生成各类别数目均衡的数据集DID。
步骤S320,具体包括:域分类器接收第一层次特征、批量归一化层中的特征平均值串联,并二者投影到嵌入空间中,相对距离为二者在投影在嵌入空间中的余弦距离,根据相对距离保留选择符合预期的、与第一层次特征相对应的输入样本,进而置入域内数据集。
多层感知机作为域分类器的检测头,域分类器接收第一层次特征v(x)作为输入并将其投影到嵌入空间中,接收批量归一化层中的特征平均值串联并将其投影到嵌入空间中,并通过余弦距离衡量投影在嵌入空间/>中的距离。
多层感知机包括全连接层和作为非线性的ReLU函数的激活函数,在投影过程能在投影的过程中将特征的重要程度和通道之间的相关性反映在高维嵌入空间中。另一方面,余弦距离会省略特征的绝对大小,达到类似将输入投影到超球面的效果,因此我们可以使用简单的线性距离来将相似的样本聚集在一起。
使用式(2)来计算每一个样本的第二层次特征{v1,v′1,…,vN,v′N}后,按照下式计算自监督对比损失:
其中h(*)代表本节阐述的多层感知机,起着将层次特征投影在超球面上的作用。
上述过程利用对比学习能够做到自监督地训练一个投影头,但是还需要多层感知机具有部分和任务强相关的能力,即如式(3)中的评分函数d(v(x),v*)计算某个层次特征和BN层统计信息的距离。值得注意的是,v(x)计算的是某个样本的层次特征,而v*是从BN层的统计信息串联得到的,这两者从概念上有一定的不匹配;另外如之前所说,多层感知机将层次特征投影在超球面中,而v*在嵌入空间中只以一个点的形式存在,因此在这种情况下直接计算sim(v(x),v*)是不合适的。
为了解决上述问题,本文提出了另一种建立对比学习多视角的方法:受到BN层聚合一个批次中所有样本特征的启发,我们随机将特征进行分组,并将分组结果作为自监督标签进行训练。
具体地,本文将层次特征{v1,v2,…,vN}进行随机分组平均得到{g1,g2,…,gM},每个层次特征属于且只属于其中一个分组。随后将这些输入通过多层感知机投影到超球面中,对于每一层次特征vi尝试拉近其与所属分组gi的距离,而退远与其它分区的距离。目标函数可以形式化地描述为:
综上,自监督对比学习多层感知机的目标函数如下:
其中β为平衡两者损失的系数,训练时只有投影头的参数是可训练的,模型f(x;θ)的参数固定。
本申请各实施例选择在通信中传递整个模型参数并在中心节点进行知识蒸馏的方案,蒸馏损失可以形式化地表示为:
其中DkL为KL散度,用来衡量两个分布之间的距离,τ是温度系数。
接下来对本申请各实施例提供的基于自监督学习的联邦学习方法进一步详细解释。
联邦学习算法往往由多轮通信组成,每轮通信中算法可以粗略地被分为两个阶段:本地训练阶段和联邦聚合阶段。算法的输入为各参与方的私有数据集以及无标签数据集DU,在初始化阶段,中心节点需要对联邦学习全局模型ffed(x;θfed)以及域分类器d(v,g;θd)进行模型参数的初始化,除此之外各参与方也需要和中心节点进行同步,交换任务元信息如使用的算法、模型的架构、训练超参数以及通信接口和密钥等,这一步骤未在算法中详细给出。接下来将详细介绍算法的细节。
本地训练阶段是局部模型从私有数据集中学习知识的主要阶段,这一阶段各参与方会并行地使用本地数据对模型的参数进行更新;同时从算法时序上来讲,中心节点在这一阶段中会对域分类器做并行地自监督训练,并利用训练后的域分类器对域内数据进行过滤。
在每一轮通信开始时,中心节点会将上一轮得到的全局模型或初始化全局模型的参数分发至每个参与方。随后各参与方利用全局模型ffed替换局部模型fi,随后的参数更新将基于全局模型的参数。在每一轮通信中由于分布式的特殊性,少数参与方会出现掉线或错过训练的情况,在算法中本文利用激活率C来模拟现实状况。处于激活状态的参与方Pi会将本地数据集随机按照固定的批次大小进行划分,对于每一批数据(x,y)都会利用交叉熵损失函数计算损失。
这一阶段的训练过程发生在中心节点,因此可以在各参与方训练本地模型时并行地训练域分类器。首先由无标签数据集DU经过随机的数据增强生成两对样本x和x′,其中的对应样本会被认为是正样例通过式(4)自监督地计算对比损失。样本经由上一轮的全局模型ffed进行特征提取,并将自底向上的多层次特征拼接为一个全局的特征表达v,随后为了模仿BN层的计算模式随机对该特征表达进行分组并取均值得到g,最终由分组信息作为自监督标签计算对抗损失。域分类器在每一轮通信中都会进行参数的更新,以适应局部模型从私有数据集中学习到的新的特征提取能力。
过滤域内数据阶段同样是并行于参与方训练发生。在域分类器训练完成后,利用式(2)将BN层的统计信息提取成为判断域内样本的先验分布v*,随后遍历无标签数据集DU中的每一个样本并进行过滤。具体地,在过滤过程中,本文将模型的输出标签作为样本的类别标签,同时记录其经由域分类器输出的置信度;随后对类别内部的置信度进行排序,选择置信度排名靠前的样本作为域内样本添加到DID中。本文并不显示地指定置信度阈值,而是通过各类别内部的置信度排名进行过滤以生成类别数目平衡的域内数据集。
联邦聚合阶段中,算法将来自多个non-IID数据集的局部模型进行聚合,期望得到一个在全局数据集Dagg上有着良好性能的全局模型。对于这一轮通信处于激活状态的参与方,中心节点会等待这些参与方完成本地训练,并收集其模型参数用以进行知识蒸馏。中心节点将过滤得到的域内数据集DID按固定的批次大小进行随机划分,随后将一批次的样本输入到此轮参与训练的局部模型中,以平均的logits作为教师将知识蒸馏至全局模型ffed中,集成蒸馏的形式化表示见式(8)。至此一轮通信的流程已经完成,在下一轮通信中ffed会被当做新的全局模型分发给各参与方,在联邦学习算法结束时,参与方从中心节点处拿到最终的模型并根据下游任务进行测试和部署。
由于联邦学习算法会涉及到多个参与方,这些参与方往往有着不同的硬件设施以及软件资源,难以做到算力层面的统一。FedAvg算法基于模型参数之间的线性组合,多个不同架构的模型已经超出了该算法的适用范围。一些基于自监督学习的算法可以通过聚合模型的预测来做到模型无关的联邦学习***,但是这些算法同样依赖于开放的无标签数据集作为知识转移的载体。因此接下来会对之前的算法进行了修改,使其适应模型异构场景。
算法的中心思想是利用知识蒸馏将来自多个参与方的异构模型集成至一个全局模型中,随后利用该模型对域分类器的参数进行更新,最终由各参与方利用域内数据集对本地模型进行更新。
对于全局模型的选择问题,本文期望全局模型能够在当前状态中达到最好的性能,且同时能满足各参与方的本地模型能学习到全局数据中的知识。因此有:
其中MAX_Count,最大数量;Arch(θi)指模型的架构。具体地,在每一轮通信中,算法选择当前激活参与方中,模型架构相似最多的模型作为全局模型。
适用于模型异构场景的联邦学习算法可以分为三个阶段,本地训练、联邦聚合以及知识迁移。本地训练指各参与方利用本地数据集训练局部个性化模型,并将模型参数发送给中心节点,这一步骤和之前提到的算法的本地训练阶段相似,在此不再赘述。
联邦聚合阶段中,算法利用局部模型进行集成蒸馏,得到一个性能较好的全局模型。具体地,算法首先进行模型架构的选择,随后算法从属于该架构的局部模型中随机选择一个模型作为初始化的全局模型。下一步则是利用域内数据集进行知识的转移,值得注意的是此处域分类器的训练和知识转移的顺序有了一个,对于第一轮模型来讲,需要中心节点对域内数据集进行随机初始化,即随机从无标签开放数据集中挑选样本加入到DID中。在全局模型的参数更新完成后,算法利用全局模型BN层的统计信息作为先验分布训练域分类器,特别地,由于模型异构会导致层次特征具有不同的大小,因此域分类器d(v,g;θd)需要在每一轮中从初始化状态开始训练。最终算法利用训练完成的域分类器对无标签数据集DU进行过滤,以方便下一轮通信利用域内数据集作知识蒸馏。
知识迁移阶段中,被激活的参与方会从中心节点处接受新一轮的全局模型ffed,并利用这部分全局知识对本地个性化模型做微调。具体地,如果全局模型和本地模型的架构相同,则直接进行相应的替换,这种特殊情况可以理解为模型同构的联邦学习算法;否则需要利用全局模型中的全局视野对本地模型进行知识蒸馏。
中心节点在共享模型参数时会附上域内数据集DID的标签,本地参与方通过标签进行数据集还原,并利用蒸馏损失式(8)进行知识蒸馏。值得一提的是在进行知识蒸馏时,教师模型即ffed的参数大小不会阻碍算力较弱的参与放的进度,这是由于此时教师模型处于固定参数的推理状态,在计算时不会记录计算图以及各参数的梯度。
本申请各实施例提供的基于自监督学习的联邦学习方法,致力于扩大联邦学习的适用场景,至少减少了基于自监督学习的算法过于依赖于开放的无标签数据集进行知识转移这一现象,自监督地训练了域分类器将样本的层次特征投影至嵌入空间中,方便算法更好地利用线性距离判断样本和模型BN层之间的相似性。随后针对模型异构的联邦学习场景,本文提出了适用于该场景的基于自监督学习的模型异构联邦学习算法。
应该理解的是,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些子步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
本申请一实施例中还提供一种基于自监督学习的联邦学习***,包括多个参与方和中心节点,实施有如本申请各实施例中提供的基于自监督学习的联邦学习方法。
对于任何一个参与方和中心节点,例如可以采用计算机设备,该计算机设备可以是终端、智能穿戴设备、移动终端、服务器,其内部结构图可以如图2所示。该计算机设备包括通过***总线连接的处理器、存储器、网络接口、显示屏和输入装置。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作***和计算机程序。该内存储器为非易失性存储介质中的操作***和计算机程序的运行提供环境。该计算机设备的网络接口用于与外部的终端通过网络连接通信。不同计算机设备的计算机程序被处理器执行时以实现一种基于自监督学习的联邦学习的方法。计算机设备的显示屏可以是液晶显示屏或者电子墨水显示屏,该计算机设备的输入装置可以是显示屏上覆盖的触摸层,也可以是计算机设备外壳上设置的按键、轨迹球或触控板,还可以是外接的键盘、触控板或鼠标等。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双数据率SDRAM(DDRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。不同实施例中的技术特征体现在同一附图中时,可视为该附图也同时披露了所涉及的各个实施例的组合例。
以上所述实施例仅表达了本申请的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。因此,本申请专利的保护范围应以所附权利要求为准。
Claims (10)
1.基于自监督学习的联邦学习方法,实施在多个参与方和中心节点之间,其特征在于,包括:
各参与方利用私有数据集训练本地模型,且在训练过程中对域内数据集进行预测,获得预测值;
所述中心节点利用域内数据集、以及对应所述域内数据集的预测值,训练全局模型;
利用所述全局模型训练域分类器,所述域分类器从开放数据集中提取所述域内数据集。
2.根据权利要求1所述的基于自监督学习的联邦学习方法,其特征在于,所述全局模型、所述域分类器、以及各所述本地模型在训练过程中均迭代更新;
各所述本地模型在本轮对域内数据集进行预测时,使用上轮所述域分类器提取的域内数据集;
各所述本地模型作为老师模型,所述全局模型作为学生模型,所述全局模型利用知识蒸馏的方式进行迭代更新;
各所述本地模型获得的预测值的均值,用于训练所述全局模型。
3.根据权利要求2所述的基于自监督学习的联邦学习方法,其特征在于,各所述本地模型为相同的结构类别,各所述本地模型的迭代,通过所述全局模型分发至各参与方的方式完成。
4.根据权利要求1所述的基于自监督学习的联邦学习方法,其特征在于,利用所述全局模型训练域分类器,包括:
所述全局模型产生输入样本的输出层信息;
所述域分类器获得所述输入样本、以及所述输入样本的输出层信息;
所述域分类器根据所述输出层信息得到评分,根据所述评分将符合预期的输入样本置入所述域内数据集。
5.根据权利要求1所述的基于自监督学习的联邦学习方法,其特征在于,利用所述全局模型训练域分类器,包括:
利用所述全局模型的中间层信息,自监督地训练所述域分类器,所述中间层信息来源于所述全局模型中间层内在每个批量归一化层之前的特征图。
6.根据权利要求5所述的基于自监督学习的联邦学习方法,其特征在于,所述域分类器包括多层感知机,训练过程包括:
对输入样本进行数据增强,获得对比样本,所述输入样本和所述对比样本一一对应;
基于所述输入样本获得第一层次特征,基于所述对比样本获得第二层次特征,所述第一层次特征和所述第二层次特征一一对应;
利用所述第一层次特征、所述第二层次特征、以及二者的对应关系训练所述域分类器。
7.根据权利要求6所述的基于自监督学习的联邦学习方法,其特征在于,所述域分类器从开放数据集中提取所述域内数据集,包括:
所述域分类器接收所述第一层次特征、所述批量归一化层中的特征平均值串联,并输出二者的相对距离,所述相对距离用于将所述开放数据集划分为所述域内数据集和域外数据集。
8.根据权利要求7所述的基于自监督学习的联邦学习方法,其特征在于,所述域分类器从开放数据集中提取所述域内数据集,包括:
所述域分类器接收所述第一层次特征、所述批量归一化层中的特征平均值串联,并二者投影到嵌入空间中,所述相对距离为二者在投影在嵌入空间中的余弦距离,根据所述相对距离保留选择符合预期的、与所述第一层次特征相对应的输入样本,进而置入域内数据集。
10.基于自监督学习的联邦学习***,包括多个参与方和中心节点,其特征在于,实施有如权利要求1~9任一项所述的基于自监督学习的联邦学习方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310189525.3A CN116306969A (zh) | 2023-02-21 | 2023-02-21 | 基于自监督学习的联邦学习方法和*** |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310189525.3A CN116306969A (zh) | 2023-02-21 | 2023-02-21 | 基于自监督学习的联邦学习方法和*** |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116306969A true CN116306969A (zh) | 2023-06-23 |
Family
ID=86782754
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310189525.3A Pending CN116306969A (zh) | 2023-02-21 | 2023-02-21 | 基于自监督学习的联邦学习方法和*** |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116306969A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117171628A (zh) * | 2023-11-01 | 2023-12-05 | 之江实验室 | 异构联邦环境中的图结构数据节点分类方法和装置 |
-
2023
- 2023-02-21 CN CN202310189525.3A patent/CN116306969A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117171628A (zh) * | 2023-11-01 | 2023-12-05 | 之江实验室 | 异构联邦环境中的图结构数据节点分类方法和装置 |
CN117171628B (zh) * | 2023-11-01 | 2024-03-26 | 之江实验室 | 异构联邦环境中的图结构数据节点分类方法和装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2023065545A1 (zh) | 风险预测方法、装置、设备及存储介质 | |
CN110866530A (zh) | 一种字符图像识别方法、装置及电子设备 | |
CN116227624A (zh) | 面向异构模型的联邦知识蒸馏方法和*** | |
Wang et al. | Deep streaming label learning | |
Zhou et al. | Improved cross-label suppression dictionary learning for face recognition | |
CN115659966A (zh) | 基于动态异构图和多级注意力的谣言检测方法及*** | |
CN113065409A (zh) | 一种基于摄像分头布差异对齐约束的无监督行人重识别方法 | |
Zhou et al. | Expanding the prediction capacity in long sequence time-series forecasting | |
Gao et al. | Adversarial mobility learning for human trajectory classification | |
CN112258250A (zh) | 基于网络热点的目标用户识别方法、装置和计算机设备 | |
CN116306969A (zh) | 基于自监督学习的联邦学习方法和*** | |
Shehu et al. | Lateralized approach for robustness against attacks in emotion categorization from images | |
CN115310589A (zh) | 一种基于深度图自监督学习的群体识别方法及*** | |
CN115114484A (zh) | 异常事件检测方法、装置、计算机设备和存储介质 | |
CN114254738A (zh) | 双层演化的动态图卷积神经网络模型构建方法及应用 | |
CN112598089B (zh) | 图像样本的筛选方法、装置、设备及介质 | |
Zhai et al. | Population-based evolutionary gaming for unsupervised person re-identification | |
CN115705706A (zh) | 视频处理方法、装置、计算机设备和存储介质 | |
CN116502705A (zh) | 兼用域内外数据集的知识蒸馏方法和计算机设备 | |
CN114937166A (zh) | 图像分类模型构建方法、图像分类方法及装置、电子设备 | |
Gorokhovatskiy et al. | Intellectual Data Processing and Self-Organization of Structural Features at Recognition of Visual Objects | |
Jiang et al. | A classification algorithm based on weighted ML-kNN for multi-label data | |
Jia et al. | An unsupervised person re‐identification approach based on cross‐view distribution alignment | |
Feng et al. | Learning from noisy correspondence with tri-partition for cross-modal matching | |
US20240070466A1 (en) | Unsupervised Labeling for Enhancing Neural Network Operations |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |