WO2023082406A1 - 基于联邦学习的脑电信号分类模型训练方法及装置 - Google Patents

基于联邦学习的脑电信号分类模型训练方法及装置 Download PDF

Info

Publication number
WO2023082406A1
WO2023082406A1 PCT/CN2021/138013 CN2021138013W WO2023082406A1 WO 2023082406 A1 WO2023082406 A1 WO 2023082406A1 CN 2021138013 W CN2021138013 W CN 2021138013W WO 2023082406 A1 WO2023082406 A1 WO 2023082406A1
Authority
WO
WIPO (PCT)
Prior art keywords
client
classification model
signal classification
eeg signal
local
Prior art date
Application number
PCT/CN2021/138013
Other languages
English (en)
French (fr)
Inventor
郑青青
陈彦锋
王琼
Original Assignee
中国科学院深圳先进技术研究院
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 中国科学院深圳先进技术研究院 filed Critical 中国科学院深圳先进技术研究院
Publication of WO2023082406A1 publication Critical patent/WO2023082406A1/zh

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F2218/00Aspects of pattern recognition specially adapted for signal processing
    • G06F2218/12Classification; Matching
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F3/00Input arrangements for transferring data to be processed into a form capable of being handled by the computer; Output arrangements for transferring data from processing unit to output unit, e.g. interface arrangements
    • G06F3/01Input arrangements or combined input and output arrangements for interaction between user and computer
    • G06F3/011Arrangements for interaction with the human body, e.g. for user immersion in virtual reality
    • G06F3/015Input arrangements based on nervous system activity detection, e.g. brain waves [EEG] detection, electromyograms [EMG] detection, electrodermal response detection
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F2218/00Aspects of pattern recognition specially adapted for signal processing
    • G06F2218/08Feature extraction

Definitions

  • the present application belongs to the field of biological information technology, and in particular relates to a federated learning-based brain electrical signal classification model training method and device.
  • the brain-computer interface (BCI, Brain Computer Interface) based on emotion recognition can identify the user's true emotional state and intention by collecting the user's EEG signal in the emotional interaction experiment, and extracting and decoding the EEG signal's features.
  • Sentiment analysis based on EEG signals has a wide range of application scenarios, such as auxiliary diagnosis of affective disorders and psychotherapy interventions such as depression.
  • Emotion recognition models based on deep learning are often data-driven and require a large amount of training data.
  • EEG Electroencephalographic
  • EEG data are often scattered among various users in the form of multiple small data sets.
  • existing methods focus on sharing data between different users, using technologies such as knowledge transfer and domain adaptation to effectively use useful information from other users and improve the emotion recognition rate of target users.
  • the EEG signals contain private information such as people's identity characteristics, thoughts and emotions, once they are misused or illegally read and disseminated, personal privacy will be leaked.
  • the EEG signal classification models mainly include: EEG signal classification models based on EEGNet (EEGNet is a general and compact convolutional neural network designed for general EEG recognition tasks), and federated transfer learning (FTL, Federate Transfer Learning) EEG signal classification model.
  • EEGNet uses the original EEG signal as input to train an end-to-end competitive emotion recognition network for each user.
  • FTL Federate Transfer Learning
  • EEGNet uses the original EEG signal as input to train an end-to-end competitive emotion recognition network for each user.
  • FTL Federate Transfer Learning
  • the network based on EEGNet training can only use The local data of each user trains the emotion recognition network separately, ignoring the data of other users and the effective information that can be provided, resulting in the problem of data waste.
  • the FTL-based method utilizes federated learning to effectively utilize other users' data information, it also meets the needs of users' local data not to be shared.
  • this method uses the spatial covariance matrix of the EEG signal as input, and part of the effective information of the original EEG signal is lost.
  • FTL relies on the federated averaging algorithm, which randomly selects the gradients of some local models in the process of joint training, and updates the gradients of the server through indiscriminate simple average aggregation, ignoring the data quality and importance of different users. This will cause the gradient change of the server model to be updated each time to be unstable, which is not conducive to the accuracy of the shared model (that is, the EEG signal classification model), and the convergence speed is often slow, which makes the model training difficult.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training method and device, which can solve the problems of low precision and slow convergence speed of the EEG signal classification model.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training method, which is applied to the server side, and the method includes:
  • the local model gradient is obtained by the client using a local training set to train the EEG classification model
  • the step of determining a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals includes:
  • a preset proportion of user terminals is selected from the K said user terminals as target user terminals.
  • the step of updating the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target clients includes:
  • the network parameters of the EEG signal classification model at the server end are updated according to the updated global gradient.
  • the step of updating the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients includes:
  • the step of normalizing the importance evaluation value of each target client includes:
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • the step of obtaining the importance evaluation value of each client according to the local model gradient of each client includes:
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the method also includes:
  • the EEG signal classification model at the server end When the EEG signal classification model at the server end converges, the EEG signal classification model at the server end is delivered to the K client ends.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training device, which is applied to the server side, and the device includes:
  • a sending module configured to send the EEG signal classification model at the server end to K client ends;
  • a receiving module configured to receive a local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG classification model;
  • An acquisition module configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals
  • a first determination module configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals
  • An update module configured to update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the plurality of target client terminals;
  • the second determination module is used to return to execute the step of sending the server-side EEG signal classification model to K clients if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Signal classification model converges.
  • the above-mentioned first determination module 304 is specifically configured to select a preset proportion of user terminals from the K user terminals as target user terminals in descending order of importance evaluation values.
  • the above-mentioned update module 305 includes:
  • a processing unit configured to perform normalization processing on the importance evaluation value of each target client
  • the first update unit is configured to update the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
  • the second update unit is configured to update the network parameters of the EEG signal classification model at the server end according to the updated global gradient.
  • the above-mentioned first update unit is specifically used to pass the formula updating the global gradient on the server side;
  • the above processing unit is specifically used to pass the formula Normalize the importance evaluation value of each selected client;
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the above-mentioned EEG signal classification model training device also includes:
  • the sending module is configured to send the EEG signal classification model at the server end to the K client terminals when the EEG signal classification model at the server end converges.
  • an embodiment of the present application provides a server, including a memory, a processor, and a computer program stored in the memory and operable on the processor, and the computer program is implemented when the processor executes the computer program. the above method.
  • an embodiment of the present application provides a computer-readable storage medium, where the computer-readable storage medium stores a computer program, and when the computer program is executed by a processor, the foregoing method is implemented.
  • an embodiment of the present application provides a computer program product, which, when the computer program product is run on a terminal device, causes the terminal device to execute the method described in any one of the foregoing first aspects.
  • joint training and distributed training can be realized to fully utilize all users' effective The effect of information on improving the accuracy of EEG signal classification models.
  • the target client since the target client is not randomly selected, but through the importance evaluation value of each client, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient of the target client and The importance evaluation value updates the network parameters of the EEG signal classification model on the server side, thereby improving the accuracy and convergence speed of the EEG signal classification model.
  • Fig. 1 is the flow chart of the federated learning-based EEG classification model training method provided by an embodiment of the present application
  • Fig. 2 is a flowchart of step 15 provided by an embodiment of the present application.
  • FIG. 3 is a schematic structural diagram of a federated learning-based EEG classification model training device provided by an embodiment of the present application
  • Fig. 4 is a schematic structural diagram of a server provided by an embodiment of the present application.
  • the term “if” may be construed, depending on the context, as “when” or “once” or “in response to determining” or “in response to detecting “.
  • the phrase “if determined” or “if [the described condition or event] is detected” may be construed, depending on the context, to mean “once determined” or “in response to the determination” or “once detected [the described condition or event] ]” or “in response to detection of [described condition or event]”.
  • references to "one embodiment” or “some embodiments” or the like in the specification of the present application means that a particular feature, structure, or characteristic described in connection with the embodiment is included in one or more embodiments of the present application.
  • appearances of the phrases “in one embodiment,” “in some embodiments,” “in other embodiments,” “in other embodiments,” etc. in various places in this specification are not necessarily All refer to the same embodiment, but mean “one or more but not all embodiments” unless specifically stated otherwise.
  • the terms “including”, “comprising”, “having” and variations thereof mean “including but not limited to”, unless specifically stated otherwise.
  • the EEG signal classification models mainly include the EEGNet-based EEG signal classification model and the FTL-based EEG signal classification model.
  • the accuracy of the EEG classification model based on EEGNet is low, while the convergence speed of the EEG classification model based on FTL is slow and the accuracy is not ideal.
  • the embodiment of the present application is based on the federated learning framework, by sending the server-side EEG signal classification model to K client terminals during distributed training, so that each client terminal can use the local training set to analyze the received EEG signal
  • the classification model is trained, and the local model gradient obtained by training is sent to the server for joint training, so that joint training and its distributed Training has achieved the effect of improving the accuracy of the EEG signal classification model while making full use of the effective information of all users.
  • the target client since the target client is not randomly selected, but through the importance evaluation value of each client, the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient of the target client and The importance evaluation value updates the network parameters of the EEG signal classification model on the server side, thereby improving the accuracy and convergence speed of the EEG signal classification model.
  • the federated learning-based EEG classification model training method provided by the present application will be exemplarily described below in conjunction with specific embodiments.
  • the embodiment of the present application provides a federated learning-based EEG signal classification model training method, which is applied to the server side, and the method includes the following steps:
  • Step 11 sending the EEG signal classification model at the server end to K client ends.
  • the above-mentioned K clients are clients that participate in federated learning with the above-mentioned server.
  • the server end can initialize an EEG signal classification model (ie, the above steps EEG classification model in 11).
  • the weight of the model can be initialized to 0, or other common initialization schemes can be used, such as Gaussian and Xavier initialization (Xavier initialization is a neural network initialization method).
  • the above-mentioned EEG signal classification model may be an EEGNet model, of course, it may also be other deep learning networks, such as a convolutional neural network (ConvNet) and other EEG signal classification neural networks.
  • ConvNet convolutional neural network
  • Step 12 receiving the local model gradient sent by each client.
  • the aforementioned local model gradient is obtained by the user terminal using a local training set to train the EEG signal classification model.
  • the server after receiving the EEG signal classification model issued by the server, it will use the local training set of the client to analyze the received EEG signal
  • the classification model is trained, and the local model gradient is obtained when the EEG classification model converges.
  • Step 13 according to the local model gradient of each client, obtain the importance evaluation value of each client.
  • the above-mentioned importance evaluation value is mainly used to characterize the importance degree of the client end, so that subsequently, according to the order of importance from high to low, select from the K End EEG signal classification model) to jointly train the target user end to improve the accuracy and convergence speed of the EEG signal classification model.
  • Step 14 Determine a plurality of target client terminals from the K client terminals according to the importance evaluation values of the K client terminals.
  • a preset proportion of user terminals can be selected from the K user terminals as the target user terminals, thereby screening from the K user terminals Target users with high importance.
  • the specific numerical value of the above preset ratio can be set according to the actual situation.
  • the importance of the above-mentioned target client is higher than that of other client in the K client, that is, the target client pairs the shared model (ie, the EEG signal classification model at the server end) The contribution is greater than the contribution of other clients to the shared model. Subsequent joint training using these target clients can improve the accuracy and convergence speed of the EEG signal classification model.
  • Step 15 Update the network parameters of the EEG signal classification model at the server end according to the local model gradients and importance evaluation values of the multiple target client ends.
  • the performance of the server-side EEG signal classification model in the joint training, by updating the network parameters of the server-side EEG signal classification model according to the target client's local model gradient and importance evaluation value, the performance of the server-side EEG signal classification model can be improved. accuracy and convergence speed.
  • Step 16 if the EEG signal classification model at the server end does not converge, return to the step of sending the EEG signal classification model at the server end to K clients until the EEG signal classification model at the server end converges .
  • the EEG signal classification model converged in step 16 above is a shared model, which can be used to classify EEG signals of any user.
  • step 15 if the server-side EEG signal classification model does not converge, then return to step 11 to update the network parameters of the server-side EEG signal classification model again until the server-side End EEG signal classification model converges.
  • the updated EEG signal classification model is Share the model, otherwise, send the EEG signal classification model after updating the network parameters to K clients, so that K clients use their own local training sets to train the received EEG signal classification model, and the obtained Local model gradients to update the network parameters of the server-side EEG signal classification model again.
  • the local training set data of the client is not directly used, but the local model gradient of the client is used to jointly train the EEG classification model of the server, so that The privacy and use security of the local data of the client are guaranteed.
  • joint training and distributed training can be realized, and the effective utilization of all users can be achieved. The effect of information on improving the accuracy of EEG signal classification models.
  • the target client that contributes the most to the shared model is selected from all the clients, and based on the local model gradient and importance of the target client.
  • the network parameters of the EEG signal classification model on the server side are updated, thereby improving the accuracy and convergence speed of the EEG signal classification model.
  • the above method further includes the following steps: when the server-side EEG signal classification model converges, sending the server-side EEG signal classification model to the The above K clients.
  • the client can use its own local training set to train the EEG signal classification model to fine-tune the model parameters of the EEG signal classification model, An EEG signal classification model that is more suitable for the client terminal is obtained, and the client terminal can then use the fine-tuned EEG signal classification model to classify the user data of the client terminal to improve classification accuracy.
  • the local training set of the user terminal may be derived from Shanghai Jiaotong University Emotional EEG Dataset (SEED).
  • SEED Shanghai Jiaotong University Emotional EEG Dataset
  • 15 screened Chinese movie clips were selected as emotional stimuli in the experiment, and the labels included positive, neutral, and negative emotions.
  • a total of 15 Chinese subjects (including 7 boys and 8 girls) were collected in this dataset, and each subject conducted 3 experiments.
  • Each sample in this dataset contains 62 electrode channels, downsampled to 200 Hz, and a bandpass frequency filter of 0–75 Hz is applied.
  • the embodiment of the present application selects 32 channels related to emotion, corresponding to Fp1, AF3, F3, F7, FC5, FC1, C3, T7, CP5, CP1, P3, P7, PO3, O1 , Oz, Pz, Fp2, AF4, Fz, F4, F8, FC6, FC2, Cz, C4, T8, CP6, CP2, P4, P8, PO4, O2.
  • each sample has a size of 32 ⁇ 200.
  • the data of 32 channels of any one of the 15 subjects can be used as a local training set of the client.
  • the user terminal can use all the data in the local training set to train the EEG signal classification model each time. It should be further explained that the local training set corresponding to each client is different.
  • the EEG signal classification model adopts the EEGNet model to extract the feature representation and classification of the EEG signal.
  • the parameters of the feature extractor and classifier model in this application are shown in Table 1.
  • Table 1 the number of convolutional layers, the size of the convolutional kernel, the pooling method, and the activation function can all be set according to the actual situation.
  • the cross entropy (cross entropy) loss function can be used to evaluate the training result, wherein the training loss function of the kth user end is as follows: Among them, n k represents the local sample size contained in the local training set of the kth client, y i is the real label of the training sample (that is, the local sample in the local training set), is the predicted label. It should be noted that the above training loss function is a commonly used loss function, so here, the principle of the training loss function will not be described in detail.
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the importance of the user terminal may also be measured by other similarity measurement learning methods or attention mechanism algorithms.
  • the above step 15 according to the local model gradients and importance evaluation values of the multiple target client terminals, updates the network parameters of the EEG signal classification model at the server end.
  • the specific implementation method includes the following steps:
  • Step 21 performing normalization processing on the importance evaluation value of each target client.
  • the formula Perform normalization processing on the importance evaluation value of each selected client.
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • Step 22 update the global gradient of the server according to the normalized importance evaluation value and the local model gradients of all target clients.
  • Step 23 Update network parameters of the server-side EEG signal classification model according to the updated global gradient.
  • a stochastic gradient descent method based on stochastic gradient descent may be used to solve network parameters.
  • SGD stochastic gradient descent
  • the global gradient of the server side will also be initialized to 0.
  • the federated learning-based EEG signal classification model training method provided by the embodiment of the present application has the following effects:
  • the EEG signal classification model adopts the EEGNet model, which is applied to the classification task of emotional EEG signals. It does not need to manually extract signal features, and can perform end-to-end feature extraction and classification of emotional EEG signals;
  • the federated learning-based EEG classification model training device provided by the present application will be exemplarily described below in conjunction with specific embodiments.
  • the EEG signal classification model training device 300 includes:
  • Sending module 301 for sending the EEG signal classification model of the server end to K client ends;
  • the receiving module 302 is configured to receive the local model gradient sent by each client; the local model gradient is obtained by the client using a local training set to train the EEG signal classification model;
  • An acquisition module 303 configured to acquire the importance evaluation value of each of the user terminals according to the local model gradient of each of the user terminals;
  • the first determination module 304 is configured to determine a plurality of target user terminals from the K user terminals according to the importance evaluation values of the K user terminals;
  • An update module 305 configured to update the network parameters of the EEG classification model on the server side according to the local model gradients and importance evaluation values of the multiple target client terminals;
  • the second determination module 306 is configured to return to execute the step of sending the server-side EEG signal classification model to K client terminals if the server-side EEG signal classification model does not converge until the server-side EEG signal classification model Electrical signal classification model converges.
  • the above-mentioned first determination module 304 is specifically configured to select a preset proportion of user terminals from the K user terminals as target user terminals in descending order of importance evaluation values.
  • the above-mentioned update module 305 includes:
  • a processing unit configured to perform normalization processing on the importance evaluation value of each target client
  • the first update unit is configured to update the global gradient of the server according to the importance evaluation value after normalization processing and the local model gradients of all target clients;
  • the second update unit is used to update the network parameters of the EEG signal classification model at the server end according to the updated global gradient.
  • the above-mentioned first update unit is specifically used to pass the formula updating the global gradient on the server side;
  • the above processing unit is specifically used to pass the formula Normalize the importance evaluation value of each selected client;
  • ⁇ k represents the importance evaluation value of the kth client
  • C represents a preset ratio
  • K represents the number of clients.
  • ⁇ k represents the importance evaluation value of the kth client
  • ⁇ k n k /n
  • nk the local sample size contained in the local training set of the kth client
  • n the sum of the local samples contained in the local training set of K clients
  • K represents the number of clients
  • t is an integer greater than 0.
  • the above-mentioned EEG signal classification model training device also includes:
  • the sending module is configured to send the EEG signal classification model at the server end to the K client terminals when the EEG signal classification model at the server end converges.
  • an embodiment of the present application provides a server, as shown in Figure 4, the server D10 of this embodiment includes: at least one processor D100 (only one processor is shown in Figure 4), a memory D101 And a computer program D102 stored in the memory D101 and operable on the at least one processor D100, when the processor D100 executes the computer program D102, the steps in any of the above method embodiments are implemented.
  • the so-called processor D100 can be a central processing unit (CPU, Central Processing Unit), and the processor D100 can also be other general processors, digital signal processors (DSP, Digital Signal Processor), application specific integrated circuits (ASIC, Application Specific Integrated Circuit), off-the-shelf programmable gate array (FPGA, Field-Programmable Gate Array) or other programmable logic devices, discrete gate or transistor logic devices, discrete hardware components, etc.
  • DSP digital signal processors
  • ASIC Application Specific Integrated Circuit
  • FPGA Field-Programmable Gate Array
  • a general-purpose processor may be a microprocessor, or the processor may be any conventional processor, or the like.
  • the storage D101 may be an internal storage unit of the server D10 in some embodiments, such as a hard disk or a memory of the server D10.
  • the memory D101 can also be an external storage device of the server D10 in other embodiments, such as a plug-in hard disk equipped on the server D10, a smart memory card (SMC, Smart Media Card), a secure digital (SD , Secure Digital) card, flash memory card (Flash Card), etc.
  • the storage D101 may also include both an internal storage unit of the server D10 and an external storage device.
  • the memory D101 is used to store operating systems, application programs, boot loaders (BootLoader), data and other programs, such as program codes of the computer programs.
  • the memory D101 can also be used to temporarily store data that has been output or will be output.
  • the embodiment of the present application also provides a computer-readable storage medium, the computer-readable storage medium stores a computer program, and when the computer program is executed by a processor, the steps in each of the foregoing method embodiments can be realized.
  • An embodiment of the present application provides a computer program product.
  • the terminal device can implement the steps in the foregoing method embodiments when executed.
  • the integrated unit is realized in the form of a software function unit and sold or used as an independent product, it can be stored in a computer-readable storage medium. Based on this understanding, all or part of the procedures in the method of the above-mentioned embodiments in the present application can be completed by instructing related hardware through a computer program.
  • the computer program can be stored in a computer-readable storage medium.
  • the computer program When executed by a processor, the steps in the above-mentioned various method embodiments can be realized.
  • the computer program includes computer program code, and the computer program code may be in the form of source code, object code, executable file or some intermediate form.
  • the computer-readable medium may at least include: any entity or device, recording medium, computer memory, read-only memory (ROM, Read-Only Memory) capable of carrying computer program codes to the EEG signal classification model training device/terminal equipment , Random Access Memory (RAM, Random Access Memory), electrical carrier signal, telecommunication signal and software distribution medium.
  • ROM read-only memory
  • RAM Random Access Memory
  • electrical carrier signal telecommunication signal and software distribution medium.
  • U disk mobile hard disk, magnetic disk or optical disk, etc.
  • computer readable media may not be electrical carrier signals and telecommunication signals under legislation and patent practice.
  • the disclosed device/network device and method may be implemented in other ways.
  • the device/network device embodiments described above are only illustrative.
  • the division of the modules or units is only a logical function division.
  • the mutual coupling or direct coupling or communication connection shown or discussed may be through some interfaces, and the indirect coupling or communication connection of devices or units may be in electrical, mechanical or other forms.
  • the units described as separate components may or may not be physically separated, and the components displayed as units may or may not be physical units, that is, they may be located in one place, or may be distributed to multiple network units. Part or all of the units can be selected according to actual needs to achieve the purpose of the solution of this embodiment.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Health & Medical Sciences (AREA)
  • Signal Processing (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Artificial Intelligence (AREA)
  • Biomedical Technology (AREA)
  • Dermatology (AREA)
  • General Health & Medical Sciences (AREA)
  • Neurology (AREA)
  • Neurosurgery (AREA)
  • Human Computer Interaction (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Measurement And Recording Of Electrical Phenomena And Electrical Characteristics Of The Living Body (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请适用于生物信息技术领域,提供了一种基于联邦学习的脑电信号分类模型训练方法及装置,该方法包括:将服务器端的脑电信号分类模型发送给K个用户端;接收每个用户端发送的本地模型梯度;根据用户端的本地模型梯度,获取用户端的重要性评估值;根据K个用户端的重要性评估值,从K个用户端中确定多个目标用户端;根据目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数;若服务器端的脑电信号分类模型未收敛,则返回将服务器端的脑电信号分类模型发送给K个用户端的步骤,直至服务器端的脑电信号分类模型收敛。本申请能在充分利用所有用户的有效信息情况下,提升脑电信号分类模型的精度及收敛速度。

Description

基于联邦学习的脑电信号分类模型训练方法及装置 技术领域
本申请属于生物信息技术领域,尤其涉及一种基于联邦学习的脑电信号分类模型训练方法及装置。
背景技术
基于情感识别的脑机接口(BCI,Brain Computer Interface)通过在情感交互实验中采集用户的脑电信号,并对脑电信号进行特征提取和解码,可以识别用户真正的情感状态和意图,从而实现用户和设备间的友好通信及交互。基于脑电信号的情感分析有广泛的应用场景,例如情感障碍疾病的辅助诊断和抑郁症等心理治疗干预等。
基于深度学习的情感识别模型往往是数据驱动,要求有大量的训练数据。然而由于脑电图(EEG,Electroencephalographic)信号的采集过程繁琐及个体间差异性巨大的特点,EEG数据往往以多个小数据集的形式分散存在于各个用户。为了构建高精度的情感识别模型,现有方法致力于通过共享不同用户之间的数据,利用知识迁移和领域自适应等技术来有效利用其他用户的有用信息和提升目标用户的情感识别率。但在数据共享的过程中,如果包含了人的身份特征及思想情感等私密信息的脑电信号,一旦被滥用或者非法阅读传播,将造成个人隐私的泄露。
目前脑电信号分类模型主要有:基于EEGNet(EEGNet是为专门一般的脑电图识别任务而设计的通用紧凑的卷积神经网络)的脑电信号分类模型,和基于联邦迁移学***均算法,该算法在联合训练的过程中,随机选择部分本地模型的梯度,通过无区别的简单平均聚合来更新服务器的梯度,忽略了不同用户的数据质量和重要性,这将导致每次更新服务器模型的梯度变化不稳定,不利于共享模型(即脑电信号分类模型)的精度,而且往往收敛速度慢,给模型训练造成一定难度。
发明内容
本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练方法及装置,可以解决脑电信号分类模型的精度低、且收敛速度慢的问题。
第一方面,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练方法,应用于服务器端,该方法包括:
将所述服务器端的脑电信号分类模型发送给K个用户端;
接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务 器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
其中,所述根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端的步骤,包括:
按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
其中,所述根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的步骤,包括:
对每个所述目标用户端的重要性评估值进行归一化处理;
根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
其中,所述根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度的步骤,包括:
通过公式
Figure PCTCN2021138013-appb-000001
更新所述服务器端的全局梯度;
其中,
Figure PCTCN2021138013-appb-000002
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000003
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000004
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,所述对每个所述目标用户端的重要性评估值进行归一化处理的步骤,包括:
通过公式
Figure PCTCN2021138013-appb-000005
对选择出的每个用户端的重要性评估值进行归一化处理;
其中,
Figure PCTCN2021138013-appb-000006
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
其中,所述根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的步骤,包括:
通过公式μ k=α k×β k,计算第k个用户端的重要性评估值;
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000007
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000008
Figure PCTCN2021138013-appb-000009
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000010
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,所述方法还包括:
在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
第二方面,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练装置,应用于服务器端,该装置包括:
发送模块,用于将所述服务器端的脑电信号分类模型发送给K个用户端;
接收模块,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
获取模块,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
第一确定模块,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
更新模块,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
第二确定模块,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
其中,上述第一确定模块304,具体用于按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
其中,上述更新模块305包括:
处理单元,用于对每个所述目标用户端的重要性评估值进行归一化处理;
第一更新单元,用于根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
第二更新单元,用于根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
其中,上述第一更新单元,具体用于通过公式
Figure PCTCN2021138013-appb-000011
更新所述服务器端的全局梯度;
其中,
Figure PCTCN2021138013-appb-000012
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000013
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000014
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,上述处理单元,具体用于通过公式
Figure PCTCN2021138013-appb-000015
对选择出的每个用户端的重要性评估值进行归一化处理;
其中,
Figure PCTCN2021138013-appb-000016
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
其中,上述获取模块303,具体用于通过公式μ k=α k×β k,计算第k个用户端的重要性评估值;
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000017
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000018
Figure PCTCN2021138013-appb-000019
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000020
表示第t轮更新时 第k个用户端的本地模型梯度,t为大于0的整数。
其中,上述脑电信号分类模型训练装置还包括:
下发模块,用于在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
第三方面,本申请实施例提供了一种服务器,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的方法。
第四方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述的方法。
第五方面,本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行上述第一方面中任一项所述的方法。
本申请实施例与现有技术相比存在的有益效果是:
在本申请的实施例中,通过基于联邦学习框架,在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,即可实现联合训练及其分布式训练,达到充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果。同时在联合训练中,由于不是随机挑选目标用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一实施例提供的基于联邦学习的脑电信号分类模型训练方法的流程图;
图2是本申请一实施例提供的步骤15的流程图;
图3是本申请一实施例提供的基于联邦学习的脑电信号分类模型训练装置的结构示意图;
图4是本申请一实施例提供的服务器的结构示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定***结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的***、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
应当理解,当在本申请说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本申请说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
另外,在本申请说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
在本申请说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本申请的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
目前脑电信号分类模型主要有基于EEGNet的脑电信号分类模型和基于FTL的脑电信号分类模型。但基于EEGNet的脑电信号分类模型的精度低,而基于FTL的脑电信号分类模型的收敛速度慢、且精度不理想。
针对上述问题,本申请实施例基于联邦学习框架,通过在分布式训练中,将服务器端的脑电信号分类模型发送给K个用户端,使各用户端利用本地训练集对接收到的脑电信号分类模型进行训练,并将训练得到的本地模型梯度发送给服务器端,以进行联合训练,从而在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,实现了联合训练及其分布式训练,达到了在充分利用所有用户的有效信息的情况下,提升脑电信号分类模型的精度的效果。
同时在联合训练中,由于不是随机挑选目标用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。
下面结合具体实施例对本申请提供的基于联邦学习的脑电信号分类模型训练方法进行示例性的说明。
如图1所示,本申请的实施例提供了一种基于联邦学习的脑电信号分类模型训练方法,应用于服务器端,该方法包括如下步骤:
步骤11,将所述服务器端的脑电信号分类模型发送给K个用户端。
在本申请的一些实施例中,上述K个用户端为与上述服务器端参与联邦学 习的用户端。需要说明的是,为确保最终得到的脑电信号分类模型是基于用户端的有用户的有效信息得到的,在执行上述训练方法的步骤之前,服务器端可以初始化一个脑电信号分类模型(即上述步骤11中的脑电信号分类模型)。具体的,可以将模型权重初始化为0,也可以采用其他常见的初始化方案,例如高斯、Xavier初始化(Xavier初始化是一种神经网络初始化方法)。
其中,上述脑电信号分类模型可以为EEGNet模型,当然也可以是其他的深度学习网络,例如卷积神经网络(ConvNet)等脑电信号分类神经网络。
步骤12,接收每个所述用户端发送的本地模型梯度。
在本申请的一些实施例中,上述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的。
即,在本申请的一些实施例中,对于参与联邦学习的每个用户端,在接收到服务器端下发的脑电信号分类模型后,会利用用户端的本地训练集对接收到的脑电信号分类模型进行训练,并在该脑电信号分类模型收敛时得到本地模型梯度。
步骤13,根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值。
在本申请的一些实施例中,上述重要性评估值主要用于表征用户端的重要性程度,以便后续按照重要性从高至低的顺序,从K个用户端中选择出对共享模型(即服务器端的脑电信号分类模型)贡献大的目标用户端进行联合训练,从提升脑电信号分类模型的精度和收敛速度。
步骤14,根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端。
在本申请的一些实施例中,可按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端,从而从K个用户端中筛选出重要性程度高的目标用户端。其中,上述预设比例的具体数值可根据实际情况进行设定。
可见,在本申请的一些实施例中,上述目标用户端的重要性程度高于K个用户端中其他用户端的重要性程度,即,目标用户端对共享模型(即服务器端的脑电信号分类模型)贡献大于其他用户端对共享模型的贡献,后续利用这些目标用户端进行联合训练,能提升脑电信号分类模型的精度和收敛速度。
步骤15,根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数。
在本申请的一些实施例中,在联合训练中,通过根据目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,能提升服务器端的脑电信号分类模型的精度和收敛速度。
步骤16,若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
在本申请的一些实施例中,上述步骤16中收敛的脑电信号分类模型是共享模型,可用于对任一用户的脑电信号进行分类。
在本申请的一些实施例中,在执行完步骤15后,若服务器端的脑电信号分类模型未收敛,则返回步骤11,以再次更新服务器端的脑电信号分类模型的网络参数,直至所述服务器端的脑电信号分类模型收敛。
需要说明的是,每次更新服务器端的脑电信号分类模型的网络参数后,都需要判断更新网络参数后的脑电信号分类模型是否收敛,若收敛,则更新后的脑电信号分类模型即为共享模型,否则,将更新网络参数后的脑电信号分类模型下发给K个用户端,使K个用户端分别利用自身的本地训练集对接收到的脑电信号分类模型进行训练,得到的本地模型梯度,以再次更新服务器端的脑电信号分类模型的网络参数。
值得一提的是,在本申请的一些实施例中,在联合训练中,不直接使用用户端的本地训练集数据,而是使用用户端的本地模型梯度来共同训练服务器端的脑电信号分类模型,从而保障了用户端本地数据的隐私和使用安全性,在满 足数据安全、无需共享或者交换各个用户端本地数据的前提下,即可实现联合训练及其分布式训练,达到了充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果。
同时在联合训练中,由于不是随机挑选用户端,而是通过各用户端的重要性评估值,从所有用户端中选择对共享模型贡献大的目标用户端,并基于目标用户端的本地模型梯度和重要性评估值,更新服务器端的脑电信号分类模型的网络参数,从而提升了脑电信号分类模型的精度及收敛速度。
在本申请的实施例中,在执行完上述步骤16后,上述方法还包括如下步骤:在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
需要说明的是,用户端在接收到该脑电信号分类模型后,可利用自身的用本地训练集对该脑电信号分类模型进行训练,以对该脑电信号分类模型的模型参数进行微调,得到更适合该用户端的脑电信号分类模型,后续该用户端可利用微调后的脑电信号分类模型对该用户端的用户数据进行分类,提高分类准确性。
接下来,结合具体实施例对用户端利用本地训练集对脑电信号分类模型进行训练的过程的进行示例性的说明。
在本申请的一些实施例中,用户端的本地训练集可来源于上海交大情感脑电数据集(SEED)。在该数据集的实验中,15个筛选过的中国电影片段被选取为实验中的情感刺激源,标签包括正面、中性和负面情绪。该数据集一共采集了15名中国受试者(包括7名男生和8名女生),其中每个受试者分别进行3次实验。该数据集中的每个样本包含62个电极通道,下采样到200Hz,并且应用了0-75Hz的带通频率滤波器。为了扩展数据量,我们将每个数据按照1s的数据窗口进行不重叠切割,最终获取3394个样本。在采集的62个通道中,本申请实施例选择与情感相关的32个通道,分别对应Fp1,AF3,F3,F7,FC5,FC1,C3,T7,CP5,CP1,P3,P7,PO3,O1,Oz,Pz,Fp2, AF4,Fz,F4,F8,FC6,FC2,Cz,C4,T8,CP6,CP2,P4,P8,PO4,O2。为此,每个样本的大小为32×200。需要说明的是,在本申请的一些实施例中,可将15名受试者中任一受试者的32个通道的数据作为一用户端的本地训练集。为提升脑电信号分类模型的精度,用户端每次均可利用本地训练集中的所有数据对脑电信号分类模型进行训练。需要进一步说明的是,每个用户端对应的本地训练集均不相同。
作为一个优选的示例,根据输入的原始EEG信号的时空属性,上述脑电信号分类模型采用EEGNet模型,用于提取脑电信号的特征表示及分类。其中本申请中特征提取器和分类器模型参数如表1所示。当然可以理解的是,卷积层数目、卷积核大小、池化方法以及激活函数均可根据实际情况进行设定。
Figure PCTCN2021138013-appb-000021
表1
其中,在用户端利用本地训练集对脑电信号分类模型进行训练时,可采用交叉熵(cross entropy)损失函数评估训练结果,其中第k个用户端的训练损失函数如下:
Figure PCTCN2021138013-appb-000022
其中,n k表示第k个用户端的本地训练集所包含的本地样本量,y i为训练样本(即本地训练集中的本地样本)的真实标签,
Figure PCTCN2021138013-appb-000023
为预测标签。需要说明的是,上述训练损失函数为常用损失函数,因此在此,不对该训练损失函数的原理进行过多赘述。
接下来,结合具体实施例对获取重要性评估值以及更新网络参数的过程的进行示例性的说明。
在本申请的一些实施例中,上述步骤13,根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的具体实现方式可以为:通过公式μ k=α k×β k,计算第k个用户端的重要性评估值。
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000024
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000025
Figure PCTCN2021138013-appb-000026
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000027
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
需要说明的是,当t的取值为1时,
Figure PCTCN2021138013-appb-000028
为服务器端初始化的脑电信号分类模型的模型梯度,公式中的更新指的是服务器端的脑电信号分类模型的网络参数的更新。
在本申请的一些实施例中,除了通过上述公式计算用户端的重要性评估值外,还可以通过其他的相似性度量学习方法或者注意力机制算法度量用户端的重要性。
在本申请的一些实施例中,如图2所示,上述步骤15,根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的具体实现方式包括如下步骤:
步骤21,对每个所述目标用户端的重要性评估值进行归一化处理。
在本申请的一些实施例中,可通过公式
Figure PCTCN2021138013-appb-000029
对选择出的每个用户端的重要性评估值进行归一化处理。
其中,
Figure PCTCN2021138013-appb-000030
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
步骤22,根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度。
在本申请的一些实施例中,可通过公式
Figure PCTCN2021138013-appb-000031
更新所述服务器端的全局梯度。
其中,
Figure PCTCN2021138013-appb-000032
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000033
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000034
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
步骤23,根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
在本申请的一些实施例中,可采用基于随机梯度下降(SGD,Stochastic Gradient Descent)的随机梯度下降法求解网络参数。需要说明的是,服务器端在初始化脑电信号分类模型的时候,服务器端的全局梯度也会初始化为0。
综上,本申请实施例提供的基于联邦学习的脑电信号分类模型训练方法具备如下效果:
一、脑电信号分类模型采用EEGNet模型,将其应用到情感脑电信号的分类任务中,无需手工提取信号特征,能端到端进行情感脑电信号的特征提取和分类;
二、将EEGNet模型应用到情感脑电识别网络,利用深度学习自动提取情感脑电信号的可判别性特征,提升用户端单个的脑电信号分类模型的准确率;
三、无需对脑电信号做繁杂的预处理,直接利用脑电信号对脑电信号分类模型进行训练,便可有效进行脑电信号的特征提取和分类;
四、能在满足数据安全、无需共享或者交换各个用户端本地数据的前提下,实现联合训练及其分布式训练,达到充分利用所有用户的有效信息提升脑电信号分类模型的精度的效果;
五、通过各用户端的重要性,选择对共享模型贡献大的目标用户端进行联 合训练,从而提升脑电信号分类模型的精度及收敛速度。
下面结合具体实施例对本申请提供的基于联邦学习的脑电信号分类模型训练装置进行示例性的说明。
如图3所示,本申请实施例提供了一种基于联邦学习的脑电信号分类模型训练装置,应用于服务器端,该脑电信号分类模型训练装置300包括:
发送模块301,用于将所述服务器端的脑电信号分类模型发送给K个用户端;
接收模块302,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
获取模块303,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
第一确定模块304,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
更新模块305,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
第二确定模块306,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
其中,上述第一确定模块304,具体用于按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
其中,上述更新模块305包括:
处理单元,用于对每个所述目标用户端的重要性评估值进行归一化处理;
第一更新单元,用于根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
第二更新单元,用于根据更新后的全局梯度,更新所述服务器端的脑电信 号分类模型的网络参数。
其中,上述第一更新单元,具体用于通过公式
Figure PCTCN2021138013-appb-000035
更新所述服务器端的全局梯度;
其中,
Figure PCTCN2021138013-appb-000036
表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
Figure PCTCN2021138013-appb-000037
表示第k个用户端归一化处理后的重要性评估值,
Figure PCTCN2021138013-appb-000038
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,上述处理单元,具体用于通过公式
Figure PCTCN2021138013-appb-000039
对选择出的每个用户端的重要性评估值进行归一化处理;
其中,
Figure PCTCN2021138013-appb-000040
表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
其中,上述获取模块303,具体用于通过公式μ k=α k×β k,计算第k个用户端的重要性评估值;
其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
Figure PCTCN2021138013-appb-000041
n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
Figure PCTCN2021138013-appb-000042
Figure PCTCN2021138013-appb-000043
表示第t-1轮更新时所述服务器端的全局梯度,
Figure PCTCN2021138013-appb-000044
表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
其中,上述脑电信号分类模型训练装置还包括:
下发模块,用于在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述***中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
如图4所示,本申请的实施例提供了一种服务器,如图4所示,该实施例的服务器D10包括:至少一个处理器D100(图4中仅示出一个处理器)、存储器D101以及存储在所述存储器D101中并可在所述至少一个处理器D100上运行的计算机程序D102,所述处理器D100执行所述计算机程序D102时实现上述任意各个方法实施例中的步骤。
所称处理器D100可以是中央处理单元(CPU,Central Processing Unit),该处理器D100还可以是其他通用处理器、数字信号处理器(DSP,Digital Signal Processor)、专用集成电路(ASIC,Application Specific Integrated Circuit)、现成可编程门阵列(FPGA,Field-Programmable Gate Array)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
所述存储器D101在一些实施例中可以是所述服务器D10的内部存储单元,例如服务器D10的硬盘或内存。所述存储器D101在另一些实施例中也可以是所述服务器D10的外部存储设备,例如所述服务器D10上配备的插接式硬盘,智能存储卡(SMC,Smart Media Card),安全数字(SD,Secure Digital)卡,闪存卡(Flash Card)等。进一步地,所述存储器D101还可以既包括所述服务 器D10的内部存储单元也包括外部存储设备。所述存储器D101用于存储操作***、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,例如所述计算机程序的程序代码等。所述存储器D101还可以用于暂时地存储已经输出或者将要输出的数据。
需要说明的是,上述装置/单元之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例部分,此处不再赘述。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述***中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本申请实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现可实现上述各个方法实施例中的步骤。
本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得终端设备执行时实现可实现上述各个方法实施例中的步骤。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机 程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质至少可以包括:能够将计算机程序代码携带到脑电信号分类模型训练装置/终端设备的任何实体或装置、记录介质、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、电载波信号、电信信号以及软件分发介质。例如U盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的实施例中,应该理解到,所揭露的装置/网络设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/网络设备实施例仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个***,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
以上所述实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。

Claims (10)

  1. 一种基于联邦学习的脑电信号分类模型训练方法,其特征在于,应用于服务器端,所述方法包括:
    将所述服务器端的脑电信号分类模型发送给K个用户端;
    接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
    根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
    根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
    根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
    若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
  2. 根据权利要求1所述的方法,其特征在于,所述根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端的步骤,包括:
    按照重要性评估值由大至小的顺序,从K个所述用户端中选择预设比例的用户端作为目标用户端。
  3. 根据权利要求2所述的方法,其特征在于,所述根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数的步骤,包括:
    对每个所述目标用户端的重要性评估值进行归一化处理;
    根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度;
    根据更新后的全局梯度,更新所述服务器端的脑电信号分类模型的网络参数。
  4. 根据权利要求3所述的方法,其特征在于,所述根据归一化处理后的重要性评估值以及所有目标用户端的本地模型梯度,更新所述服务器端的全局梯度的步骤,包括:
    通过公式
    Figure PCTCN2021138013-appb-100001
    更新所述服务器端的全局梯度;
    其中,
    Figure PCTCN2021138013-appb-100002
    表示第t轮更新得到的全局梯度,C表示预设比例,K表示用户端的数量,
    Figure PCTCN2021138013-appb-100003
    表示第k个用户端归一化处理后的重要性评估值,
    Figure PCTCN2021138013-appb-100004
    表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
  5. 根据权利要求3所述的方法,其特征在于,所述对每个所述目标用户端的重要性评估值进行归一化处理的步骤,包括:
    通过公式
    Figure PCTCN2021138013-appb-100005
    对选择出的每个用户端的重要性评估值进行归一化处理;
    其中,
    Figure PCTCN2021138013-appb-100006
    表示第k个用户端归一化处理后的重要性评估值,μ k表示第k个用户端的重要性评估值,C表示预设比例,K表示用户端的数量。
  6. 根据权利要求1所述的方法,其特征在于,所述根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值的步骤,包括:
    通过公式μ k=α k×β k,计算第k个用户端的重要性评估值;
    其中,μ k表示第k个用户端的重要性评估值,α k=n k/n,n k表示第k个用户端的本地训练集所包含的本地样本量,
    Figure PCTCN2021138013-appb-100007
    n表示K个用户端的本地训练集所包含的本地样本量的总和,K表示用户端的数量,
    Figure PCTCN2021138013-appb-100008
    Figure PCTCN2021138013-appb-100009
    表示第t-1轮更新时所述服务器端的全局梯度,
    Figure PCTCN2021138013-appb-100010
    表示第t轮更新时第k个用户端的本地模型梯度,t为大于0的整数。
  7. 根据权利要求1所述的方法,其特征在于,所述方法还包括:
    在所述服务器端的脑电信号分类模型收敛时,将所述服务器端的脑电信号分类模型下发给所述K个用户端。
  8. 一种基于联邦学习的脑电信号分类模型训练装置,其特征在于,应用于服务器端,所述装置包括:
    发送模块,用于将所述服务器端的脑电信号分类模型发送给K个用户端;
    接收模块,用于接收每个所述用户端发送的本地模型梯度;所述本地模型梯度是所述用户端利用本地训练集对所述脑电信号分类模型进行训练得到的;
    获取模块,用于根据每个所述用户端的本地模型梯度,获取每个所述用户端的重要性评估值;
    第一确定模块,用于根据K个所述用户端的重要性评估值,从K个所述用户端中确定多个目标用户端;
    更新模块,用于根据所述多个目标用户端的本地模型梯度和重要性评估值,更新所述服务器端的脑电信号分类模型的网络参数;
    第二确定模块,用于若所述服务器端的脑电信号分类模型未收敛,则返回执行所述将所述服务器端的脑电信号分类模型发送给K个用户端的步骤,直至所述服务器端的脑电信号分类模型收敛。
  9. 一种服务器,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的方法。
  10. 一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的方法。
PCT/CN2021/138013 2021-11-15 2021-12-14 基于联邦学习的脑电信号分类模型训练方法及装置 WO2023082406A1 (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202111347340.8A CN114048780A (zh) 2021-11-15 2021-11-15 基于联邦学习的脑电信号分类模型训练方法及装置
CN202111347340.8 2021-11-15

Publications (1)

Publication Number Publication Date
WO2023082406A1 true WO2023082406A1 (zh) 2023-05-19

Family

ID=80208990

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/138013 WO2023082406A1 (zh) 2021-11-15 2021-12-14 基于联邦学习的脑电信号分类模型训练方法及装置

Country Status (2)

Country Link
CN (1) CN114048780A (zh)
WO (1) WO2023082406A1 (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114664434A (zh) * 2022-03-28 2022-06-24 上海韶脑传感技术有限公司 面向不同医疗机构的脑卒中康复训练***及其训练方法
CN117708681B (zh) * 2024-02-06 2024-04-26 南京邮电大学 基于结构图指导的个性化联邦脑电信号分类方法及***

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150324690A1 (en) * 2014-05-08 2015-11-12 Microsoft Corporation Deep Learning Training System
CN111814985A (zh) * 2020-06-30 2020-10-23 平安科技(深圳)有限公司 联邦学习网络下的模型训练方法及其相关设备
CN112181666A (zh) * 2020-10-26 2021-01-05 华侨大学 一种基于边缘智能的设备评估和联邦学习重要性聚合方法、***、设备和可读存储介质
CN112633146A (zh) * 2020-12-21 2021-04-09 杭州趣链科技有限公司 多姿态人脸性别检测训练优化方法、装置及相关设备
CN113158241A (zh) * 2021-04-06 2021-07-23 深圳市洞见智慧科技有限公司 基于联邦学习的岗位推荐方法及装置

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150324690A1 (en) * 2014-05-08 2015-11-12 Microsoft Corporation Deep Learning Training System
CN111814985A (zh) * 2020-06-30 2020-10-23 平安科技(深圳)有限公司 联邦学习网络下的模型训练方法及其相关设备
CN112181666A (zh) * 2020-10-26 2021-01-05 华侨大学 一种基于边缘智能的设备评估和联邦学习重要性聚合方法、***、设备和可读存储介质
CN112633146A (zh) * 2020-12-21 2021-04-09 杭州趣链科技有限公司 多姿态人脸性别检测训练优化方法、装置及相关设备
CN113158241A (zh) * 2021-04-06 2021-07-23 深圳市洞见智慧科技有限公司 基于联邦学习的岗位推荐方法及装置

Also Published As

Publication number Publication date
CN114048780A (zh) 2022-02-15

Similar Documents

Publication Publication Date Title
WO2023082406A1 (zh) 基于联邦学习的脑电信号分类模型训练方法及装置
CN108595585B (zh) 样本数据分类方法、模型训练方法、电子设备及存储介质
Duggal et al. Prediction of thyroid disorders using advanced machine learning techniques
CN103399896B (zh) 识别用户间关联关系的方法及***
WO2019192118A1 (zh) 基于边缘计算的健康监测方法、装置、设备及存储介质
Alqahtani et al. Breast cancer pathological image classification based on the multiscale CNN squeeze model
WO2020192112A1 (zh) 人脸识别方法及装置
CN105446741B (zh) 一种基于api比对的移动应用程序辨识方法
US10885361B2 (en) Biometric method and device for identifying a person through an electrocardiogram (ECG) waveform
CN109817339A (zh) 基于大数据的患者分组方法和装置
CN110348326A (zh) 基于身份证识别和多设备访问的家庭健康信息处理方法
CN110489659A (zh) 数据匹配方法和装置
CN114973330A (zh) 一种跨场景鲁棒的人员疲劳状态无线检测方法及相关设备
WO2021114818A1 (zh) 基于傅里叶变换的oct图像质量评估方法、***及装置
CN111447081B (zh) 数据链生成方法、装置、服务器及存储介质
Liong et al. Automatic traditional Chinese painting classification: A benchmarking analysis
CN104679967A (zh) 一种判断心理测试可靠性的方法
Katti et al. Are you from North or South India? A hard face-classification task reveals systematic representational differences between humans and machines
Kong et al. Task-free brainprint recognition based on degree of brain networks
CN113014881A (zh) 一种神经外科患者日常监护方法及***
Araújo et al. Generic biometry algorithm based on signal morphology information: Application in the electrocardiogram signal
CN107832690B (zh) 人脸识别的方法及相关产品
CN110222622B (zh) 一种环境土壤检测方法及装置
CN106126758A (zh) 用于信息处理和信息评估的云***
Talukdar et al. Malaria detection in segmented blood cell using convolutional neural networks and canny edge detection