CN111882133B - 一种基于预测的联邦学习通信优化方法及*** - Google Patents
一种基于预测的联邦学习通信优化方法及*** Download PDFInfo
- Publication number
- CN111882133B CN111882133B CN202010768983.9A CN202010768983A CN111882133B CN 111882133 B CN111882133 B CN 111882133B CN 202010768983 A CN202010768983 A CN 202010768983A CN 111882133 B CN111882133 B CN 111882133B
- Authority
- CN
- China
- Prior art keywords
- user
- prediction
- update
- submodule
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q10/00—Administration; Management
- G06Q10/04—Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/20—Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
- G06F16/23—Updating
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- H—ELECTRICITY
- H04—ELECTRIC COMMUNICATION TECHNIQUE
- H04L—TRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
- H04L67/00—Network arrangements or protocols for supporting network services or applications
- H04L67/01—Protocols
- H04L67/10—Protocols in which an application is distributed across nodes in the network
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Business, Economics & Management (AREA)
- General Engineering & Computer Science (AREA)
- Economics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Strategic Management (AREA)
- Human Resources & Organizations (AREA)
- Life Sciences & Earth Sciences (AREA)
- Mathematical Physics (AREA)
- Development Economics (AREA)
- Computational Linguistics (AREA)
- Software Systems (AREA)
- Biophysics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Computer Networks & Wireless Communication (AREA)
- Signal Processing (AREA)
- Databases & Information Systems (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Game Theory and Decision Science (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Entrepreneurship & Innovation (AREA)
- Marketing (AREA)
- Operations Research (AREA)
- Quality & Reliability (AREA)
- Tourism & Hospitality (AREA)
- General Business, Economics & Management (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及联邦机器学习领域,公开了一种基于预测的联邦学习通信优化方法及***。本发明中,首先,初始化全局模型以及本发明中所需的全局变量,每个用户根据其本地数据进行本地模型训练,得到本地更新。随后,云中心分别根据每个用户的历史模型更新趋势,预测其本地更新。然后,通过计算每个用户采用其预测更新时全局模型的损失函数变化,设置其预测误差阈值,其中包括初始阈值和动态阈值设置两个步骤。最后,根据设置的预测误差阈值设计全局模型更新策略,云中心采用准确的预测更新代替本地更新计算全局模型更新。解决了联邦学习技术中,终端用户与云中心频繁传递更新参数所导致的高通信代价问题。
Description
技术领域
本发明涉及联邦机器学习领域,更具体地,涉及一种基于预测的联邦学习通信优化方法,用于解决联邦学习技术中终端用户/设备与云中心频繁传递更新参数所导致的高通信代价问题。
背景技术
机器学习作为人工智能领域的一个重要分支,被成功且广泛的应用于模式识别、数据挖掘和计算机视觉等各个领域。由于终端设备计算资源受限,目前对于机器学习模型的训练通常采用基于云的方式,在这种方式中,终端设备所收集的数据,如图片、视频,或者个人位置信息,必须上传至云中心集中完成模型的训练。然而,上传用户真实数据会泄露其隐私,出于隐私保护的考虑,终端用户不愿共享其隐私数据。从长远来看,这严重阻碍了机器学习技术的发展和应用。
因此,为了保护终端用户的敏感数据,同时又不影响机器学习模型的训练,联邦学习应运而生。在联邦学习环境中,用户不用上传其敏感数据至云中心,而只需共享其本地更新,云中心通过与终端用户多次交互,迭代计算得到全局模型更新,既保护了用户的敏感数据,又得到了最终的可用模型。
在联邦学习环境中,终端用户与云中心需要多轮交互才能获得目标精度的全局模型。那么,对于复杂的模型训练,如深度学习模型训练,每次模型更新可能包含数百万个参数,模型更新的高维性将耗费大量的通信成本,甚至成为一个模型训练瓶颈。此外,由于终端用户/设备的异构性,每个设备网络状态的不可靠性以及互联网连接速度的不对称性,如下传速度大于上载速度,导致终端用户上传更新参数的延迟,都会使模型训练瓶颈进一步恶化。
目前,为了解决联邦学习的高通信代价问题,国内外研究学者纷纷对其进行了大量研究,并提出了许多有效的通信优化方法。根据其优化的目标,这些解决方法大致可划分为两类:一类是以减少终端用户与云中心通信轮数为目标;另一类是以减少终端用户与云中心通信量为目标。在以减少通信量为目标的方法中,通常对本地更新进行压缩、轻量化、知识蒸馏以及稀疏化等操作,使得上传的模型更加紧致,从而达到通信量减少的目的。然而,由于模型压缩通常会造成模型信息量的丢失,甚至不能保证模型收敛,因此,越来越多的研究学者开始研究以减少通信次数为目标的通信优化方法。
主流的通信次数减少方法可划分为两类,一类是基于模型收敛的方法,另一类是基于重要性的方法。在基于模型收敛的方法中,通常采用增加本地模型训练迭代轮数、减少每轮本地训练batch块的大小或者修改联邦学习算法等方式加快模型学习速度,使得每次通信迭代上传的本地模型更新更有利于全局模型的收敛;另一类则是通过研究本地更新与全局模型更新的相关性或者计算本地更新对全局模型的重要性,选择重要的或者与全局模型收敛趋势相同的本地更新上传至云中心。虽然这两类方法能够从一定程度上提高联邦学习的通信效率,但它们仍然存在以下不足:基于模型收敛的方法,通常是以消耗更多的本地计算资源为代价,然而,在联邦学习环境中,终端通常是资源受限的异构设备,它们没有足够的计算资源来处理复杂模型的训练,因此,将该算法运用于实际场景的联邦通信优化具有一定的挑战性;基于重要性的方法中,本地更新的重要性或者相关性都是通过一个可调的阈值判断,且这个阈值的设置通常是基于最大化通信次数减少为目标,因此,这类算法由于大量本地更新没有被聚集,而导致严重的模型准确率降低。
综合所述,为了弥补基于云训练所造成的用户敏感数据泄露以及模型可用性问题,联邦学习应运而生。然而由于模型训练参数的高维性以及联邦学习环境中网络的不可靠性,使得通信代价问题成为联邦学习中基础且重要的问题。虽然现有研究方法分别从减小通信量和通信轮数两个方面提出了许多有效的通信优化方法,但他们通常伴随着其他方面的不足,如需要消耗更多的本地计算资源或者严重降低训练模型的准确率,因此,为了更好的解决联邦学习的高通信代价问题,需要设计一种既不需要消耗更多本地计算资源,又能极大减少所需的通信轮数同时保证训练模型准确率的方法。
基于上述背景,本发明提出了一种简单易实现的基于预测的联邦学习通信优化方法,为联邦学习高通信代价问题的解决奠定基础。
发明内容
为了更有效的解决联邦学习的高通信代价问题,本发明提出了一种基于预测的联邦学习通信优化方法。首先,初始化全局模型以及本发明中所需的全局变量,每个终端用户根据其本地数据进行本地模型训练,得到本地更新。随后,云中心分别根据每个终端用户的历史模型更新趋势,预测其本地更新。然后,通过计算每个终端用户采用预测更新,全局模型损失函数的变化,设置其预测误差阈值,其中包括初始阈值和动态阈值设置两个步骤。最后,根据设置的预测误差阈值设计全局模型更新策略,云中心采用准确的预测更新代替本地更新计算全局模型更新。
本发明提出的基于预测的联邦学习通信优化方法,包括以下步骤,
步骤S1,云中心初始化,包括搭建训练模型、初始化全局模型以及所需的全局变量,包括以下子步骤:
步骤S1-1,用于搭建训练模型,其包括输入层、隐藏层以及输出层的神经元个数设计;
步骤S1-2,用于初始化全局模型,其包括全局模型参数W0,全局模型更新G0;
步骤S1-3,用于初始化全局变量,其包括由n个终端用户组成的用户集合U={u1,u2,...,uj,...,un},通信轮数R;
步骤S2,本地模型训练,由n个终端用户组成的用户集合U={u1,u2,...,uj,...,un}中的每个用户uj根据其本地数据并行地进行本地模型训练,得到在第t轮迭代的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t},以用户uj为例,包括以下子步骤:
步骤S2-1,用于从云中心获取聚集的全局模型参数Wt;
步骤S2-2,用于根据本地数据进行本地模型训练,得到用户uj在第t轮迭代的本地更新Lj,t;
重复步骤S2,得到用户集合U中所有用户的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t};
步骤S3,本地更新预测,预测用户集合U中每个用户uj在第t轮迭代的本地更新,得到第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新,k表示更新参数的维度,以用户uj为例,包括以下子步骤:
步骤S3-1,用于从云中心获取用户uj的历史参数更新集合Hj,计算用户uj在第t-1轮迭代的一步预测更新,其中,Hj=<Hj,1,Hj,2,...,Hj,i,...,Hj,t-1>,k表示更新参数的维度,下面以用户uj第d维更新参数为例,假设由用户uj第d维更新参数组成的历史参数更新集合为则用户uj第d维更新参数的一步预测更新值可表示为:
步骤S3-2,用于计算第t-1轮迭代的状态协方差矩阵mt-1,其计算公式如(2)所示:
mt-1=f*mt-2*fT+q (2)
其中,q为预测噪声,fT为状态转移矩阵f的转置;
步骤S3-3,用于计算第t-1轮迭代的卡尔曼增益zt-1,其计算公式如(3)所示:
其中,r表示本地更新协方差,c表示转换矩阵;
步骤S3-5,用于更新第t轮迭代的状态协方差矩阵mt,更新公式如(5)所示:
mt=(1-zt-1*c)*mt-1 (5)
重复步骤S3,并行计算得到用户集合U中所有用户第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新k表示更新参数的维度;
步骤S4,设置预测误差阈值,并行计算得到用户集合U中每个用户uj在第t-1轮迭代采用其预测更新时全局模型的损失函数变化e,并为每个用户uj设置预测误差阈值,以用户uj为例,包括以下子步骤:
步骤S4-1,用于从云中心获取用户uj第t-1轮迭代的预测更新Pj,t-1以及用户集合U中所有用户的本地更新集合Lt-1;
步骤S4-2,用于检查标记变量Checkj,若Checkj=true,则进入步骤S4-3;反之,若Checkj=false,则进入步骤S4-7;
步骤S4-3,用于计算用户uj第t-1轮迭代采用预测更新Pj,t-1时的全局模型更新Gj,t-1、全局模型Wj,t-1,用户集合U中所有用户采用本地更新时的全局模型更新Gall,t-1、全局模型Wall,t-1以及全局模型的损失函数变化e,具体计算公式如(6)、(7)、(8)、(9)、(10)所示:
其中,L-j,t-1表示非用户uj第t-1轮迭代的本地更新;
用户uj第t-1轮迭代采用预测更新Pj,t-1时的全局模型Wj,t-1以及用户集合U中所有用户采用本地更新时的全局模型Wall,t-1的计算公式分别(8)、(9)所示:
Wj,t-1=Wt-2-Gj,t-1 (8)
Wall,t-1=Wt-2-Gall,t-1 (9)
其中,Wt-2表示第t-2轮迭代的全局模型;
进一步,全局模型的损失函数变化e的计算公式如(10)所示:
其中,f(.)表示损失函数,|.|表示绝对值;
步骤S4-4,用于比较全局模型的损失函数变化e与预先设定阈值δ的大小,若e≤δ,则进入步骤S4-6,同时设置Checkj=false,变量Tj=Tj+1;反之,若e>δ,则进入步骤S4-5;
步骤S4-5,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代;
步骤S4-6,用于设置预测误差初始阈值vj,0,其具体计算公式如下:
vj,0=||Pj,t-1-Lj,t-1|| (11)
其中,||.||表示两个向量的内积;
步骤S4-7,用于设置用户uj第t轮迭代的预测误差阈值vj,t,其具体计算公式如下:
其中,参数Tj表示当前迭代轮数与用户uj设置初始预测误差阈值vj,0时轮数的差值;
重复步骤S4,为用户集合U中每个用户uj设置预测误差阈值;
步骤S5,全局模型更新策略,为用户集合U中每个用户uj制定全局模型更新策略,以用户uj为例,包括以下子步骤:
步骤S5-1,用于计算第t轮迭代的预测更新误差Δj,t,其具体公式如下:
Δj,t=||Pj,t-Lj,t|| (13)
步骤S5-2,比较Δj,t与vj,t的大小,若Δj,t≤vj,t,表示预测更新准确,则进入步骤S5-3;反之,若Δj,t>vj,t,表示预测参数不准确,则进入步骤S5-4;
步骤S5-3,云中心采用用户uj的预测更新Pj,t进行全局模型聚集;
步骤S5-4,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池;
重复步骤S5,为用户集合U中每个用户uj制定全局模型更新策略;
步骤S6,云中心全局模型更新,云中心聚集用户集合U中所有用户上传的本地更新或者云中心准确的预测更新,计算得到聚集的全局模型更新和全局模型,模型训练进入下一轮迭代;
重复以上步骤S1~S6,直至全局模型收敛,模型训练结束。
同时,本发明还相应提供了一种基于预测的联邦学习通信优化***,如图4所示,包含:
初始化模块,用户搭建训练模型、初始化全局模型以及所需的全局变量,包含以下子模块,
训练模型构建子模块,用于搭建训练模型,主要包括输入层、隐藏层以及输出层的神经元个数设计;
全局模型初始化子模块,用于初始化全局模型和全局模型更新;
全局变量初始化子模块,用于初始化通信轮数;
本地模型训练模块,用于用户集合U中每个用户根据其本地训练数据并行地进行本地模型训练,得到用户在第t轮迭代的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t},以用户uj为例,包含以下子模块,
全局模型输入子模块,用于从云中心获取用户在第t轮迭代的全局模型参数;
模型训练子模块,用于根据本地数据并行地进行本地模型训练,得到用户uj在第t轮迭代的本地更新Lj,t;
并行训练子模块,用于并行执行全局模型输入子模块和模型训练子模块,得到用户集合U中所有用户的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t};
本地更新预测模块,用于预测用户集合U中每个用户在第t轮迭代的本地更新,得到第t轮迭代的预测更新集合Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新,k表示更新参数的维度,以用户uj为例,包含以下子模块,历史更新输入子模块,用于从云中心获取用户的历史本地更新集合;
中间变量子模块,用于存储中间步骤所计算得到的中间变量值,这些中间变量值主要包括用户uj第d维更新参数在第t-1轮迭代的一步预测更新值状态协方差矩阵mt-1,卡尔曼增益zt-1,用户uj第d维更新参数在第t轮迭代的预测更新值状态协方差矩阵mt;
并行预测子模块,用于并行执行历史更新输入子模块、中间变量子模块以及预测更新输出子模块,预测得到用户集合U中所有用户在第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t};
预测误差阈值设置模块,用于并行计算得到用户集合U中每个用户uj在第t-1轮迭代采用其预测更新时全局模型的损失函数变化e,并为每个用户设置预测误差阈值,以用户uj为例,包含以下子模块;
变量判断子模块,用于判断用户是否已经设置预测误差初始阈值,若标记变量Checkj=true,表示用户uj未设置预测误差初始阈值,则进入全局损失函数变化计算子模块;反之,进入预测误差动态阈值设置子模块;
全局模型的损失函数变化计算子模块,用于计算用户在第t-1轮迭代采用预测更新时全局模型的损失函数变化e;
损失函数判断子模块,用于比较全局模型的损失函数变化e与预先设定阈值δ的大小,若e≤δ,则进入预测误差初始阈值设置子模块;
预测误差初始阈值设置子模块,用于设置用户的预测误差初始阈值vj,0;
预测误差动态阈值设置子模块,用于设置用户在第t轮迭代的预测误差阈值vj,t;
并行设置子模块,用于并行执行变量判断子模块、全局损失函数变化计算子模块、损失函数判断子模块、预测误差初始阈值设置子模块以及预测误差动态阈值设置子模块,得到每个用户的预测误差阈值;
全局模型更新策略模块,用于为用户集合U中每个用户制定全局模型更新策略,以用户uj为例,包含以下子模块,
预测误差阈值输入子模块,用于获取用户在第t轮迭代的预测误差阈值vj,t;
变量判断子模块,用于判断用户是否设置预测误差阈值,若已设置,则进入预测误差计算子模块,反之,进入本地更新上传子模块;
预测误差计算子模块,用于计算用户在第t轮迭代的预测更新误差Δj,t;
预测准确性判断子模块,用于比较用户的预测误差Δj,t与预测误差阈值vj,t的大小,若Δj,t>vj,t,则进入本地更新上传子模块;
本地更新上传子模块,用于上传用户的本地更新Lj,t至云中心及预测资源池;
通信轮数计算及输出子模块,用于计算和输出模型训练的通信轮数;
云中心全局模型更新模块,用于计算全局模型更新和判定训练模型是否收敛,包含以下子模块,
全局模型更新子模块,用于聚集上传的本地更新和云中心中准确的预测更新,计算得到全局模型更新和全局模型,模型训练进入下一轮更新迭代;
终止判定子模块,用于判定训练模型是否收敛,若收敛,则模型训练结束;反之,进入下一轮训练迭代。
本发明根据本地模型的历史更新趋势,预测本地更新,然后通过计算全局模型的损失函数变化设置预测误差阈值,并根据设置的预测误差阈值设计全局模型更新策略,云中心采用准确的预测更新代替本地更新计算全局模型更新,解决了联邦学习技术中,终端用户与云中心频繁传递更新参数所导致的高通信代价问题,与现有技术相比,具有以下有益效果:
(1)本发明所提方法及***不仅可以极大减少终端用户与云中心的通信轮次,而且可以极小降低训练模型的准确率;
(2)由于本发明将本地更新的预测放置在资源丰富的云中心,而终端用户只需进行简单的预测准确性判断,因此,可以消耗极少的本地计算资源;
(3)本发明的本地更新预测部分采用卡尔曼滤波预测,由于卡尔曼滤波能够对数据进行实时处理,具有较好的预测效果,且便于计算机编程实现,因此采用卡尔曼滤波预测不仅可以获得准确的本地更新预测,而且可以进一步降低计算复杂度,便于算法高效实施。
附图说明
图1是本发明实施例提供的总体方法流程图。
图2是本发明实施例提供的具体步骤流程图。
图3是本发明实施例提供的总体原理示意图。
图4是本发明实施例基于预测的联邦学习通信优化***模块设计示意图。
具体实施方式
以下将结合附图及实施例,对本发明的构思、具体结构及产生的技术效果作进一步说明,以充分地了解本发明的目的、特征和效果。
本发明技术方案所提供方法可采用计算机软件技术实现自动运行流程,图1是本发明实施例的总体方法流程图,参见图1,结合图2本发明实施例的具体步骤流程图,本发明基于预测的联邦学习通信优化方法的实施例具体步骤包括:
步骤S1,云中心初始化,包括搭建训练模型、初始化全局模型以及所需的全局变量,包括以下子步骤:
步骤S1-1,用于搭建训练模型,其包括输入层、隐藏层以及输出层的神经元个数设计;
实施例中,模拟输入层和输出层分别为784和1个神经元节点的线性回归模型;
步骤S1-2,用于初始化全局模型,其包括全局模型参数W0,全局模型更新G0;
实施例中,初始化全局模型参数W0,全局模型更新G0;
步骤S1-3,用于初始化全局变量,其包括由n个终端用户组成的用户集合U={u1,u2,...,uj,...,un},通信轮数R;
实施例中,初始化用户集合U={u1,u2,...,uj,...,u100},通信轮数R=0;
步骤S2,本地模型训练,由n个终端用户组成的用户集合U={u1,u2,...,uj,...,un}中的每个用户uj根据其本地数据并行地进行本地模型训练,得到在第t轮迭代的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t}:
步骤S2-1,用于从云中心获取聚集的全局模型参数Wt;
实施例中,假定当前迭代轮次t=4,以用户u100为例,则用户u100从云中心获取聚集的全局模型参数W4;
步骤S2-2,用于根据本地数据进行本地模型训练,得到用户uj在第t轮迭代的本地更新Lj,t;
实施例中,用户u100根据其本地数据进行本地模型训练,得到在第t=4轮迭代本地模型更新L100,4;
重复步骤S2,得到用户集合U中所有用户的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t};
实施例中,重复步骤S2,得到用户集合U中所有用户的本地更新集合L={L1,4,L2,4,...,Lj,4,...,L100,4};
步骤S3,本地更新预测,预测用户集合U中每个用户uj在第t轮迭代的本地更新,得到第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新,k表示更新参数的维度,以用户uj为例,包括以下子步骤:
步骤S3-1,用于从云中心获取用户uj的历史参数更新集合Hj,计算用户uj在第t-1轮迭代的一步预测更新,其中,Hj=<Hj,1,Hj,2,...,Hj,i,...,Hj,t-1>,k表示更新参数的维度,下面以用户uj第d维更新参数为例,假设由用户uj第d维更新参数组成的历史参数更新集合为则根据公式计算得到用户uj第d维更新参数的一步预测更新值
实施例中,从云中心获取用户u100的历史参数更新集合H100=<H100,1,H100,2,H100,3>,以用户u100第784维更新参数为例,由用户u100第784维参数组成的历史参数更新集合为设置公式中f=1,b=0,从而计算得到用户u100的第784维参数的一步预测更新值
步骤S3-2,用于根据公式mt-1=f*mt-2*fT+q,计算第t-1轮迭代的状态协方差矩阵mt-1;
实施例中,设置q=0.001,根据公式mt-1=f*mt-2*fT+q,计算第t=3轮迭代的状态协方差矩阵m3=m2+q→m3=m2+0.001;
步骤S3-5,用于根据公式mt=(1-zt-1*c)*mt-1,更新第t轮迭代的状态协方差矩阵mt;
实施例中,根据公式mt=(1-zt-1*c)*mt-1,更新第t=4轮迭次的状态协方差矩阵m4=(1-z3)*m3;
重复步骤S3,并行计算得到用户集合U中所有用户第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新k表示更新参数的维度;
实施例中,重复步骤S3,并行计算得到用户集合U中所有用户第t=4轮迭代的预测更新向量集P4,其中,P4={P1,4,P2,4,...,Pj,4,...,P100,4},P100,4表示用户u100的预测更新,784表示更新向量的维度大小;
步骤S4,设置预测误差阈值,并行计算得到用户集合U中每个用户uj在第t-1轮迭代采用其预测更新时全局模型的损失函数变化e,并为每个用户uj设置预测误差阈值,以用户uj为例,包括以下子步骤:
步骤S4-1,用于从云中心获取用户uj第t-1轮迭代的预测更新Pj,t-1以及集合U中所有用户的本地更新集合Lt-1;
实施例中,从云中心获取用户u100第t=3轮迭代的预测更新P100,3以及U中所有用户的本地更新集合L3;
步骤S4-2,用于检查标记变量Checkj,若Checkj=true,则进入步骤S4-3;反之,若Checkj=false,则进入步骤S4-7;
实施例中,检查用户u100的标记变量Check100,若Check100=true,进入步骤S4-3;反之,若Check100=false,则进入步骤S4-7;
步骤S4-3,用于根据公式以及Wj,t-1=Wt-2-Gj,t-1,计算用户uj第t-1轮迭代采用预测更新Pj,t-1时的全局模型更新Gj,t-1(j=1,2,...n)、全局模型Wj,t-1(j=1,2,...n),根据公式和Wall,t-1=Wt-2-Gall,t-1,计算用户集合U中所有用户采用本地更新时的全局模型更新Gall,t-1、全局模型Wall,t-1以及根据公式计算全局模型的损失函数变化e;
实施例中,根据公式和公式计算第3轮迭代,用户u100采用预测更新P100,3时的全局模型更新G100,3以及用户集合U中所有用户采用本地更新时的全局模型更新Gall,3,并根据公式Wj,t-1=Wt-2-Gj,t-1和Wall,t-1=Wt-2-Gall,t-1分别计算得到用户u100第t=3轮迭代的全局模型W100,3和Wall,3,根据公式计算全局模型的损失函数变化
步骤S4-4,用于比较全局模型的损失函数变化e与预先设定阈值δ的大小,若e≤δ,则进入步骤S4-6,同时设置Checkj=false,变量Tj=Tj+1;反之,若e>δ,则进入步骤S4-5;
实施例中,设置δ=0.01,比较全局模型的损失函数变化e与预先设定阈值δ的大小,若e≤0.01,则进入步骤S4-6,同时设置Check100=false,变量T100=T100+1;反之,若e>0.01,则进入步骤S4-5;
步骤S4-5,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代;
实施例中,上传用户u100的本地更新L100,4至云中心及预测资源池,设置通信轮数R=R+1,模型训练进入下一轮迭代;
步骤S4-6,用于根据公式vj,0=||Pj,t-1-Lj,t-1||,设置预测误差初始阈值vj,0;
实施例中,设置用户u100的预测误差初始阈值v100,0=||P100,3-L100,3||;
重复步骤S4,为用户集合U中每个用户uj设置预测误差阈值;
实施例中,重复步骤S4,为用户集合U={u1,u2,...,uj,...,u100}中每个用户设置预测误差阈值;
步骤S5,全局模型更新策略,为用户集合U中每个用户uj制定全局模型更新策略,以用户uj为例,包括以下子步骤:
步骤S5-1,用于根据公式Δj,t=||Pj,t-Lj,t||,计算第t轮迭代的预测更新误差Δj,t;
实施例中,计算用户u100当前迭代轮次t=4的预测误差Δ100,4=||P100,4-L100,4||;
步骤S5-2,比较Δj,t与vj,t的大小,若Δj,t≤vj,t,表示预测更新准确,则进入步骤S5-3;反之,若Δj,t>vj,t,表示预测参数不准确,则进入步骤S5-4;
实施例中,比较用户u100当前迭代轮次t=4的预测误差Δ100,4与设定的预测误差阈值v100,4的大小,若Δ100,4≤v100,4,则进入步骤S5-3;反之,Δ100,4>v100,4,则进入步骤S5-4;
步骤S5-3,云中心采用用户uj的预测更新Pj,t进行全局模型聚集;
实施例中,云中心采用用户u100的预测更新P100,4进行全局模型聚集,;
步骤S5-4,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池;
实施例中,上传用户u100的本地更新L100,4至云中心及预测资源池,设置通信轮数R=R+1;
重复步骤S5,为用户集合U中每个用户uj制定全局模型更新策略;
实施例中,重复步骤S5,为用户集合U={u1,u2,...,uj,...,u100}中每个用户制定全局模型更新策略;
步骤S6,云中心全局模型更新,云中心聚集用户集合U中所有用户上传的本地更新或者云中心准确的预测更新,计算得到聚集的全局模型更新和全局模型,模型训练进入下一轮迭代;
实施例中,云中心聚集用户集合U={u1,u2,...,uj,...,u100}中所有用户上传的本地更新或者云中心准确的预测更新,计算得到全局模型更新Gt和全局模型Wt,模型训练进入下一轮更新迭代;
重复以上步骤S1~S6,直至全局模型收敛,模型训练结束。
实施例中,重复以上步骤S1~S6,直至全局模型收敛,模型训练结束。
本发明提供了本领域技术人员能够实现的技术方案。以上实施例仅供说明本发明之用,而非对本发明的限制,有关技术领域的技术人员,在不脱离本发明的精神和范围的情况下,还可以做出各种变换或变型,因此所有等同的技术方案,都落入本发明的保护范围。
Claims (4)
1.一种基于预测的联邦学习通信优化方法,其特征在于:包括以下步骤,
步骤S1,云中心初始化,包括搭建训练模型、初始化全局模型以及所需的全局变量;
步骤S2,本地模型训练,由n个终端用户组成的用户集合U={u1,u2,...,uj,...,un}中的每个用户uj根据其本地数据并行地进行本地模型训练,得到在第t轮迭代的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t};
步骤S3,本地更新预测,预测用户集合U中每个用户uj在第t轮迭代的本地更新,得到第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新,k表示更新参数的维度;
步骤S4,设置预测误差阈值,并行计算得到用户集合U中每个用户uj在第t-1轮迭代采用其预测更新时全局模型的损失函数变化e,并为每个用户uj设置预测误差阈值;
步骤S5,全局模型更新策略,为用户集合U中每个用户uj制定全局模型更新策略;
步骤S6,云中心全局模型更新,云中心聚集用户集合U中所有用户上传的本地更新或者云中心准确的预测更新,计算得到聚集的全局模型更新和全局模型,模型训练进入下一轮迭代;
重复以上步骤S1~S6,直至全局模型收敛,模型训练结束。
2.根据权利要求1所述一种基于预测的联邦学习通信优化方法,其特征在于,所述步骤S4包括以下子步骤:
步骤S4-1,用于从云中心获取用户uj第t-1轮迭代的预测更新Pj,t-1以及用户集合U中所有用户的本地更新集合Lt-1;
步骤S4-2,用于检查标记变量Checkj,若Checkj=true,则进入步骤S4-3;反之,若Checkj=false,则进入步骤S4-7;
步骤S4-3,用于计算用户uj第t-1轮迭代采用预测更新Pj,t-1时的全局模型更新Gj,t-1、全局模型Wj,t-1,用户集合U中所有用户采用本地更新时的全局模型更新Gall,t-1、全局模型Wall,t-1以及全局模型的损失函数变化e,具体计算公式如(1)、(2)、(3)、(4)、(5)所示:
其中,L-j,t-1表示非用户uj第t-1轮迭代的本地更新;
用户uj第t-1轮迭代采用预测更新Pj,t-1时的全局模型Wj,t-1以及用户集合U中所有用户采用本地更新时的全局模型Wall,t-1的计算公式分别如(3)、(4)所示:
Wj,t-1=Wt-2-Gj,t-1 (3)
Wall,t-1=Wt-2-Gall,t-1 (4)
其中,Wt-2表示第t-2轮迭代的全局模型;
进一步,全局模型的损失函数变化e的计算公式如(5)所示:
其中,f(.)表示损失函数,|.|表示绝对值;
步骤S4-4,用于比较全局模型的损失函数变化e与预先设定阈值δ的大小,若e≤δ,则进入步骤S4-6,同时设置Checkj=false,变量Tj=Tj+1;反之,若e>δ,则进入步骤S4-5;
步骤S4-5,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代;
步骤S4-6,用于设置预测误差初始阈值vj,0,其具体计算公式如下:
vj,0=||Pj,t-1-Lj,t-1|| (6)
其中,||.||表示两个向量的内积;
步骤S4-7,用于设置用户uj第t轮迭代的预测误差阈值vj,t,其具体计算公式如下:
其中,参数Tj表示当前迭代轮数与用户uj设置初始预测误差阈值vj,0时轮数的差值;
重复步骤S4,为用户集合U中每个用户uj设置预测误差阈值。
3.根据权利要求1所述一种基于预测的联邦学习通信优化方法,其特征在于,所述步骤S5包括以下子步骤:
步骤S5-1,用于计算第t轮迭代的预测更新误差Δj,t,其具体公式如下:
Δj,t=||Pj,t-Lj,t|| (8)
步骤S5-2,比较Δj,t与vj,t的大小,若Δj,t≤vj,t,表示预测更新准确,则进入步骤S5-3;反之,若Δj,t>vj,t,表示预测参数不准确,则进入步骤S5-4;
步骤S5-3,云中心采用用户uj的预测更新Pj,t进行全局模型聚集;
步骤S5-4,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池;
重复步骤S5,为用户集合U中每个用户uj制定全局模型更新策略。
4.一种基于预测的联邦学习通信优化***,其特征在于:包括以下模块,
初始化模块,用户搭建训练模型、初始化全局模型以及所需的全局变量,包含以下子模块,
训练模型构建子模块,用于搭建训练模型,主要包括输入层、隐藏层以及输出层的神经元个数设计;
全局模型初始化子模块,用于初始化全局模型和全局模型更新;
全局变量初始化子模块,用于初始化通信轮数;
本地模型训练模块,用于用户集合U中每个用户根据其本地训练数据并行地进行本地模型训练,得到用户在第t轮迭代的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t},以用户uj为例,包含以下子模块,
全局模型输入子模块,用于从云中心获取用户在第t轮迭代的全局模型参数;
模型训练子模块,用于根据本地数据并行地进行本地模型训练,得到用户uj在第t轮迭代的本地更新Lj,t;
并行训练子模块,用于并行执行全局模型输入子模块和模型训练子模块,得到用户集合U中所有用户的本地更新集合L={L1,t,L2,t,...,Lj,t,...,Ln,t};
本地更新预测模块,用于预测用户集合U中每个用户在第t轮迭代的本地更新,得到第t轮迭代的预测更新集合Pt={P1,t,P2,t,...,Pj,t,...,Pn,t},Pj,t表示用户uj的预测更新,k表示更新参数的维度,以用户uj为例,包含以下子模块,历史更新输入子模块,用于从云中心获取用户的历史本地更新集合;
中间变量子模块,用于存储中间步骤所计算得到的中间变量值,这些中间变量值主要包括用户uj第d维更新参数在第t-1轮迭代的一步预测更新值状态协方差矩阵mt-1,卡尔曼增益zt-1,用户uj第d维更新参数在第t轮迭代的预测更新值状态协方差矩阵mt;
并行预测子模块,用于并行执行历史更新输入子模块、中间变量子模块以及预测更新输出子模块,预测得到用户集合U中所有用户在第t轮迭代的预测更新集合Pt,其中,Pt={P1,t,P2,t,...,Pj,t,...,Pn,t};
预测误差阈值设置模块,用于并行计算得到用户集合U中每个用户uj在第t-1轮迭代采用其预测更新时全局模型的损失函数变化e,并为每个用户设置预测误差阈值,以用户uj为例,包含以下子模块;
变量判断子模块,用于判断用户是否已经设置预测误差初始阈值,若标记变量Checkj=true,表示用户uj未设置预测误差初始阈值,则进入全局损失函数变化计算子模块;反之,进入预测误差动态阈值设置子模块;
全局模型的损失函数变化计算子模块,用于计算用户在第t-1轮迭代采用预测更新时全局模型的损失函数变化e;
损失函数判断子模块,用于比较全局模型的损失函数变化e与预先设定阈值δ的大小,若e≤δ,则进入预测误差初始阈值设置子模块;
预测误差初始阈值设置子模块,用于设置用户的预测误差初始阈值vj,0;
预测误差动态阈值设置子模块,用于设置用户在第t轮迭代的预测误差阈值vj,t;
并行设置子模块,用于并行执行变量判断子模块、全局损失函数变化计算子模块、损失函数判断子模块、预测误差初始阈值设置子模块以及预测误差动态阈值设置子模块,得到每个用户的预测误差阈值;
全局模型更新策略模块,用于为用户集合U中每个用户制定全局模型更新策略,以用户uj为例,包含以下子模块,
预测误差阈值输入子模块,用于获取用户在第t轮迭代的预测误差阈值vj,t;
变量判断子模块,用于判断用户是否设置预测误差阈值,若已设置,则进入预测误差计算子模块,反之,进入本地更新上传子模块;
预测误差计算子模块,用于计算用户在第t轮迭代的预测更新误差Δj,t;
预测准确性判断子模块,用于比较用户的预测误差Δj,t与预测误差阈值vj,t的大小,若Δj,t>vj,t,则进入本地更新上传子模块;
本地更新上传子模块,用于上传用户的本地更新Lj,t至云中心及预测资源池;
通信轮数计算及输出子模块,用于计算和输出模型训练的通信轮数;
云中心全局模型更新模块,用于计算全局模型更新和判定训练模型是否收敛,包含以下子模块,
全局模型更新子模块,用于聚集上传的本地更新和云中心中准确的预测更新,计算得到全局模型更新和全局模型,模型训练进入下一轮更新迭代;
终止判定子模块,用于判定训练模型是否收敛,若收敛,则模型训练结束;反之,进入下一轮训练迭代。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010768983.9A CN111882133B (zh) | 2020-08-03 | 2020-08-03 | 一种基于预测的联邦学习通信优化方法及*** |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010768983.9A CN111882133B (zh) | 2020-08-03 | 2020-08-03 | 一种基于预测的联邦学习通信优化方法及*** |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111882133A CN111882133A (zh) | 2020-11-03 |
CN111882133B true CN111882133B (zh) | 2022-02-01 |
Family
ID=73204433
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010768983.9A Active CN111882133B (zh) | 2020-08-03 | 2020-08-03 | 一种基于预测的联邦学习通信优化方法及*** |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111882133B (zh) |
Families Citing this family (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112364913A (zh) * | 2020-11-09 | 2021-02-12 | 重庆大学 | 一种基于核心数据集的联邦学习通信量优化方法及*** |
CN112801815B (zh) * | 2020-12-30 | 2024-03-29 | 国网江苏省电力公司信息通信分公司 | 一种基于联邦学习的电力通信网络故障预警方法 |
CN113158223A (zh) * | 2021-01-27 | 2021-07-23 | 深圳前海微众银行股份有限公司 | 基于状态转移核优化的数据处理方法、装置、设备及介质 |
CN113011603A (zh) * | 2021-03-17 | 2021-06-22 | 深圳前海微众银行股份有限公司 | 模型参数更新方法、装置、设备、存储介质及程序产品 |
CN113222179B (zh) * | 2021-03-18 | 2023-06-20 | 北京邮电大学 | 一种基于模型稀疏化与权重量化的联邦学习模型压缩方法 |
CN113919512B (zh) * | 2021-09-26 | 2022-09-23 | 重庆邮电大学 | 基于计算资源逻辑分层的联邦学习通信优化方法及*** |
CN114301573B (zh) * | 2021-11-24 | 2023-05-23 | 超讯通信股份有限公司 | 联邦学习模型参数传输方法及*** |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109189825A (zh) * | 2018-08-10 | 2019-01-11 | 深圳前海微众银行股份有限公司 | 横向数据切分联邦学习建模方法、服务器及介质 |
CN109871702A (zh) * | 2019-02-18 | 2019-06-11 | 深圳前海微众银行股份有限公司 | 联邦模型训练方法、***、设备及计算机可读存储介质 |
CN111460443A (zh) * | 2020-05-28 | 2020-07-28 | 南京大学 | 一种联邦学习中数据操纵攻击的安全防御方法 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180089587A1 (en) * | 2016-09-26 | 2018-03-29 | Google Inc. | Systems and Methods for Communication Efficient Distributed Mean Estimation |
CN110442457A (zh) * | 2019-08-12 | 2019-11-12 | 北京大学深圳研究生院 | 基于联邦学习的模型训练方法、装置及服务器 |
CN110797124B (zh) * | 2019-10-30 | 2024-04-12 | 腾讯科技(深圳)有限公司 | 一种模型多端协同训练方法、医疗风险预测方法和装置 |
-
2020
- 2020-08-03 CN CN202010768983.9A patent/CN111882133B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109189825A (zh) * | 2018-08-10 | 2019-01-11 | 深圳前海微众银行股份有限公司 | 横向数据切分联邦学习建模方法、服务器及介质 |
CN109871702A (zh) * | 2019-02-18 | 2019-06-11 | 深圳前海微众银行股份有限公司 | 联邦模型训练方法、***、设备及计算机可读存储介质 |
CN111460443A (zh) * | 2020-05-28 | 2020-07-28 | 南京大学 | 一种联邦学习中数据操纵攻击的安全防御方法 |
Non-Patent Citations (1)
Title |
---|
"基于联邦学习和卷积神经网络的入侵检测方法";王蓉等;《信息网络安全》;20200430(第4期);第47-54页 * |
Also Published As
Publication number | Publication date |
---|---|
CN111882133A (zh) | 2020-11-03 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111882133B (zh) | 一种基于预测的联邦学习通信优化方法及*** | |
CN112651509B (zh) | 量子线路的确定方法及装置 | |
JP7273108B2 (ja) | モデルトレーニング方法、装置、電子デバイス、記憶媒体、プログラム | |
CN109257429A (zh) | 一种基于深度强化学习的计算卸载调度方法 | |
CN111696345A (zh) | 一种基于网络社区检测和gcn的耦合大规模数据流宽度学习快速预测智能算法 | |
Liang et al. | Biased ReLU neural networks | |
CN113361680A (zh) | 一种神经网络架构搜索方法、装置、设备及介质 | |
WO2021103675A1 (zh) | 神经网络的训练及人脸检测方法、装置、设备和存储介质 | |
CN110531996B (zh) | 一种多微云环境下基于粒子群优化的计算任务卸载方法 | |
CN113469891A (zh) | 一种神经网络架构搜索方法、训练方法、图像补全方法 | |
CN113537580B (zh) | 一种基于自适应图学习的公共交通客流预测方法及*** | |
KR20220064866A (ko) | 코스-파인 검색, 2-페이즈 블록 증류, 그리고 뉴럴 하드웨어 예측기를 이용한 하드웨어 및 뉴럴 네트워크 아키텍처들의 공동 설계를 위한 방법 | |
CN114707670A (zh) | 一种面向无标签数据的异构联邦学习方法和*** | |
CN113326869A (zh) | 基于最长路融合算法的深度学习计算图优化方法 | |
CN117520956A (zh) | 一种基于强化学习和元学习的两阶段自动化特征工程方法 | |
CN116757260A (zh) | 一种大型预训练模型的训练方法和*** | |
CN117151195A (zh) | 基于求逆归一化的模型优化方法、装置、设备和介质 | |
CN115438588B (zh) | 一种锂电池的温度预测方法、***、设备及存储介质 | |
Ni et al. | Policy iteration for bounded-parameter POMDPs | |
Zhao | Business intelligence application of enhanced learning in big data scenario | |
CN111950691A (zh) | 一种基于潜在动作表示空间的强化学习策略学习方法 | |
Wang et al. | Using parallel algorithm to speedup the rules learning process of a type-2 fuzzy logic system | |
Li et al. | A multi-task service recommendation model considering dynamic and static QoS | |
Chen et al. | Automated Machine Learning | |
Cai et al. | Value Iteration Networks With Gated Summarization Module |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |