CN114492849B - 一种基于联邦学习的模型更新方法及装置 - Google Patents
一种基于联邦学习的模型更新方法及装置 Download PDFInfo
- Publication number
- CN114492849B CN114492849B CN202210080990.9A CN202210080990A CN114492849B CN 114492849 B CN114492849 B CN 114492849B CN 202210080990 A CN202210080990 A CN 202210080990A CN 114492849 B CN114492849 B CN 114492849B
- Authority
- CN
- China
- Prior art keywords
- model
- local
- clients
- parameters
- updated
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Bioethics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明提供了一种基于联邦学习的模型更新方法及装置,其中,该方法包括:利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,该多个客户端的本地模型分别是该多个客户端基于服务端下载的总模型进行本地训练得到的;根据该部分客户端的本地模型的参数对该全局模型的参数进行更新,得到更新后的总模型,可以解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,不需要节点进行数据共享,保证数据的安全,利用注意力机制选取部分客户端的本地模型的参数更新总模型的参数,可以更好地捕获客户端的异构型。
Description
技术领域
本发明涉及数据处理领域,具体而言,涉及一种基于联邦学习的模型更新方法及装置。
背景技术
联邦学习作为分布式的机器学习范式,可以支持多个数据拥有方的数据在不出域的情况下进行建模,在联邦机制下,利用隐私安全计算技术,各参与方的数据不发生转移,因此不会泄漏用户隐私或者影响数据规范,是一种在保护数据隐私、满足合法合规的要求下解决数据孤岛问题的有效措施,联邦学习技术的引入为分布式数据共享治理提供了有力的技术支持。
随着物联网设备的大规模使用,协同数以万计的设备及其数据进行联邦学习训练需求增大。然而,由于这些设备和数据归属于不同的用户、企业、场景等,因此其数据分布往往是差异极大的,导致数据具有非独立同分布特性。利用倾斜的Non-IID数据进行训练,往往由于weight-divergence原因,导致模型精度大幅下降。
针对相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,尚未提出解决方案。
发明内容
本发明实施例提供了一种基于联邦学习的模型更新方法及装置,以至少解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题。
根据本发明的一个实施例,提供了一种基于联邦学习的模型更新方法,应用于服务端,包括:
利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型包括:
通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
可选地,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数包括:
利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,利用注意力机制确定所述多个客户端的本地模型的参数的概率分布包括:
获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型包括:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,在根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新之后,所述方法还包括:
确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
可选地,所述方法还包括:
通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
其中,α∈[0,1),α表示注意力分数的衰减率,为客户端i的更新后的本地模型的注意力分数,/>为客户端i的本地模型的注意力分数,/>为客户端k的本地模型的注意力分数,/>为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
根据本发明的另一个实施例,还提供了一种基于联邦学习的模型更新装置,应用于服务端,包括:
选取模块,用于利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
第一更新模块,用于根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,所述第一更新模块包括:
第一确定子模块,用于通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
第二确定模块,用于通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
可选地,所述选取模块包括:
第三确定子模块,用于利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
选取子模块,用于根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述第三确定子模块包括:
获取单元,用于获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
选取单元,用于根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述选取单元,还用于:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述装置还包括:
确定模块,用于确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
第二更新模块,用于分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
可选地,所述第二更新模块,还用于通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
其中,α∈[0,1),α表示注意力分数的衰减率,为客户端i的更新后的本地模型的注意力分数,/>为客户端i的本地模型的注意力分数,/>为客户端k的本地模型的注意力分数,/>为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
根据本发明的又一个实施例,还提供了一种计算机可读的存储介质,所述存储介质中存储有计算机程序,其中,所述计算机程序被设置为运行时执行上述任一项方法实施例中的步骤。
根据本发明的又一个实施例,还提供了一种电子装置,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行上述任一项方法实施例中的步骤。
通过本发明,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型,可以解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,不需要节点进行数据共享,保证数据的安全,利用注意力机制选取部分客户端的本地模型的参数更新总模型的参数,可以更好地捕获客户端的异构型。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是本发明实施例的基于联邦学习的模型更新方法的移动终端的硬件结构框图;
图2是根据本发明实施例的基于联邦学习的模型更新方法的流程图;
图3是根据本发明实施例的全局模型更新的示意图;
图4是根据本发明实施例的基于联邦学习的模型更新装置的框图。
具体实施方式
下文中将参考附图并结合实施例来详细说明本发明。需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
本申请实施例一所提供的方法实施例可以在移动终端、计算机终端或者类似的运算装置中执行。以运行在移动终端上为例,图1是本发明实施例的基于联邦学习的模型更新方法的移动终端的硬件结构框图,如图1所示,移动终端可以包括一个或多个(图1中仅示出一个)处理器102(处理器102可以包括但不限于微处理器MCU或可编程逻辑器件FPGA等的处理装置)和用于存储数据的存储器104,可选地,上述移动终端还可以包括用于通信功能的传输设备106以及输入输出设备108。本领域普通技术人员可以理解,图1所示的结构仅为示意,其并不对上述移动终端的结构造成限定。例如,移动终端还可包括比图1中所示更多或者更少的组件,或者具有与图1所示不同的配置。
存储器104可用于存储计算机程序,例如,应用软件的软件程序以及模块,如本发明实施例中的基于联邦学习的模型更新方法对应的计算机程序,处理器102通过运行存储在存储器104内的计算机程序,从而执行各种功能应用以及数据处理,即实现上述的方法。存储器104可包括高速随机存储器,还可包括非易失性存储器,如一个或者多个磁性存储装置、闪存、或者其他非易失性固态存储器。在一些实例中,存储器104可进一步包括相对于处理器102远程设置的存储器,这些远程存储器可以通过网络连接至移动终端。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
传输装置106用于经由一个网络接收或者发送数据。上述的网络具体实例可包括移动终端的通信供应商提供的无线网络。在一个实例中,传输装置106包括一个网络适配器(Network Interface Controller,简称为NIC),其可通过基站与其他网络设备相连从而可与互联网进行通讯。在一个实例中,传输装置106可以为射频(Radio Frequency,简称为RF)模块,其用于通过无线方式与互联网进行通讯。
在本实施例中提供了一种运行于上述移动终端或网络架构的基于联邦学习的模型更新方法,图2是根据本发明实施例的基于联邦学习的模型更新方法的流程图,如图2所示,应用于服务端,该流程包括如下步骤:
步骤S202,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
步骤S204,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
本发明实施例中,上述步骤S204具体可以通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
通过上述步骤S202至S204,可以解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,不需要节点进行数据共享,保证数据的安全,利用注意力机制选取部分客户端的本地模型的参数更新总模型的参数,可以更好地捕获客户端的异构型。
本发明实施例中,上述步骤S202具体可以包括:
S2021,利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
具体的上述步骤S2021具体可以包括:获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型,具体的,根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
S2022,根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
在一可选的实施例中,在上述步骤S204之后,确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变,进一步的,可以通过以下方式更新所述部分客户端的本地模型的注意力分数:
其中,α∈[0,1),α表示注意力分数的衰减率,为客户端i的更新后的本地模型的注意力分数,/>为客户端i的本地模型的注意力分数,/>为客户端k的本地模型的注意力分数,/>为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
本发明实施例基于注意力机制的横向联邦模型更新策略,可以有效解决基于非独立同分布数据进行模型训练时的模型迭代问题,为模型有效安全聚合提供了技术支持。
在联邦平均FedAvg算法中,每一次交流迭代中,每个用户端从server端获取同样的模型,然后基于此模型进行本地模型的训练,之后将有一小组的用户端的模型会被选中来更新全局模型(对应上述的总模型)。全局模型参数更新的主要思想是通过加权平均在服务器上聚合来自客户端的梯度更新。每一轮t,server更新模型通过以下方程进行模型的更新:
wt+1←wt-η▽f(wt),其中,η是学习率,gk是第k个用户(即客户端)的本地更新参数,St是第t轮选中的用户集合,/>
在实现中,无法***不同客户的相对训练性能。不同的客户对模型聚合可能具有不同的相对重要性,且这在不同的通信回合中可能会有所不同。本实施例提出利用注意力机制来衡量不同回合中不同用户端本地模型的相对重要性。
本实施例使用欧几里得距离来衡量每个局部模型相对于全局模型的模型差异。差异化大小用来给每个用户局部模型打分,这个分数将决定每一轮不同用户局部模型被选中更新全局模型的概率。对于第t轮不同用户端模型的得分和概率分别表示为:p=[p1,p2,...,pM]。
在训练的开始阶段,被选取的用户端下载server端的全局模型,然后进行局部模型训练;server端根据概率分布选取K个用户,选取的K个局部模型表示为其中ij表示从集合St中选取的第j个用户,每个/>是神经网络各层的权重矩阵的集合。通过聚合选取的所有权重得到下一回合的全局权重W(t+1),图3是根据本发明实施例的全局模型更新的示意图,如图3所示,对于在第t回合选取的用户i∈St,全局参数和局部参数的欧几里得距离可以表示为
另外,为了减少连续几轮的注意力得分波动,将当前注意力得分纳入更新标准:/>
其中,α∈[0,1)表示之前注意力得分贡献的衰减率,对于任意一个未被选中的用户j,默认下一个轮的得分与上一个回合一致,即因此,某用户端在第(t+1)回合的概率分布为p=a(t+1)。由于不涉及通常的联邦优化,因此注意机制只更新客户端选择概率分布,而不会更改聚合权重。
基于注意力用户打分机制的联邦学习,输入:M,T,α,W(1),n,具体包括:
步骤1,For t=1到T,do;
步骤2,p←a(t),K←M;
步骤3,Server端选择用户子集合St,其中St的大小为K;
步骤4,对于选中的用户k∈St,用户k下载全局模型W(t),用户k计算局部模型
步骤7,Server计算新的全局模型:
步骤8,对于选取的用户i∈St,Server端更新以及
步骤10,对于为选取的用户
本发明实施例利用注意力机制计算用户被选择用于模型聚合的概率,既提升了模型性能,又不需要任何全局数据共享操作,保证了数据的安全。为非独立同分布横向联邦学习建模提供了有力的技术支撑,可以解决联邦学习中由Non-IID数据导致的模型性能大幅下降问题。
根据本发明的另一个实施例,还提供了一种基于联邦学习的模型更新装置,应用于服务端,图4是根据本发明实施例的基于联邦学习的模型更新装置的框图,如图4所示,包括:
选取模块42,用于利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
第一更新模块44,用于根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,所述第一更新模块44包括:
第一确定子模块,用于通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
第二确定模块,用于通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
可选地,所述选取模块42包括:
第三确定子模块,用于利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
选取子模块,用于根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述第三确定子模块包括:
获取单元,用于获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
选取单元,用于根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述选取单元,还用于:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述装置还包括:
确定模块,用于确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
第二更新模块,用于分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
可选地,所述第二更新模块,还用于通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
其中,α∈[0,1),α表示注意力分数的衰减率,为客户端i的更新后的本地模型的注意力分数,/>为客户端i的本地模型的注意力分数,/>为客户端k的本地模型的注意力分数,/>为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
需要说明的是,上述各个模块是可以通过软件或硬件来实现的,对于后者,可以通过以下方式实现,但不限于此:上述模块均位于同一处理器中;或者,上述各个模块以任意组合的形式分别位于不同的处理器中。
本发明的实施例还提供了一种计算机可读的存储介质,该存储介质中存储有计算机程序,其中,该计算机程序被设置为运行时执行上述任一项方法实施例中的步骤。
可选地,在本实施例中,上述存储介质可以被设置为存储用于执行以下步骤的计算机程序:
S1,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
S2,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,在本实施例中,上述存储介质可以包括但不限于:U盘、只读存储器(Read-Only Memory,简称为ROM)、随机存取存储器(Random Access Memory,简称为RAM)、移动硬盘、磁碟或者光盘等各种可以存储计算机程序的介质。
本发明的实施例还提供了一种电子装置,包括存储器和处理器,该存储器中存储有计算机程序,该处理器被设置为运行计算机程序以执行上述任一项方法实施例中的步骤。
可选地,上述电子装置还可以包括传输设备以及输入输出设备,其中,该传输设备和上述处理器连接,该输入输出设备和上述处理器连接。
可选地,在本实施例中,上述处理器可以被设置为通过计算机程序执行以下步骤:
S1,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
S2,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,本实施例中的具体示例可以参考上述实施例及可选实施方式中所描述的示例,本实施例在此不再赘述。
显然,本领域的技术人员应该明白,上述的本发明的各模块或各步骤可以用通用的计算装置来实现,它们可以集中在单个的计算装置上,或者分布在多个计算装置所组成的网络上,可选地,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,并且在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。这样,本发明不限制于任何特定的硬件和软件结合。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (7)
1.一种基于联邦学习的模型更新方法,应用于服务端,其特征在于,包括:
利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型;
其中,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,包括:获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
2.根据权利要求1所述的方法,其特征在于,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型包括:
通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,为所述总模型的更新参数,nk为客户端K提供的数据维度,nSt为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
3.根据权利要求1所述的方法,其特征在于,在根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新之后,所述方法还包括:
确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
4.根据权利要求3所述的方法,其特征在于,所述方法还包括:
通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
其中,α∈[0,1),α表示注意力分数的衰减率,为客户端i的更新后的本地模型的注意力分数,/>为客户端i的本地模型的注意力分数,/>为客户端k的本地模型的注意力分数,/>为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,/>为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
5.一种基于联邦学习的模型更新装置,应用于服务端,其特征在于,包括:
选取模块,用于利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
第一更新模块,用于根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型;
其中,所述选取模块包括:
获取单元,用于获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
选取单元,用于根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
6.一种计算机可读的存储介质,其特征在于,所述存储介质中存储有计算机程序,其中,所述计算机程序被设置为运行时执行所述权利要求1至4任一项中所述的方法。
7.一种电子装置,包括存储器和处理器,其特征在于,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行所述权利要求1至4中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210080990.9A CN114492849B (zh) | 2022-01-24 | 2022-01-24 | 一种基于联邦学习的模型更新方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210080990.9A CN114492849B (zh) | 2022-01-24 | 2022-01-24 | 一种基于联邦学习的模型更新方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114492849A CN114492849A (zh) | 2022-05-13 |
CN114492849B true CN114492849B (zh) | 2023-09-08 |
Family
ID=81474358
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210080990.9A Active CN114492849B (zh) | 2022-01-24 | 2022-01-24 | 一种基于联邦学习的模型更新方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114492849B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114782758B (zh) * | 2022-06-21 | 2022-09-02 | 平安科技(深圳)有限公司 | 图像处理模型训练方法、***、计算机设备及存储介质 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112580821A (zh) * | 2020-12-10 | 2021-03-30 | 深圳前海微众银行股份有限公司 | 一种联邦学习方法、装置、设备及存储介质 |
CN113011599A (zh) * | 2021-03-23 | 2021-06-22 | 上海嗨普智能信息科技股份有限公司 | 基于异构数据的联邦学习*** |
CN113221470A (zh) * | 2021-06-10 | 2021-08-06 | 南方电网科学研究院有限责任公司 | 一种用于电网边缘计算***的联邦学习方法及其相关装置 |
CN113378243A (zh) * | 2021-07-14 | 2021-09-10 | 南京信息工程大学 | 一种基于多头注意力机制的个性化联邦学习方法 |
WO2021179720A1 (zh) * | 2020-10-12 | 2021-09-16 | 平安科技(深圳)有限公司 | 基于联邦学习的用户数据分类方法、装置、设备及介质 |
WO2021184836A1 (zh) * | 2020-03-20 | 2021-09-23 | 深圳前海微众银行股份有限公司 | 识别模型的训练方法、装置、设备及可读存储介质 |
CN113537509A (zh) * | 2021-06-28 | 2021-10-22 | 南方科技大学 | 协作式的模型训练方法及装置 |
CN113705610A (zh) * | 2021-07-26 | 2021-11-26 | 广州大学 | 一种基于联邦学习的异构模型聚合方法和*** |
CN113781397A (zh) * | 2021-08-11 | 2021-12-10 | 中国科学院信息工程研究所 | 基于联邦学习的医疗影像病灶检测建模方法、装置及*** |
-
2022
- 2022-01-24 CN CN202210080990.9A patent/CN114492849B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021184836A1 (zh) * | 2020-03-20 | 2021-09-23 | 深圳前海微众银行股份有限公司 | 识别模型的训练方法、装置、设备及可读存储介质 |
WO2021179720A1 (zh) * | 2020-10-12 | 2021-09-16 | 平安科技(深圳)有限公司 | 基于联邦学习的用户数据分类方法、装置、设备及介质 |
CN112580821A (zh) * | 2020-12-10 | 2021-03-30 | 深圳前海微众银行股份有限公司 | 一种联邦学习方法、装置、设备及存储介质 |
CN113011599A (zh) * | 2021-03-23 | 2021-06-22 | 上海嗨普智能信息科技股份有限公司 | 基于异构数据的联邦学习*** |
CN113221470A (zh) * | 2021-06-10 | 2021-08-06 | 南方电网科学研究院有限责任公司 | 一种用于电网边缘计算***的联邦学习方法及其相关装置 |
CN113537509A (zh) * | 2021-06-28 | 2021-10-22 | 南方科技大学 | 协作式的模型训练方法及装置 |
CN113378243A (zh) * | 2021-07-14 | 2021-09-10 | 南京信息工程大学 | 一种基于多头注意力机制的个性化联邦学习方法 |
CN113705610A (zh) * | 2021-07-26 | 2021-11-26 | 广州大学 | 一种基于联邦学习的异构模型聚合方法和*** |
CN113781397A (zh) * | 2021-08-11 | 2021-12-10 | 中国科学院信息工程研究所 | 基于联邦学习的医疗影像病灶检测建模方法、装置及*** |
Non-Patent Citations (1)
Title |
---|
TiFL: A Tier-based Federated Learning System;Zheng Chai et al.;《arXiv》;全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN114492849A (zh) | 2022-05-13 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108985309B (zh) | 一种数据处理方法以及装置 | |
CN108052639A (zh) | 基于运营商数据的行业用户推荐方法及装置 | |
CN105719221B (zh) | 针对多任务的路径协同规划方法和装置 | |
CN105608179A (zh) | 确定用户标识的关联性的方法和装置 | |
CN110166344B (zh) | 一种身份标识识别方法、装置以及相关设备 | |
Hassani et al. | Context-aware recruitment scheme for opportunistic mobile crowdsensing | |
CN107404541A (zh) | 一种对等网络传输邻居节点选择的方法及*** | |
CN113094181A (zh) | 面向边缘设备的多任务联邦学习方法及装置 | |
CN106294778A (zh) | 信息推送方法和装置 | |
CN112669084B (zh) | 策略确定方法、设备及计算机可读存储介质 | |
WO2021008675A1 (en) | Dynamic network configuration | |
CN114492849B (zh) | 一种基于联邦学习的模型更新方法及装置 | |
Lee et al. | Accurate and fast federated learning via IID and communication-aware grouping | |
CN116989819B (zh) | 一种基于模型解的路径确定方法及装置 | |
CN108289115B (zh) | 一种信息处理方法及*** | |
CN110457387B (zh) | 一种应用于网络中用户标签确定的方法及相关装置 | |
CN117119535A (zh) | 一种移动端集群热点共享的数据分流方法和*** | |
CN116843016A (zh) | 一种移动边缘计算网络下基于强化学习的联邦学习方法、***及介质 | |
CN110505186A (zh) | 一种安全规则冲突的识别方法、识别设备及存储介质 | |
US9536199B1 (en) | Recommendations based on device usage | |
CN114022731A (zh) | 基于drl的联邦学习节点选择方法 | |
CN109039907B (zh) | 确定网络数据流量最优路径方法、装置、设备及存储介质 | |
US8392434B1 (en) | Random sampling from distributed streams | |
Cao et al. | FedQMIX: Communication-efficient federated learning via multi-agent reinforcement learning | |
CN115861662B (zh) | 基于组合神经网络模型的预测方法、装置、设备及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |