CN114492849A - 一种基于联邦学习的模型更新方法及装置 - Google Patents

一种基于联邦学习的模型更新方法及装置 Download PDF

Info

Publication number
CN114492849A
CN114492849A CN202210080990.9A CN202210080990A CN114492849A CN 114492849 A CN114492849 A CN 114492849A CN 202210080990 A CN202210080990 A CN 202210080990A CN 114492849 A CN114492849 A CN 114492849A
Authority
CN
China
Prior art keywords
model
clients
parameters
local
local models
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210080990.9A
Other languages
English (en)
Other versions
CN114492849B (zh
Inventor
王鹏
贾雪丽
李钰
樊昕晔
田江
向小佳
丁永建
李璠
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Everbright Technology Co ltd
Original Assignee
Everbright Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Everbright Technology Co ltd filed Critical Everbright Technology Co ltd
Priority to CN202210080990.9A priority Critical patent/CN114492849B/zh
Publication of CN114492849A publication Critical patent/CN114492849A/zh
Application granted granted Critical
Publication of CN114492849B publication Critical patent/CN114492849B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • G06N20/20Ensemble learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60Protecting data
    • G06F21/62Protecting access to data via a platform, e.g. using keys or access control rules
    • G06F21/6218Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
    • G06F21/6245Protecting personal data, e.g. for financial or medical purposes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Medical Informatics (AREA)
  • Molecular Biology (AREA)
  • Bioethics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Databases & Information Systems (AREA)
  • Computer Hardware Design (AREA)
  • Computer Security & Cryptography (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提供了一种基于联邦学习的模型更新方法及装置,其中,该方法包括:利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,该多个客户端的本地模型分别是该多个客户端基于服务端下载的总模型进行本地训练得到的;根据该部分客户端的本地模型的参数对该全局模型的参数进行更新,得到更新后的总模型,可以解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,不需要节点进行数据共享,保证数据的安全,利用注意力机制选取部分客户端的本地模型的参数更新总模型的参数,可以更好地捕获客户端的异构型。

Description

一种基于联邦学习的模型更新方法及装置
技术领域
本发明涉及数据处理领域,具体而言,涉及一种基于联邦学习的模型更新方法及装置。
背景技术
联邦学习作为分布式的机器学习范式,可以支持多个数据拥有方的数据在不出域的情况下进行建模,在联邦机制下,利用隐私安全计算技术,各参与方的数据不发生转移,因此不会泄漏用户隐私或者影响数据规范,是一种在保护数据隐私、满足合法合规的要求下解决数据孤岛问题的有效措施,联邦学习技术的引入为分布式数据共享治理提供了有力的技术支持。
随着物联网设备的大规模使用,协同数以万计的设备及其数据进行联邦学习训练需求增大。然而,由于这些设备和数据归属于不同的用户、企业、场景等,因此其数据分布往往是差异极大的,导致数据具有非独立同分布特性。利用倾斜的Non-IID数据进行训练,往往由于weight-divergence原因,导致模型精度大幅下降。
针对相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,尚未提出解决方案。
发明内容
本发明实施例提供了一种基于联邦学习的模型更新方法及装置,以至少解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题。
根据本发明的一个实施例,提供了一种基于联邦学习的模型更新方法,应用于服务端,包括:
利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型包括:
通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
Figure BDA0003485882740000021
通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,
Figure BDA0003485882740000022
为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
可选地,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数包括:
利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,利用注意力机制确定所述多个客户端的本地模型的参数的概率分布包括:
获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型包括:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,在根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新之后,所述方法还包括:
确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
可选地,所述方法还包括:
通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
Figure BDA0003485882740000031
其中,α∈[0,1),α表示注意力分数的衰减率,
Figure BDA0003485882740000032
为客户端i的更新后的本地模型的注意力分数,
Figure BDA0003485882740000033
为客户端i的本地模型的注意力分数,
Figure BDA0003485882740000034
为客户端k的本地模型的注意力分数,
Figure BDA0003485882740000035
为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,
Figure BDA0003485882740000036
为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
根据本发明的另一个实施例,还提供了一种基于联邦学习的模型更新装置,应用于服务端,包括:
选取模块,用于利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
第一更新模块,用于根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,所述第一更新模块包括:
第一确定子模块,用于通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
Figure BDA0003485882740000041
第二确定模块,用于通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,
Figure BDA0003485882740000042
为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
可选地,所述选取模块包括:
第三确定子模块,用于利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
选取子模块,用于根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述第三确定子模块包括:
获取单元,用于获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
选取单元,用于根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述选取单元,还用于:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述装置还包括:
确定模块,用于确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
第二更新模块,用于分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
可选地,所述第二更新模块,还用于通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
Figure BDA0003485882740000051
其中,α∈[0,1),α表示注意力分数的衰减率,
Figure BDA0003485882740000052
为客户端i的更新后的本地模型的注意力分数,
Figure BDA0003485882740000053
为客户端i的本地模型的注意力分数,
Figure BDA0003485882740000054
为客户端k的本地模型的注意力分数,
Figure BDA0003485882740000055
为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,
Figure BDA0003485882740000061
为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
根据本发明的又一个实施例,还提供了一种计算机可读的存储介质,所述存储介质中存储有计算机程序,其中,所述计算机程序被设置为运行时执行上述任一项方法实施例中的步骤。
根据本发明的又一个实施例,还提供了一种电子装置,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行上述任一项方法实施例中的步骤。
通过本发明,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型,可以解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,不需要节点进行数据共享,保证数据的安全,利用注意力机制选取部分客户端的本地模型的参数更新总模型的参数,可以更好地捕获客户端的异构型。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1是本发明实施例的基于联邦学习的模型更新方法的移动终端的硬件结构框图;
图2是根据本发明实施例的基于联邦学习的模型更新方法的流程图;
图3是根据本发明实施例的全局模型更新的示意图;
图4是根据本发明实施例的基于联邦学习的模型更新装置的框图。
具体实施方式
下文中将参考附图并结合实施例来详细说明本发明。需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
本申请实施例一所提供的方法实施例可以在移动终端、计算机终端或者类似的运算装置中执行。以运行在移动终端上为例,图1是本发明实施例的基于联邦学习的模型更新方法的移动终端的硬件结构框图,如图1所示,移动终端可以包括一个或多个(图1中仅示出一个)处理器102(处理器102可以包括但不限于微处理器MCU或可编程逻辑器件FPGA等的处理装置)和用于存储数据的存储器104,可选地,上述移动终端还可以包括用于通信功能的传输设备106以及输入输出设备108。本领域普通技术人员可以理解,图1所示的结构仅为示意,其并不对上述移动终端的结构造成限定。例如,移动终端还可包括比图1中所示更多或者更少的组件,或者具有与图1所示不同的配置。
存储器104可用于存储计算机程序,例如,应用软件的软件程序以及模块,如本发明实施例中的基于联邦学习的模型更新方法对应的计算机程序,处理器102通过运行存储在存储器104内的计算机程序,从而执行各种功能应用以及数据处理,即实现上述的方法。存储器104可包括高速随机存储器,还可包括非易失性存储器,如一个或者多个磁性存储装置、闪存、或者其他非易失性固态存储器。在一些实例中,存储器104可进一步包括相对于处理器102远程设置的存储器,这些远程存储器可以通过网络连接至移动终端。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
传输装置106用于经由一个网络接收或者发送数据。上述的网络具体实例可包括移动终端的通信供应商提供的无线网络。在一个实例中,传输装置106包括一个网络适配器(Network Interface Controller,简称为NIC),其可通过基站与其他网络设备相连从而可与互联网进行通讯。在一个实例中,传输装置106可以为射频(Radio Frequency,简称为RF)模块,其用于通过无线方式与互联网进行通讯。
在本实施例中提供了一种运行于上述移动终端或网络架构的基于联邦学习的模型更新方法,图2是根据本发明实施例的基于联邦学习的模型更新方法的流程图,如图2所示,应用于服务端,该流程包括如下步骤:
步骤S202,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
步骤S204,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
本发明实施例中,上述步骤S204具体可以通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
Figure BDA0003485882740000081
通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,
Figure BDA0003485882740000082
为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
通过上述步骤S202至S204,可以解决相关技术中基于联邦学习的模型训练需要各个节点进行少量数据的共享,数据的出域存在数据安全隐患的问题,不需要节点进行数据共享,保证数据的安全,利用注意力机制选取部分客户端的本地模型的参数更新总模型的参数,可以更好地捕获客户端的异构型。
本发明实施例中,上述步骤S202具体可以包括:
S2021,利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
具体的上述步骤S2021具体可以包括:获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型,具体的,根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
S2022,根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
在一可选的实施例中,在上述步骤S204之后,确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变,进一步的,可以通过以下方式更新所述部分客户端的本地模型的注意力分数:
Figure BDA0003485882740000091
其中,α∈[0,1),α表示注意力分数的衰减率,
Figure BDA0003485882740000092
为客户端i的更新后的本地模型的注意力分数,
Figure BDA0003485882740000093
为客户端i的本地模型的注意力分数,
Figure BDA0003485882740000094
为客户端k的本地模型的注意力分数,
Figure BDA0003485882740000095
为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,
Figure BDA0003485882740000096
为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
本发明实施例基于注意力机制的横向联邦模型更新策略,可以有效解决基于非独立同分布数据进行模型训练时的模型迭代问题,为模型有效安全聚合提供了技术支持。
在联邦平均FedAvg算法中,每一次交流迭代中,每个用户端从server端获取同样的模型,然后基于此模型进行本地模型的训练,之后将有一小组的用户端的模型会被选中来更新全局模型(对应上述的总模型)。全局模型参数更新的主要思想是通过加权平均在服务器上聚合来自客户端的梯度更新。每一轮t,server更新模型通过以下方程进行模型的更新:
wt+1←wt-η▽f(wt),其中,
Figure BDA0003485882740000101
η是学习率,gk是第k个用户(即客户端)的本地更新参数,St是第t轮选中的用户集合,
Figure BDA0003485882740000102
在实现中,无法***不同客户的相对训练性能。不同的客户对模型聚合可能具有不同的相对重要性,且这在不同的通信回合中可能会有所不同。本实施例提出利用注意力机制来衡量不同回合中不同用户端本地模型的相对重要性。
本实施例使用欧几里得距离来衡量每个局部模型相对于全局模型的模型差异。差异化大小用来给每个用户局部模型打分,这个分数将决定每一轮不同用户局部模型被选中更新全局模型的概率。对于第t轮不同用户端模型的得分和概率分别表示为:
Figure BDA0003485882740000103
p=[p1,p2,...,pM]。
在训练的开始阶段,被选取的用户端下载server端的全局模型,然后进行局部模型训练;server端根据概率分布选取K个用户,选取的K个局部模型表示为
Figure BDA0003485882740000104
其中ij表示从集合St中选取的第j个用户,每个
Figure BDA0003485882740000105
是神经网络各层的权重矩阵的集合。通过聚合选取的所有权重得到下一回合的全局权重W(t+1),图3是根据本发明实施例的全局模型更新的示意图,如图3所示,对于在第t回合选取的用户i∈St,全局参数和局部参数的欧几里得距离可以表示为
Figure BDA0003485882740000111
另外,为了减少连续几轮的注意力得分波动,将当前注意力得分纳入更新标准:
Figure BDA0003485882740000112
其中,α∈[0,1)表示之前注意力得分贡献的衰减率,对于任意一个未被选中的用户j,默认下一个轮的得分与上一个回合一致,即
Figure BDA0003485882740000113
因此,某用户端在第(t+1)回合的概率分布为p=a(t+1)。由于不涉及通常的联邦优化,因此注意机制只更新客户端选择概率分布,而不会更改聚合权重。
基于注意力用户打分机制的联邦学习,输入:M,T,α,W(1),n,具体包括:
步骤1,For t=1到T,do;
步骤2,p←a(t),K←M;
步骤3,Server端选择用户子集合St,其中St的大小为K;
步骤4,对于选中的用户k∈St,用户k下载全局模型W(t),用户k计算局部模型
Figure BDA0003485882740000114
步骤7,Server计算新的全局模型:
Figure BDA0003485882740000115
步骤8,对于选取的用户i∈St,Server端更新
Figure BDA0003485882740000116
以及
Figure BDA0003485882740000117
步骤10,对于为选取的用户
Figure BDA0003485882740000118
本发明实施例利用注意力机制计算用户被选择用于模型聚合的概率,既提升了模型性能,又不需要任何全局数据共享操作,保证了数据的安全。为非独立同分布横向联邦学习建模提供了有力的技术支撑,可以解决联邦学习中由Non-IID数据导致的模型性能大幅下降问题。
根据本发明的另一个实施例,还提供了一种基于联邦学习的模型更新装置,应用于服务端,图4是根据本发明实施例的基于联邦学习的模型更新装置的框图,如图4所示,包括:
选取模块42,用于利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
第一更新模块44,用于根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,所述第一更新模块44包括:
第一确定子模块,用于通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
Figure BDA0003485882740000121
第二确定模块,用于通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
wt+1=wt-η▽f(wt);
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,▽f(wt)为所述总模型的更新参数,nk为客户端K提供的数据维度,
Figure BDA0003485882740000122
为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
可选地,所述选取模块42包括:
第三确定子模块,用于利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
选取子模块,用于根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述第三确定子模块包括:
获取单元,用于获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
选取单元,用于根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述选取单元,还用于:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
可选地,所述装置还包括:
确定模块,用于确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
第二更新模块,用于分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
可选地,所述第二更新模块,还用于通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
Figure BDA0003485882740000141
其中,α∈[0,1),α表示注意力分数的衰减率,
Figure BDA0003485882740000142
为客户端i的更新后的本地模型的注意力分数,
Figure BDA0003485882740000143
为客户端i的本地模型的注意力分数,
Figure BDA0003485882740000144
为客户端k的本地模型的注意力分数,
Figure BDA0003485882740000145
为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,
Figure BDA0003485882740000146
为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
需要说明的是,上述各个模块是可以通过软件或硬件来实现的,对于后者,可以通过以下方式实现,但不限于此:上述模块均位于同一处理器中;或者,上述各个模块以任意组合的形式分别位于不同的处理器中。
本发明的实施例还提供了一种计算机可读的存储介质,该存储介质中存储有计算机程序,其中,该计算机程序被设置为运行时执行上述任一项方法实施例中的步骤。
可选地,在本实施例中,上述存储介质可以被设置为存储用于执行以下步骤的计算机程序:
S1,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
S2,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,在本实施例中,上述存储介质可以包括但不限于:U盘、只读存储器(Read-Only Memory,简称为ROM)、随机存取存储器(Random Access Memory,简称为RAM)、移动硬盘、磁碟或者光盘等各种可以存储计算机程序的介质。
本发明的实施例还提供了一种电子装置,包括存储器和处理器,该存储器中存储有计算机程序,该处理器被设置为运行计算机程序以执行上述任一项方法实施例中的步骤。
可选地,上述电子装置还可以包括传输设备以及输入输出设备,其中,该传输设备和上述处理器连接,该输入输出设备和上述处理器连接。
可选地,在本实施例中,上述处理器可以被设置为通过计算机程序执行以下步骤:
S1,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
S2,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
可选地,本实施例中的具体示例可以参考上述实施例及可选实施方式中所描述的示例,本实施例在此不再赘述。
显然,本领域的技术人员应该明白,上述的本发明的各模块或各步骤可以用通用的计算装置来实现,它们可以集中在单个的计算装置上,或者分布在多个计算装置所组成的网络上,可选地,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,并且在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。这样,本发明不限制于任何特定的硬件和软件结合。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (10)

1.一种基于联邦学习的模型更新方法,应用于服务端,其特征在于,包括:
利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
2.根据权利要求1所述的方法,其特征在于,根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型包括:
通过以下方式根据所述部分客户端的本地模型的参数确定所述总模型的更新参数:
Figure FDA0003485882730000011
通过以下方式根据所述总模型的更新参数确定所述更新后的总模型的参数:
Figure FDA0003485882730000012
wt+1为所述更新后的模型的参数,wt为所述总模型的参数,η为学习率,
Figure FDA0003485882730000013
为所述总模型的更新参数,nk为客户端K提供的数据维度,
Figure FDA0003485882730000014
为所述部分客户端提供的数据维度,gk是客户端k的本地模型的更新参数,St是被选中的所述部分客户端的集合。
3.根据权利要求1所述的方法,其特征在于,利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数包括:
利用注意力机制确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
4.根据权利要求3所述的方法,其特征在于,利用注意力机制确定所述多个客户端的本地模型的参数的概率分布包括:
获取上一次更新所述总模型之后确定的所述多个客户端的本地模型的注意力分数,其中,所述注意力分数是根据所述多个客户端的本地模型中每个本地模型的参数与所述总模型的参数的欧几里得距离确定的;
根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
5.根据权利要求4所述的方法,其特征在于,根据所述注意力分数从所述多个客户端的本地模型中选取所述部分客户端的本地模型包括:
根据所述注意力分数确定所述多个客户端的本地模型的参数的概率分布;
根据所述概率分布从所述多个客户端的本地模型中选取所述部分客户端的本地模型。
6.根据权利要求3所述的方法,其特征在于,在根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新之后,所述方法还包括:
确定所述部分客户端的本地模型中每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离;
分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数,其中,除所述部分客户端之外的未被选中的客户端的本地模型的注意力分数不变。
7.根据权利要求6所述的方法,其特征在于,所述方法还包括:
通过以下方式分别根据所述每个本地模型的参数与所述更新后的总模型的参数的欧几里得距离更新所述部分客户端的本地模型的注意力分数:
Figure FDA0003485882730000031
其中,α∈[0,1),α表示注意力分数的衰减率,
Figure FDA0003485882730000032
为客户端i的更新后的本地模型的注意力分数,
Figure FDA0003485882730000033
为客户端i的本地模型的注意力分数,
Figure FDA0003485882730000034
为客户端k的本地模型的注意力分数,
Figure FDA0003485882730000035
为客户端i的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,
Figure FDA0003485882730000036
为客户端k的本地模型的参数与所述更新后的总模型的参数的欧几里得距离,St为所述部分客户端的集合。
8.一种基于联邦学习的模型更新装置,应用于服务端,其特征在于,包括:
选取模块,用于利用注意力机制从多个客户端的本地模型中选取部分客户端的本地模型的参数,其中,所述多个客户端的本地模型分别是所述多个客户端基于服务端下载的总模型进行本地训练得到的;
第一更新模块,用于根据所述部分客户端的本地模型的参数对所述总模型的参数进行更新,得到更新后的总模型。
9.一种计算机可读的存储介质,其特征在于,所述存储介质中存储有计算机程序,其中,所述计算机程序被设置为运行时执行所述权利要求1至7任一项中所述的方法。
10.一种电子装置,包括存储器和处理器,其特征在于,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行所述权利要求1至7中任一项所述的方法。
CN202210080990.9A 2022-01-24 2022-01-24 一种基于联邦学习的模型更新方法及装置 Active CN114492849B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210080990.9A CN114492849B (zh) 2022-01-24 2022-01-24 一种基于联邦学习的模型更新方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210080990.9A CN114492849B (zh) 2022-01-24 2022-01-24 一种基于联邦学习的模型更新方法及装置

Publications (2)

Publication Number Publication Date
CN114492849A true CN114492849A (zh) 2022-05-13
CN114492849B CN114492849B (zh) 2023-09-08

Family

ID=81474358

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210080990.9A Active CN114492849B (zh) 2022-01-24 2022-01-24 一种基于联邦学习的模型更新方法及装置

Country Status (1)

Country Link
CN (1) CN114492849B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114782758A (zh) * 2022-06-21 2022-07-22 平安科技(深圳)有限公司 图像处理模型训练方法、***、计算机设备及存储介质

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112580821A (zh) * 2020-12-10 2021-03-30 深圳前海微众银行股份有限公司 一种联邦学习方法、装置、设备及存储介质
CN113011599A (zh) * 2021-03-23 2021-06-22 上海嗨普智能信息科技股份有限公司 基于异构数据的联邦学习***
CN113221470A (zh) * 2021-06-10 2021-08-06 南方电网科学研究院有限责任公司 一种用于电网边缘计算***的联邦学习方法及其相关装置
CN113378243A (zh) * 2021-07-14 2021-09-10 南京信息工程大学 一种基于多头注意力机制的个性化联邦学习方法
WO2021179720A1 (zh) * 2020-10-12 2021-09-16 平安科技(深圳)有限公司 基于联邦学习的用户数据分类方法、装置、设备及介质
WO2021184836A1 (zh) * 2020-03-20 2021-09-23 深圳前海微众银行股份有限公司 识别模型的训练方法、装置、设备及可读存储介质
CN113537509A (zh) * 2021-06-28 2021-10-22 南方科技大学 协作式的模型训练方法及装置
CN113705610A (zh) * 2021-07-26 2021-11-26 广州大学 一种基于联邦学习的异构模型聚合方法和***
CN113781397A (zh) * 2021-08-11 2021-12-10 中国科学院信息工程研究所 基于联邦学习的医疗影像病灶检测建模方法、装置及***

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2021184836A1 (zh) * 2020-03-20 2021-09-23 深圳前海微众银行股份有限公司 识别模型的训练方法、装置、设备及可读存储介质
WO2021179720A1 (zh) * 2020-10-12 2021-09-16 平安科技(深圳)有限公司 基于联邦学习的用户数据分类方法、装置、设备及介质
CN112580821A (zh) * 2020-12-10 2021-03-30 深圳前海微众银行股份有限公司 一种联邦学习方法、装置、设备及存储介质
CN113011599A (zh) * 2021-03-23 2021-06-22 上海嗨普智能信息科技股份有限公司 基于异构数据的联邦学习***
CN113221470A (zh) * 2021-06-10 2021-08-06 南方电网科学研究院有限责任公司 一种用于电网边缘计算***的联邦学习方法及其相关装置
CN113537509A (zh) * 2021-06-28 2021-10-22 南方科技大学 协作式的模型训练方法及装置
CN113378243A (zh) * 2021-07-14 2021-09-10 南京信息工程大学 一种基于多头注意力机制的个性化联邦学习方法
CN113705610A (zh) * 2021-07-26 2021-11-26 广州大学 一种基于联邦学习的异构模型聚合方法和***
CN113781397A (zh) * 2021-08-11 2021-12-10 中国科学院信息工程研究所 基于联邦学习的医疗影像病灶检测建模方法、装置及***

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
JIYING: "加入联邦学习的客户端设备-随机选择真的好吗?", Retrieved from the Internet <URL:https://m.thepaper.cn/baijiahao_12522801> *
ZHENG CHAI ET AL.: "TiFL: A Tier-based Federated Learning System", 《ARXIV》 *
王健宗 等: "联邦学习算法综述", 《大数据》 *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114782758A (zh) * 2022-06-21 2022-07-22 平安科技(深圳)有限公司 图像处理模型训练方法、***、计算机设备及存储介质
CN114782758B (zh) * 2022-06-21 2022-09-02 平安科技(深圳)有限公司 图像处理模型训练方法、***、计算机设备及存储介质

Also Published As

Publication number Publication date
CN114492849B (zh) 2023-09-08

Similar Documents

Publication Publication Date Title
CN109754105B (zh) 一种预测方法及终端、服务器
RU2497293C2 (ru) Способ и система передачи информации в социальной сети
CN108052639A (zh) 基于运营商数据的行业用户推荐方法及装置
US20230106985A1 (en) Developing machine-learning models
CN106055630A (zh) 日志存储的方法及装置
CN107404541A (zh) 一种对等网络传输邻居节点选择的方法及***
CN113094181A (zh) 面向边缘设备的多任务联邦学习方法及装置
WO2021008675A1 (en) Dynamic network configuration
Lee et al. Accurate and fast federated learning via IID and communication-aware grouping
CN114492849A (zh) 一种基于联邦学习的模型更新方法及装置
CN117499297B (zh) 数据包传输路径的筛选方法及装置
CN116989819B (zh) 一种基于模型解的路径确定方法及装置
CN108810089B (zh) 一种信息推送方法、装置及存储介质
CN114021017A (zh) 信息推送方法、装置及存储介质
CN112766560B (zh) 联盟区块链网络优化方法、装置、***和电子设备
CN110457387B (zh) 一种应用于网络中用户标签确定的方法及相关装置
CN112465371A (zh) 一种资源数据分配方法、装置及设备
CN116843016A (zh) 一种移动边缘计算网络下基于强化学习的联邦学习方法、***及介质
CN111291092A (zh) 一种数据处理方法、装置、服务器及存储介质
CN109635183A (zh) 一种基于社区的合作者推荐方法
CN109299388A (zh) 一种查找高质量社交用户的***及方法
CN110781384B (zh) 一种基于优先级的内容推荐方法、装置、设备及介质
CN107688582A (zh) 资源推荐模型的获取方法及装置
EP3216167B1 (en) Orchestrator and method for virtual network embedding using offline feedback
CN111324444A (zh) 一种云计算任务调度方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant