CN115471716A - Chest radiographic image disease classification model lightweight method based on knowledge distillation - Google Patents
- Publication number
- CN115471716A (application number CN202211056063.XA)
- Authority
- CN
- China
- Prior art keywords
- network
- label
- model
- student
- classification
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F17/00—Digital computing or data processing equipment or methods, specially adapted for specific functions
- G06F17/10—Complex mathematical operations
- G06F17/16—Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/0002—Inspection of images, e.g. flaw detection
- G06T7/0012—Biomedical image inspection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
- G06V10/806—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/30—Subject of image; Context of image processing
- G06T2207/30004—Biomedical image processing
- G06T2207/30061—Lung
Abstract
The invention relates to a chest radiographic image disease classification model lightweight method based on knowledge distillation, which comprises the following steps: a chest radiographic image disease classification model serves as the teacher network; the teacher network is composed of a graph convolutional neural network (GCN) module and a convolutional neural network (CNN) module, where the GCN extracts disease-label features and the CNN extracts chest medical-image features, and the features of the two modalities are finally fused to predict a multi-label classification result. A ResNet18 network, which runs fast and has a low memory footprint, serves as the student network. The teacher and student networks are trained jointly: a loss function performs regression and classification on both networks, and the loss indices of the teacher network guide those of the student network, yielding the multi-label classification student model. The invention compresses the model while reducing overall accuracy as little as possible, greatly improving the model's running efficiency and reducing its memory usage.
Description
Technical Field
The invention relates to a chest radiographic image disease classification model lightweight method based on knowledge distillation, and belongs to the technical field of computer vision.
Background
Chest diseases are a major threat to human health: hundreds of millions of people worldwide suffer from them every year, and without timely treatment they severely affect patients and can even be life-threatening. Computer-aided diagnosis (CAD) is often used to help radiologists make efficient diagnoses, and it has many advantages over manual reading of radiographs: it is free from subjective influence, it extracts and screens visual features automatically, and model accuracy improves rapidly as the amount of training data grows. Existing computer-aided diagnosis models automatically perform classification, segmentation and detection of medical images with good results, greatly easing the problems faced in radiographic image processing. However, this high performance comes at the cost of ever-larger networks, in both computation and the memory occupied by their parameters. While keeping model performance high, it is therefore very important to effectively reduce the number of parameters and improve running efficiency.
To improve model efficiency, many deep-learning solutions have been proposed in recent years; these methods generally fall into four categories. (1) Network architecture design: for example, the residual network (ResNet) designed by He et al. uses skip connections to remedy the failure of normal training caused by vanishing or exploding gradients; DenseNet feeds the feature maps of all previous layers as input to subsequent layers, maximizing the reuse of feature-map properties; ResNeXt introduces a new dimension (cardinality), which in turn reduces the amount of computation. (2) Network pruning: for example, Han et al. compress the model more than tenfold by uniform norm-based pruning with almost no loss in model performance; Li et al. delete convolution kernels with smaller absolute values, measure the effect on model performance after deleting each layer's kernels, reduce the pruning ratio for layers whose kernels have larger effect and increase it for layers whose kernels have smaller effect, and their results show that pruning a model uniformly is not as good as considering each convolution layer independently. (3) Data quantization: for example, NVIDIA's quantization technique uses the minimum divergence between floating-point and integer distributions as the quantization threshold, and the DFQ model can be quantized to 6 bits through weight equalization and bias correction while maintaining performance close to that of the floating-point network. (4) Knowledge distillation: for example, Zagoruyko et al. add an attention mechanism to the knowledge distillation model, sharing attention maps so that the student network focuses on the features the teacher network attends to; reducing the difference between the student's and teacher's attention maps lets the student learn the teacher's focus of resource allocation. Xu et al. use adversarial training to learn the loss and then transfer the teacher network's knowledge to the student network.
When deep-learning algorithms are applied to medical-image detection and classification tasks, they often suffer from large parameter counts and heavy memory usage, making them hard to run on mobile or embedded devices. Aiming at the problems that existing chest radiological image classification models have too many parameters and low running efficiency, and that directly training a small network on medical images yields unstable and poor results, the invention provides a knowledge-distillation-based chest radiographic image disease classification model: knowledge distillation takes a large model as the teacher network and a small model as the student network and migrates knowledge to the small model, thereby improving the small model's accuracy and stability, reducing the parameter count of the disease classification model, and improving running efficiency.
Disclosure of Invention
The invention relates to a chest radiographic image disease classification model lightweight method based on knowledge distillation, which solves the problems that existing chest radiographic image classification models have too many parameters and low running efficiency, and that directly training a small network on medical images yields poor results; the invention improves the efficiency, accuracy and stability of the model.
The technical scheme of the invention is as follows: a chest radiographic image disease classification model lightweight method based on knowledge distillation specifically comprises the following steps:
step1, a graph convolutional neural network (GCN) module converts the medical-image disease labels into GloVe word-embedding representations using a pre-trained language model; a label relation graph matrix is then constructed by data mining and input into the GCN module, and disease-label features are extracted through two layers of graph convolution;
step2, the medical image is input into a convolutional neural network (CNN) module, and chest medical-image features are extracted after convolution and maximum pooling operations;
step3, fusing the medical image characteristics and the disease label characteristics to predict a multi-label classification result;
step4, selecting a ResNet18 network model as a student network, wherein the ResNet18 solves the performance degradation problem of the deep network through a residual error unit;
and Step5, selecting a loss function to carry out regression and classification, and guiding the loss index of the student network by using the loss index of the teacher network to obtain the multi-label classification model student network.
Further, the graph convolution neural network (GCN) module in Step1 specifically includes the following:
each disease label is represented as a single node. The GCN is set to 2 layers; its inputs are a feature representation matrix H^l ∈ R^{n×d} (where n is the number of label categories and d is the word-embedding dimension of the labels) and a label correlation matrix A ∈ R^{n×n}. The goal is to learn a function f(·,·) on the graph so that the updated node representation is H^{l+1} ∈ R^{n×d′}. Each layer of the GCN is expressed with a nonlinear activation function as H^{l+1} = f(H^l, A). The graph convolution operation of the invention uses f(·,·) in the form H^{l+1} = h(A H^l W^l), where h(·) is a nonlinear operation (the invention uses the LeakyReLU activation function) and W^l ∈ R^{d×d′} is the transformation matrix to be learned.
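The two-layer graph convolution H^{l+1} = LeakyReLU(A H^l W^l) can be sketched in NumPy as follows. This is an illustrative example, not the patent's implementation: all dimensions (14 labels, 300-dimensional GloVe embeddings, the hidden and output widths) are assumed, and the identity matrix stands in for the data-mined label correlation matrix.

```python
import numpy as np

def leaky_relu(z, slope=0.2):
    # Elementwise LeakyReLU; the negative slope is an assumed value.
    return np.where(z > 0, z, slope * z)

def gcn_layer(H, A, W, slope=0.2):
    # One graph convolution layer: H^{l+1} = LeakyReLU(A · H^l · W^l)
    return leaky_relu(A @ H @ W, slope)

n, d, d_hidden, d_out = 14, 300, 1024, 2048   # assumed dimensions
rng = np.random.default_rng(0)
H0 = rng.normal(size=(n, d))                  # stand-in for GloVe label embeddings
A  = np.eye(n)                                # stand-in label correlation matrix
W1 = rng.normal(size=(d, d_hidden)) * 0.01    # transformation matrices to be learned
W2 = rng.normal(size=(d_hidden, d_out)) * 0.01

H1 = gcn_layer(H0, A, W1)                     # first graph convolution layer
Y  = gcn_layer(H1, A, W2)                     # disease-label features
print(Y.shape)                                # (14, 2048)
```

In the actual method, H0 would hold the GloVe embeddings of the 14 disease labels, A the conditional-probability correlation matrix, and W1, W2 would be learned by back-propagation.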
Further, the constructing of the label relationship graph matrix in Step1 specifically includes the following steps:
a label relation graph matrix is constructed by data mining and input into the GCN module. First the total occurrences of every disease category are counted; then, for each disease, data mining finds how often every other disease occurs together with it, i.e. the relation matrix is built in the form of conditional probabilities. P(Lb | La) is defined as the probability that label Lb occurs given that label La occurs. For example, let La be Pneumothorax and Lb be Emphysema; suppose the probability of emphysema given pneumothorax is 0.3, while the probability of pneumothorax given emphysema is 0.1. The medical-image dataset used by the invention has 14 disease label categories, so the final label correlation matrix is a 14 × 14 two-dimensional matrix.
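The conditional-probability construction can be illustrated on a toy multi-hot label set. The three label columns and their counts below are invented for demonstration; the real matrix built from the 14 dataset labels is 14 × 14.

```python
import numpy as np

# Toy multi-hot label matrix: one row per image, one column per disease label.
# Column 0 = "Pneumothorax", column 1 = "Emphysema" (assignments assumed for illustration).
labels = np.array([
    [1, 1, 0],
    [1, 0, 0],
    [1, 0, 1],
    [0, 1, 0],
    [1, 1, 0],
])

counts = labels.T @ labels                 # counts[i, j] = co-occurrences of labels i and j
occurs = np.diag(counts).astype(float)     # occurs[i]   = total occurrences of label i
P = counts / occurs[:, None]               # P[i, j]     = P(Lj | Li)

print(P[0, 1])   # P(Emphysema | Pneumothorax) = 2 / 4 = 0.5
```

Each row of P is the conditional distribution of the other labels given that the row's label occurs; note P is not symmetric, mirroring the pneumothorax/emphysema example above.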
Further, the Convolutional Neural Network (CNN) module in Step2 specifically includes the following:
the convolutional neural network (CNN) module uses a DenseNet network model, which contains four dense blocks named DenseBlock1 through DenseBlock4; the blocks differ in the number of convolution operations they contain. Each DenseBlock contains 1 × 1 and 3 × 3 convolution kernels and batch-normalization layers; transition layers between the dense blocks perform down-sampling, and DenseNet-121 contains 3 transition layers in total. To carry out feature fusion smoothly and obtain texture features better, the last fully connected layer of the DenseNet-121 network is removed and replaced with a maximum pooling layer.
Further, the fusion method in Step3 specifically comprises the following steps:
the invention performs feature fusion by matrix product, as shown in the calculation formula ẑ = y x, where ẑ represents the overall (fused) features, x is the medical-image feature and y is the disease-label feature. The overall features are then put into the multi-label classification loss function, as shown in the calculation formula L = −Σ_{c=1}^{C} [label_c ln δ(ẑ_c) + (1 − label_c) ln(1 − δ(ẑ_c))], where δ(·) is the sigmoid function, label_c is the ground truth for category c, and C represents the number of label categories.
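A minimal sketch of the fusion and multi-label loss (all shapes and values are assumed; δ is the sigmoid function from the formula above):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fuse(label_feats, image_feats):
    # Matrix product: (C, D) disease-label features × (D,) image feature -> (C,) scores
    return label_feats @ image_feats

def multilabel_bce(scores, targets):
    # L = -mean over categories of [t·ln δ(z) + (1 - t)·ln(1 - δ(z))]
    p = sigmoid(scores)
    return -np.mean(targets * np.log(p) + (1 - targets) * np.log(1 - p))

C, D = 14, 2048                              # assumed: 14 labels, 2048-dim features
rng = np.random.default_rng(1)
y_feats = rng.normal(size=(C, D)) * 0.01     # stand-in GCN disease-label features
x_feats = rng.normal(size=D)                 # stand-in CNN image feature
targets = np.zeros(C)
targets[[2, 7]] = 1                          # two diseases present in this image

scores = fuse(y_feats, x_feats)
loss = multilabel_bce(scores, targets)
print(scores.shape, float(loss))
```

With the random stand-in features the scores are near zero and the loss sits near ln 2; training would shape y_feats and x_feats so that present labels score high.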
Further, the student network in Step4 specifically includes the following steps:
residual learning occurs once every two layers of the ResNet18 network model; the network is divided into five parts, namely Conv1, Conv2_x, Conv3_x, Conv4_x and Conv5_x, followed by a pooling layer at the end.
Further, the loss function in Step5 further includes the following:
in order to enable the student network to learn the soft targets, the temperature parameter T in knowledge distillation is used to regulate knowledge transfer, and the softmax function is defined as p_i = exp(x_i / T) / Σ_j exp(x_j / T), where p_i is the probability of the i-th output of the teacher network, x_i and x_j are the inputs of the softmax, and T is the temperature coefficient. As the temperature increases, the softmax output distribution becomes flatter and its information entropy grows, so the student network pays more attention to the negative labels. To let the student network better fit the classification results of the teacher network, the overall loss function is defined as Loss = (1 − α) H(label, y) + α H(p, y) T², where α is a weight coefficient, H is the cross entropy, label is the real label result, y is the student network's label result, and p is the teacher network's overall probability.
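The temperature-softened softmax and the overall distillation loss can be sketched as follows; the logits and the choices α = 0.5, T = 4 are illustrative values, not taken from the patent.

```python
import numpy as np

def softmax_T(x, T=1.0):
    # Softmax with temperature: p_i = exp(x_i / T) / sum_j exp(x_j / T)
    z = x / T
    z = z - z.max()                  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(p, q, eps=1e-12):
    # H(p, q) = -sum_i p_i · ln q_i
    return -np.sum(p * np.log(q + eps))

def kd_loss(student_logits, teacher_logits, hard_label, alpha=0.5, T=4.0):
    # Loss = (1 - alpha) · H(label, y) + alpha · H(p, y) · T^2
    y_hard = softmax_T(student_logits, T=1.0)   # student output at T = 1
    y_soft = softmax_T(student_logits, T=T)     # student soft output
    p_soft = softmax_T(teacher_logits, T=T)     # teacher soft targets
    return (1 - alpha) * cross_entropy(hard_label, y_hard) \
         + alpha * cross_entropy(p_soft, y_soft) * T ** 2

student = np.array([1.0, 2.0, 0.5])   # illustrative 3-class logits
teacher = np.array([1.2, 2.5, 0.3])
label   = np.array([0.0, 1.0, 0.0])   # one-hot ground truth
print(kd_loss(student, teacher, label))
```

Raising T flattens both distributions, so the negative-label structure of the teacher's output carries more weight in the soft term, exactly as the paragraph above describes.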
The invention has the following beneficial effects: using the student network for prediction reduces memory usage by 35 percent and increases running speed by 34 percent compared with using the teacher network. In a further ablation experiment, the average AUC is 0.756 without knowledge distillation and 0.817 with teacher-network guidance, an improvement of 6 percentage points, proving that the knowledge distillation is indeed useful.
Drawings
FIG. 1 is a general model and key component structure (Teacher model, student model);
Detailed Description
Embodiment 1: as shown in FIG. 1, a chest radiographic image disease classification model lightweight method based on knowledge distillation comprises the following steps:
step1, a graph convolutional neural network (GCN) module converts the medical-image disease labels into GloVe word-embedding representations using a pre-trained language model; a label relation graph matrix is then constructed by data mining and input into the GCN module, and disease-label features are extracted through two layers of graph convolution;
step2, the medical image is input into a convolutional neural network (CNN) module, and chest medical-image features are extracted after convolution and maximum pooling operations;
step3, fusing the medical image features and the disease label features to predict a multi-label classification result;
step4, selecting a ResNet18 network model as a student network, and solving the performance degradation problem of the deep network by the ResNet18 through a residual error unit;
and Step5, selecting a loss function to carry out regression and classification, and guiding the loss index of the student network by using the loss index of the teacher network to obtain the multi-label classification model student network.
Further, the graph convolution neural network (GCN) module in Step1 specifically includes the following:
labeling each diseaseRepresented as a single node, GCN set to 2 levels, input is a feature representation matrix H l ∈R n ×d (where n is the class of the tag and d is the word embedding dimension of the tag) and a tag correlation matrix A ∈ R n×n The object is in the figureLearning a function f (·,) to make the new node update represented as H l+1 ∈R n×n For example, each layer of the GCN is represented by a nonlinear activation function: h l+1 =f(H l And A). The graph convolution operation of the present invention uses f (·,) in the form: h l+1 =h(AH l W l ) Where h (-) refers to non-linear operation, the invention uses the LeakyReLU activation function, W l ∈R d×d′ Refers to the transformation matrix to be learned.
Further, the constructing of the label relationship graph matrix in Step1 specifically includes the following steps:
a label relation graph matrix is constructed by data mining and input into the GCN module. First the total occurrences of every disease category are counted; then, for each disease, data mining finds how often every other disease occurs together with it, i.e. the relation matrix is built in the form of conditional probabilities. P(Lb | La) is defined as the probability that label Lb occurs given that label La occurs. For example, let La be Pneumothorax and Lb be Emphysema; suppose the probability of emphysema given pneumothorax is 0.3, while the probability of pneumothorax given emphysema is 0.1. The medical-image dataset used by the invention has 14 disease label categories, so the final label correlation matrix is a 14 × 14 two-dimensional matrix.
Further, the Convolutional Neural Network (CNN) module in Step2 specifically includes the following:
the convolutional neural network (CNN) module uses a DenseNet network model, which contains four dense blocks named DenseBlock1 through DenseBlock4; the blocks differ in the number of convolution operations they contain. Each DenseBlock contains 1 × 1 and 3 × 3 convolution kernels and batch-normalization layers; transition layers between the dense blocks perform down-sampling, and DenseNet-121 contains 3 transition layers in total. To carry out feature fusion smoothly and obtain texture features better, the last fully connected layer of the DenseNet-121 network is removed and replaced with a maximum pooling layer.
Further, the fusion method in Step3 specifically comprises the following steps:
the invention performs feature fusion by matrix product, as shown in the calculation formula ẑ = y x, where ẑ represents the overall (fused) features, x is the medical-image feature and y is the disease-label feature. The overall features are then put into the multi-label classification loss function, as shown in the calculation formula L = −Σ_{c=1}^{C} [label_c ln δ(ẑ_c) + (1 − label_c) ln(1 − δ(ẑ_c))], where δ(·) is the sigmoid function, label_c is the ground truth for category c, and C represents the number of label categories.
Further, the student network in Step4 specifically includes the following steps:
residual learning occurs once every two layers of the ResNet18 network model; the network is divided into five parts, namely Conv1, Conv2_x, Conv3_x, Conv4_x and Conv5_x, followed by a pooling layer at the end.
Further, the loss function in Step5 further includes the following:
in order to enable the student network to learn the soft targets, the temperature parameter T in knowledge distillation is used to regulate knowledge transfer, and the softmax function is defined as p_i = exp(x_i / T) / Σ_j exp(x_j / T), where p_i is the probability of the i-th output of the teacher network, x_i and x_j are the inputs of the softmax, and T is the temperature coefficient. As the temperature increases, the softmax output distribution becomes flatter and its information entropy grows, so the student network pays more attention to the negative labels. To let the student network better fit the classification results of the teacher network, the overall loss function is defined as Loss = (1 − α) H(label, y) + α H(p, y) T², where α is a weight coefficient, H is the cross entropy, label is the real label result, y is the student network's label result, and p is the teacher network's overall probability.
Further, to verify the effect of the invention, the model trained through the above steps is tested on the large multi-label chest X-ray dataset ChestX-ray14, organized and published by the National Institutes of Health (NIH). The test environment is: GPU, NVIDIA RTX 2080; GPU memory, 11 GB; operating system, Ubuntu 18.04.3; machine-learning framework, PyTorch.
The chest radiological image disease classification model based on the graph convolutional neural network serves as the teacher network (Teacher model); it is composed of a graph convolutional neural network (GCN) module and a convolutional neural network (CNN) module. A ResNet18 network with fast running speed and low memory occupancy serves as the student network (Student model). The teacher and student networks are trained jointly, regressed and classified with the loss function, and the loss indices of the teacher network guide those of the student network. To evaluate the prediction results objectively, the area under the ROC curve (AUC) is used as the evaluation index.
Chest DR radiological images are classified by the knowledge distillation method: ML Chest-GCN serves as the teacher network, the accuracy of the resulting ResNet18 student network is tested and compared with the teacher network, and the experimental results of the teacher and student networks are obtained by independent testing. The experimental results are compared in Table 1:
TABLE 1 teacher's network and student's network comparison experiment results
The results in Table 1 show that, under the guidance of the teacher network ML-GCN, the ResNet18 disease classification student network improves its average AUC value to 0.817. Although this is three percentage points lower than the teacher network, it still exceeds the 0.738 of Wang et al. and the 0.803 of Yao et al. The experimental results prove that the knowledge distillation method indeed guides the student network's learning.
To evaluate whether the trained student network improves efficiency, its running speed and memory occupancy are tested and compared with the teacher network ML-GCN, with running speed measured in images processed per second and the same test environment for both. The test results are shown in Table 2:
TABLE 2 model efficiency comparison Table
The results in Table 2 show that, in the same experimental environment, using the student network for prediction reduces memory usage by 35 percent and increases running speed by 34 percent compared with using the teacher network.
Further, the invention also carries out an ablation experiment: the ResNet18 network model is trained alone, with the knowledge distillation link completely removed and no teacher-network guidance added. The experimental results are shown in the Resnet18 row of Table 3:
table 3 shows the results of the ablation experiment
As shown in Table 3, the average AUC is 0.756 without knowledge distillation and 0.817 with teacher-network guidance, 6 percentage points higher, proving that knowledge distillation is indeed useful.
While the present invention has been described in detail with reference to the embodiments shown in the drawings, the present invention is not limited to the embodiments, and various changes can be made without departing from the spirit and scope of the present invention.
Claims (7)
1. A chest radiographic image disease classification model lightweight method based on knowledge distillation is characterized by comprising the following steps:
Step 1, a graph convolutional neural network (GCN) module converts the medical image labels into GloVe word-embedding representations via a pre-trained language model, constructs a label relation graph matrix by data mining, inputs the matrix into the GCN module, and extracts disease label features through two layers of graph convolution;
Step 2, the medical image is input to a convolutional neural network (CNN) module, and chest medical image features are extracted after convolution and maximum pooling operations;
Step 3, the medical image features and the disease label features are fused to predict the multi-label classification result;
Step 4, a ResNet18 network model is selected as the student network, where ResNet18 alleviates the performance degradation problem of deep networks through residual units;
and Step 5, a loss function is selected for regression and classification, and the loss of the student network is guided by the loss of the teacher network to obtain the student network as the multi-label classification model.
2. The knowledge distillation-based chest radiographic image disease classification model lightweight method according to claim 1, wherein the graph convolutional neural network GCN module in Step 1 specifically comprises the following:
each disease label is represented as a single node, and the GCN is set to 2 layers; the input is a feature representation matrix H^l ∈ R^(n×d) and a label correlation matrix A ∈ R^(n×n), where n is the number of label categories and d is the word-embedding dimension of the labels; the goal of graph learning is to learn a function f(·,·) so that the updated node representation is H^(l+1) ∈ R^(n×d); each layer of the GCN is expressed with a nonlinear activation function as H^(l+1) = f(H^l, A), and the graph convolution operation instantiates f(·,·) as H^(l+1) = h(A H^l W^l), where h(·) denotes the nonlinear operation using the LeakyReLU activation function, and W^l ∈ R^(d×d) is the transformation matrix to be learned.
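The two-layer graph convolution described in this claim can be sketched in NumPy as follows; the 14-label count, 300-dimensional GloVe embeddings, random weights, row normalisation of A, and the LeakyReLU slope are illustrative assumptions, not part of the claim:

```python
import numpy as np

def leaky_relu(x, slope=0.01):
    """LeakyReLU nonlinearity h(.) used in each graph convolution layer."""
    return np.where(x > 0, x, slope * x)

def gcn_layer(H, A, W):
    """One graph convolution: H_{l+1} = h(A @ H_l @ W_l)."""
    return leaky_relu(A @ H @ W)

n, d = 14, 300                        # 14 disease labels, 300-dim GloVe embeddings (assumed)
rng = np.random.default_rng(0)
H0 = rng.normal(size=(n, d))          # label word embeddings H^0
A = rng.random((n, n))                # label correlation matrix A
A /= A.sum(axis=1, keepdims=True)     # row-normalise (a common practice, assumed here)
W1 = rng.normal(size=(d, d)) * 0.01   # learnable transforms W^l in R^(d x d)
W2 = rng.normal(size=(d, d)) * 0.01

H1 = gcn_layer(H0, A, W1)             # first graph convolution layer
H2 = gcn_layer(H1, A, W2)             # second layer: final disease-label features
print(H2.shape)
```

Note that since W^l is d×d, each layer preserves the n×d shape of the node feature matrix.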
3. The knowledge distillation-based chest radiographic image disease classification model lightweight method according to claim 1, wherein the construction of the label relation graph matrix in Step 1 specifically comprises the following:
a label relation graph matrix is constructed by data mining and input into the GCN module: the total number of occurrences of each disease type is counted, and then the number of occurrences of every other disease given that each disease occurs is found by data mining; that is, the relation matrix is built from conditional probabilities, where P(Lb | La) is defined as the probability that label Lb occurs given that label La occurs.
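The conditional-probability mining in this claim can be sketched as follows; the toy 4-image, 3-label matrix is an assumption standing in for the real chest X-ray label set:

```python
import numpy as np

# Binary label matrix mined from the dataset: rows = images, columns = disease labels.
# (Toy data; the real counts come from the chest X-ray annotations.)
Y = np.array([
    [1, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
    [1, 1, 1],
])

counts = Y.T @ Y                          # counts[a, b] = images containing both La and Lb
occurs = np.diag(counts).astype(float)    # total occurrences of each label
P = counts / occurs[:, None]              # P[a, b] = P(Lb | La)
print(P)
```

For the toy data above, label L0 occurs in 3 images and co-occurs with L1 in 2 of them, so P(L1 | L0) = 2/3; the diagonal is always 1 since P(La | La) = 1.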
4. The knowledge distillation-based chest radiographic image disease classification model lightweight method according to claim 1, wherein the convolutional neural network CNN module in Step 2 specifically comprises the following: the CNN module selects the DenseNet network model, which internally has four modules named DenseBlock1 to DenseBlock4; the modules differ in their convolution operations and in the number of layers per block, and each DenseBlock contains 1×1 and 3×3 convolution kernels together with batch normalization layers; transition layers performing down-sampling are placed between the dense blocks, and DenseNet-121 contains 3 transition layers in total; in order to perform feature fusion smoothly and better capture texture characteristics, the last fully connected layer of the DenseNet-121 network is removed and replaced with a global maximum pooling layer.
5. The knowledge distillation-based chest radiographic image disease classification model lightweight method according to claim 1, wherein the fusion method in Step 3 specifically comprises the following:
feature fusion is performed by matrix product, as shown in the calculation formula: F = y · x, where F represents the overall feature, x is the medical image feature, and y is the disease label feature; the overall feature is then put into the multi-label classification loss function, as shown in the calculation formula: L = −Σ_{c=1}^{C} [ l^c log δ(F^c) + (1 − l^c) log(1 − δ(F^c)) ], where δ(·) is the sigmoid function, l^c is the ground-truth value of the c-th label, and C represents the number of disease categories.
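The matrix-product fusion and multi-label sigmoid cross-entropy loss can be sketched as follows; the 14-category count, 1024-dimensional image feature, and random toy values are assumptions for illustration:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def multilabel_bce(scores, labels):
    """Multi-label classification loss: sigmoid + binary cross-entropy,
    summed over the C disease categories."""
    p = sigmoid(scores)
    eps = 1e-12
    return -np.sum(labels * np.log(p + eps) + (1 - labels) * np.log(1 - p + eps))

C, D = 14, 1024                          # assumed: 14 diseases, 1024-dim image feature
rng = np.random.default_rng(0)
x = rng.normal(size=(D,))                # medical image feature from the CNN branch
y_feat = rng.normal(size=(C, D)) * 0.01  # disease label features from the GCN branch

scores = y_feat @ x                      # fusion by matrix product: one score per disease
labels = rng.integers(0, 2, size=C)      # ground-truth multi-hot label vector (toy)
loss = multilabel_bce(scores, labels)
print(loss)
```

The matrix product collapses the shared feature dimension D, so the fused output is one logit per disease category, which the sigmoid then maps to an independent per-label probability.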
6. The knowledge distillation-based chest radiographic image disease classification model lightweight method according to claim 1, wherein the student network in Step 4 specifically comprises the following:
residual learning occurs once every two layers in the ResNet18 network model; the network model is divided into five parts, namely Conv1, Conv2_x, Conv3_x, Conv4_x and Conv5_x, followed finally by a pooling layer.
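The two-layer residual unit underlying ResNet18 can be sketched as follows; dense matrices stand in for the 3×3 convolutions, and the 64-dimensional size and random weights are assumptions:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, W1, W2):
    """Basic two-layer residual unit: y = ReLU(F(x) + x), where the
    skip connection adds the input x back onto the learned residual F(x)."""
    return relu(relu(x @ W1) @ W2 + x)

d = 64
rng = np.random.default_rng(0)
x = rng.normal(size=(d,))
W1 = rng.normal(size=(d, d)) * 0.05
W2 = rng.normal(size=(d, d)) * 0.05
y = residual_block(x, W1, W2)
print(y.shape)
```

Because the block learns only the residual F(x) = y − x, gradients flow through the identity skip path, which is what alleviates the degradation problem in deep networks mentioned in Step 4.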
7. The knowledge distillation-based chest radiographic image disease classification model lightweight method according to claim 1, wherein the loss function in Step 5 specifically comprises the following:
in order to enable the student network to learn soft targets, the temperature parameter T in knowledge distillation is used to regulate knowledge transfer, and the softmax function is defined as: p_i = exp(x_i / T) / Σ_j exp(x_j / T), where p_i is the probability of the i-th output of the teacher network, x_i and x_j represent the inputs of the softmax, and T is the temperature coefficient; as the temperature rises, the softmax output distribution becomes flatter and its information entropy grows, so the student network pays more attention to the negative labels; in order for the student network to better fit the classification results of the teacher network, the overall loss function is defined as: Loss = (1 − α)H(label, y) + αH(p, y)T², where α represents a weight coefficient, H represents cross entropy, label is the ground-truth label result, y is the student network output, and p represents the teacher network's soft probability output.
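The temperature-scaled softmax and combined distillation loss defined above can be sketched as follows; the 5-class toy logits and the hyper-parameter choices α = 0.3 and T = 4 are illustrative assumptions:

```python
import numpy as np

def softmax_T(x, T=1.0):
    """Temperature-scaled softmax: p_i = exp(x_i/T) / sum_j exp(x_j/T).
    Larger T flattens the distribution, so negative labels carry more signal."""
    z = x / T
    z -= z.max()                  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(p, q):
    return -np.sum(p * np.log(q + 1e-12))

def distill_loss(student_logits, teacher_logits, onehot, alpha=0.3, T=4.0):
    """Loss = (1 - alpha) * H(label, y) + alpha * H(p, y) * T^2
    (alpha and T are assumed hyper-parameter values)."""
    y_hard = softmax_T(student_logits, T=1.0)   # student prediction vs hard label
    y_soft = softmax_T(student_logits, T=T)     # softened student prediction
    p = softmax_T(teacher_logits, T=T)          # teacher soft target
    return ((1 - alpha) * cross_entropy(onehot, y_hard)
            + alpha * cross_entropy(p, y_soft) * T ** 2)

rng = np.random.default_rng(0)
s = rng.normal(size=5)            # toy student logits
t = rng.normal(size=5)            # toy teacher logits
onehot = np.eye(5)[2]             # toy ground-truth label
print(distill_loss(s, t, onehot))
```

The T² factor compensates for the 1/T scaling of the gradients through the softened softmax, keeping the two loss terms on comparable scales.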
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211056063.XA CN115471716A (en) | 2022-08-31 | 2022-08-31 | Chest radiographic image disease classification model lightweight method based on knowledge distillation |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115471716A true CN115471716A (en) | 2022-12-13 |
Family
ID=84371065
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211056063.XA Pending CN115471716A (en) | 2022-08-31 | 2022-08-31 | Chest radiographic image disease classification model lightweight method based on knowledge distillation |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115471716A (en) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title
---|---|---|---|---
CN116311102A (en) * | 2023-03-30 | 2023-06-23 | 哈尔滨市科佳通用机电股份有限公司 | Railway wagon fault detection method and system based on improved knowledge distillation
CN116311102B (en) * | 2023-03-30 | 2023-12-15 | 哈尔滨市科佳通用机电股份有限公司 | Railway wagon fault detection method and system based on improved knowledge distillation
CN117253611A (en) * | 2023-09-25 | 2023-12-19 | 四川大学 | Intelligent early cancer screening method and system based on multi-modal knowledge distillation
CN117253611B (en) * | 2023-09-25 | 2024-04-30 | 四川大学 | Intelligent early cancer screening method and system based on multi-modal knowledge distillation
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||