CN113705610A - Heterogeneous model aggregation method and system based on federal learning - Google Patents

Heterogeneous model aggregation method and system based on federal learning

Info

Publication number
CN113705610A
Authority
CN
China
Prior art keywords
model
client
data
data set
local
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110844739.0A
Other languages
Chinese (zh)
Other versions
CN113705610B (en
Inventor
陈孔阳
张炜斌
陈卓荣
严基杰
黄耀
李进
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Guangzhou University
Original Assignee
Guangzhou University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Guangzhou University filed Critical Guangzhou University
Priority to CN202110844739.0A priority Critical patent/CN113705610B/en
Publication of CN113705610A publication Critical patent/CN113705610A/en
Application granted granted Critical
Publication of CN113705610B publication Critical patent/CN113705610B/en
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The invention relates to the field of federal learning, and in particular to a heterogeneous model aggregation method and system based on federal learning. The method comprises the following steps: initializing a neural network model; each client contributes part of its local data and uploads it to a server to form a shared data set, on which a CGAN model is trained; each client trains its local model using the local data set and the data set generated by the CGAN model, predicts each data of the shared data set, and uploads the prediction scores to the server; the server calculates the degree to which each client's prediction scores deviate from the others', takes the reciprocal of the result as the weight, calculates a global prediction score, and uses the global prediction score to perform knowledge distillation on the server model; each client downloads the prediction scores of the other client models from the server to perform cooperative training; the models converge after multiple iterations. The invention can solve the problem of heterogeneous client data, and because the client models upload and download only the prediction scores of the shared data set, the communication volume between the clients and the server is reduced.

Description

Heterogeneous model aggregation method and system based on federal learning
Technical Field
The invention relates to the field of federal learning, and in particular to a heterogeneous model aggregation method and system based on federal learning.
Background
Deep learning is developing rapidly, but it has an obvious drawback: a large amount of data is required for training to achieve good performance. In recent years, attention to data privacy and security has become a worldwide trend, and at the same time most industrial data exist as isolated data islands. How to jointly train an excellent model while satisfying user privacy protection, data security, and government regulations is therefore a key technical problem.
Federated learning still faces many challenges. The two most important are the heterogeneity of client models and the differences among clients' local data. Because the clients are not necessarily identical and are located in different places, their communication capacity, computing power, and data differ greatly, and these differences can seriously degrade the quality of the jointly trained model.
In recent years, a method has been proposed that allows different clients to design different network structures according to their computing power: in each round, every client downloads the average prediction score of all client models on a shared data set and uses knowledge distillation to fit its local model to this average prediction score, thereby learning the global consensus. This method still has the following disadvantages:
1. the method still does not well solve the problem of client data heterogeneity, and under the condition that local data sets of the clients are not independently and identically distributed, the model performance of the clients is greatly different, and fairness is poor.
2. The average prediction score of the shared data set is calculated by all the client models only by adopting a simple average method without considering the performance difference of all the client models, so that the quality of the average prediction score is seriously influenced if some client models have poor performance.
Another approach has also been proposed in recent years. In this method, several clients are randomly selected in each round and sent the parameters of the previous round's aggregated model; the clients update the parameters with local data and send them back to the server. The server averages the weights of the received model parameters and then performs ensemble distillation using unlabeled data or data produced by a generator (such as a GAN) to obtain the aggregated model parameters of the current round. This method still has the following disadvantages:
1. the client still needs to upload and download the model parameters to the server, and the problem of communication volume is not effectively solved.
2. In the method, the server side needs to average the model parameters of each client side, so that the client side models are not completely heterogeneous, some client side models are required to be isomorphic, and only the client side models with the same model structure can be aggregated.
Disclosure of Invention
In order to solve the technical problems in the prior art, the invention aims to provide a federal learning method which can better solve the problem of heterogeneous client data, allow the client model to be heterogeneous and send a prediction score to a server to reduce the communication volume.
In order to solve the technical problems, the technical scheme of the invention is as follows:
a heterogeneous model aggregation method and system based on federated learning comprise the following steps:
s1, each client side contains a local data set and initializes a neural network model, and the server side initializes a neural network model;
s2, each client contributes a small part of local data set and uploads the small part of local data set to the server;
s3, forming a batch of shared data sets by the service end, and training a CGAN model by using the shared data sets;
s4, each client downloads the shared data set and the CGAN model from the server to the local;
s5, the server randomly selects a plurality of clients;
s6, the client trains a local client model by using an enhanced iteration method by using local data and CGAN generated data;
s7, the client uses the local client model to predict each data of the shared data set in turn and uploads the predicted score to the server;
s8, the server side calculates the deviation degree of the prediction scores of the clients and other clients by using a JS function, and the reciprocal of the calculated result is used as the weight of the prediction scores;
s9, calculating global prediction scores by using a weighted average method according to the prediction score weights and the prediction scores of the clients calculated in the step S7;
s10, the server model carries out knowledge distillation on the server model by using the global prediction score;
s11, downloading the prediction scores of other client models from the server by the client, and then performing cooperative training on the local model through the prediction scores of the other client models and the prediction scores of the local model;
s12, iterating S5 to S11 for multiple times, and finally converging the server model and the client model;
and S13, downloading the server model to the local client by each client.
Preferably, in step S1, the model structure and the model parameters of each client are different, and the data distribution of the local data set of each client is different.
Preferably, in step S3, the data classes of the shared data set are balanced, and only one class of data can be generated by one CGAN model.
Preferably, in step S5, only k clients are selected from each round of federal learning to perform local training, where k is generally smaller than the total number of clients.
Another object of the present invention is to provide a heterogeneous model aggregation system based on federal learning, which includes:
the enhanced iteration module, which is used for ensuring that, during batch iteration, the client model sees a complete and uniformly distributed set of data classes, so that the client model can continuously correct its gradient descent direction during training and the gradient descends toward the optimal solution;
the cooperative training module, which is used for solving the problem of insufficient data representation at the client during training; it adds an extra term to the loss function to guide the gradient descent direction of the model, so that the client models achieve a cooperative effect;
the knowledge distillation module, which is used for solving the reliability problem of the client models under heterogeneous conditions; in each round, every client sends its prediction scores to the server, the server calculates the weight of each client model's prediction scores with the JS function, then calculates the global prediction score by weighted averaging, and the server model performs knowledge distillation using the global prediction score.
Compared with the prior art, the invention has the following advantages and beneficial effects:
1. The method first trains the CGAN model using the shared data set, and each client uses the CGAN model to generate the data it is missing, so that the local data become approximately independently and identically distributed; this better solves the problem of heterogeneous client data.
2. The invention allows the client models to be heterogeneous, and each client model uses enhanced iteration and cooperative training in turn, so that every client model can correct its gradient descent direction and learn the knowledge of the other client models, which further improves the client models' performance and reduces the differences between them. In terms of communication volume, the client models upload and download only the prediction scores of the shared data set, which greatly reduces the communication volume compared with federated learning schemes that exchange model parameters.
3. The method calculates, in turn, the JS divergence between each client model's prediction scores and the average prediction scores of the remaining client models, and uses the reciprocal of this JS divergence as the weight of that client model's prediction scores, so that an excellent client model, i.e., one whose prediction scores have a small JS divergence from the average prediction scores of the remaining clients, receives a larger weight, and a poor one a smaller weight; finally the weights of the client models' prediction scores are normalized. The quality of the weighted prediction score obtained with the JS function is much higher than that of the simple average prediction score used in previous federated learning.
Drawings
FIG. 1 is a flow chart of the training of federated learning-based heterogeneous model aggregation in an embodiment of the present invention;
FIG. 2 is a schematic diagram illustrating the effect of enhanced iteration in an embodiment of the present invention;
FIG. 3 is a graph showing the experimental effect of the accuracy of the model of the server under different values of lambda in the embodiment of the present invention;
fig. 4 is a diagram of an effect of an experiment on client model accuracy under client model heterogeneity in the embodiment of the present invention.
Detailed Description
The technical solutions of the present invention will be described in further detail with reference to the accompanying drawings and examples, and it is obvious that the described examples are some, but not all, examples of the present invention, and the embodiments of the present invention are not limited thereto. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
For convenience of understanding, terms referred to in the embodiments of the present invention are explained below:
A neural network: a mathematical model that simulates the behavioral characteristics of biological neural networks and performs distributed, parallel information processing. Such a network processes information by adjusting the interconnections among a large number of internal nodes, depending on the complexity of the system.
Federal learning: federal machine learning is a machine learning framework, and can enable multiple parties to develop efficient machine learning under the condition of meeting the requirements of user privacy protection, data security and government regulations.
Knowledge distillation: a deep learning technique that introduces a complex teacher network with strong inference performance and uses the soft targets produced by the teacher as part of the optimization objective, thereby guiding the training of a simplified student network and realizing knowledge transfer.
JS divergence: a variant of the KL divergence that measures the similarity of two probability distributions. The JS divergence solves the asymmetry problem of the KL divergence: it is symmetric, and its value lies between 0 and 1 (with base-2 logarithms). For two distributions P and Q it is defined as
JS(P ‖ Q) = (1/2)·KL(P ‖ M) + (1/2)·KL(Q ‖ M), where M = (1/2)·(P + Q).
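As a concrete illustration, the following is a minimal numpy sketch of the JS divergence between two discrete distributions, assuming both are given as probability vectors over the same support (base-2 logarithms, so the value stays in [0, 1]):

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for discrete probability vectors p and q (base-2 log)."""
    p = np.asarray(p, dtype=float) + eps
    q = np.asarray(q, dtype=float) + eps
    return float(np.sum(p * np.log2(p / q)))

def js_divergence(p, q):
    """Symmetric JS divergence; 0 for identical distributions, at most 1."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

# Example: two softmax-style prediction scores over three classes.
print(js_divergence([0.7, 0.2, 0.1], [0.6, 0.3, 0.1]))  # small positive value
```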
Conditional Generative Adversarial Network (CGAN): a deep learning model consisting of (at least) two modules, a generative model and a discriminative model, which learn through a mutual adversarial game and eventually produce high-quality output. After training, feeding specific condition information into the generative model makes it generate data of the corresponding kind.
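To make the conditioning mechanism concrete, here is a toy PyTorch sketch (an illustrative generator only, not the network used by the invention) in which a noise vector is concatenated with a one-hot class label, so that after adversarial training the requested label steers what is generated:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyConditionalGenerator(nn.Module):
    """Toy CGAN generator: input is noise z concatenated with a one-hot label."""
    def __init__(self, noise_dim=64, num_classes=10, out_dim=784):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
            nn.Tanh(),  # outputs in [-1, 1], matching normalized image pixels
        )

    def forward(self, z, labels):
        onehot = F.one_hot(labels, self.num_classes).float()
        return self.net(torch.cat([z, onehot], dim=1))

# Sampling 16 examples "of class 3" (meaningful only after adversarial training):
gen = ToyConditionalGenerator()
z = torch.randn(16, 64)
labels = torch.full((16,), 3, dtype=torch.long)
fake = gen(z, labels)  # shape (16, 784)
```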
In the field of internet of things, a plurality of devices are located at different spatial positions, so that the data acquired by the devices are distributed differently, the network bandwidths of the devices are different, and the performances and the computing capabilities of the devices are different. If the traditional federal learning algorithm is directly used in the above scenario, the performance of the constructed combined model cannot reach the expected index because the differences of the device data distribution, the performance and the network bandwidth are not fully considered. The invention provides a heterogeneous model aggregation method and system based on federal learning. In addition, the invention allows the client models to be heterogeneous, and each client model uses the enhanced iteration and the cooperative training in sequence, so that each client model can correct the gradient descending direction and learn the knowledge of other client models, thereby further improving the performance of the client models and reducing the difference between the client models. In terms of communication volume, the client model uploads and downloads the prediction scores of the shared data set, and compared with federal learning, the communication volume is greatly reduced.
Example 1
As shown in fig. 1, the federated learning-based heterogeneous model aggregation method in this embodiment uses the MNIST data set and comprises the following steps:
s1, setting a local data set D at each clientiAnd initializing a neural network model MiWherein i is a client serial number, and a neural network model M is initialized at a server; wherein, the neural network model M of each clientiAre allowed to be different, each client local data set DiThe data distribution is not the same.
In this embodiment, each client is a computer with a certain computing power. For the client numbered i, the local data set Di consists of data of a subset of the classes in the MNIST data set, and the structure of the initialized neural network model Mi is shown in Table 1:
TABLE 1 (network structure; given as an image in the original publication and not reproduced here)
The server is a data center with strong performance and large communication capacity; its model is M, whose structure is also shown in Table 1. Because the client computers are of different brands and located in different places, their data sets, network bandwidths, and computing power differ.
And S2, each client contributes a small part of local data and uploads the local data to the server.
In this embodiment, each computer wirelessly uploads a small portion of data of the local client to the data center of the server.
S3, the uploaded local data form a shared data set D at the server, and the shared data set D is used to train CGAN models Gj, where j is a data category and the data categories of the shared data set D are balanced; each CGAN model Gj can only generate data of category j, and the number of CGAN models Gj equals the number of data categories.
In this embodiment, the data center of the server receives the local data uploaded by all computers and integrates the uploaded local data into the shared data set D. Using its strong computing power, the data center trains high-performance CGAN models Gj, where j is the data class.
S4, each client downloads the shared data set D and the CGAN models Gj from the server to the local machine, where the number of CGAN models Gj each client downloads equals the number m of CGAN models trained in step S3.
In this embodiment, each computer wirelessly downloads the shared data set D and the CGAN models Gj from the data center to its local storage.
S5, the server randomly selects several clients and sends them selection instructions; in each round of federated learning the server randomly selects only k clients to perform local training, where k is generally smaller than the total number of clients.
In this embodiment, the data center randomly selects several computers and sends the selection instructions to those computers.
S6, client i uses the local data set Di and the data sets generated by the CGAN models Gj to train the local model Mi with an enhanced iteration method. The purpose of the enhanced iteration is to correct the direction of gradient descent; a schematic diagram of its effect is shown in fig. 2. The specific steps of the enhanced iteration method are:
Step S61: for each data class j, the CGAN model Gj is used in turn to generate a data set dj labelled with class j, where the size of dj is N. The data set dj is used to train the local client model Mi for one round according to
θi ← θi − η·∇θi Σn CrossEntropyLoss(Mi(dj,n), j),
where CrossEntropyLoss() is the cross entropy loss function, θi are the parameters of the local client model Mi, η is the learning rate, Mi(dj,n) is the prediction of the local client model Mi for the nth data of data set dj, and j is the data label. After training with data set dj, the average prediction score P̄j of the local client model Mi over all data of dj during this training is obtained according to
P̄j = (1/N)·Σn Mi(dj,n).
Step S62: the local data set Di is used to train the local client model Mi for multiple rounds according to
θi ← θi − η·∇θi Σn [ α·KLDivLoss(Mi(Di,n), P̄yn) + (1−α)·CrossEntropyLoss(Mi(Di,n), yn) ],
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, θi are the parameters of the local client model Mi, η is the learning rate, α is a regularization parameter, Mi(Di,n) is the prediction of the local client model Mi for the nth data of the local data set Di, yn is the label of the nth data of Di, and P̄yn is the average prediction score corresponding to the label class of the nth data of Di.
In this embodiment, the computer that receives the selection instruction uses the local original data set Di and the data generated by the CGAN models Gj to train the local client model Mi.
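The two stages above can be sketched compactly in PyTorch as follows; the generator interface, the batching, and the hyper-parameter values are illustrative assumptions rather than the reference implementation of the invention:

```python
import torch
import torch.nn.functional as F

def enhanced_iteration(model, generators, local_loader, optimizer,
                       num_classes, n_per_class=128, alpha=0.5,
                       local_epochs=3, noise_dim=64):
    """Sketch of steps S61/S62: one pass per class on CGAN-generated data,
    then local training regularized toward the recorded average scores."""
    avg_scores = {}

    # S61: for each class j, generate a batch from G_j, train one round with
    # cross entropy, and record the average prediction score for that class.
    for j in range(num_classes):
        z = torch.randn(n_per_class, noise_dim)
        labels = torch.full((n_per_class,), j, dtype=torch.long)
        x_fake = generators[j](z).detach()   # assumed: G_j only generates class j
        loss = F.cross_entropy(model(x_fake), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            avg_scores[j] = F.softmax(model(x_fake), dim=1).mean(dim=0)

    # S62: several epochs on the local data set, with a KL term pulling every
    # prediction toward the average score of its own label class.
    for _ in range(local_epochs):
        for x, y in local_loader:
            logits = model(x)
            log_probs = F.log_softmax(logits, dim=1)
            targets = torch.stack([avg_scores[int(c)] for c in y])
            loss = (alpha * F.kl_div(log_probs, targets, reduction="batchmean")
                    + (1 - alpha) * F.cross_entropy(logits, y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return avg_scores
```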
S7, client i uses the local client model Mi to predict each data of the shared data set D in turn and uploads the prediction score Pi to the server.
In this embodiment, the computer numbered i that received the selection instruction uses the local client model Mi to predict the shared data set D and wirelessly uploads the prediction score Pi to the data center of the server.
S8, the server calculates, with the JS divergence function, the degree to which each selected client's prediction scores deviate from those of the other clients, and uses the reciprocal of the JS divergence result as the weight Wi of that client's prediction scores, so that a superior client model, i.e., one whose prediction scores have a small JS divergence from the average prediction scores of the remaining client models, has a larger weight, and conversely a smaller weight.
In this embodiment, the data center receives the prediction scores uploaded by the selected computers, calculates with the JS divergence function the deviation of each selected computer's prediction scores from the other computers' prediction scores, and takes the reciprocal of the result as the weight Wi of that computer's prediction scores. The data center judges the similarity of the client models with the JS divergence function and thereby determines the weight of each client model.
S9, the global prediction score P is calculated with a weighted average method from the prediction scores Pi and the weights Wi of the selected clients' prediction scores calculated in step S8.
In this embodiment, the data center uses the weights Wi of each computer's prediction scores obtained in step S8 to compute the global prediction score P over all selected computers by weighted averaging.
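A minimal numpy sketch of steps S8 and S9 (reusing the js_divergence helper sketched earlier; comparing each client against the mean prediction of the remaining clients is an assumption consistent with the description) might look like this:

```python
import numpy as np

def global_prediction_score(client_scores, eps=1e-8):
    """client_scores: array of shape (k, n_shared, n_classes) holding each
    selected client's softmax scores on the shared data set."""
    scores = np.asarray(client_scores, dtype=float)
    k = scores.shape[0]
    weights = np.empty(k)
    for i in range(k):
        others_mean = scores[np.arange(k) != i].mean(axis=0)
        # S8: average JS deviation of client i from the consensus of the others.
        deviation = np.mean([js_divergence(p, q)
                             for p, q in zip(scores[i], others_mean)])
        weights[i] = 1.0 / (deviation + eps)  # reciprocal: small deviation -> large weight
    weights /= weights.sum()                  # normalize the weights
    # S9: weighted average over clients gives the global prediction score P.
    return np.tensordot(weights, scores, axes=1), weights
```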
S10, the server performs knowledge distillation on the server model M using the shared data set D and the global prediction score P. The specific operations of the knowledge distillation are:
S101, the server model M is trained for multiple rounds on the shared data set D and the global prediction score P according to
θ ← θ − η·∇θ Σn [ α·T²·KLDivLoss(M(Dn), Pn) + (1−α)·CrossEntropyLoss(M(Dn), yn) ],
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, θ are the parameters of the server model M, η is the learning rate, α is a regularization parameter, M(Dn) is the prediction of the server model M for the nth data of the shared data set D, yn is the label of the nth data of D, Pn is the nth score of the global prediction score P, and T is the temperature of the knowledge distillation. It should be noted that different values of the regularization parameter α lead to different final model accuracy; the corresponding experimental accuracy is shown in fig. 3.
S102, the server model performs knowledge distillation with the global prediction score, and the loss function of the knowledge distillation is set as:
L(M) = α·T²·KLDivLoss(M(Dn), Pn) + (1−α)·CrossEntropyLoss(M(Dn), yn)
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, L(M) is the training loss of the server model M, θ are the parameters of the server model M, η is the learning rate, α is a regularization parameter, M(Dn) is the prediction of the server model M for the nth data of the shared data set D, yn is the label of the nth data of D, Pn is the nth score of the global prediction score P, and T is the temperature of the knowledge distillation.
In this embodiment, the data center trains its model M by knowledge distillation, using the shared data set D obtained in step S3 and the global prediction score P obtained in step S9.
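A PyTorch sketch of one distillation epoch following the loss above is given below; how the temperature is applied to the global scores is an assumption (the text only states that T is the distillation temperature), and the data loader is assumed to also yield the sample index so each batch can look up its rows of P:

```python
import torch
import torch.nn.functional as F

def distill_server_model(server_model, shared_loader, global_scores, optimizer,
                         alpha=0.5, temperature=2.0):
    """Sketch of step S10: fit the server model to the global prediction score P
    (soft targets) and the true labels of the shared data set (hard targets)."""
    server_model.train()
    for x, y, idx in shared_loader:           # loader yields (data, label, index)
        logits = server_model(x)
        soft_student = F.log_softmax(logits / temperature, dim=1)
        soft_target = global_scores[idx]      # rows of P, assumed to be probabilities
        kd_loss = F.kl_div(soft_student, soft_target, reduction="batchmean")
        ce_loss = F.cross_entropy(logits, y)
        loss = alpha * temperature ** 2 * kd_loss + (1 - alpha) * ce_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```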
S11, client i downloads the prediction scores Pj (j ≠ i) of the other client models from the server, and then uses these prediction scores together with the prediction score Pi of the local client model to perform cooperative training on the local client model Mi. The specific operations of the cooperative training are:
The shared data set D is used to train the local client model Mi for multiple rounds according to
θi ← θi − η·∇θi Σn [ CrossEntropyLoss(Mi(Dn), yn) + α·λ^epoch·Σ_{j≠i} JS(Mi(Dn), Pj,n) ],
where CrossEntropyLoss() is the cross entropy loss function, η is the learning rate, α is the weight factor of the extra term of the loss function, Mi(Dn) is the prediction of the local client model Mi for the nth data of the shared data set D, yn is the label of the nth data of D, and JS(Mi(Dn), Pj,n) is the JS divergence between that prediction and the prediction score Pj,n of client model Mj (j ≠ i) for the nth data of D. Considering that the predictions of the other client models are not updated while the client i model is being trained, the weighting factor λ^epoch is added, where 0 < λ < 1 and epoch is the index of the current training epoch; as the number of training epochs increases, the influence of the other client models' predictions on the training of the client i model gradually decreases. The client judges the similarity between its own model and the other client models with the JS divergence function, which makes it convenient to train cooperatively with multiple client models.
The cooperative training is based on the idea of ensemble learning: by integrating different clients' feature representations of the same group of data, the features of the data are represented more fully. For the client models to achieve this cooperative effect, an extra term is added to the loss function to guide the gradient descent direction of the model. The invention sets the loss function as:
L(Mi) = CrossEntropyLoss(Mi(Dn), yn) + α·λ^epoch·Σ_{j≠i} JS(Mi(Dn), Pj,n)
where CrossEntropyLoss() is the cross entropy loss function, L(Mi) is the training loss of the local client model Mi, η is the learning rate, α is the weight factor of the extra term of the loss function, Mi(Dn) is the prediction of the local client model Mi for the nth data of the shared data set D, yn is the label of the nth data of D, JS(Mi(Dn), Pj,n) is the JS divergence between that prediction and the prediction score Pj,n of client model Mj (j ≠ i) for the nth data of D, and λ^epoch is the weighting factor, with 0 < λ < 1 and epoch the index of the current training epoch; as the number of training epochs increases, the influence of the other client models' predictions on the training of the client i model gradually decreases.
In this embodiment, the computer numbered i that received the selection instruction downloads the other computers' prediction scores Pj (j ≠ i) on the shared data set from the data center and uses these prediction scores together with the shared data set D to cooperatively train the local model Mi.
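One cooperative-training epoch following the reconstructed loss above can be sketched in PyTorch as follows; the torch-side JS helper and the exact placement of the α·λ^epoch factor are assumptions consistent with the description:

```python
import torch
import torch.nn.functional as F

def js_div(p, q, eps=1e-12):
    """Batched JS divergence (base-2) between probability vectors p and q."""
    m = 0.5 * (p + q)
    kl_pm = torch.sum(p * torch.log2((p + eps) / (m + eps)), dim=1)
    kl_qm = torch.sum(q * torch.log2((q + eps) / (m + eps)), dim=1)
    return 0.5 * kl_pm + 0.5 * kl_qm

def cooperative_epoch(model, shared_loader, other_scores, optimizer,
                      alpha=0.5, lam=0.9, epoch=1):
    """Sketch of step S11: cross entropy on the shared data set plus a decaying
    JS penalty toward every other client's prediction scores.
    other_scores: dict {client_id: tensor of shape (n_shared, n_classes)}."""
    model.train()
    for x, y, idx in shared_loader:           # loader yields (data, label, index)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        js_term = sum(js_div(probs, scores[idx]).mean()
                      for scores in other_scores.values())
        loss = F.cross_entropy(logits, y) + alpha * lam ** epoch * js_term
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```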
And S12, iterating steps S5 to S11 for multiple times, and finally converging the server model and the client model.
After multiple iterations of steps S5 to S11, the server model and the client models gradually converge, and the server model becomes a high-performance model that has learned the knowledge contained in the local data of all computers. The accuracy of the heterogeneous client models is shown in fig. 4.
And S13, downloading the server model M to the computer of the client by each client.
The computer of each client wirelessly downloads the server model M from the data center of the server and deploys it on the local machine to run.
Example 2
Based on the same inventive concept as that of embodiment 1, this embodiment further provides a heterogeneous model aggregation system based on federal learning, including:
and the enhanced iteration module is used for ensuring that the data types owned by the client model in the batch iteration process are complete and the types are uniformly distributed, so that the client model can continuously correct the gradient descending direction in the training process, and the gradient descends towards the optimal solution direction.
And the knowledge distillation module is used for solving the reliability problem of the client model under the heterogeneous condition. And each client sends the prediction score to the server in each round, the server calculates the weight of the prediction score of each client model by using a JS function, and then calculates the global prediction score by weighted average. The service-side model carries out knowledge distillation through the global prediction score, wherein the loss function of the knowledge distillation is set as follows:
L(M) = α·T²·KLDivLoss(M(Dn), Pn) + (1−α)·CrossEntropyLoss(M(Dn), yn)
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, L(M) is the training loss of the server model M, θ are the parameters of the server model M, η is the learning rate, α is a regularization parameter, M(Dn) is the prediction of the server model M for the nth data of the shared data set D, yn is the label of the nth data of D, Pn is the nth score of the global prediction score P, and T is the temperature of the knowledge distillation.
And the cooperative training module is used for solving the problem of insufficient data representation at the client during training. In a federated learning system, the local computing resources of the clients are scarce, and this shortage causes the data feature representation of the client model to be insufficient. The cooperative training is based on the idea of ensemble learning: by integrating different clients' feature representations of the same group of data, the features of the data are represented more fully. For the client models to achieve this cooperative effect, an extra term is added to the loss function to guide the gradient descent direction of the model. The invention sets the loss function as:
L(Mi) = CrossEntropyLoss(Mi(Dn), yn) + α·λ^epoch·Σ_{j≠i} JS(Mi(Dn), Pj,n)
where CrossEntropyLoss() is the cross entropy loss function, L(Mi) is the training loss of the local client model Mi, η is the learning rate, α is the weight factor of the extra term of the loss function, Mi(Dn) is the prediction of the local client model Mi for the nth data of the shared data set D, yn is the label of the nth data of D, JS(Mi(Dn), Pj,n) is the JS divergence between that prediction and the prediction score Pj,n of client model Mj (j ≠ i) for the nth data of D, and λ^epoch is the weighting factor, with 0 < λ < 1 and epoch the index of the current training epoch; as the number of training epochs increases, the influence of the other client models' predictions on the training of the client i model gradually decreases.
The above embodiments are preferred embodiments of the present invention, but the present invention is not limited to the above embodiments, and any other changes, modifications, substitutions, combinations, and simplifications which do not depart from the spirit and principle of the present invention should be construed as equivalents thereof, and all such changes, modifications, substitutions, combinations, and simplifications are intended to be included in the scope of the present invention.

Claims (10)

1. A heterogeneous model aggregation method based on federated learning is characterized by comprising the following steps:
s1, setting a local data set at each client and initializing a neural network model, and initializing a neural network model at the server;
s2, each client contributes a small part of local data and uploads the local data to the server;
s3, forming a batch of shared data sets at the service end by the uploaded local data, and training a CGAN model by using the shared data sets;
s4, each client downloads the shared data set and the CGAN model from the server to the local;
s5, the server randomly selects a plurality of clients;
s6, the client trains a local client model by using an enhanced iteration method by using the local data set and the data generated by the CGAN model;
s7, the client uses the local client model to predict each data of the shared data set in turn and uploads the predicted score to the server;
s8, the server side calculates the deviation degree of the prediction scores of the clients and other clients by using a JS function, and the reciprocal of the calculated result is used as the weight of the prediction scores;
s9, according to the prediction score PiStep S8, calculating the global prediction score by using a weighted average method according to the prediction score weight and the prediction score of each client calculated in the step S8;
s10, the server model carries out knowledge distillation on the server model by using the shared data set and the global prediction score;
s11, downloading the prediction scores of other client models from the server by the client, and then performing cooperative training on the local client model through the prediction scores of the other client models and the prediction score of the local client model;
s12, iterating steps S5 to S11 for multiple times, and finally converging the server model and the client model;
and S13, downloading the server model to the computer of the client by each client.
2. The heterogeneous model aggregation method according to claim 1, wherein in step S1, the model structure and the model parameters of the neural network model of each client are different, and the data distribution of the local data set of each client is different.
3. The heterogeneous model aggregation method of claim 1, wherein in step S3, the data classes of the shared data set are balanced, and only one class of data can be generated by one CGAN model.
4. The method for aggregating heterogeneous models according to claim 1, wherein in step S5, the number of clients randomly selected by the server is smaller than the total number of clients.
5. The heterogeneous model aggregation method according to claim 1, wherein in step S6, the enhancement iteration method includes the steps of:
S61: for each data class j, a CGAN model Gj is used to generate a data set dj labelled with class j, where the size of dj is N, and the data set dj is used to train the local model Mi for one round according to
θi ← θi − η·∇θi Σn CrossEntropyLoss(Mi(dj,n), j),
where CrossEntropyLoss() is the cross entropy loss function, θi are the parameters of the local client model Mi, η is the learning rate, Mi(dj,n) is the prediction of the local client model Mi for the nth data of data set dj, and j is the data label; after training with data set dj, the average prediction score P̄j of the local client model Mi over all data of dj during this training is obtained according to
P̄j = (1/N)·Σn Mi(dj,n);
S62: the local data set Di is used to train the local model Mi for multiple rounds according to
θi ← θi − η·∇θi Σn [ α·KLDivLoss(Mi(Di,n), P̄yn) + (1−α)·CrossEntropyLoss(Mi(Di,n), yn) ],
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, θi are the parameters of the local client model Mi, η is the learning rate, α is a regularization parameter, Mi(Di,n) is the prediction of the local client model Mi for the nth data of the local data set Di, yn is the label of the nth data of Di, and P̄yn is the average prediction score corresponding to the label class of the nth data of Di.
6. The heterogeneous model aggregation method of claim 1, wherein in step S10, the operation of knowledge distillation comprises:
using the shared data set D and the global prediction score P, the server model M is trained for multiple rounds according to
θ ← θ − η·∇θ Σn [ α·T²·KLDivLoss(M(Dn), Pn) + (1−α)·CrossEntropyLoss(M(Dn), yn) ],
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, θ are the parameters of the server model M, η is the learning rate, α is a regularization parameter, M(Dn) is the prediction of the server model M for the nth data of the shared data set D, yn is the label of the nth data of D, Pn is the nth score of the global prediction score P, and T is the temperature of the knowledge distillation.
7. The heterogeneous model aggregation method according to claim 1, wherein in step S11, the specific operation of the cooperative training includes:
using the shared data set D, the local model Mi is trained for multiple rounds according to
θi ← θi − η·∇θi Σn [ CrossEntropyLoss(Mi(Dn), yn) + α·λ^epoch·Σ_{j≠i} JS(Mi(Dn), Pj,n) ],
where CrossEntropyLoss() is the cross entropy loss function, η is the learning rate, α is the weight factor of the extra term of the loss function, Mi(Dn) is the prediction of the local client model Mi for the nth data of the shared data set D, yn is the label of the nth data of D, JS(Mi(Dn), Pj,n) is the JS divergence between that prediction and the prediction score Pj,n of client model Mj for the nth data of D, with j ≠ i, and λ^epoch is a weighting factor with 0 < λ < 1, epoch being the index of the current training epoch.
8. A federated learning-based heterogeneous model aggregation system, comprising:
the enhanced iteration module is used for ensuring that the data types owned by the client model in the batch iteration process are complete and are uniformly distributed, so that the client model can continuously correct the gradient descending direction in the training process, and the gradient descends towards the optimal solution direction;
the knowledge distillation module is used for solving the reliability problem of the client model under the heterogeneous condition; each client sends a prediction score to the server in each round, the server calculates the weight of the prediction score of each client model by using a JS function, then calculates a global prediction score by weighted average, and the server model carries out knowledge distillation by the global prediction score;
and the cooperation training module is used for solving the problem of insufficient data representation of the client in the training process, and adding additional items by adopting a loss function to guide the gradient descending direction of the model so as to enable the client model to achieve the cooperation effect.
9. The heterogeneous model aggregation system of claim 8, wherein the cooperative training module employs a loss function of:
L(Mi) = CrossEntropyLoss(Mi(Dn), yn) + α·λ^epoch·Σ_{j≠i} JS(Mi(Dn), Pj,n)
where CrossEntropyLoss() is the cross entropy loss function, L(Mi) is the training loss of the local client model Mi, η is the learning rate, α is the weight factor of the extra term of the loss function, Mi(Dn) is the prediction of the local client model Mi for the nth data of the shared data set D, yn is the label of the nth data of D, JS(Mi(Dn), Pj,n) is the JS divergence between that prediction and the prediction score Pj,n of client model Mj for the nth data of D, with j ≠ i, and λ^epoch is a weighting factor with 0 < λ < 1, epoch being the index of the current training epoch.
10. The heterogeneous model polymerization system of claim 8, wherein the loss function of the knowledge distillation is:
L(M) = α·T²·KLDivLoss(M(Dn), Pn) + (1−α)·CrossEntropyLoss(M(Dn), yn)
where KLDivLoss() is the relative entropy loss function, CrossEntropyLoss() is the cross entropy loss function, L(M) is the training loss of the server model M, θ are the parameters of the server model M, η is the learning rate, α is a regularization parameter, M(Dn) is the prediction of the server model M for the nth data of the shared data set D, yn is the label of the nth data of D, Pn is the nth score of the global prediction score P, and T is the temperature of the knowledge distillation.
CN202110844739.0A 2021-07-26 2021-07-26 Heterogeneous model aggregation method and system based on federal learning Active CN113705610B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110844739.0A CN113705610B (en) 2021-07-26 2021-07-26 Heterogeneous model aggregation method and system based on federal learning

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110844739.0A CN113705610B (en) 2021-07-26 2021-07-26 Heterogeneous model aggregation method and system based on federal learning

Publications (2)

Publication Number Publication Date
CN113705610A true CN113705610A (en) 2021-11-26
CN113705610B CN113705610B (en) 2024-05-24

Family

ID=78650475

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110844739.0A Active CN113705610B (en) 2021-07-26 2021-07-26 Heterogeneous model aggregation method and system based on federal learning

Country Status (1)

Country Link
CN (1) CN113705610B (en)

Cited By (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114154647A (en) * 2021-12-07 2022-03-08 天津大学 Multi-granularity federated learning based method
CN114429223A (en) * 2022-01-26 2022-05-03 上海富数科技有限公司 Heterogeneous model establishing method and device
CN114492849A (en) * 2022-01-24 2022-05-13 光大科技有限公司 Model updating method and device based on federal learning
CN114626550A (en) * 2022-03-18 2022-06-14 支付宝(杭州)信息技术有限公司 Distributed model collaborative training method and system
CN114844889A (en) * 2022-04-14 2022-08-02 北京百度网讯科技有限公司 Video processing model updating method and device, electronic equipment and storage medium
CN114863169A (en) * 2022-04-27 2022-08-05 电子科技大学 Image classification method combining parallel ensemble learning and federal learning
CN115145966A (en) * 2022-09-05 2022-10-04 山东省计算中心(国家超级计算济南中心) Comparison federal learning method and system for heterogeneous data
CN115511108A (en) * 2022-09-27 2022-12-23 河南大学 Data set distillation-based federal learning personalized method
CN115775010A (en) * 2022-11-23 2023-03-10 国网江苏省电力有限公司信息通信分公司 Electric power data sharing method based on horizontal federal learning
TWI800304B (en) * 2022-03-16 2023-04-21 英業達股份有限公司 Fedrated learning system using synonym
CN116822647A (en) * 2023-05-25 2023-09-29 大连海事大学 Model interpretation method based on federal learning
CN117390448A (en) * 2023-10-25 2024-01-12 西安交通大学 Client model aggregation method and related system for inter-cloud federal learning

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113052334A (en) * 2021-04-14 2021-06-29 中南大学 Method and system for realizing federated learning, terminal equipment and readable storage medium
CN113112027A (en) * 2021-04-06 2021-07-13 杭州电子科技大学 Federal learning method based on dynamic adjustment model aggregation weight

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113112027A (en) * 2021-04-06 2021-07-13 杭州电子科技大学 Federal learning method based on dynamic adjustment model aggregation weight
CN113052334A (en) * 2021-04-14 2021-06-29 中南大学 Method and system for realizing federated learning, terminal equipment and readable storage medium

Cited By (18)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114154647A (en) * 2021-12-07 2022-03-08 天津大学 Multi-granularity federated learning based method
CN114492849A (en) * 2022-01-24 2022-05-13 光大科技有限公司 Model updating method and device based on federal learning
CN114492849B (en) * 2022-01-24 2023-09-08 光大科技有限公司 Model updating method and device based on federal learning
CN114429223A (en) * 2022-01-26 2022-05-03 上海富数科技有限公司 Heterogeneous model establishing method and device
CN114429223B (en) * 2022-01-26 2023-11-07 上海富数科技有限公司 Heterogeneous model building method and device
TWI800304B (en) * 2022-03-16 2023-04-21 英業達股份有限公司 Fedrated learning system using synonym
CN114626550A (en) * 2022-03-18 2022-06-14 支付宝(杭州)信息技术有限公司 Distributed model collaborative training method and system
CN114844889A (en) * 2022-04-14 2022-08-02 北京百度网讯科技有限公司 Video processing model updating method and device, electronic equipment and storage medium
CN114863169A (en) * 2022-04-27 2022-08-05 电子科技大学 Image classification method combining parallel ensemble learning and federal learning
CN115145966A (en) * 2022-09-05 2022-10-04 山东省计算中心(国家超级计算济南中心) Comparison federal learning method and system for heterogeneous data
CN115511108A (en) * 2022-09-27 2022-12-23 河南大学 Data set distillation-based federal learning personalized method
CN115511108B (en) * 2022-09-27 2024-07-12 河南大学 Federal learning individualization method based on data set distillation
CN115775010A (en) * 2022-11-23 2023-03-10 国网江苏省电力有限公司信息通信分公司 Electric power data sharing method based on horizontal federal learning
CN115775010B (en) * 2022-11-23 2024-03-19 国网江苏省电力有限公司信息通信分公司 Power data sharing method based on transverse federal learning
CN116822647A (en) * 2023-05-25 2023-09-29 大连海事大学 Model interpretation method based on federal learning
CN116822647B (en) * 2023-05-25 2024-01-16 大连海事大学 Model interpretation method based on federal learning
CN117390448A (en) * 2023-10-25 2024-01-12 西安交通大学 Client model aggregation method and related system for inter-cloud federal learning
CN117390448B (en) * 2023-10-25 2024-04-26 西安交通大学 Client model aggregation method and related system for inter-cloud federal learning

Also Published As

Publication number Publication date
CN113705610B (en) 2024-05-24

Similar Documents

Publication Publication Date Title
CN113705610A (en) Heterogeneous model aggregation method and system based on federal learning
Mills et al. Communication-efficient federated learning for wireless edge intelligence in IoT
Wang et al. Fast adaptive task offloading in edge computing based on meta reinforcement learning
Itahara et al. Distillation-based semi-supervised federated learning for communication-efficient collaborative training with non-iid private data
Zhang et al. MR-DRO: A fast and efficient task offloading algorithm in heterogeneous edge/cloud computing environments
US11715044B2 (en) Methods and systems for horizontal federated learning using non-IID data
CN110113190A (en) Time delay optimization method is unloaded in a kind of mobile edge calculations scene
Lu et al. Auction-based cluster federated learning in mobile edge computing systems
CN108873936B (en) Autonomous aircraft formation method based on potential game
Djigal et al. Machine and deep learning for resource allocation in multi-access edge computing: A survey
CN113469325A (en) Layered federated learning method, computer equipment and storage medium for edge aggregation interval adaptive control
CN107103359A (en) The online Reliability Prediction Method of big service system based on convolutional neural networks
CN115374853A (en) Asynchronous federal learning method and system based on T-Step polymerization algorithm
CN114091667A (en) Federal mutual learning model training method oriented to non-independent same distribution data
Long et al. Fedsiam: Towards adaptive federated semi-supervised learning
CN115879542A (en) Federal learning method oriented to non-independent same-distribution heterogeneous data
Liu et al. Enhancing federated learning with intelligent model migration in heterogeneous edge computing
CN116645130A (en) Automobile order demand prediction method based on combination of federal learning and GRU
CN115686846A (en) Container cluster online deployment method for fusing graph neural network and reinforcement learning in edge computing
Huang et al. Active client selection for clustered federated learning
CN112667912B (en) Task amount prediction method of edge server
Hu et al. Communication-efficient federated learning in channel constrained internet of things
Yuan et al. Accuracy rate maximization in edge federated learning with delay and energy constraints
CN114022731A (en) Federal learning node selection method based on DRL
Balevi et al. Synergies between cloud-fag-thing and brain-spinal cord-nerve networks

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant