CN115565021A - Neural network knowledge distillation method based on learnable feature transformation - Google Patents
- Publication number
- CN115565021A CN115565021A CN202211196707.5A CN202211196707A CN115565021A CN 115565021 A CN115565021 A CN 115565021A CN 202211196707 A CN202211196707 A CN 202211196707A CN 115565021 A CN115565021 A CN 115565021A
- Authority
- CN
- China
- Prior art keywords
- knowledge distillation
- loss
- model
- feature map
- feature
- Prior art date
- Legal status
- Pending
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/7715—Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Image Processing (AREA)
Abstract
The invention provides a neural network knowledge distillation method based on learnable feature transformation, belonging to the technical field of computer vision. The method aligns the intermediate features and output results of a student model with those of a teacher model. It requires no complex task-specific feature transformation modules, introduces no complex hyper-parameters, and eliminates tedious parameter tuning, thereby improving both the generality of knowledge distillation across tasks and the distillation effect itself, while avoiding laborious manual structure design. Performance improvements are achieved on multiple computer vision tasks (such as image classification, object detection, and semantic segmentation).
Description
Technical Field
The invention belongs to the technical field of computer vision and relates to deep learning technologies such as computer vision, neural network model compression, and neural network knowledge distillation based on intermediate features.
Background
In recent years, with the continuous development of deep learning, deep convolutional neural networks have been widely applied to computer vision tasks such as image classification, object detection, and semantic segmentation, achieving ever better performance on these tasks. This performance comes at the cost of increasing model complexity and growing demands on computing and storage resources, which makes such models difficult to deploy on resource-limited devices such as mobile devices and embedded platforms. Neural network model compression techniques are needed to solve this problem.
Knowledge distillation is an important method in current neural network model compression. It takes a large neural network as the teacher network and a small neural network as the student network, and transfers the teacher's knowledge to the student, yielding a neural network with low complexity, good performance, and easy deployability, thereby achieving the goal of model compression.
At present, mainstream knowledge distillation methods fall into two categories: distillation based on output responses and distillation based on intermediate features. Response-based methods use the predictions of the teacher model's final layer as supervision information to guide the student model to imitate the teacher's behavior. Intermediate-feature-based methods use the features of the teacher model's hidden layers as supervision signals to guide the training of the student model. In practice, many distillation methods have been derived for different visual tasks, and they usually contain several hand-designed components, such as loss functions and feature masks. These hand-designed components reduce the generality of the distillation method and introduce additional hyper-parameters, increasing the difficulty of parameter tuning.
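For orientation, response-based distillation is classically implemented as a temperature-softened KL divergence between teacher and student outputs. The sketch below illustrates that standard formulation from the literature, not the claimed method; the function names, temperature value, and toy logits are illustrative assumptions:

```python
import numpy as np

def softmax(z, t=1.0):
    # Temperature-softened softmax; larger t flattens the distribution.
    z = np.asarray(z, dtype=float) / t
    e = np.exp(z - z.max())
    return e / e.sum()

def response_kd_loss(student_logits, teacher_logits, t=4.0):
    # KL(teacher || student) on softened distributions, scaled by t^2
    # so gradient magnitudes stay comparable across temperatures.
    p = softmax(teacher_logits, t)
    q = softmax(student_logits, t)
    return float(np.sum(p * np.log(p / q))) * t * t

loss = response_kd_loss([1.0, 0.2, -0.5], [2.0, 0.1, -1.0])
print(loss)  # positive whenever the two distributions differ
```

The loss vanishes exactly when the student reproduces the teacher's softened distribution, which is the sense in which the teacher's predictions act as supervision.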
Disclosure of Invention
To solve these problems, the invention provides a knowledge distillation method based on learnable feature transformation. It aligns the intermediate features and output responses of the student model with those of the teacher model, improves the distillation effect, avoids laborious manual structure design, and achieves performance improvements on multiple computer vision tasks (such as image classification, object detection, and semantic segmentation).
The technical scheme provided by the invention is as follows:
A knowledge distillation method based on learnable feature transformation, as shown in FIG. 1, comprising the steps of:
1) Input data are fed into the teacher model, whose intermediate layer outputs a first feature map; the same input data are fed into the student model, whose intermediate layer outputs a second feature map.
2) The second feature map is aligned with the first feature map in the spatial and channel dimensions, and the aligned feature map is passed through a multi-layer perceptron module to obtain a third feature map. In parallel, the aligned feature map is flattened and transposed, passed through another multi-layer perceptron module to obtain a transformed feature map, and the transformed feature map is restored to its original shape to obtain a fourth feature map.
3) The mean square error between the first and third feature maps is computed as the spatial feature loss, and the mean square error between the first and fourth feature maps as the channel feature loss; their weighted sum serves as the knowledge distillation loss function between the teacher model and the student model.
4) The student model is trained according to the knowledge distillation loss function to realize knowledge distillation.
Preferably, the multi-layer perceptron module is a multi-layer perceptron with one hidden layer and a ReLU activation function.
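Such a module can be sketched in a few lines of NumPy. This is a minimal illustration only: the weight shapes, the random initialization, and the application along the last axis are illustrative assumptions, not the patent's implementation:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mlp_one_hidden(x, w1, b1, w2, b2):
    # One hidden layer with ReLU, applied along the last axis of x.
    return relu(x @ w1 + b1) @ w2 + b2

# Toy weights; the feature dimension d = 4 is arbitrary.
rng = np.random.default_rng(0)
d = 4
w1, b1 = rng.normal(size=(d, d)), np.zeros(d)
w2, b2 = rng.normal(size=(d, d)), np.zeros(d)
x = rng.normal(size=(2, 3, d))        # e.g. [batch, positions, channels]
y = mlp_one_hidden(x, w1, b1, w2, b2)
print(y.shape)                        # the module preserves the input shape
```

Because input and output dimensions match, the module can transform a feature map without changing its shape, which is what allows the mean-square-error comparisons against the teacher's feature map later on.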
Preferably, the second feature map is aligned with the first feature map in the spatial and channel dimensions by bilinear interpolation and a 1x1 convolution.
Further, the downstream task of the student model is obtained, the objective function of the model is matched according to the downstream task type, and the student model is trained by combining the objective function with the knowledge distillation loss function.
Further, the hyper-parameters of the distillation loss function are adjusted according to the teacher model, the student model, and the downstream task; the regression loss function, the classification loss function, and the knowledge distillation loss function in the objective function are summed to obtain the total loss function for student model training, and the student model is trained according to the total loss function.
The invention has the beneficial effects that:
The invention provides a knowledge distillation method based on learnable feature transformation, which aligns the features of the teacher model and the student model and improves the distillation effect. It requires no complex task-specific feature transformation modules, introduces no complex hyper-parameters, and eliminates tedious parameter tuning, improving the generality of knowledge distillation across tasks and obtaining good results on multiple computer vision tasks.
Drawings
FIG. 1 is a schematic flow diagram of a knowledge distillation method based on learnable feature transformation according to the present invention;
FIG. 2 is a schematic diagram of the training process architecture of a student model according to an embodiment of the present invention.
Detailed Description
The invention will be further described, by way of example, with reference to the accompanying drawings, without in any way limiting the scope of the invention.
Taking the large-scale object detection dataset COCO as an example, RetinaNet-RX101 pre-trained on this dataset is taken as the teacher model and RetinaNet-R50 as the student model, to explain how knowledge distillation is performed on an object detection task through the learnable transformation modules, as shown in FIG. 2.
Step S1: input data are fed into the teacher model to obtain a first feature map output by an intermediate layer of the teacher model, and into the student model to obtain a second feature map output by an intermediate layer of the student model. Specifically:
S11: Any batch of original training images is input into the teacher model RetinaNet-RX101, and a first feature map is obtained from an intermediate layer of the FPN part of the teacher model.
S12: The same training images are input into the student model RetinaNet-R50, and a second feature map is obtained from an intermediate layer of the FPN part of the student model.
Step S2: the third and fourth feature maps are obtained using the multi-layer perceptron modules. Specifically:
S21: The second feature map is aligned with the first feature map in the spatial and channel dimensions by bilinear interpolation and a 1x1 convolution, yielding the aligned feature map.
S22: The aligned feature map is passed through a multi-layer perceptron module with one hidden layer and a ReLU activation function to obtain the third feature map.
S23: Let the shape of the aligned feature map be [N, C, H, W]. It is reshaped to [N, H×W, C] by flattening and transposition, passed through another multi-layer perceptron module with one hidden layer and a ReLU activation function to obtain the transformed feature map, and the transformed feature map is reshaped back to [N, C, H, W] to obtain the fourth feature map.
Step S3: the spatial feature loss and the channel feature loss between the teacher model and the student model are calculated from the first, third, and fourth feature maps, and their weighted sum is taken as the knowledge distillation loss function between the teacher model and the student model. Specifically:
s31: calculating the mean square error loss between the first feature map and the third feature map as the spatial feature loss, wherein the expression is as follows:
wherein feat T In order to provide the first characteristic diagram,is the third characteristic diagram
S32: calculating the mean square error loss between the first feature map and the fourth feature map as the channel feature loss, wherein the expression is as follows:
S33: and weighting and summing the spatial characteristic loss and the channel characteristic loss to obtain the knowledge distillation loss function, wherein the expression of the knowledge distillation loss function is as follows:
L distill =αLoss Spatial +βLoss Channel
where α and β are hyper-parameters, set to 2e-5 and 1e-6 respectively in this embodiment.
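Steps S31–S33 amount to two mean square errors and a weighted sum, which can be sketched as follows (random arrays stand in for the actual teacher and transformed student feature maps; the α and β values are the ones used in this embodiment):

```python
import numpy as np

def mse(a, b):
    # Mean square error between two equally shaped feature maps.
    return float(np.mean((a - b) ** 2))

alpha, beta = 2e-5, 1e-6                 # hyper-parameters of this embodiment
rng = np.random.default_rng(0)
feat_t = rng.normal(size=(1, 4, 2, 2))   # first feature map (teacher)
feat_3 = rng.normal(size=(1, 4, 2, 2))   # third feature map (spatial branch)
feat_4 = rng.normal(size=(1, 4, 2, 2))   # fourth feature map (channel branch)

loss_spatial = mse(feat_t, feat_3)
loss_channel = mse(feat_t, feat_4)
l_distill = alpha * loss_spatial + beta * loss_channel
print(l_distill)
```

Both terms vanish exactly when the transformed student features reproduce the teacher's features, so minimizing L_distill drives the alignment described above.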
Step S4: the student model is trained according to the knowledge distillation loss function to realize knowledge distillation.
Further, the downstream task of the student model is obtained; in this embodiment, the downstream task is an object detection task.
Step S5: the objective function of the model is matched according to the downstream task type. In this embodiment, the objective function is divided into a regression loss function and a classification loss function. The regression loss penalizes the deviation between predicted and true anchor offsets (a smooth L1 loss, as is standard for RetinaNet):

L_reg = Σ_i smoothL1(t_i − t_i*)

where t_i is the predicted deviation of each anchor from the ground truth (GT), and t_i* is the true deviation of each anchor from the GT.
In this embodiment, the classification loss function uses Focal Loss:

L_cls = −α_t·(1 − p_t)^γ·log(p_t)

where p_t is the probability that the sample is classified correctly, and α_t and γ are hyper-parameters, set to 0.25 and 2.0 respectively in this embodiment.
Step S6: the hyper-parameters of the distillation loss function are adjusted according to the teacher model, the student model, and the downstream task, and the objective function and the knowledge distillation loss function are summed to obtain the total loss function for student model training; the student model is trained according to the total loss function, whose expression is:

L_total = L_reg + L_cls + L_distill
For the image classification task, results on the ImageNet dataset show that, using ResNet34 as the teacher model and ResNet18 as the student model, the proposed distillation method improves Top-1 accuracy on the test set from 69.9% to 71.4%. For the object detection task, results on the MS COCO dataset show that, using RetinaNet-RX101 as the teacher model and RetinaNet-R50 as the student model, the method improves the student's mAP from 37.4% to 41.0%. For the semantic segmentation task, results on the Cityscapes dataset show that, using PSPNet-ResNet34 as the teacher model and PSPNet-ResNet18 as the student model, the method improves the student's mIoU from 69.9% to 74.2%. (Note: ImageNet is a large-scale image classification dataset, and Top-1 accuracy measures classification accuracy; MS COCO is a large-scale dataset covering tasks such as object detection, and bbox mAP measures detection performance; Cityscapes is a semantic segmentation dataset, and mIoU measures segmentation performance.) The method also transfers across architectures: for image classification on the Cifar-100 dataset, using ResNet56 (a convolutional neural network architecture) as the teacher model and ViT-tiny (a Transformer architecture) as the student model, the student's Top-1 accuracy can be increased from 57.8% to 77.5% (note: Cifar-100 is a small-scale image classification dataset).
The present invention has been described in detail with reference to the embodiments, and those skilled in the art can make insubstantial changes in form or content from the steps described above without departing from the scope of the present invention. Therefore, the present invention is not limited to the disclosure in the above embodiments, and the scope of the present invention should be determined by the appended claims.
Claims (5)
1. A knowledge distillation method based on learnable feature transformation, comprising the steps of:
1) inputting input data into a teacher model, wherein an intermediate layer of the teacher model outputs a first feature map, and inputting the input data into a student model, wherein an intermediate layer of the student model outputs a second feature map;
2) aligning the second feature map with the first feature map in the spatial and channel dimensions, and passing the aligned feature map through a multi-layer perceptron module to obtain a third feature map; meanwhile, flattening and transposing the aligned feature map, passing it through another multi-layer perceptron module to obtain a transformed feature map, and restoring the transformed feature map to its original shape to obtain a fourth feature map;
3) calculating the mean square error between the first feature map and the third feature map as a spatial feature loss, calculating the mean square error between the first feature map and the fourth feature map as a channel feature loss, and taking the weighted sum of the spatial feature loss and the channel feature loss as a knowledge distillation loss function between the teacher model and the student model;
4) training the student model according to the knowledge distillation loss function to realize knowledge distillation.
2. The knowledge distillation method based on learnable feature transformation of claim 1, wherein the multi-layer perceptron module in step 2) employs a multi-layer perceptron with one hidden layer and a ReLU activation function.
3. The knowledge distillation method based on learnable feature transformation of claim 1, wherein in step 2) the second feature map is aligned with the first feature map in the spatial and channel dimensions by bilinear interpolation and a 1x1 convolution.
4. The knowledge distillation method based on learnable feature transformation of claim 1, wherein in step 4) the downstream task of the student model is obtained, the objective function of the model is matched according to the downstream task type, and the student model is trained by combining the objective function with the knowledge distillation loss function.
5. The knowledge distillation method based on learnable feature transformation of claim 4, wherein in step 4) the regression loss function, the classification loss function, and the knowledge distillation loss function in the objective function are summed to obtain a total loss function for student model training, and the student model is trained according to the total loss function.
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211196707.5A CN115565021A (en) | 2022-09-28 | 2022-09-28 | Neural network knowledge distillation method based on learnable feature transformation |
PCT/CN2022/143756 WO2024066111A1 (en) | 2022-09-28 | 2022-12-30 | Image processing model training method and apparatus, image processing method and apparatus, and device and medium |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211196707.5A CN115565021A (en) | 2022-09-28 | 2022-09-28 | Neural network knowledge distillation method based on learnable feature transformation |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115565021A (en) | 2023-01-03 |
Family
ID=84743371
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211196707.5A Pending CN115565021A (en) | 2022-09-28 | 2022-09-28 | Neural network knowledge distillation method based on learnable feature transformation |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN115565021A (en) |
WO (1) | WO2024066111A1 (en) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118015316B (en) * | 2024-04-07 | 2024-06-11 | 之江实验室 | Image matching model training method, device, storage medium and equipment |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111160409A (en) * | 2019-12-11 | 2020-05-15 | 浙江大学 | Heterogeneous neural network knowledge reorganization method based on common feature learning |
CN112819050B (en) * | 2021-01-22 | 2023-10-27 | 北京市商汤科技开发有限公司 | Knowledge distillation and image processing method, apparatus, electronic device and storage medium |
CN114998694A (en) * | 2022-06-08 | 2022-09-02 | 上海商汤智能科技有限公司 | Method, apparatus, device, medium and program product for training image processing model |
- 2022-09-28: CN application CN202211196707.5A filed, published as CN115565021A (status: pending)
- 2022-12-30: PCT application PCT/CN2022/143756 filed, published as WO2024066111A1 (status: unknown)
Also Published As
Publication number | Publication date |
---|---|
WO2024066111A1 (en) | 2024-04-04 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
| PB01 | Publication | |
| SE01 | Entry into force of request for substantive examination | |