CN115829027A - 一种基于对比学习的联邦学习稀疏训练方法及*** - Google Patents
一种基于对比学习的联邦学习稀疏训练方法及*** Download PDFInfo
- Publication number
- CN115829027A CN115829027A CN202211349843.3A CN202211349843A CN115829027A CN 115829027 A CN115829027 A CN 115829027A CN 202211349843 A CN202211349843 A CN 202211349843A CN 115829027 A CN115829027 A CN 115829027A
- Authority
- CN
- China
- Prior art keywords
- local
- sparse
- model
- learning
- global model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于对比学习的联邦学习稀疏训练方法及***,涉及联邦学习算法框架、神经网络稀疏训练和对比学习的交叉领域。其中,所述方法包括:服务端向本地客户端发送全局模型和掩码;本地客户端根据接收的全局模型和掩码,生成局部稀疏模型,并使用本地数据集对局部稀疏模型进行训练;本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型,并将更新后的局部稀疏模型上传到服务端;服务端聚合本地客户端更新后的局部稀疏模型,对全局模型进行更新,并将更新后的全局模型发送给本地客户端,开始新一轮的沟通训练直至全局模型收敛。本发明通过在联邦学习中引入稀疏训练和对比学习,显著降低计算通信开销,提高了全局模型的性能。
Description
技术领域
本发明涉及分布式机器学习技术领域,涉及联邦学习算法框架、神经网络稀疏训练和对比学习的交叉领域,更具体地,涉及一种基于对比学习的联邦学习稀疏训练方法及***。
背景技术
由于隐私保护、计算资源等方面的原因造成的数据孤岛,正在阻碍着训练人工智能模型所必须的大数据使用。
作为一种分布式机器学习技术,联邦学习成为一种解决数据孤岛的方法,通过多个客户端共同训练机器学习模型。联邦学习在数据不发送给他人的情况下,通过交换模型来协同训练机器学习模型,从而保护数据隐私,已在医学学习、自然语言处理和欺诈***检测等广泛应用。
但联邦学习目前仍然存在以下问题:
(1)异质性问题:数据的异质性,即非独立同分布的数据会使局部模型偏离全局模型,影响聚合后的全局模型的性能;
(2)计算通信开销问题:由于现实生活中,一些本地客户端是小型设备,如手机或者个人笔记本,这些设备没有足够的算力训练大模型,同时,与服务器的沟通也会受到带宽的限制。
在资源受限时,上述问题的存在使得联邦学习的训练精度大大降低。
发明内容
本发明提供了一种基于对比学习的联邦学习动态稀疏训练方法,旨在降低联邦学习的通信开销的同时保证模型的准确率。
为解决上述技术问题,本发明的技术方案如下:
第一方面,一种基于对比学习的联邦学习稀疏训练方法,包括:
服务端向本地客户端发送全局模型和掩码;其中,所述掩码基于稀疏度生成,用于表示全局模型参数是否被保留下来;
本地客户端根据接收的全局模型和掩码,生成局部稀疏模型,并使用本地数据集对局部稀疏模型进行训练;
在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型,并将更新后的局部稀疏模型上传到服务端;
服务端聚合本地客户端更新后的局部稀疏模型,对全局模型进行更新,并将更新后的全局模型发送给本地客户端,开始新一轮的沟通训练直至全局模型收敛。
本技术方案中,通过在联邦学习的过程直接训练稀疏模型,有效减少了训练过程中的计算量,降低设备的存储成本,加快训练过程,显著降低联邦学习计算通信开销;此外,在联邦学习的过程中还引入了对比学习方法,学习相似实例之间的共同特征,利用对比损失函数使同一目标在不同数据增强下的相似性最大化,使不同目标之间的相似性最小化,解决数据异质性问题,在降低联邦学习计算通信开销的同时提高了模型的准确率。
作为优选方案,所述服务端向本地客户端发送全局模型和掩码,包括:
作为优选方案,所述本地客户端根据接收的全局模型和掩码,生成局部稀疏模型,具体为:
作为优选方案,所述使用本地数据集对局部稀疏模型进行训练,包括:
作为优选方案,所述在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型,包括:
式中,τ为预设的温度超参数;
更新本地损失函数,其表达式为:
作为优选方案,在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型后,在预设的通信轮次进行掩码调整,动态演化更新局部稀疏模型的网络结构,再将动态演化更新后的局部稀疏模型上传到服务端。
本优选方案中,通过在特定轮次调整掩码,对局部稀疏网络进行动态更新,可实现寻找更好的稀疏结构的目的。相较于静态稀疏训练,在高稀疏性下,动态稀疏训练可提高局部稀疏模型的精度,进而提高整个联邦学习模型的准确率。
作为本优选方案的一种可能设计,所述在预设的通信轮次进行掩码调整,动态演化更新局部稀疏模型的网络结构,具体为:
在本地客户端与服务端通信的特定轮次,移除局部稀疏模型部分神经元结点之间的连接,使局部稀疏模型被调整至更高稀疏度S+(1-S)αt;其中,αt是动态调整参数,其表达式为:
式中,α表示预设的第一轮的动态调整参数α1的值,t表示联邦学习轮次,Tend表示最后一轮学习轮次;
根据局部稀疏模型即时的梯度信息增长与移除相同数量的神经元、梯度最大的连接,使模型的稀疏度恢复为原稀疏度S。
作为优选方案,所述服务端聚合本地客户端更新后的局部稀疏模型,对全局模型进行更新,包括:
第二方面,一种基于对比学习的联邦学习稀疏训练***,应用于第一方面任一技术方案提出的一种基于对比学习的联邦学习稀疏训练方法,包括服务端和本地客户端,所述服务端与本地客户端连接;
其中,所述服务端,用于向本地客户端发送全局模型和掩码,还用于聚合本地客户端上传的局部稀疏模型,更新全局模型;所述掩码基于稀疏度生成,用于表示全局模型参数是否被保留下来;
所述本地客户端,用于接收全局模型和掩码生成局部稀疏模型,利用本地数据集对局部稀疏模型进行训练,还用于计算对比损失函数,更新本地损失函数和局部稀疏模型,并向服务端上传更新后的局部稀疏模型。
与现有技术相比,本发明技术方案的有益效果是:
本发明在联邦学习的过程中采用了稀疏训练的方法,显著降低了计算通信开销,同时,引入了对比学习的方法,基于模型表示之间的相似性修正本地模型,训练出偏差更小的全局模型,解决联邦学习中的数据异质性问题,提高了全局模型的性能。
附图说明
图1为联邦学习稀疏训练方法的流程图;
图2为包括掩码调整的联邦学习稀疏训练方法的流程图;
图3为实施例2中联邦学习稀疏训练方法学习过程框架示意图;
图4为实施例2中基于对比学习的联邦学习稀疏训练方法与其他联邦学习方法在MNIST数据集上测试准确率结果的比较图。
具体实施方式
附图仅用于示例性说明,不能理解为对本专利的限制;
为了更好说明本实施例,附图某些部件会有省略、放大或缩小,并不代表实际产品的尺寸;
对于本领域技术人员来说,附图中某些公知结构及其说明可能省略是可以理解的。
下面结合附图和实施例对本发明的技术方案做进一步的说明。
实施例1
本实施例提供了一种基于对比学习的联邦学习稀疏训练方法,参阅图1,包括:
服务端向本地客户端发送全局模型和掩码;其中,所述掩码基于稀疏度生成,用于表示全局模型参数是否被保留下来;
本地客户端根据接收的全局模型和掩码,生成局部稀疏模型,并使用本地数据集对局部稀疏模型进行训练;
在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型,并将更新后的局部稀疏模型上传到服务端;
服务端聚合本地客户端更新后的局部稀疏模型,对全局模型进行更新,并将更新后的全局模型发送给本地客户端,开始新一轮的沟通训练直至全局模型收敛。
本实施例中,在联邦学习的过程中,引入了稀疏训练方法,利用掩码在本地客户端生成局部稀疏模型并直接训练,有效减少了联邦学习过程中的计算量,并降低设备的存储成本,加快训练过程,显著降低了联邦学习的计算通信开销;同时,通过引入对比学习方法,基于模型表示之间的相似性修正本地模型,解决数据异质性问题。通过联邦学习、系数训练和对比学习间的交叉配合,在降低联邦学习计算通信开销的同时提高了全局模型的准确率。
在一优选实施例中,所述服务端向本地客户端发送全局模型和掩码,包括:
在本优选实施例中,稀疏度S为全局模型中被裁剪掉的参数数量与总参数量之比,掩码基于稀疏度生成,其代表了稀疏网络的结构。
在一可选实施例中,所述掩码为二进制形式。
作为非限制性示例,所述掩码基于稀疏度,利用剪枝算法生成。
在一优选实施例中,所述本地客户端根据接收的全局模型和掩码,生成局部稀疏模型,具体为:
在一优选实施例中,所述使用本地数据集对局部稀疏模型进行训练,包括:
在一优选实施例中,所述在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型,包括:
将本地数据集分别输入第t轮局部稀疏模型第t-1轮的局部稀疏模型第t轮的全局模型中,分别得到对应的特征向量z、zlast和zglob;其中,z表示样本的特征经过特征表示网络的投影头(Projection head)结构的输出的向量;
式中,τ为预设的温度超参数;
更新本地损失函数,其表达式为:
在一优选实施例中,在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型后,在预设的通信轮次进行掩码调整,动态演化更新局部稀疏模型的网络结构,再将动态演化更新后的局部稀疏模型上传到服务端。
在一具体实施过程中,在训练初始阶段随机选择一种稀疏网络结构,在随后的稀疏训练过程中,进行掩码调整。由于掩码代表了稀疏网络的结构,通过掩码调整,可不断改变稀疏网络的结构,以实现寻找更好的稀疏结构的目的。
在一可选实施例中,参阅图2,所述在预设的通信轮次进行掩码调整,动态演化更新局部稀疏模型的网络结构,具体为:
在本地客户端与服务端通信的特定轮次,移除局部稀疏模型部分神经元结点之间的连接,使局部稀疏模型被调整至更高稀疏度S+(1-S)αt;其中,αt是动态调整参数,其表达式为:
式中,α表示预设的第一轮的动态调整参数α1的值,t表示联邦学习轮次,Tend表示最后一轮学习轮次;
根据局部稀疏模型即时的梯度信息,增长与移除相同数量的神经元、梯度最大的连接,使模型的稀疏度恢复为原稀疏度S。
在一优选实施例中,所述服务端聚合本地客户端更新后的局部稀疏模型,对全局模型进行更新,包括:
在一具体实施过程中,服务端完成局部稀疏模型后,将新生成的全局模型发送给选中的本地客户端,开始新一轮的沟通训练直至全局模型收敛。
实施例2
本实施例采用公开的MNIST数据集,对实施例1提出的基于对比学习的联邦学习稀疏训练方法进行实验,参阅图1-图4。
MNIST数据集(Mixed National Institute of Standards and Technologydatabase)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60000个示例的训练集以及10000个示例的测试集。
考虑一个典型的联邦学习框架:设定全局模型为包含两个5*5卷积层、两个最大池化层和四个全连接层的卷积神经网络;设计总共有100个本地客户端,每个通信轮次中随机选取20个本地客户端参与训练,每个本地客户端每轮在本地数据集上使用SGD优化器迭代10次,和服务端沟通50次。
本地客户端接收全局模型和掩码后,生成局部稀疏模型在本地数据集上训练局部模型,将本地数据x以32个样本的小批次输入局部稀疏模型中,局部稀疏模型进行预测,计算损失函数预设学习率η=0.01并进行如下操作更新局部稀疏模型:
更新本地损失函数为:
设定每十轮本地客户端执行一次掩码调整,动态更新稀疏网络的结果。设定α=0.01,当本地训练完成,在本地客户端与服务端通信的特定轮数,本地客户端通过移除局部稀疏模型部分神经元结点之间的连接,局部稀疏模型被调整到的更高的稀疏度S+(1-S)αt;随后根据局部稀疏模型即时的梯度信息增长与移除相同数量的神经元、梯度最大的连接,使局部稀疏模型的稀疏度恢复为S。其中αt是动态调整参数,按照余弦衰减更新计划调整稀疏度的变化。
其中,聚合方式如下:
此外,本实施例还选取了与上述全局模型相同结构的卷积神经网络和相同设置,执行MNIST分类预测任务。从100个本地客户端中选择20个本地客户端,在给定稀疏度为S=0.5的条件下,每个本地客户端每轮在本地数据集上使用SGD优化器迭代10次,和服务端沟通50次,进行联邦学习训练后预测,其预测结果的准确率如图4所示。显而易见,相较于FedDST、FedAvg和FedProx,本实施例提出的基于对比学习的联邦学习稀疏训练方法得到的模型性能更好,经需要较少的沟通轮数便可获得较高的准确率。
实施例3
本实施提出一种基于对比学习的联邦学习稀疏训练***,参照图3,应用于实施例1提出的基于对比学习的联邦学习稀疏训练方法,包括服务端和本地客户端,所述服务端与本地客户端连接;
其中,所述服务端,用于向本地客户端发送全局模型和掩码,还用于聚合本地客户端上传的局部稀疏模型,更新全局模型;所述掩码基于稀疏度生成,用于表示全局模型参数是否被保留下来;
所述本地客户端,用于接收全局模型和掩码生成局部稀疏模型,利用本地数据集对局部稀疏模型进行训练,还用于计算对比损失函数,更新本地损失函数和局部稀疏模型,并向服务端上传更新后的局部稀疏模型。
相同或相似的标号对应相同或相似的部件;
附图中描述位置关系的用语仅用于示例性说明,不能理解为对本专利的限制;
显然,本发明的上述实施例仅仅是为清楚地说明本发明所作的举例,而并非是对本发明的实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式的变化或变动。这里无需也无法对所有的实施方式予以穷举。凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明权利要求的保护范围之内。
Claims (10)
1.一种基于对比学习的联邦学习稀疏训练方法,其特征在于,包括:
服务端向本地客户端发送全局模型和掩码;其中,所述掩码基于稀疏度生成,用于表示全局模型参数是否被保留下来;
本地客户端根据接收的全局模型和掩码,生成局部稀疏模型,并使用本地数据集对局部稀疏模型进行训练;
在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型,并将更新后的局部稀疏模型上传到服务端;
服务端聚合本地客户端更新后的局部稀疏模型,对全局模型进行更新,并将更新后的全局模型发送给本地客户端,开始新一轮的沟通训练直至全局模型收敛。
7.根据权利要求1所述的一种基于对比学习的联邦学习稀疏训练方法,其特征在于,所述在每轮训练过程中,本地客户端进行对比损失函数计算,更新本地损失函数和局部稀疏模型后,在预设的通信轮次进行掩码调整,动态演化更新局部稀疏模型的网络结构,再将动态演化更新后的局部稀疏模型上传到服务端。
10.一种基于对比学习的联邦学习稀疏训练***,应用于权利要求1-9任一项所述的一种基于对比学习的联邦学习动态稀疏训练方法,其特征在于,包括服务端和本地客户端,所述服务端与本地客户端连接;
其中,所述服务端,用于向本地客户端发送全局模型和掩码,还用于聚合本地客户端上传的局部稀疏模型,更新全局模型;所述掩码基于稀疏度生成,用于表示全局模型参数是否被保留下来;
所述本地客户端,用于接收全局模型和掩码生成局部稀疏模型,利用本地数据集对局部稀疏模型进行训练,还用于计算对比损失函数,更新本地损失函数和局部稀疏模型,并向服务端上传更新后的局部稀疏模型。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211349843.3A CN115829027A (zh) | 2022-10-31 | 2022-10-31 | 一种基于对比学习的联邦学习稀疏训练方法及*** |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211349843.3A CN115829027A (zh) | 2022-10-31 | 2022-10-31 | 一种基于对比学习的联邦学习稀疏训练方法及*** |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115829027A true CN115829027A (zh) | 2023-03-21 |
Family
ID=85525940
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211349843.3A Pending CN115829027A (zh) | 2022-10-31 | 2022-10-31 | 一种基于对比学习的联邦学习稀疏训练方法及*** |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115829027A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116341689A (zh) * | 2023-03-22 | 2023-06-27 | 深圳大学 | 机器学习模型的训练方法、装置、电子设备及存储介质 |
CN116578674A (zh) * | 2023-07-07 | 2023-08-11 | 北京邮电大学 | 联邦变分自编码主题模型训练方法、主题预测方法及装置 |
CN117196014A (zh) * | 2023-09-18 | 2023-12-08 | 深圳大学 | 基于联邦学习的模型训练方法、装置、计算机设备及介质 |
CN117391187A (zh) * | 2023-10-27 | 2024-01-12 | 广州恒沙数字科技有限公司 | 基于动态层次化掩码的神经网络有损传输优化方法及*** |
-
2022
- 2022-10-31 CN CN202211349843.3A patent/CN115829027A/zh active Pending
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116341689A (zh) * | 2023-03-22 | 2023-06-27 | 深圳大学 | 机器学习模型的训练方法、装置、电子设备及存储介质 |
CN116341689B (zh) * | 2023-03-22 | 2024-02-06 | 深圳大学 | 机器学习模型的训练方法、装置、电子设备及存储介质 |
CN116578674A (zh) * | 2023-07-07 | 2023-08-11 | 北京邮电大学 | 联邦变分自编码主题模型训练方法、主题预测方法及装置 |
CN116578674B (zh) * | 2023-07-07 | 2023-10-31 | 北京邮电大学 | 联邦变分自编码主题模型训练方法、主题预测方法及装置 |
CN117196014A (zh) * | 2023-09-18 | 2023-12-08 | 深圳大学 | 基于联邦学习的模型训练方法、装置、计算机设备及介质 |
CN117196014B (zh) * | 2023-09-18 | 2024-05-10 | 深圳大学 | 基于联邦学习的模型训练方法、装置、计算机设备及介质 |
CN117391187A (zh) * | 2023-10-27 | 2024-01-12 | 广州恒沙数字科技有限公司 | 基于动态层次化掩码的神经网络有损传输优化方法及*** |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN115829027A (zh) | 一种基于对比学习的联邦学习稀疏训练方法及*** | |
Lin et al. | Network pruning using adaptive exemplar filters | |
CN114943345B (zh) | 基于主动学习和模型压缩的联邦学习全局模型训练方法 | |
CN110781912A (zh) | 一种基于通道扩张倒置卷积神经网络的图像分类方法 | |
CN112836822B (zh) | 基于宽度学习的联邦学习策略优化方法和装置 | |
Wehenkel et al. | Diffusion priors in variational autoencoders | |
CN113987236B (zh) | 基于图卷积网络的视觉检索模型的无监督训练方法和装置 | |
CN112115967A (zh) | 一种基于数据保护的图像增量学习方法 | |
CN115331069A (zh) | 一种基于联邦学习的个性化图像分类模型训练方法 | |
CN115081532A (zh) | 基于记忆重放和差分隐私的联邦持续学习训练方法 | |
Gil et al. | Quantization-aware pruning criterion for industrial applications | |
CN111694977A (zh) | 一种基于数据增强的车辆图像检索方法 | |
CN115600686A (zh) | 基于个性化Transformer的联邦学习模型训练方法及联邦学习*** | |
CN115359298A (zh) | 基于稀疏神经网络的联邦元学习图像分类方法 | |
CN116168197A (zh) | 一种基于Transformer分割网络和正则化训练的图像分割方法 | |
CN109948589B (zh) | 基于量子深度信念网络的人脸表情识别方法 | |
CN115278709A (zh) | 一种基于联邦学习的通信优化方法 | |
CN111401193A (zh) | 获取表情识别模型的方法及装置、表情识别方法及装置 | |
Zhang et al. | Stochastic approximation approaches to group distributionally robust optimization | |
Du et al. | CGaP: Continuous growth and pruning for efficient deep learning | |
Huang et al. | Distributed pruning towards tiny neural networks in federated learning | |
Zhang et al. | Federated multi-task learning with non-stationary heterogeneous data | |
Zhao et al. | Exploiting channel similarity for network pruning | |
CN111414937A (zh) | 物联网场景下提升多分支预测单模型鲁棒性的训练方法 | |
CN116010832A (zh) | 联邦聚类方法、装置、中心服务器、***和电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |