WO2023082406A1 - Federated learning-based electroencephalogram signal classification model training method and device - Google Patents

Federated learning-based electroencephalogram signal classification model training method and device Download PDF

Info

Publication number
WO2023082406A1
WO2023082406A1 PCT/CN2021/138013 CN2021138013W WO2023082406A1 WO 2023082406 A1 WO2023082406 A1 WO 2023082406A1 CN 2021138013 W CN2021138013 W CN 2021138013W WO 2023082406 A1 WO2023082406 A1 WO 2023082406A1
Authority
WO
WIPO (PCT)
Prior art keywords
client
classification model
signal classification
eeg signal
local
Prior art date
Application number
PCT/CN2021/138013
Other languages
French (fr)
Chinese (zh)
Inventor
郑青青
陈彦锋
王琼
Original Assignee
中国科学院深圳先进技术研究院
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 中国科学院深圳先进技术研究院 filed Critical 中国科学院深圳先进技术研究院
Publication of WO2023082406A1 publication Critical patent/WO2023082406A1/en

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F2218/00Aspects of pattern recognition specially adapted for signal processing
    • G06F2218/12Classification; Matching
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F3/00Input arrangements for transferring data to be processed into a form capable of being handled by the computer; Output arrangements for transferring data from processing unit to output unit, e.g. interface arrangements
    • G06F3/01Input arrangements or combined input and output arrangements for interaction between user and computer
    • G06F3/011Arrangements for interaction with the human body, e.g. for user immersion in virtual reality
    • G06F3/015Input arrangements based on nervous system activity detection, e.g. brain waves [EEG] detection, electromyograms [EMG] detection, electrodermal response detection
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F2218/00Aspects of pattern recognition specially adapted for signal processing
    • G06F2218/08Feature extraction

Definitions

  • the present application belongs to the field of biological information technology, and in particular relates to a federated learning-based brain electrical signal classification model training method and device.
  • the brain-computer interface (BCI, Brain Computer Interface) based on emotion recognition can identify the user's true emotional state and intention by collecting the user's EEG signal in the emotional interaction experiment, and extracting and decoding the EEG signal's features.
  • Sentiment analysis based on EEG signals has a wide range of application scenarios, such as auxiliary diagnosis of affective disorders and psychotherapy interventions such as depression.
  • Emotion recognition models based on deep learning are often data-driven and require a large amount of training data.
  • EEG Electroencephalographic
  • EEG data are often scattered among various users in the form of multiple small data sets.
  • existing methods focus on sharing data between different users, using technologies such as knowledge transfer and domain adaptation to effectively use useful information from other users and improve the emotion recognition rate of target users.
  • the EEG signals contain private information such as people's identity characteristics, thoughts and emotions, once they are misused or illegally read and disseminated, personal privacy will be leaked.
  • the EEG signal classification models mainly include: EEG signal classification models based on EEGNet (EEGNet is a general and compact convolutional neural network designed for general EEG recognition tasks), and federated transfer learning (FTL, Federate Transfer Learning) EEG signal classification model.
  • EEGNet uses the original EEG signal as input to train an end-to-end competitive emotion recognition network for each user.
  • FTL Federate Transfer Learning
  • EEGNet uses the original EEG signal as input to train an end-to-end competitive emotion recognition network for each user.
  • FTL Federate Transfer Learning
  • the network based on EEGNet training can only use The local data of each user trains the emotion recognition network separately, ignoring the data of other users and the effective information that can be provided, resulting in the problem of data waste.
  • the FTL-based method utilizes federated learning to effectively utilize other users' data information, it also meets the needs of users' local data not to be shared.
  • this method uses the spatial covariance matrix of the EEG signal as input, and part of the effective information of the original EEG signal is lost.
  • FTL relies on the federated averaging algorithm, which randomly selects the gradients of some local models in the process of joint training, and updates the gradients of the server through indiscriminate simple average aggregation, ignoring the data quality and importance of different users. This will cause the gradient change of the server model to be updated each time to be unstable, which is not conducive to the accuracy of the shared model (that is, the EEG signal classification model), and the convergence speed is often slow, which makes the model training difficult.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training method and device, which can solve the problems of low precision and slow convergence speed of the EEG signal classification model.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training method, which is applied to the server side, and the method includes:
  • the local model gradient is obtained by the client using a local training set to train the EEG classification model
  • the step of determining a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals includes:
  • a preset proportion of user terminals is selected from the K said user terminals as target user terminals.
  • the step of updating the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target clients includes:
  • the network parameters of the EEG signal classification model at the server end are updated according to the updated global gradient.
  • the step of updating the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients includes:
  • the step of normalizing the importance evaluation value of each target client includes:
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • the step of obtaining the importance evaluation value of each client according to the local model gradient of each client includes:
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the method also includes:
  • the EEG signal classification model at the server end When the EEG signal classification model at the server end converges, the EEG signal classification model at the server end is delivered to the K client ends.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training device, which is applied to the server side, and the device includes:
  • a sending module configured to send the EEG signal classification model at the server end to K client ends;
  • a receiving module configured to receive a local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG classification model;
  • An acquisition module configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals
  • a first determination module configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals
  • An update module configured to update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client terminals;
  • the second determination module is used to return to execute the step of sending the server-side EEG signal classification model to K clients if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Signal classification model converges.
  • the above-mentioned first determination module 304 is specifically configured to select a preset proportion of user terminals from the K user terminals as target user terminals in descending order of importance evaluation values.
  • the above-mentioned update module 305 includes:
  • a processing unit configured to perform normalization processing on the importance evaluation value of each target client
  • the first update unit is configured to update the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
  • the second update unit is configured to update the network parameters of the EEG signal classification model at the server end according to the updated global gradient.
  • the above-mentioned first update unit is specifically used to pass the formula updating the global gradient on the server side;
  • the above processing unit is specifically used to pass the formula Normalize the importance evaluation value of each selected client;
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the above-mentioned EEG signal classification model training device also includes:
  • the sending module is configured to send the EEG signal classification model at the server end to the K client terminals when the EEG signal classification model at the server end converges.
  • an embodiment of the present application provides a server, including a memory, a processor, and a computer program stored in the memory and operable on the processor, and the computer program is implemented when the processor executes the computer program. the above method.
  • an embodiment of the present application provides a computer-readable storage medium, where the computer-readable storage medium stores a computer program, and when the computer program is executed by a processor, the foregoing method is implemented.
  • an embodiment of the present application provides a computer program product, which, when the computer program product is run on a terminal device, causes the terminal device to execute the method described in any one of the foregoing first aspects.
  • joint training and distributed training can be realized to fully utilize all users' effective The effect of information on improving the accuracy of EEG signal classification models.
  • the target client since the target client is not randomly selected, but through the importance evaluation value of each client, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient of the target client and The importance evaluation value updates the network parameters of the EEG signal classification model on the server side, thereby improving the accuracy and convergence speed of the EEG signal classification model.
  • Fig. 1 is the flow chart of the federated learning-based EEG classification model training method provided by an embodiment of the present application
  • Fig. 2 is a flowchart of step 15 provided by an embodiment of the present application.
  • FIG. 3 is a schematic structural diagram of a federated learning-based EEG classification model training device provided by an embodiment of the present application
  • Fig. 4 is a schematic structural diagram of a server provided by an embodiment of the present application.
  • the term “if” may be construed, depending on the context, as “when” or “once” or “in response to determining” or “in response to detecting “.
  • the phrase “if determined” or “if [the described condition or event] is detected” may be construed, depending on the context, to mean “once determined” or “in response to the determination” or “once detected [the described condition or event] ]” or “in response to detection of [described condition or event]”.
  • references to "one embodiment” or “some embodiments” or the like in the specification of the present application means that a particular feature, structure, or characteristic described in connection with the embodiment is included in one or more embodiments of the present application.
  • appearances of the phrases “in one embodiment,” “in some embodiments,” “in other embodiments,” “in other embodiments,” etc. in various places in this specification are not necessarily All refer to the same embodiment, but mean “one or more but not all embodiments” unless specifically stated otherwise.
  • the terms “including”, “comprising”, “having” and variations thereof mean “including but not limited to”, unless specifically stated otherwise.
  • the EEG signal classification models mainly include the EEGNet-based EEG signal classification model and the FTL-based EEG signal classification model.
  • the accuracy of the EEG classification model based on EEGNet is low, while the convergence speed of the EEG classification model based on FTL is slow and the accuracy is not ideal.
  • the embodiment of the present application is based on the federated learning framework, by sending the server-side EEG signal classification model to K client terminals during distributed training, so that each client terminal can use the local training set to analyze the received EEG signal
  • the classification model is trained, and the local model gradient obtained by training is sent to the server for joint training, so that joint training and its distributed Training has achieved the effect of improving the accuracy of the EEG signal classification model while making full use of the effective information of all users.
  • the target client since the target client is not randomly selected, but through the importance evaluation value of each client, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient of the target client and The importance evaluation value updates the network parameters of the EEG signal classification model on the server side, thereby improving the accuracy and convergence speed of the EEG signal classification model.
  • the federated learning-based EEG classification model training method provided by the present application will be exemplarily described below in conjunction with specific embodiments.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training method, which is applied to the server side, and the method includes the following steps:
  • Step 11 sending the EEG signal classification model at the server end to K client ends.
  • the above-mentioned K clients are clients that participate in federated learning with the above-mentioned server.
  • the server end can initialize an EEG signal classification model (ie, the above steps EEG classification model in 11).
  • the weight of the model can be initialized to 0, or other common initialization schemes can be used, such as Gaussian and Xavier initialization (Xavier initialization is a neural network initialization method).
  • the above-mentioned EEG signal classification model may be an EEGNet model, of course, it may also be other deep learning networks, such as a convolutional neural network (ConvNet) and other EEG signal classification neural networks.
  • ConvNet convolutional neural network
  • Step 12 receiving the local model gradient sent by each client.
  • the aforementioned local model gradient is obtained by the user terminal using a local training set to train the EEG signal classification model.
  • the server after receiving the EEG signal classification model issued by the server, it will use the local training set of the client to analyze the received EEG signal
  • the classification model is trained, and the local model gradient is obtained when the EEG classification model converges.
  • Step 13 according to the local model gradient of each client, obtain the importance evaluation value of each client.
  • the above-mentioned importance evaluation value is mainly used to characterize the importance degree of the client end, so that subsequently, according to the order of importance from high to low, select from the K End EEG signal classification model) to jointly train the target user end to improve the accuracy and convergence speed of the EEG signal classification model.
  • Step 14 Determine a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals.
  • a preset proportion of user terminals can be selected from the K user terminals as the target user terminals, thereby screening from the K user terminals Target users with high importance.
  • the specific numerical value of the above preset ratio can be set according to the actual situation.
  • the importance of the above-mentioned target client is higher than that of other client in the K client, that is, the target client pairs the shared model (ie, the EEG signal classification model at the server end) The contribution is greater than the contribution of other clients to the shared model. Subsequent joint training using these target clients can improve the accuracy and convergence speed of the EEG signal classification model.
  • Step 15 Update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the multiple target client ends.
  • the performance of the server-side EEG signal classification model in the joint training, by updating the network parameters of the server-side EEG signal classification model according to the target client's local model gradient and importance evaluation value, the performance of the server-side EEG signal classification model can be improved. accuracy and convergence speed.
  • Step 16 if the EEG signal classification model at the server end does not converge, return to the step of sending the EEG signal classification model at the server end to K clients until the EEG signal classification model at the server end converges .
  • the EEG signal classification model converged in step 16 above is a shared model, which can be used to classify EEG signals of any user.
  • step 15 if the server-side EEG signal classification model does not converge, then return to step 11 to update the network parameters of the server-side EEG signal classification model again until the server-side End EEG signal classification model converges.
  • the updated EEG signal classification model is Share the model, otherwise, send the EEG signal classification model after updating the network parameters to K clients, so that K clients use their own local training sets to train the received EEG signal classification model, and the obtained Local model gradients to update the network parameters of the server-side EEG signal classification model again.
  • the local training set data of the client is not directly used, but the local model gradient of the client is used to jointly train the EEG classification model of the server, so that The privacy and use security of the local data of the client are guaranteed.
  • joint training and distributed training can be realized, and the effective utilization of all users can be achieved. The effect of information on improving the accuracy of EEG signal classification models.
  • the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient and importance of the target client.
  • the network parameters of the EEG signal classification model on the server side are updated, thereby improving the accuracy and convergence speed of the EEG signal classification model.
  • the above method further includes the following steps: when the server-side EEG signal classification model converges, sending the server-side EEG signal classification model to the The above K clients.
  • the client can use its own local training set to train the EEG signal classification model to fine-tune the model parameters of the EEG signal classification model, An EEG signal classification model that is more suitable for the client terminal is obtained, and the client terminal can then use the fine-tuned EEG signal classification model to classify the user data of the client terminal to improve classification accuracy.
  • the local training set of the user terminal may be derived from Shanghai Jiaotong University Emotional EEG Dataset (SEED).
  • SEED Shanghai Jiaotong University Emotional EEG Dataset
  • 15 screened Chinese movie clips were selected as emotional stimuli in the experiment, and the labels included positive, neutral, and negative emotions.
  • a total of 15 Chinese subjects (including 7 boys and 8 girls) were collected in this dataset, and each subject conducted 3 experiments.
  • Each sample in this dataset contains 62 electrode channels, downsampled to 200 Hz, and a bandpass frequency filter of 0–75 Hz is applied.
  • the embodiment of the present application selects 32 channels related to emotion, corresponding to Fp1, AF3, F3, F7, FC5, FC1, C3, T7, CP5, CP1, P3, P7, PO3, O1 , Oz, Pz, Fp2, AF4, Fz, F4, F8, FC6, FC2, Cz, C4, T8, CP6, CP2, P4, P8, PO4, O2.
  • each sample has a size of 32 ⁇ 200.
  • the data of 32 channels of any one of the 15 subjects can be used as a local training set of the client.
  • the user terminal can use all the data in the local training set to train the EEG signal classification model each time. It should be further explained that the local training set corresponding to each client is different.
  • the EEG signal classification model adopts the EEGNet model to extract the feature representation and classification of the EEG signal.
  • the parameters of the feature extractor and classifier model in this application are shown in Table 1.
  • Table 1 the number of convolutional layers, the size of the convolutional kernel, the pooling method, and the activation function can all be set according to the actual situation.
  • the cross entropy (cross entropy) loss function can be used to evaluate the training result, wherein the training loss function of the kth user end is as follows: Among them, n k represents the local sample size contained in the local training set of the kth client, y i is the real label of the training sample (that is, the local sample in the local training set), is the predicted label. It should be noted that the above training loss function is a commonly used loss function, so here, the principle of the training loss function will not be described in detail.
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the importance of the user terminal may also be measured by other similarity measurement learning methods or attention mechanism algorithms.
  • the above step 15 according to the local model gradients and importance evaluation values of the multiple target client terminals, updates the network parameters of the EEG signal classification model at the server end.
  • the specific implementation method includes the following steps:
  • Step 21 performing normalization processing on the importance evaluation value of each target client.
  • the formula Perform normalization processing on the importance evaluation value of each selected client.
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • Step 22 update the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients.
  • Step 23 Update network parameters of the server-side EEG signal classification model according to the updated global gradient.
  • a stochastic gradient descent method based on stochastic gradient descent may be used to solve network parameters.
  • SGD stochastic gradient descent
  • the global gradient of the server side will also be initialized to 0.
  • the federated learning-based EEG signal classification model training method provided by the embodiment of the present application has the following effects:
  • the EEG signal classification model adopts the EEGNet model, which is applied to the classification task of emotional EEG signals. It does not need to manually extract signal features, and can perform end-to-end feature extraction and classification of emotional EEG signals;
  • the federated learning-based EEG classification model training device provided by the present application will be exemplarily described below in conjunction with specific embodiments.
  • the EEG signal classification model training device 300 includes:
  • Sending module 301 for sending the EEG signal classification model of the server end to K client ends;
  • the receiving module 302 is configured to receive the local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG signal classification model;
  • An acquisition module 303 configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
  • the first determination module 304 is configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals;
  • An update module 305 configured to update the network parameters of the EEG classification model on the server side according to the local model gradients and importance evaluation values of the multiple target client terminals;
  • the second determination module 306 is configured to return to execute the step of sending the server-side EEG signal classification model to K client terminals if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Electrical signal classification model converges.
  • the above-mentioned first determination module 304 is specifically configured to select a preset proportion of user terminals from the K user terminals as target user terminals in descending order of importance evaluation values.
  • the above-mentioned update module 305 includes:
  • a processing unit configured to perform normalization processing on the importance evaluation value of each target client
  • the first update unit is configured to update the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
  • the second update unit is used to update the network parameters of the EEG signal classification model at the server end according to the updated global gradient.
  • the above-mentioned first update unit is specifically used to pass the formula updating the global gradient on the server side;
  • the above processing unit is specifically used to pass the formula Normalize the importance evaluation value of each selected client;
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the above-mentioned EEG signal classification model training device also includes:
  • the sending module is configured to send the EEG signal classification model at the server end to the K client terminals when the EEG signal classification model at the server end converges.
  • an embodiment of the present application provides a server, as shown in Figure 4, the server D10 of this embodiment includes: at least one processor D100 (only one processor is shown in Figure 4), a memory D101 And a computer program D102 stored in the memory D101 and operable on the at least one processor D100, when the processor D100 executes the computer program D102, the steps in any of the above method embodiments are implemented.
  • the so-called processor D100 can be a central processing unit (CPU, Central Processing Unit), and the processor D100 can also be other general processors, digital signal processors (DSP, Digital Signal Processor), application specific integrated circuits (ASIC, Application Specific Integrated Circuit), off-the-shelf programmable gate array (FPGA, Field-Programmable Gate Array) or other programmable logic devices, discrete gate or transistor logic devices, discrete hardware components, etc.
  • DSP digital signal processors
  • ASIC Application Specific Integrated Circuit
  • FPGA Field-Programmable Gate Array
  • a general-purpose processor may be a microprocessor, or the processor may be any conventional processor, or the like.
  • the storage D101 may be an internal storage unit of the server D10 in some embodiments, such as a hard disk or a memory of the server D10.
  • the memory D101 can also be an external storage device of the server D10 in other embodiments, such as a plug-in hard disk equipped on the server D10, a smart memory card (SMC, Smart Media Card), a secure digital (SD , Secure Digital) card, flash memory card (Flash Card), etc.
  • the storage D101 may also include both an internal storage unit of the server D10 and an external storage device.
  • the memory D101 is used to store operating systems, application programs, boot loaders (BootLoader), data and other programs, such as program codes of the computer programs.
  • the memory D101 can also be used to temporarily store data that has been output or will be output.
  • the embodiment of the present application also provides a computer-readable storage medium, the computer-readable storage medium stores a computer program, and when the computer program is executed by a processor, the steps in each of the foregoing method embodiments can be realized.
  • An embodiment of the present application provides a computer program product.
  • the terminal device can implement the steps in the foregoing method embodiments when executed.
  • the integrated unit is realized in the form of a software function unit and sold or used as an independent product, it can be stored in a computer-readable storage medium. Based on this understanding, all or part of the procedures in the method of the above-mentioned embodiments in the present application can be completed by instructing related hardware through a computer program.
  • the computer program can be stored in a computer-readable storage medium.
  • the computer program When executed by a processor, the steps in the above-mentioned various method embodiments can be realized.
  • the computer program includes computer program code, and the computer program code may be in the form of source code, object code, executable file or some intermediate form.
  • the computer-readable medium may at least include: any entity or device, recording medium, computer memory, read-only memory (ROM, Read-Only Memory) capable of carrying computer program codes to the EEG signal classification model training device/terminal equipment , Random Access Memory (RAM, Random Access Memory), electrical carrier signal, telecommunication signal and software distribution medium.
  • ROM read-only memory
  • RAM Random Access Memory
  • electrical carrier signal telecommunication signal and software distribution medium.
  • U disk mobile hard disk, magnetic disk or optical disk, etc.
  • computer readable media may not be electrical carrier signals and telecommunication signals under legislation and patent practice.
  • the disclosed device/network device and method may be implemented in other ways.
  • the device/network device embodiments described above are only illustrative.
  • the division of the modules or units is only a logical function division.
  • the mutual coupling or direct coupling or communication connection shown or discussed may be through some interfaces, and the indirect coupling or communication connection of devices or units may be in electrical, mechanical or other forms.
  • the units described as separate components may or may not be physically separated, and the components displayed as units may or may not be physical units, that is, they may be located in one place, or may be distributed to multiple network units. Part or all of the units can be selected according to actual needs to achieve the purpose of the solution of this embodiment.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Health & Medical Sciences (AREA)
  • Signal Processing (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Dermatology (AREA)
  • General Health & Medical Sciences (AREA)
  • Neurology (AREA)
  • Neurosurgery (AREA)
  • Human Computer Interaction (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Measurement And Recording Of Electrical Phenomena And Electrical Characteristics Of The Living Body (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The present application is suitable for the technical field of biological information, and provides a federated learning-based electroencephalogram signal classification model training method and device. The method comprises: sending an electroencephalogram signal classification model of a server to K clients; receiving a local model gradient sent by each client; obtaining an importance evaluated value of the client according to the local model gradient of the client; determining multiple target clients from the K clients according to the importance evaluated values of the K clients; updating network parameters of the electroencephalogram signal classification model of the server according to the local model gradients and importance evaluated values of the target clients; and if the electroencephalogram signal classification model of the server is not converged, returning to the step of sending the electroencephalogram signal classification model of the server to the K clients until the electroencephalogram signal classification model of the server is converged. The present application can improve the precision and convergence speed of the electroencephalogram signal classification model while fully using effective information of all users.

Description

基于联邦学习的脑电信号分类模型训练方法及装置EEG signal classification model training method and device based on federated learning 技术领域technical field
本申请属于生物信息技术领域,尤其涉及一种基于联邦学习的脑电信号分类模型训练方法及装置。The present application belongs to the field of biological information technology, and in particular relates to a federated learning-based brain electrical signal classification model training method and device.
背景技术Background technique
基于情感识别的脑机接口(BCI,Brain Computer Interface)通过在情感交互实验中采集用户的脑电信号,并对脑电信号进行特征提取和解码,可以识别用户真正的情感状态和意图,从而实现用户和设备间的友好通信及交互。基于脑电信号的情感分析有广泛的应用场景,例如情感障碍疾病的辅助诊断和抑郁症等心理治疗干预等。The brain-computer interface (BCI, Brain Computer Interface) based on emotion recognition can identify the user's true emotional state and intention by collecting the user's EEG signal in the emotional interaction experiment, and extracting and decoding the EEG signal's features. Friendly communication and interaction between users and devices. Sentiment analysis based on EEG signals has a wide range of application scenarios, such as auxiliary diagnosis of affective disorders and psychotherapy interventions such as depression.
基于深度学习的情感识别模型往往是数据驱动,要求有大量的训练数据。然而由于脑电图(EEG,Electroencephalographic)信号的采集过程繁琐及个体间差异性巨大的特点,EEG数据往往以多个小数据集的形式分散存在于各个用户。为了构建高精度的情感识别模型,现有方法致力于通过共享不同用户之间的数据,利用知识迁移和领域自适应等技术来有效利用其他用户的有用信息和提升目标用户的情感识别率。但在数据共享的过程中,如果包含了人的身份特征及思想情感等私密信息的脑电信号,一旦被滥用或者非法阅读传播,将造成个人隐私的泄露。Emotion recognition models based on deep learning are often data-driven and require a large amount of training data. However, due to the cumbersome acquisition process of EEG (Electroencephalographic) signals and the characteristics of huge differences among individuals, EEG data are often scattered among various users in the form of multiple small data sets. In order to build a high-precision emotion recognition model, existing methods focus on sharing data between different users, using technologies such as knowledge transfer and domain adaptation to effectively use useful information from other users and improve the emotion recognition rate of target users. However, in the process of data sharing, if the EEG signals contain private information such as people's identity characteristics, thoughts and emotions, once they are misused or illegally read and disseminated, personal privacy will be leaked.
目前脑电信号分类模型主要有:基于EEGNet(EEGNet是为专门一般的脑电图识别任务而设计的通用紧凑的卷积神经网络)的脑电信号分类模型,和基于联邦迁移学***均算法,该算法在联合训练的过程中,随机选择部分本地模型的梯度,通过无区别的简单平均聚合来更新服务器的梯度,忽略了不同用户的数据质量和重要性,这将导致每次更新服务器模型的梯度变化不稳定,不利于共享模型(即脑电信号分类模型)的精度,而且往往收敛速度慢,给模型训练造成一定难度。At present, the EEG signal classification models mainly include: EEG signal classification models based on EEGNet (EEGNet is a general and compact convolutional neural network designed for general EEG recognition tasks), and federated transfer learning (FTL, Federate Transfer Learning) EEG signal classification model. Among them, EEGNet uses the original EEG signal as input to train an end-to-end competitive emotion recognition network for each user. However, due to the large individual differences of users' EEG signals, directly using all users' data to train a unified network often leads to low accuracy of the shared model (ie EEG signal classification model). Therefore, the network based on EEGNet training can only use The local data of each user trains the emotion recognition network separately, ignoring the data of other users and the effective information that can be provided, resulting in the problem of data waste. Although the FTL-based method utilizes federated learning to effectively utilize other users' data information, it also meets the needs of users' local data not to be shared. However, this method uses the spatial covariance matrix of the EEG signal as input, and part of the effective information of the original EEG signal is lost. In addition, FTL relies on the federated averaging algorithm, which randomly selects the gradients of some local models in the process of joint training, and updates the gradients of the server through indiscriminate simple average aggregation, ignoring the data quality and importance of different users. This will cause the gradient change of the server model to be updated each time to be unstable, which is not conducive to the accuracy of the shared model (that is, the EEG signal classification model), and the convergence speed is often slow, which makes the model training difficult.
发明内容Contents of the invention
本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练方法及装置,可以解决脑电信号分类模型的精度低、且收敛速度慢的问题。The embodiment of the present application provides a federated learning-based EEG signal classification model training method and device, which can solve the problems of low precision and slow convergence speed of the EEG signal classification model.
第一方面,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练方法,应用于服务器端,该方法包括:In the first aspect, the embodiment of the present application provides a federated learning-based EEG signal classification model training method, which is applied to the server side, and the method includes:
将所述服务器端的脑电信号分类模型发送给K个用户端;Send the EEG signal classification model of the server end to K client ends;
接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;receiving a local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG classification model;
根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;Acquiring an importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;determining a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals;
根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;Updating the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client ends;
若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务 器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。If the EEG signal classification model at the server end does not converge, then return to execute the step of sending the EEG signal classification model at the server end to K client ends until the EEG signal classification model at the server end converges.
其中,所述根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端的步骤,包括:Wherein, the step of determining a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals includes:
按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。According to the descending order of the importance evaluation value, a preset proportion of user terminals is selected from the K said user terminals as target user terminals.
其中,所述根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的步骤,包括:Wherein, the step of updating the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target clients includes:
对每个所述目标用户端的重要性评估值进行归一化处理;performing normalization processing on the importance evaluation value of each target client;
根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;Updating the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。The network parameters of the EEG signal classification model at the server end are updated according to the updated global gradient.
其中,所述根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度的步骤,包括:Wherein, the step of updating the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients includes:
通过公式
Figure PCTCN2021138013-appb-000001
更新所述服务器端的全局梯度;
by formula
Figure PCTCN2021138013-appb-000001
updating the global gradient on the server side;
其中,
Figure PCTCN2021138013-appb-000002
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000003
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000004
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
in,
Figure PCTCN2021138013-appb-000002
Indicates the global gradient obtained by the t-th round of update, C indicates the preset ratio, K indicates the number of clients,
Figure PCTCN2021138013-appb-000003
Indicates the importance evaluation value of the kth client after normalization processing,
Figure PCTCN2021138013-appb-000004
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
其中,所述对每个所述目标用户端的重要性评估值进行归一化处理的步骤,包括:Wherein, the step of normalizing the importance evaluation value of each target client includes:
通过公式
Figure PCTCN2021138013-appb-000005
对选择出的每个用户端的重要性评估值进行归一化处理;
by formula
Figure PCTCN2021138013-appb-000005
Normalize the importance evaluation value of each selected client;
其中,
Figure PCTCN2021138013-appb-000006
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
in,
Figure PCTCN2021138013-appb-000006
Indicates the importance evaluation value of the kth client after normalization processing, μ k represents the importance evaluation value of the kth client, C represents a preset ratio, and K represents the number of clients.
其中,所述根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的步骤,包括:Wherein, the step of obtaining the importance evaluation value of each client according to the local model gradient of each client includes:
通过公式μ k=α k×β k,计算第k个用户端的重要性评估值; Calculate the importance evaluation value of the kth user terminal by the formula μ kk ×β k ;
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000007
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000008
Figure PCTCN2021138013-appb-000009
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000010
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
Among them, μ k represents the importance evaluation value of the kth client, α k = n k /n, nk represents the local sample size contained in the local training set of the kth client,
Figure PCTCN2021138013-appb-000007
n represents the sum of the local samples contained in the local training set of K clients, K represents the number of clients,
Figure PCTCN2021138013-appb-000008
Figure PCTCN2021138013-appb-000009
Indicates the global gradient of the server side at the time of the t-1 round update,
Figure PCTCN2021138013-appb-000010
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
其中,所述方法还包括:Wherein, the method also includes:
在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。When the EEG signal classification model at the server end converges, the EEG signal classification model at the server end is delivered to the K client ends.
第二方面,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练装置,应用于服务器端,该装置包括:In the second aspect, the embodiment of the present application provides a federated learning-based EEG signal classification model training device, which is applied to the server side, and the device includes:
发送模块,用于将所述服务器端的脑电信号分类模型发送给K个用户端;A sending module, configured to send the EEG signal classification model at the server end to K client ends;
接收模块,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;A receiving module, configured to receive a local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG classification model;
获取模块,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;An acquisition module, configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
第一确定模块,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;A first determination module, configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals;
更新模块,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;An update module, configured to update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client terminals;
第二确定模块,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。The second determination module is used to return to execute the step of sending the server-side EEG signal classification model to K clients if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Signal classification model converges.
其中,上述第一确定模块304,具体用于按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。Wherein, the above-mentioned first determination module 304 is specifically configured to select a preset proportion of user terminals from the K user terminals as target user terminals in descending order of importance evaluation values.
其中,上述更新模块305包括:Wherein, the above-mentioned update module 305 includes:
处理单元,用于对每个所述目标用户端的重要性评估值进行归一化处理;a processing unit, configured to perform normalization processing on the importance evaluation value of each target client;
第一更新单元,用于根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;The first update unit is configured to update the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
第二更新单元,用于根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。The second update unit is configured to update the network parameters of the EEG signal classification model at the server end according to the updated global gradient.
其中,上述第一更新单元,具体用于通过公式
Figure PCTCN2021138013-appb-000011
更新所述服务器端的全局梯度;
Wherein, the above-mentioned first update unit is specifically used to pass the formula
Figure PCTCN2021138013-appb-000011
updating the global gradient on the server side;
其中,
Figure PCTCN2021138013-appb-000012
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000013
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000014
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
in,
Figure PCTCN2021138013-appb-000012
Indicates the global gradient obtained by the t-th round of update, C indicates the preset ratio, K indicates the number of clients,
Figure PCTCN2021138013-appb-000013
Indicates the importance evaluation value of the kth client after normalization processing,
Figure PCTCN2021138013-appb-000014
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
其中,上述处理单元,具体用于通过公式
Figure PCTCN2021138013-appb-000015
对选择出的每个用户端的重要性评估值进行归一化处理;
Among them, the above processing unit is specifically used to pass the formula
Figure PCTCN2021138013-appb-000015
Normalize the importance evaluation value of each selected client;
其中,
Figure PCTCN2021138013-appb-000016
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
in,
Figure PCTCN2021138013-appb-000016
Indicates the importance evaluation value of the kth client after normalization processing, μ k represents the importance evaluation value of the kth client, C represents a preset ratio, and K represents the number of clients.
其中,上述获取模块303,具体用于通过公式μ k=α k×β k,计算第k个用户端的重要性评估值; Wherein, the above acquisition module 303 is specifically used to calculate the importance evaluation value of the kth user terminal through the formula μ kk ×β k ;
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000017
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000018
Figure PCTCN2021138013-appb-000019
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000020
表示第t轮更新时 第k个用户端的本地模型梯度,t为大于0的整数。
Among them, μ k represents the importance evaluation value of the kth client, α k = n k /n, nk represents the local sample size contained in the local training set of the kth client,
Figure PCTCN2021138013-appb-000017
n represents the sum of the local samples contained in the local training set of K clients, K represents the number of clients,
Figure PCTCN2021138013-appb-000018
Figure PCTCN2021138013-appb-000019
Indicates the global gradient of the server side at the time of the t-1 round update,
Figure PCTCN2021138013-appb-000020
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
其中,上述脑电信号分类模型训练装置还包括:Wherein, the above-mentioned EEG signal classification model training device also includes:
下发模块,用于在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。The sending module is configured to send the EEG signal classification model at the server end to the K client terminals when the EEG signal classification model at the server end converges.
第三方面,本申请实施例提供了一种服务器,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的方法。In a third aspect, an embodiment of the present application provides a server, including a memory, a processor, and a computer program stored in the memory and operable on the processor, and the computer program is implemented when the processor executes the computer program. the above method.
第四方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述的方法。In a fourth aspect, an embodiment of the present application provides a computer-readable storage medium, where the computer-readable storage medium stores a computer program, and when the computer program is executed by a processor, the foregoing method is implemented.
第五方面,本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行上述第一方面中任一项所述的方法。In a fifth aspect, an embodiment of the present application provides a computer program product, which, when the computer program product is run on a terminal device, causes the terminal device to execute the method described in any one of the foregoing first aspects.
本申请实施例与现有技术相比存在的有益效果是:Compared with the prior art, the embodiments of the present application have the following beneficial effects:
在本申请的实施例中,通过基于联邦学习框架,在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,即可实现联合训练及其分布式训练,达到充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果。同时在联合训练中,由于不是随机挑选目标用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。In the embodiment of this application, based on the federated learning framework, on the premise of satisfying data security and without sharing or exchanging local data of each client, joint training and distributed training can be realized to fully utilize all users' effective The effect of information on improving the accuracy of EEG signal classification models. At the same time, in the joint training, since the target client is not randomly selected, but through the importance evaluation value of each client, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient of the target client and The importance evaluation value updates the network parameters of the EEG signal classification model on the server side, thereby improving the accuracy and convergence speed of the EEG signal classification model.
附图说明Description of drawings
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。In order to more clearly illustrate the technical solutions in the embodiments of the present application, the accompanying drawings that need to be used in the descriptions of the embodiments or the prior art will be briefly introduced below. Obviously, the accompanying drawings in the following description are only for the present application For some embodiments, those of ordinary skill in the art can also obtain other drawings based on these drawings without any creative effort.
图1是本申请一实施例提供的基于联邦学习的脑电信号分类模型训练方法的流程图;Fig. 1 is the flow chart of the federated learning-based EEG classification model training method provided by an embodiment of the present application;
图2是本申请一实施例提供的步骤15的流程图;Fig. 2 is a flowchart of step 15 provided by an embodiment of the present application;
图3是本申请一实施例提供的基于联邦学习的脑电信号分类模型训练装置的结构示意图;3 is a schematic structural diagram of a federated learning-based EEG classification model training device provided by an embodiment of the present application;
图4是本申请一实施例提供的服务器的结构示意图。Fig. 4 is a schematic structural diagram of a server provided by an embodiment of the present application.
具体实施方式Detailed ways
以下描述中,为了说明而不是为了限定,提出了诸如特定***结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的***、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。In the following description, specific details such as specific system structures and technologies are presented for the purpose of illustration rather than limitation, so as to thoroughly understand the embodiments of the present application. It will be apparent, however, to one skilled in the art that the present application may be practiced in other embodiments without these specific details. In other instances, detailed descriptions of well-known systems, devices, circuits, and methods are omitted so as not to obscure the description of the present application with unnecessary detail.
应当理解,当在本申请说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。It should be understood that when used in this specification and the appended claims, the term "comprising" indicates the presence of described features, integers, steps, operations, elements and/or components, but does not exclude one or more other Presence or addition of features, wholes, steps, operations, elements, components and/or collections thereof.
还应当理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。It should also be understood that the term "and/or" used in the description of the present application and the appended claims refers to any combination and all possible combinations of one or more of the associated listed items, and includes these combinations.
如在本申请说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。As used in this specification and the appended claims, the term "if" may be construed, depending on the context, as "when" or "once" or "in response to determining" or "in response to detecting ". Similarly, the phrase "if determined" or "if [the described condition or event] is detected" may be construed, depending on the context, to mean "once determined" or "in response to the determination" or "once detected [the described condition or event] ]” or “in response to detection of [described condition or event]”.
另外,在本申请说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。In addition, in the description of the specification and the appended claims of the present application, the terms "first", "second", "third" and so on are only used to distinguish descriptions, and should not be understood as indicating or implying relative importance.
在本申请说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本申请的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。Reference to "one embodiment" or "some embodiments" or the like in the specification of the present application means that a particular feature, structure, or characteristic described in connection with the embodiment is included in one or more embodiments of the present application. Thus, appearances of the phrases "in one embodiment," "in some embodiments," "in other embodiments," "in other embodiments," etc. in various places in this specification are not necessarily All refer to the same embodiment, but mean "one or more but not all embodiments" unless specifically stated otherwise. The terms "including", "comprising", "having" and variations thereof mean "including but not limited to", unless specifically stated otherwise.
目前脑电信号分类模型主要有基于EEGNet的脑电信号分类模型和基于FTL的脑电信号分类模型。但基于EEGNet的脑电信号分类模型的精度低,而基于FTL的脑电信号分类模型的收敛速度慢、且精度不理想。At present, the EEG signal classification models mainly include the EEGNet-based EEG signal classification model and the FTL-based EEG signal classification model. However, the accuracy of the EEG classification model based on EEGNet is low, while the convergence speed of the EEG classification model based on FTL is slow and the accuracy is not ideal.
针对上述问题,本申请实施例基于联邦学习框架,通过在分布式训练中,将服务器端的脑电信号分类模型发送给K个用户端,使各用户端利用本地训练集对接收到的脑电信号分类模型进行训练,并将训练得到的本地模型梯度发送给服务器端,以进行联合训练,从而在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,实现了联合训练及其分布式训练,达到了在充分利用所有用户的有效信息的情况下,提升脑电信号分类模型的精度的效果。In view of the above problems, the embodiment of the present application is based on the federated learning framework, by sending the server-side EEG signal classification model to K client terminals during distributed training, so that each client terminal can use the local training set to analyze the received EEG signal The classification model is trained, and the local model gradient obtained by training is sent to the server for joint training, so that joint training and its distributed Training has achieved the effect of improving the accuracy of the EEG signal classification model while making full use of the effective information of all users.
同时在联合训练中,由于不是随机挑选目标用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。At the same time, in the joint training, since the target client is not randomly selected, but through the importance evaluation value of each client, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient of the target client and The importance evaluation value updates the network parameters of the EEG signal classification model on the server side, thereby improving the accuracy and convergence speed of the EEG signal classification model.
下面结合具体实施例对本申请提供的基于联邦学习的脑电信号分类模型训练方法进行示例性的说明。The federated learning-based EEG classification model training method provided by the present application will be exemplarily described below in conjunction with specific embodiments.
如图1所示,本申请的实施例提供了一种基于联邦学习的脑电信号分类模型训练方法,应用于服务器端,该方法包括如下步骤:As shown in Figure 1, the embodiment of the present application provides a federated learning-based EEG signal classification model training method, which is applied to the server side, and the method includes the following steps:
步骤11,将所述服务器端的脑电信号分类模型发送给K个用户端。 Step 11, sending the EEG signal classification model at the server end to K client ends.
在本申请的一些实施例中,上述K个用户端为与上述服务器端参与联邦学 习的用户端。需要说明的是,为确保最终得到的脑电信号分类模型是基于用户端的有用户的有效信息得到的,在执行上述训练方法的步骤之前,服务器端可以初始化一个脑电信号分类模型(即上述步骤11中的脑电信号分类模型)。具体的,可以将模型权重初始化为0,也可以采用其他常见的初始化方案,例如高斯、Xavier初始化(Xavier初始化是一种神经网络初始化方法)。In some embodiments of the present application, the above-mentioned K clients are clients that participate in federated learning with the above-mentioned server. It should be noted that, in order to ensure that the finally obtained EEG signal classification model is obtained based on the effective information of the user at the client end, before performing the steps of the above training method, the server end can initialize an EEG signal classification model (ie, the above steps EEG classification model in 11). Specifically, the weight of the model can be initialized to 0, or other common initialization schemes can be used, such as Gaussian and Xavier initialization (Xavier initialization is a neural network initialization method).
其中,上述脑电信号分类模型可以为EEGNet模型,当然也可以是其他的深度学习网络,例如卷积神经网络(ConvNet)等脑电信号分类神经网络。Wherein, the above-mentioned EEG signal classification model may be an EEGNet model, of course, it may also be other deep learning networks, such as a convolutional neural network (ConvNet) and other EEG signal classification neural networks.
步骤12,接收每个所述用户端发送的本地模型梯度。 Step 12, receiving the local model gradient sent by each client.
在本申请的一些实施例中,上述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的。In some embodiments of the present application, the aforementioned local model gradient is obtained by the user terminal using a local training set to train the EEG signal classification model.
即,在本申请的一些实施例中,对于参与联邦学习的每个用户端,在接收到服务器端下发的脑电信号分类模型后,会利用用户端的本地训练集对接收到的脑电信号分类模型进行训练,并在该脑电信号分类模型收敛时得到本地模型梯度。That is, in some embodiments of the present application, for each client participating in federated learning, after receiving the EEG signal classification model issued by the server, it will use the local training set of the client to analyze the received EEG signal The classification model is trained, and the local model gradient is obtained when the EEG classification model converges.
步骤13,根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值。 Step 13, according to the local model gradient of each client, obtain the importance evaluation value of each client.
在本申请的一些实施例中,上述重要性评估值主要用于表征用户端的重要性程度,以便后续按照重要性从高至低的顺序,从K个用户端中选择出对共享模型(即服务器端的脑电信号分类模型)贡献大的目标用户端进行联合训练,从提升脑电信号分类模型的精度和收敛速度。In some embodiments of the present application, the above-mentioned importance evaluation value is mainly used to characterize the importance degree of the client end, so that subsequently, according to the order of importance from high to low, select from the K End EEG signal classification model) to jointly train the target user end to improve the accuracy and convergence speed of the EEG signal classification model.
步骤14,根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端。Step 14: Determine a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals.
在本申请的一些实施例中,可按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端,从而从K个用户端中筛选出重要性程度高的目标用户端。其中,上述预设比例的具体数值可根据实际情况进行设定。In some embodiments of the present application, according to the order of the importance evaluation value from large to small, a preset proportion of user terminals can be selected from the K user terminals as the target user terminals, thereby screening from the K user terminals Target users with high importance. Wherein, the specific numerical value of the above preset ratio can be set according to the actual situation.
可见,在本申请的一些实施例中,上述目标用户端的重要性程度高于K个用户端中其他用户端的重要性程度,即,目标用户端对共享模型(即服务器端的脑电信号分类模型)贡献大于其他用户端对共享模型的贡献,后续利用这些目标用户端进行联合训练,能提升脑电信号分类模型的精度和收敛速度。It can be seen that, in some embodiments of the present application, the importance of the above-mentioned target client is higher than that of other client in the K client, that is, the target client pairs the shared model (ie, the EEG signal classification model at the server end) The contribution is greater than the contribution of other clients to the shared model. Subsequent joint training using these target clients can improve the accuracy and convergence speed of the EEG signal classification model.
步骤15,根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数。Step 15: Update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the multiple target client ends.
在本申请的一些实施例中,在联合训练中,通过根据目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,能提升服务器端的脑电信号分类模型的精度和收敛速度。In some embodiments of the present application, in the joint training, by updating the network parameters of the server-side EEG signal classification model according to the target client's local model gradient and importance evaluation value, the performance of the server-side EEG signal classification model can be improved. accuracy and convergence speed.
步骤16,若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。 Step 16, if the EEG signal classification model at the server end does not converge, return to the step of sending the EEG signal classification model at the server end to K clients until the EEG signal classification model at the server end converges .
在本申请的一些实施例中,上述步骤16中收敛的脑电信号分类模型是共享模型,可用于对任一用户的脑电信号进行分类。In some embodiments of the present application, the EEG signal classification model converged in step 16 above is a shared model, which can be used to classify EEG signals of any user.
在本申请的一些实施例中,在执行完步骤15后,若服务器端的脑电信号分类模型未收敛,则返回步骤11,以再次更新服务器端的脑电信号分类模型的网络参数,直至所述服务器端的脑电信号分类模型收敛。In some embodiments of the present application, after step 15 is performed, if the server-side EEG signal classification model does not converge, then return to step 11 to update the network parameters of the server-side EEG signal classification model again until the server-side End EEG signal classification model converges.
需要说明的是,每次更新服务器端的脑电信号分类模型的网络参数后,都需要判断更新网络参数后的脑电信号分类模型是否收敛,若收敛,则更新后的脑电信号分类模型即为共享模型,否则,将更新网络参数后的脑电信号分类模型下发给K个用户端,使K个用户端分别利用自身的本地训练集对接收到的脑电信号分类模型进行训练,得到的本地模型梯度,以再次更新服务器端的脑电信号分类模型的网络参数。It should be noted that every time the network parameters of the EEG signal classification model on the server side are updated, it is necessary to judge whether the EEG signal classification model after updating the network parameters converges. If it converges, the updated EEG signal classification model is Share the model, otherwise, send the EEG signal classification model after updating the network parameters to K clients, so that K clients use their own local training sets to train the received EEG signal classification model, and the obtained Local model gradients to update the network parameters of the server-side EEG signal classification model again.
值得一提的是,在本申请的一些实施例中,在联合训练中,不直接使用用户端的本地训练集数据,而是使用用户端的本地模型梯度来共同训练服务器端的脑电信号分类模型,从而保障了用户端本地数据的隐私和使用安全性,在满 足数据安全、无需共享或者交换各个用户端本地数据的前提下,即可实现联合训练及其分布式训练,达到了充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果。It is worth mentioning that in some embodiments of the present application, in the joint training, the local training set data of the client is not directly used, but the local model gradient of the client is used to jointly train the EEG classification model of the server, so that The privacy and use security of the local data of the client are guaranteed. Under the premise of satisfying data security and no need to share or exchange the local data of each client, joint training and distributed training can be realized, and the effective utilization of all users can be achieved. The effect of information on improving the accuracy of EEG signal classification models.
同时在联合训练中,由于不是随机挑选用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。At the same time, in the joint training, since the client is not randomly selected, but the importance evaluation value of each client is used, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient and importance of the target client. The network parameters of the EEG signal classification model on the server side are updated, thereby improving the accuracy and convergence speed of the EEG signal classification model.
在本申请的实施例中,在执行完上述步骤16后,上述方法还包括如下步骤:在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。In the embodiment of the present application, after the above step 16 is performed, the above method further includes the following steps: when the server-side EEG signal classification model converges, sending the server-side EEG signal classification model to the The above K clients.
需要说明的是,用户端在接收到该脑电信号分类模型后,可利用自身的用本地训练集对该脑电信号分类模型进行训练,以对该脑电信号分类模型的模型参数进行微调,得到更适合该用户端的脑电信号分类模型,后续该用户端可利用微调后的脑电信号分类模型对该用户端的用户数据进行分类,提高分类准确性。It should be noted that, after receiving the EEG signal classification model, the client can use its own local training set to train the EEG signal classification model to fine-tune the model parameters of the EEG signal classification model, An EEG signal classification model that is more suitable for the client terminal is obtained, and the client terminal can then use the fine-tuned EEG signal classification model to classify the user data of the client terminal to improve classification accuracy.
接下来,结合具体实施例对用户端利用本地训练集对脑电信号分类模型进行训练的过程的进行示例性的说明。Next, the process of training the EEG signal classification model by the user terminal using the local training set will be exemplarily described in conjunction with specific embodiments.
在本申请的一些实施例中,用户端的本地训练集可来源于上海交大情感脑电数据集(SEED)。在该数据集的实验中,15个筛选过的中国电影片段被选取为实验中的情感刺激源,标签包括正面、中性和负面情绪。该数据集一共采集了15名中国受试者(包括7名男生和8名女生),其中每个受试者分别进行3次实验。该数据集中的每个样本包含62个电极通道,下采样到200Hz,并且应用了0-75Hz的带通频率滤波器。为了扩展数据量,我们将每个数据按照1s的数据窗口进行不重叠切割,最终获取3394个样本。在采集的62个通道中,本申请实施例选择与情感相关的32个通道,分别对应Fp1,AF3,F3,F7,FC5,FC1,C3,T7,CP5,CP1,P3,P7,PO3,O1,Oz,Pz,Fp2, AF4,Fz,F4,F8,FC6,FC2,Cz,C4,T8,CP6,CP2,P4,P8,PO4,O2。为此,每个样本的大小为32×200。需要说明的是,在本申请的一些实施例中,可将15名受试者中任一受试者的32个通道的数据作为一用户端的本地训练集。为提升脑电信号分类模型的精度,用户端每次均可利用本地训练集中的所有数据对脑电信号分类模型进行训练。需要进一步说明的是,每个用户端对应的本地训练集均不相同。In some embodiments of the present application, the local training set of the user terminal may be derived from Shanghai Jiaotong University Emotional EEG Dataset (SEED). In experiments on this dataset, 15 screened Chinese movie clips were selected as emotional stimuli in the experiment, and the labels included positive, neutral, and negative emotions. A total of 15 Chinese subjects (including 7 boys and 8 girls) were collected in this dataset, and each subject conducted 3 experiments. Each sample in this dataset contains 62 electrode channels, downsampled to 200 Hz, and a bandpass frequency filter of 0–75 Hz is applied. In order to expand the amount of data, we cut each data according to the data window of 1s without overlapping, and finally obtained 3394 samples. Among the 62 channels collected, the embodiment of the present application selects 32 channels related to emotion, corresponding to Fp1, AF3, F3, F7, FC5, FC1, C3, T7, CP5, CP1, P3, P7, PO3, O1 , Oz, Pz, Fp2, AF4, Fz, F4, F8, FC6, FC2, Cz, C4, T8, CP6, CP2, P4, P8, PO4, O2. For this, each sample has a size of 32×200. It should be noted that, in some embodiments of the present application, the data of 32 channels of any one of the 15 subjects can be used as a local training set of the client. In order to improve the accuracy of the EEG signal classification model, the user terminal can use all the data in the local training set to train the EEG signal classification model each time. It should be further explained that the local training set corresponding to each client is different.
作为一个优选的示例,根据输入的原始EEG信号的时空属性,上述脑电信号分类模型采用EEGNet模型,用于提取脑电信号的特征表示及分类。其中本申请中特征提取器和分类器模型参数如表1所示。当然可以理解的是,卷积层数目、卷积核大小、池化方法以及激活函数均可根据实际情况进行设定。As a preferred example, according to the spatio-temporal properties of the input original EEG signal, the EEG signal classification model adopts the EEGNet model to extract the feature representation and classification of the EEG signal. The parameters of the feature extractor and classifier model in this application are shown in Table 1. Of course, it can be understood that the number of convolutional layers, the size of the convolutional kernel, the pooling method, and the activation function can all be set according to the actual situation.
Figure PCTCN2021138013-appb-000021
Figure PCTCN2021138013-appb-000021
表1Table 1
其中,在用户端利用本地训练集对脑电信号分类模型进行训练时,可采用交叉熵(cross entropy)损失函数评估训练结果,其中第k个用户端的训练损失函数如下:
Figure PCTCN2021138013-appb-000022
其中,n k表示第k个用户端的本地训练集所包含的本地样本量,y i为训练样本(即本地训练集中的本地样本)的真实标签,
Figure PCTCN2021138013-appb-000023
为预测标签。需要说明的是,上述训练损失函数为常用损失函数,因此在此,不对该训练损失函数的原理进行过多赘述。
Wherein, when the user end utilizes the local training set to train the EEG signal classification model, the cross entropy (cross entropy) loss function can be used to evaluate the training result, wherein the training loss function of the kth user end is as follows:
Figure PCTCN2021138013-appb-000022
Among them, n k represents the local sample size contained in the local training set of the kth client, y i is the real label of the training sample (that is, the local sample in the local training set),
Figure PCTCN2021138013-appb-000023
is the predicted label. It should be noted that the above training loss function is a commonly used loss function, so here, the principle of the training loss function will not be described in detail.
接下来,结合具体实施例对获取重要性评估值以及更新网络参数的过程的进行示例性的说明。Next, the process of obtaining the importance evaluation value and updating the network parameters will be exemplarily described in conjunction with specific embodiments.
在本申请的一些实施例中,上述步骤13,根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的具体实现方式可以为:通过公式μ k=α k×β k,计算第k个用户端的重要性评估值。 In some embodiments of the present application, the above-mentioned step 13, according to the local model gradient of each user terminal, the specific implementation manner of obtaining the importance evaluation value of each user terminal may be: through the formula μ kk × β k , calculate the importance evaluation value of the kth client.
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000024
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000025
Figure PCTCN2021138013-appb-000026
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000027
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
Among them, μ k represents the importance evaluation value of the kth client, α k = n k /n, nk represents the local sample size contained in the local training set of the kth client,
Figure PCTCN2021138013-appb-000024
n represents the sum of the local samples contained in the local training set of K clients, K represents the number of clients,
Figure PCTCN2021138013-appb-000025
Figure PCTCN2021138013-appb-000026
Indicates the global gradient of the server side at the time of the t-1 round update,
Figure PCTCN2021138013-appb-000027
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
需要说明的是,当t的取值为1时,
Figure PCTCN2021138013-appb-000028
为服务器端初始化的脑电信号分类模型的模型梯度,公式中的更新指的是服务器端的脑电信号分类模型的网络参数的更新。
It should be noted that when the value of t is 1,
Figure PCTCN2021138013-appb-000028
is the model gradient of the EEG signal classification model initialized on the server side, and the update in the formula refers to the update of the network parameters of the EEG signal classification model on the server side.
在本申请的一些实施例中,除了通过上述公式计算用户端的重要性评估值外,还可以通过其他的相似性度量学习方法或者注意力机制算法度量用户端的重要性。In some embodiments of the present application, in addition to calculating the importance evaluation value of the user terminal by the above formula, the importance of the user terminal may also be measured by other similarity measurement learning methods or attention mechanism algorithms.
在本申请的一些实施例中,如图2所示,上述步骤15,根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的具体实现方式包括如下步骤:In some embodiments of the present application, as shown in FIG. 2, the above step 15, according to the local model gradients and importance evaluation values of the multiple target client terminals, updates the network parameters of the EEG signal classification model at the server end. The specific implementation method includes the following steps:
步骤21,对每个所述目标用户端的重要性评估值进行归一化处理。 Step 21, performing normalization processing on the importance evaluation value of each target client.
在本申请的一些实施例中,可通过公式
Figure PCTCN2021138013-appb-000029
对选择出的每个用户端的重要性评估值进行归一化处理。
In some embodiments of the present application, the formula
Figure PCTCN2021138013-appb-000029
Perform normalization processing on the importance evaluation value of each selected client.
其中,
Figure PCTCN2021138013-appb-000030
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
in,
Figure PCTCN2021138013-appb-000030
Indicates the importance evaluation value of the kth client after normalization processing, μ k represents the importance evaluation value of the kth client, C represents a preset ratio, and K represents the number of clients.
步骤22,根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度。 Step 22, update the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients.
在本申请的一些实施例中,可通过公式
Figure PCTCN2021138013-appb-000031
更新所述服务器端的全局梯度。
In some embodiments of the present application, the formula
Figure PCTCN2021138013-appb-000031
Update the global gradient on the server side.
其中,
Figure PCTCN2021138013-appb-000032
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000033
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000034
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
in,
Figure PCTCN2021138013-appb-000032
Indicates the global gradient obtained by the t-th round of update, C indicates the preset ratio, K indicates the number of clients,
Figure PCTCN2021138013-appb-000033
Indicates the importance evaluation value of the kth client after normalization processing,
Figure PCTCN2021138013-appb-000034
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
步骤23,根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。Step 23: Update network parameters of the server-side EEG signal classification model according to the updated global gradient.
在本申请的一些实施例中,可采用基于随机梯度下降(SGD,Stochastic Gradient Descent)的随机梯度下降法求解网络参数。需要说明的是,服务器端在初始化脑电信号分类模型的时候,服务器端的全局梯度也会初始化为0。In some embodiments of the present application, a stochastic gradient descent method based on stochastic gradient descent (SGD, Stochastic Gradient Descent) may be used to solve network parameters. It should be noted that when the server side initializes the EEG signal classification model, the global gradient of the server side will also be initialized to 0.
综上,本申请实施例提供的基于联邦学习的脑电信号分类模型训练方法具备如下效果:In summary, the federated learning-based EEG signal classification model training method provided by the embodiment of the present application has the following effects:
一、脑电信号分类模型采用EEGNet模型,将其应用到情感脑电信号的分类任务中,无需手工提取信号特征,能端到端进行情感脑电信号的特征提取和分类;1. The EEG signal classification model adopts the EEGNet model, which is applied to the classification task of emotional EEG signals. It does not need to manually extract signal features, and can perform end-to-end feature extraction and classification of emotional EEG signals;
二、将EEGNet模型应用到情感脑电识别网络,利用深度学习自动提取情感脑电信号的可判别性特征,提升用户端单个的脑电信号分类模型的准确率;2. Apply the EEGNet model to the emotional EEG recognition network, use deep learning to automatically extract the discriminative features of the emotional EEG signal, and improve the accuracy of a single EEG signal classification model at the user end;
三、无需对脑电信号做繁杂的预处理,直接利用脑电信号对脑电信号分类模型进行训练,便可有效进行脑电信号的特征提取和分类;3. There is no need to do complicated preprocessing on the EEG signal, and the EEG signal classification model can be directly used to train the EEG signal classification model, so that the feature extraction and classification of the EEG signal can be effectively performed;
四、能在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,实现联合训练及其分布式训练,达到充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果;4. Under the premise of satisfying data security and no need to share or exchange local data of each client, realize joint training and distributed training, and achieve the effect of making full use of the effective information of all users to improve the accuracy of the EEG signal classification model;
五、通过各用户端的重要性,选择对共享模型贡献大的目标用户端进行联 合训练,从而提升脑电信号分类模型的精度及收敛速度。5. Through the importance of each client, select the target client that contributes a lot to the shared model for joint training, thereby improving the accuracy and convergence speed of the EEG signal classification model.
下面结合具体实施例对本申请提供的基于联邦学习的脑电信号分类模型训练装置进行示例性的说明。The federated learning-based EEG classification model training device provided by the present application will be exemplarily described below in conjunction with specific embodiments.
如图3所示,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练装置,应用于服务器端,该脑电信号分类模型训练装置300包括:As shown in Figure 3, the embodiment of the present application provides a federated learning-based EEG signal classification model training device, which is applied to the server side. The EEG signal classification model training device 300 includes:
发送模块301,用于将所述服务器端的脑电信号分类模型发送给K个用户端;Sending module 301, for sending the EEG signal classification model of the server end to K client ends;
接收模块302,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;The receiving module 302 is configured to receive the local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG signal classification model;
获取模块303,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;An acquisition module 303, configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
第一确定模块304,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;The first determination module 304 is configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals;
更新模块305,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;An update module 305, configured to update the network parameters of the EEG classification model on the server side according to the local model gradients and importance evaluation values of the multiple target client terminals;
第二确定模块306,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。The second determination module 306 is configured to return to execute the step of sending the server-side EEG signal classification model to K client terminals if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Electrical signal classification model converges.
其中,上述第一确定模块304,具体用于按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。Wherein, the above-mentioned first determination module 304 is specifically configured to select a preset proportion of user terminals from the K user terminals as target user terminals in descending order of importance evaluation values.
其中,上述更新模块305包括:Wherein, the above-mentioned update module 305 includes:
处理单元,用于对每个所述目标用户端的重要性评估值进行归一化处理;a processing unit, configured to perform normalization processing on the importance evaluation value of each target client;
第一更新单元,用于根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;The first update unit is configured to update the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
第二更新单元,用于根据更新后的全局梯度,更新所述服务器端的脑电信 号分类模型的网络参数。The second update unit is used to update the network parameters of the EEG signal classification model at the server end according to the updated global gradient.
其中,上述第一更新单元,具体用于通过公式
Figure PCTCN2021138013-appb-000035
更新所述服务器端的全局梯度;
Wherein, the above-mentioned first update unit is specifically used to pass the formula
Figure PCTCN2021138013-appb-000035
updating the global gradient on the server side;
其中,
Figure PCTCN2021138013-appb-000036
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000037
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000038
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
in,
Figure PCTCN2021138013-appb-000036
Indicates the global gradient obtained by the t-th round of update, C indicates the preset ratio, K indicates the number of clients,
Figure PCTCN2021138013-appb-000037
Indicates the importance evaluation value of the kth client after normalization processing,
Figure PCTCN2021138013-appb-000038
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
其中,上述处理单元,具体用于通过公式
Figure PCTCN2021138013-appb-000039
对选择出的每个用户端的重要性评估值进行归一化处理;
Among them, the above processing unit is specifically used to pass the formula
Figure PCTCN2021138013-appb-000039
Normalize the importance evaluation value of each selected client;
其中,
Figure PCTCN2021138013-appb-000040
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
in,
Figure PCTCN2021138013-appb-000040
Indicates the importance evaluation value of the kth client after normalization processing, μ k represents the importance evaluation value of the kth client, C represents a preset ratio, and K represents the number of clients.
其中,上述获取模块303,具体用于通过公式μ k=α k×β k,计算第k个用户端的重要性评估值; Wherein, the above acquisition module 303 is specifically used to calculate the importance evaluation value of the kth user terminal through the formula μ kk ×β k ;
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000041
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000042
Figure PCTCN2021138013-appb-000043
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000044
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
Among them, μ k represents the importance evaluation value of the kth client, α k = n k /n, nk represents the local sample size contained in the local training set of the kth client,
Figure PCTCN2021138013-appb-000041
n represents the sum of the local samples contained in the local training set of K clients, K represents the number of clients,
Figure PCTCN2021138013-appb-000042
Figure PCTCN2021138013-appb-000043
Indicates the global gradient of the server side at the time of the t-1 round update,
Figure PCTCN2021138013-appb-000044
Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
其中,上述脑电信号分类模型训练装置还包括:Wherein, the above-mentioned EEG signal classification model training device also includes:
下发模块,用于在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。The sending module is configured to send the EEG signal classification model at the server end to the K client terminals when the EEG signal classification model at the server end converges.
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。It should be noted that the information interaction and execution process between the above-mentioned devices/units are based on the same concept as the method embodiment of the present application, and its specific functions and technical effects can be found in the method embodiment section. I won't repeat them here.
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述***中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。Those skilled in the art can clearly understand that for the convenience and brevity of description, only the division of the above-mentioned functional units and modules is used for illustration. In practical applications, the above-mentioned functions can be assigned to different functional units, Completion of modules means that the internal structure of the device is divided into different functional units or modules to complete all or part of the functions described above. Each functional unit and module in the embodiment may be integrated into one processing unit, or each unit may exist separately physically, or two or more units may be integrated into one unit, and the above-mentioned integrated units may adopt hardware It can also be implemented in the form of software functional units. In addition, the specific names of the functional units and modules are only for the convenience of distinguishing each other, and are not used to limit the protection scope of the present application. For the specific working processes of the units and modules in the above system, reference may be made to the corresponding processes in the aforementioned method embodiments, and details will not be repeated here.
如图4所示,本申请的实施例提供了一种服务器,如图4所示,该实施例的服务器D10包括:至少一个处理器D100(图4中仅示出一个处理器)、存储器D101以及存储在所述存储器D101中并可在所述至少一个处理器D100上运行的计算机程序D102,所述处理器D100执行所述计算机程序D102时实现上述任意各个方法实施例中的步骤。As shown in Figure 4, an embodiment of the present application provides a server, as shown in Figure 4, the server D10 of this embodiment includes: at least one processor D100 (only one processor is shown in Figure 4), a memory D101 And a computer program D102 stored in the memory D101 and operable on the at least one processor D100, when the processor D100 executes the computer program D102, the steps in any of the above method embodiments are implemented.
所称处理器D100可以是中央处理单元(CPU,Central Processing Unit),该处理器D100还可以是其他通用处理器、数字信号处理器(DSP,Digital Signal Processor)、专用集成电路(ASIC,Application Specific Integrated Circuit)、现成可编程门阵列(FPGA,Field-Programmable Gate Array)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。The so-called processor D100 can be a central processing unit (CPU, Central Processing Unit), and the processor D100 can also be other general processors, digital signal processors (DSP, Digital Signal Processor), application specific integrated circuits (ASIC, Application Specific Integrated Circuit), off-the-shelf programmable gate array (FPGA, Field-Programmable Gate Array) or other programmable logic devices, discrete gate or transistor logic devices, discrete hardware components, etc. A general-purpose processor may be a microprocessor, or the processor may be any conventional processor, or the like.
所述存储器D101在一些实施例中可以是所述服务器D10的内部存储单元,例如服务器D10的硬盘或内存。所述存储器D101在另一些实施例中也可以是所述服务器D10的外部存储设备,例如所述服务器D10上配备的插接式硬盘,智能存储卡(SMC,Smart Media Card),安全数字(SD,Secure Digital)卡,闪存卡(Flash Card)等。进一步地,所述存储器D101还可以既包括所述服务 器D10的内部存储单元也包括外部存储设备。所述存储器D101用于存储操作***、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,例如所述计算机程序的程序代码等。所述存储器D101还可以用于暂时地存储已经输出或者将要输出的数据。The storage D101 may be an internal storage unit of the server D10 in some embodiments, such as a hard disk or a memory of the server D10. The memory D101 can also be an external storage device of the server D10 in other embodiments, such as a plug-in hard disk equipped on the server D10, a smart memory card (SMC, Smart Media Card), a secure digital (SD , Secure Digital) card, flash memory card (Flash Card), etc. Further, the storage D101 may also include both an internal storage unit of the server D10 and an external storage device. The memory D101 is used to store operating systems, application programs, boot loaders (BootLoader), data and other programs, such as program codes of the computer programs. The memory D101 can also be used to temporarily store data that has been output or will be output.
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。It should be noted that the information interaction and execution process between the above-mentioned devices/units are based on the same concept as the method embodiment of the present application, and its specific functions and technical effects can be found in the method embodiment section. I won't repeat them here.
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述***中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。Those skilled in the art can clearly understand that for the convenience and brevity of description, only the division of the above-mentioned functional units and modules is used for illustration. In practical applications, the above-mentioned functions can be assigned to different functional units, Completion of modules means that the internal structure of the device is divided into different functional units or modules to complete all or part of the functions described above. Each functional unit and module in the embodiment may be integrated into one processing unit, or each unit may exist separately physically, or two or more units may be integrated into one unit, and the above-mentioned integrated units may adopt hardware It can also be implemented in the form of software functional units. In addition, the specific names of the functional units and modules are only for the convenience of distinguishing each other, and are not used to limit the protection scope of the present application. For the specific working processes of the units and modules in the above system, reference may be made to the corresponding processes in the aforementioned method embodiments, and details will not be repeated here.
本申请实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现可实现上述各个方法实施例中的步骤。The embodiment of the present application also provides a computer-readable storage medium, the computer-readable storage medium stores a computer program, and when the computer program is executed by a processor, the steps in each of the foregoing method embodiments can be realized.
本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行时实现可实现上述各个方法实施例中的步骤。An embodiment of the present application provides a computer program product. When the computer program product is run on a terminal device, the terminal device can implement the steps in the foregoing method embodiments when executed.
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机 程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质至少可以包括:能够将计算机程序代码携带到脑电信号分类模型训练装置/终端设备的任何实体或装置、记录介质、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质。例如U盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。If the integrated unit is realized in the form of a software function unit and sold or used as an independent product, it can be stored in a computer-readable storage medium. Based on this understanding, all or part of the procedures in the method of the above-mentioned embodiments in the present application can be completed by instructing related hardware through a computer program. The computer program can be stored in a computer-readable storage medium. The computer program When executed by a processor, the steps in the above-mentioned various method embodiments can be realized. Wherein, the computer program includes computer program code, and the computer program code may be in the form of source code, object code, executable file or some intermediate form. The computer-readable medium may at least include: any entity or device, recording medium, computer memory, read-only memory (ROM, Read-Only Memory) capable of carrying computer program codes to the EEG signal classification model training device/terminal equipment , Random Access Memory (RAM, Random Access Memory), electrical carrier signal, telecommunication signal and software distribution medium. Such as U disk, mobile hard disk, magnetic disk or optical disk, etc. In some jurisdictions, computer readable media may not be electrical carrier signals and telecommunication signals under legislation and patent practice.
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。In the above-mentioned embodiments, the descriptions of each embodiment have their own emphases, and for parts that are not detailed or recorded in a certain embodiment, refer to the relevant descriptions of other embodiments.
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。Those skilled in the art can appreciate that the units and algorithm steps of the examples described in conjunction with the embodiments disclosed herein can be implemented by electronic hardware, or a combination of computer software and electronic hardware. Whether these functions are executed by hardware or software depends on the specific application and design constraints of the technical solution. Skilled artisans may use different methods to implement the described functions for each specific application, but such implementation should not be regarded as exceeding the scope of the present application.
在本申请所提供的实施例中,应该理解到,所揭露的装置/网络设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/网络设备实施例仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个***,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。In the embodiments provided in this application, it should be understood that the disclosed device/network device and method may be implemented in other ways. For example, the device/network device embodiments described above are only illustrative. For example, the division of the modules or units is only a logical function division. In actual implementation, there may be other division methods, such as multiple units Or components may be combined or may be integrated into another system, or some features may be omitted, or not implemented. In another point, the mutual coupling or direct coupling or communication connection shown or discussed may be through some interfaces, and the indirect coupling or communication connection of devices or units may be in electrical, mechanical or other forms.
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。The units described as separate components may or may not be physically separated, and the components displayed as units may or may not be physical units, that is, they may be located in one place, or may be distributed to multiple network units. Part or all of the units can be selected according to actual needs to achieve the purpose of the solution of this embodiment.
以上所述实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。The above-described embodiments are only used to illustrate the technical solutions of the present application, rather than to limit them; although the present application has been described in detail with reference to the foregoing embodiments, those of ordinary skill in the art should understand that: it can still implement the foregoing embodiments Modifications to the technical solutions described in the examples, or equivalent replacements for some of the technical features; and these modifications or replacements do not make the essence of the corresponding technical solutions deviate from the spirit and scope of the technical solutions of the various embodiments of the application, and should be included in the Within the protection scope of this application.

Claims (10)

  1. 一种基于联邦学习的脑电信号分类模型训练方法,其特征在于,应用于服务器端,所述方法包括:A method for training an EEG classification model based on federated learning, characterized in that it is applied to a server, and the method includes:
    将所述服务器端的脑电信号分类模型发送给K个用户端;Send the EEG signal classification model of the server end to K client ends;
    接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;receiving a local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG classification model;
    根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;Acquiring an importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
    根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;determining a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals;
    根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;Updating the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client ends;
    若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。If the EEG signal classification model at the server end does not converge, return to the step of sending the EEG signal classification model at the server end to K clients until the EEG signal classification model at the server end converges.
  2. 根据权利要求1所述的方法,其特征在于,所述根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端的步骤,包括:The method according to claim 1, wherein the step of determining a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals includes:
    按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。According to the descending order of the importance evaluation value, a preset proportion of user terminals is selected from the K said user terminals as target user terminals.
  3. 根据权利要求2所述的方法,其特征在于,所述根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的步骤,包括:The method according to claim 2, wherein the step of updating the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client terminals includes:
    对每个所述目标用户端的重要性评估值进行归一化处理;performing normalization processing on the importance evaluation value of each target client;
    根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;Updating the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
    根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。The network parameters of the EEG signal classification model at the server end are updated according to the updated global gradient.
  4. 根据权利要求3所述的方法,其特征在于,所述根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度的步骤,包括:The method according to claim 3, wherein the step of updating the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients includes:
    通过公式
    Figure PCTCN2021138013-appb-100001
    更新所述服务器端的全局梯度;
    by formula
    Figure PCTCN2021138013-appb-100001
    updating the global gradient on the server side;
    其中,
    Figure PCTCN2021138013-appb-100002
    表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
    Figure PCTCN2021138013-appb-100003
    表示第k个用户端归一化处理后的重要性评估值,
    Figure PCTCN2021138013-appb-100004
    表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
    in,
    Figure PCTCN2021138013-appb-100002
    Indicates the global gradient obtained by the t-th round of update, C indicates the preset ratio, K indicates the number of clients,
    Figure PCTCN2021138013-appb-100003
    Indicates the importance evaluation value of the kth client after normalization processing,
    Figure PCTCN2021138013-appb-100004
    Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
  5. 根据权利要求3所述的方法,其特征在于,所述对每个所述目标用户端的重要性评估值进行归一化处理的步骤,包括:The method according to claim 3, wherein the step of normalizing the importance evaluation value of each target client includes:
    通过公式
    Figure PCTCN2021138013-appb-100005
    对选择出的每个用户端的重要性评估值进行归一化处理;
    by formula
    Figure PCTCN2021138013-appb-100005
    Normalize the importance evaluation value of each selected client;
    其中,
    Figure PCTCN2021138013-appb-100006
    表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
    in,
    Figure PCTCN2021138013-appb-100006
    Indicates the importance evaluation value of the kth client after normalization processing, μ k represents the importance evaluation value of the kth client, C represents a preset ratio, and K represents the number of clients.
  6. 根据权利要求1所述的方法,其特征在于,所述根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的步骤,包括:The method according to claim 1, wherein the step of obtaining the importance evaluation value of each client according to the local model gradient of each client includes:
    通过公式μ k=α k×β k,计算第k个用户端的重要性评估值; Calculate the importance evaluation value of the kth user terminal by the formula μ kk ×β k ;
    其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
    Figure PCTCN2021138013-appb-100007
    n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
    Figure PCTCN2021138013-appb-100008
    Figure PCTCN2021138013-appb-100009
    表示第t-1轮更新时所述服务器端的全局梯度,
    Figure PCTCN2021138013-appb-100010
    表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
    Among them, μ k represents the importance evaluation value of the kth client, α k = n k /n, nk represents the local sample size contained in the local training set of the kth client,
    Figure PCTCN2021138013-appb-100007
    n represents the sum of the local samples contained in the local training set of K clients, K represents the number of clients,
    Figure PCTCN2021138013-appb-100008
    Figure PCTCN2021138013-appb-100009
    Indicates the global gradient of the server side at the time of the t-1 round update,
    Figure PCTCN2021138013-appb-100010
    Indicates the local model gradient of the kth client at the tth round of update, t is an integer greater than 0.
  7. 根据权利要求1所述的方法,其特征在于,所述方法还包括:The method according to claim 1, further comprising:
    在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。When the EEG signal classification model at the server end converges, the EEG signal classification model at the server end is delivered to the K client ends.
  8. 一种基于联邦学习的脑电信号分类模型训练装置,其特征在于,应用于服务器端,所述装置包括:A federated learning-based EEG signal classification model training device is characterized in that it is applied to the server side, and the device includes:
    发送模块,用于将所述服务器端的脑电信号分类模型发送给K个用户端;A sending module, configured to send the EEG signal classification model at the server end to K client ends;
    接收模块,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;A receiving module, configured to receive a local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG classification model;
    获取模块,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;An acquisition module, configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
    第一确定模块,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;A first determination module, configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals;
    更新模块,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;An update module, configured to update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client terminals;
    第二确定模块,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。The second determination module is used to return to execute the step of sending the server-side EEG signal classification model to K clients if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Signal classification model converges.
  9. 一种服务器,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的方法。A server, comprising a memory, a processor, and a computer program stored in the memory and operable on the processor, wherein the computer program according to claims 1 to 7 is implemented when the processor executes the computer program. any one of the methods described.
  10. 一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的方法。A computer-readable storage medium storing a computer program, wherein the computer program implements the method according to any one of claims 1 to 7 when executed by a processor.
PCT/CN2021/138013 2021-11-15 2021-12-14 Federated learning-based electroencephalogram signal classification model training method and device WO2023082406A1 (en)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202111347340.8A CN114048780A (en) 2021-11-15 2021-11-15 Electroencephalogram classification model training method and device based on federal learning
CN202111347340.8 2021-11-15

Publications (1)

Publication Number Publication Date
WO2023082406A1 true WO2023082406A1 (en) 2023-05-19

Family

ID=80208990

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/138013 WO2023082406A1 (en) 2021-11-15 2021-12-14 Federated learning-based electroencephalogram signal classification model training method and device

Country Status (2)

Country Link
CN (1) CN114048780A (en)
WO (1) WO2023082406A1 (en)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114664434A (en) * 2022-03-28 2022-06-24 上海韶脑传感技术有限公司 Cerebral apoplexy rehabilitation training system for different medical institutions and training method thereof
CN117708681B (en) * 2024-02-06 2024-04-26 南京邮电大学 Personalized federal electroencephalogram signal classification method and system based on structural diagram guidance

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150324690A1 (en) * 2014-05-08 2015-11-12 Microsoft Corporation Deep Learning Training System
CN111814985A (en) * 2020-06-30 2020-10-23 平安科技(深圳)有限公司 Model training method under federated learning network and related equipment thereof
CN112181666A (en) * 2020-10-26 2021-01-05 华侨大学 Method, system, equipment and readable storage medium for equipment evaluation and federal learning importance aggregation based on edge intelligence
CN112633146A (en) * 2020-12-21 2021-04-09 杭州趣链科技有限公司 Multi-pose face gender detection training optimization method and device and related equipment
CN113158241A (en) * 2021-04-06 2021-07-23 深圳市洞见智慧科技有限公司 Post recommendation method and device based on federal learning

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150324690A1 (en) * 2014-05-08 2015-11-12 Microsoft Corporation Deep Learning Training System
CN111814985A (en) * 2020-06-30 2020-10-23 平安科技(深圳)有限公司 Model training method under federated learning network and related equipment thereof
CN112181666A (en) * 2020-10-26 2021-01-05 华侨大学 Method, system, equipment and readable storage medium for equipment evaluation and federal learning importance aggregation based on edge intelligence
CN112633146A (en) * 2020-12-21 2021-04-09 杭州趣链科技有限公司 Multi-pose face gender detection training optimization method and device and related equipment
CN113158241A (en) * 2021-04-06 2021-07-23 深圳市洞见智慧科技有限公司 Post recommendation method and device based on federal learning

Also Published As

Publication number Publication date
CN114048780A (en) 2022-02-15

Similar Documents

Publication Publication Date Title
WO2023082406A1 (en) Federated learning-based electroencephalogram signal classification model training method and device
CN109309630B (en) Network traffic classification method and system and electronic equipment
CN108595585B (en) Sample data classification method, model training method, electronic equipment and storage medium
Duggal et al. Prediction of thyroid disorders using advanced machine learning techniques
CN103399896B (en) The method and system of incidence relation between identification user
CN109497990B (en) Electrocardiosignal identity recognition method and system based on canonical correlation analysis
WO2019192118A1 (en) Edge computing-based health monitoring method and apparatus, device, and storage medium
Alqahtani et al. Breast cancer pathological image classification based on the multiscale CNN squeeze model
CN105446741B (en) A kind of mobile applications discrimination method compared based on API
Vhaduri et al. Biometric-based wearable user authentication during sedentary and non-sedentary periods
CN109817339A (en) Patient's group technology and device based on big data
CN110489659A (en) Data matching method and device
CN114973330A (en) Cross-scene robust personnel fatigue state wireless detection method and related equipment
WO2021114818A1 (en) Method, system, and device for oct image quality evaluation based on fourier transform
WO2023134060A1 (en) Information pushing method and apparatus based on drug molecule image classification
Chen et al. Patient emotion recognition in human computer interaction system based on machine learning method and interactive design theory
CN104679967A (en) Method for judging reliability of psychological test
Katti et al. Are you from North or South India? A hard face-classification task reveals systematic representational differences between humans and machines
CN110348326A (en) The family health care information processing method of the identification of identity-based card and the access of more equipment
CN111651755A (en) Intrusion detection method and device
Gururaj et al. Fundus image features extraction for exudate mining in coordination with content based image retrieval: A study
CN110351303A (en) A kind of DDoS feature extracting method and device
Kong et al. Task-free brainprint recognition based on degree of brain networks
Araújo et al. Generic biometry algorithm based on signal morphology information: Application in the electrocardiogram signal
CN107832690B (en) Face recognition method and related product