CN115565021A - Neural network knowledge distillation method based on learnable feature transformation - Google Patents


Info

Publication number
CN115565021A
CN115565021A (application CN202211196707.5A)
Authority
CN
China
Prior art keywords
knowledge distillation
loss
model
feature map
feature
Prior art date
Legal status
Pending
Application number
CN202211196707.5A
Other languages
Chinese (zh)
Inventor
王勇涛
刘子炜
Current Assignee
Peking University
Original Assignee
Peking University
Priority date
Filing date
Publication date
Application filed by Peking University filed Critical Peking University
Priority to CN202211196707.5A priority Critical patent/CN115565021A/en
Priority to PCT/CN2022/143756 priority patent/WO2024066111A1/en
Publication of CN115565021A publication Critical patent/CN115565021A/en
Pending legal-status Critical Current

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77 Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/7715 Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks


Abstract

The invention provides a neural network knowledge distillation method based on learnable feature transformation, belonging to the technical field of computer vision. The method aligns the intermediate features and output results of a student model with those of a teacher model. It requires no complex, task-specific feature transformation modules, introduces no complex hyper-parameters, and omits tedious parameter-tuning steps. It thereby improves both the generality of knowledge distillation across tasks and the distillation effect, avoids laborious manual structure design, and achieves performance gains on multiple computer vision tasks (e.g., image classification, object detection, and semantic segmentation).

Description

Neural network knowledge distillation method based on learnable feature transformation
Technical Field
The invention belongs to the technical field of computer vision, and relates to deep learning technologies including computer vision, neural network model compression, and neural network knowledge distillation based on intermediate features.
Background
In recent years, with the continuous development of deep learning, deep convolutional neural networks have been widely applied to computer vision tasks such as image classification, object detection, and semantic segmentation, achieving ever-better performance. This performance, however, comes with growing model complexity and increasing demands on computation and storage, making such models difficult to deploy on resource-limited devices such as mobile devices and embedded platforms. Neural network model compression techniques are needed to solve this problem.
Knowledge distillation is an important model-compression technique. It takes a large-scale neural network as the teacher and a small-scale neural network as the student, and transfers the teacher's knowledge to the student, yielding a network with low complexity, good performance, and easy deployment, thereby achieving model compression.
Mainstream knowledge distillation methods divide into response-based and feature-based approaches. Response-based methods use the predictions of the teacher model's final layer as supervision to guide the student model to imitate the teacher's behavior. Feature-based methods use the features of the teacher model's intermediate hidden layers as supervision signals for training the student model. In practice, many task-specific distillation methods have been derived, and they usually contain hand-designed components such as loss functions and feature masks; these components reduce the generality of the distillation method and introduce extra hyper-parameters, increasing tuning difficulty.
Disclosure of Invention
To solve these problems, the invention provides a knowledge distillation method based on learnable feature transformation that aligns the intermediate features and output responses of the student model with those of the teacher model, improves the distillation effect, avoids laborious manual structure design, and achieves performance gains on multiple computer vision tasks (e.g., image classification, object detection, and semantic segmentation).
The technical scheme provided by the invention is as follows:
a knowledge distillation method based on learnable feature transformation, as shown in fig. 1, comprising the steps of:
1) Input data is fed into the teacher model, whose intermediate layer outputs a first feature map; the same input data is fed into the student model, whose intermediate layer outputs a second feature map;
2) The second feature map is aligned with the first feature map in the spatial and channel dimensions, and the aligned feature map is passed through a multilayer perceptron module to obtain a third feature map; meanwhile, the aligned feature map is flattened and transposed, passed through another multilayer perceptron module to obtain a transformed feature map, and reshaped back to the shape before transformation to obtain a fourth feature map;
3) The mean square error loss between the first feature map and the third feature map is computed as a spatial feature loss, the mean square error loss between the first feature map and the fourth feature map as a channel feature loss, and the weighted sum of the spatial and channel feature losses serves as the knowledge distillation loss function between the teacher model and the student model;
4) The student model is trained according to the knowledge distillation loss function to realize knowledge distillation.
Preferably, the multilayer perceptron module is a multilayer perceptron with one hidden layer and a ReLU activation function.
Preferably, the second feature map is aligned with the first feature map in the spatial and channel dimensions by bilinear interpolation and 1x1 convolution.
Further, the downstream task of the student model is obtained, the model's objective function is matched to the task type, and the student model is trained by combining the objective function with the knowledge distillation loss function.
Further, the hyper-parameters of the distillation loss function are adjusted according to the teacher model, the student model, and the downstream task; the regression loss function, the classification loss function, and the knowledge distillation loss function are summed to obtain the total loss function for student model training, and the student model is trained according to the total loss function.
The invention has the beneficial effects that:
the invention provides a knowledge distillation method based on learnable feature transformation, which aligns the features of the teacher and student models and improves the distillation effect. It requires no complex task-specific feature transformation modules, introduces no complex hyper-parameters, omits tedious parameter-tuning steps, improves the generality of knowledge distillation across tasks, and obtains good results on multiple computer vision tasks.
Drawings
FIG. 1 is a schematic flow diagram of a knowledge distillation method based on learnable feature transformation according to the present invention;
fig. 2 is a schematic diagram of a training process architecture of a student model according to an embodiment of the present invention.
Detailed Description
The invention will be further described, by way of example, with reference to the accompanying drawings, without in any way limiting the scope of the invention.
Taking the large-scale object detection dataset COCO as an example, RetinaNet-RX101 pre-trained on this data serves as the teacher model and RetinaNet-R50 as the student model, to illustrate how knowledge distillation is performed on an object detection task through the learnable transformation module, as shown in FIG. 2.
Step S1: input data is fed into the teacher model to obtain a first feature map output by an intermediate layer of the teacher model, and into the student model to obtain a second feature map output by an intermediate layer of the student model. Specifically:
S11: any batch of original training images is input into the teacher model RetinaNet-RX101, and the first feature map is taken from an intermediate layer of the teacher model's FPN.
S12: the same training images are input into the student model RetinaNet-R50, and the second feature map is taken from an intermediate layer of the student model's FPN.
Step S2: obtain the third feature map and the fourth feature map using the multilayer perceptron modules. Specifically:
S21: align the second feature map with the first feature map in the spatial and channel dimensions by bilinear interpolation and 1x1 convolution, obtaining the aligned feature map.
S22: pass the aligned feature map through a multilayer perceptron module with one hidden layer and a ReLU activation function to obtain the third feature map.
S23: let the aligned feature map have shape [N, C, H, W]; reshape it to [N, H×W, C] by flattening and transposition, pass it through another multilayer perceptron module with one hidden layer and a ReLU activation function to obtain the transformed feature map, and reshape that back to [N, C, H, W] to obtain the fourth feature map.
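Steps S21-S23 can be sketched in PyTorch as follows. This is a minimal illustration, not the patented implementation: the hidden width, and the choice to realize the S22 MLP with 1x1 convolutions (an MLP applied per spatial location over the channel dimension), are assumptions not fixed by the text.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableFeatureTransform(nn.Module):
    """Sketch of the learnable transform of steps S21-S23 (assumptions noted above)."""

    def __init__(self, c_student, c_teacher, hidden=256):
        super().__init__()
        # S21: 1x1 convolution aligns the channel dimension with the teacher.
        self.align = nn.Conv2d(c_student, c_teacher, kernel_size=1)
        # S22: MLP with one hidden layer and ReLU, producing the third feature map.
        self.spatial_mlp = nn.Sequential(
            nn.Conv2d(c_teacher, hidden, 1), nn.ReLU(),
            nn.Conv2d(hidden, c_teacher, 1))
        # S23: MLP applied after flattening/transposing to [N, H*W, C].
        self.channel_mlp = nn.Sequential(
            nn.Linear(c_teacher, hidden), nn.ReLU(),
            nn.Linear(hidden, c_teacher))

    def forward(self, feat_s, teacher_hw):
        # S21: bilinear interpolation aligns the spatial dimensions.
        x = F.interpolate(feat_s, size=teacher_hw, mode="bilinear",
                          align_corners=False)
        x = self.align(x)
        n, c, h, w = x.shape
        third = self.spatial_mlp(x)                 # [N, C, H, W]
        flat = x.flatten(2).transpose(1, 2)         # [N, H*W, C]
        fourth = self.channel_mlp(flat)             # [N, H*W, C]
        fourth = fourth.transpose(1, 2).reshape(n, c, h, w)
        return third, fourth
```

Both outputs match the teacher feature map's shape, so the mean square error losses of step S3 can be computed directly.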
Step S3: compute the spatial feature loss and the channel feature loss between the teacher and student models from the first, third, and fourth feature maps, and take their weighted sum as the knowledge distillation loss function. Specifically:
S31: compute the mean square error loss between the first feature map and the third feature map as the spatial feature loss:

Loss_Spatial = MSE(feat_T, feat_spatial)

where feat_T is the first feature map and feat_spatial is the third feature map.
S32: compute the mean square error loss between the first feature map and the fourth feature map as the channel feature loss:

Loss_Channel = MSE(feat_T, feat_channel)

where feat_channel is the fourth feature map.
S33: take the weighted sum of the spatial and channel feature losses as the knowledge distillation loss function:

L_distill = α·Loss_Spatial + β·Loss_Channel

where α and β are hyper-parameters, set to 2e-5 and 1e-6 respectively in this embodiment.
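As a sketch, the loss of steps S31-S33 translates directly to code; `feat_t`, `third`, and `fourth` stand for the first, third, and fourth feature maps, with α and β defaulting to the embodiment's values:

```python
import torch
import torch.nn.functional as F

def distill_loss(feat_t, third, fourth, alpha=2e-5, beta=1e-6):
    """L_distill = alpha * Loss_Spatial + beta * Loss_Channel (steps S31-S33)."""
    loss_spatial = F.mse_loss(third, feat_t)   # S31: spatial feature loss
    loss_channel = F.mse_loss(fourth, feat_t)  # S32: channel feature loss
    return alpha * loss_spatial + beta * loss_channel  # S33: weighted sum
```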
Step S4: train the student model according to the knowledge distillation loss function to realize knowledge distillation.
Further, the downstream task of the student model is obtained; in this embodiment, the downstream task is object detection.
Step S5: match the model objective function to the downstream task type. In this embodiment the objective function is divided into a regression loss function and a classification loss function. The regression loss is the smooth-L1 loss between the predicted and true anchor deviations, as standard for RetinaNet:

L_reg = Σ_i SmoothL1(t_i − t_i*)

where t_i is the predicted deviation of anchor i from the ground truth (GT) and t_i* is the true deviation of anchor i from the GT.
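A minimal sketch of this regression loss; the patent's formula image is not reproduced in the text, so the standard smooth-L1 form used by RetinaNet is assumed here:

```python
import torch
import torch.nn.functional as F

def regression_loss(t, t_star):
    # Smooth-L1 over predicted vs. ground-truth anchor deviations,
    # summed over anchors (assumed standard RetinaNet form).
    return F.smooth_l1_loss(t, t_star, reduction="sum")
```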
In this embodiment, the classification loss function uses Focal Loss:

L_cls = −α_t (1 − p_t)^γ log(p_t)

where p_t is the probability that the sample is correctly classified, and α_t and γ are hyper-parameters, set to 0.25 and 2.0 respectively in this embodiment.
Step S6: adjust the hyper-parameters of the distillation loss function according to the teacher model, the student model, and the downstream task, then sum the regression loss function, the classification loss function, and the knowledge distillation loss function to obtain the total loss function for student model training, and train the student model according to it:

L_total = L_reg + L_cls + L_distill
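The full training step of S6 can be sketched as follows. The interfaces are assumptions about the surrounding training code, not part of the patent: `teacher` and `student` are callables returning their FPN feature maps, `transform` produces the (third, fourth) feature maps, and `task_loss_fn` returns L_reg + L_cls.

```python
import torch
import torch.nn.functional as F

def train_step(student, teacher, transform, images, task_loss_fn,
               optimizer, alpha=2e-5, beta=1e-6):
    """One optimisation step for L_total = L_reg + L_cls + L_distill (sketch)."""
    with torch.no_grad():                        # teacher is frozen
        feat_t = teacher(images)                 # first feature map
    feat_s = student(images)                     # second feature map
    third, fourth = transform(feat_s, feat_t.shape[-2:])
    l_distill = (alpha * F.mse_loss(third, feat_t)
                 + beta * F.mse_loss(fourth, feat_t))
    loss = task_loss_fn(images) + l_distill      # L_total
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss.detach())
```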
for image classification tasks, the result on the ImageNet data set shows that by using ResNet34 as a teacher model and ResNet18 as a student model and adopting the distillation method provided by the invention to carry out knowledge distillation, the Top-1 accuracy on the test set can be improved from 69.9% to 71.4%; for the target detection task, the result on the MSCOCO data set shows that by using RetinaNet-RX101 as a teacher model and RetinaNet-R50 as a student model, the knowledge distillation method provided by the invention can improve the mAP of the student model from 37.4% to 41.0%; as for the semantic segmentation task, the result on a CityScaps data set shows that PSPNet-ResNet34 is used as a teacher model, PSPNet-ResNet18 is used as a student model, and the knowledge distillation method provided by the invention can be used for increasing the mIoU of the student model from 69.9% to 74.2% (note: imageNet is a large-scale image classification data set, top1-accuracy is used for measuring the image classification accuracy, MSCOCO is a large-scale data set and comprises tasks such as target detection, mAP of bbox is an index for measuring the target detection performance, cityScaps is a semantic segmentation data set, and mIoU is an index for measuring the semantic segmentation performance). For example, for the image classification task, on the Cifar-100 dataset, using ResNet56 based on the convolutional neural network architecture as the teacher model and ViT-tiny based on the Transformer architecture as the student model, the Top 1-accuacy of the student model can be increased from 57.8% to 77.5% (note: cifar100 is a small-scale image classification dataset).
The present invention has been described in detail with reference to the embodiments, and those skilled in the art can make insubstantial changes in form or content from the steps described above without departing from the scope of the present invention. Therefore, the present invention is not limited to the disclosure in the above embodiments, and the scope of the present invention should be determined by the appended claims.

Claims (5)

1. A knowledge distillation method based on learnable feature transformation, comprising the steps of:
1) Input data is fed into the teacher model, whose intermediate layer outputs a first feature map; the input data is fed into the student model, whose intermediate layer outputs a second feature map;
2) The second feature map is aligned with the first feature map in the spatial and channel dimensions, and the aligned feature map is passed through a multilayer perceptron module to obtain a third feature map; meanwhile, the aligned feature map is flattened and transposed, passed through another multilayer perceptron module to obtain a transformed feature map, and the transformed feature map is reshaped back to the shape before transformation to obtain a fourth feature map;
3) The mean square error loss between the first feature map and the third feature map is computed as a spatial feature loss, the mean square error loss between the first feature map and the fourth feature map as a channel feature loss, and the weighted sum of the spatial and channel feature losses serves as the knowledge distillation loss function between the teacher model and the student model;
4) The student model is trained according to the knowledge distillation loss function to realize knowledge distillation.
2. The knowledge distillation method based on learnable feature transformation of claim 1, wherein the multilayer perceptron module in step 2) is a multilayer perceptron with one hidden layer and a ReLU activation function.
3. The method for knowledge distillation based on learnable feature transformation of claim 1, characterized in that in step 2) the second feature map is aligned with the first feature map in spatial dimension and channel dimension by bilinear interpolation and 1x1 convolution.
4. The learnable feature transform based knowledge distillation method of claim 1, wherein in step 4) downstream tasks of the student model are obtained, the objective function of the model is matched according to the downstream task type, and the student model is trained by combining the objective function and the knowledge distillation loss function.
5. The learnable feature transformation based knowledge distillation method of claim 4, wherein the regression loss function, the classification loss function and the knowledge distillation loss function in the objective function are summed in step 4) to obtain a total loss function of the student model training, and the student model is trained according to the total loss function.
CN202211196707.5A 2022-09-28 2022-09-28 Neural network knowledge distillation method based on learnable feature transformation Pending CN115565021A (en)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202211196707.5A CN115565021A (en) 2022-09-28 2022-09-28 Neural network knowledge distillation method based on learnable feature transformation
PCT/CN2022/143756 WO2024066111A1 (en) 2022-09-28 2022-12-30 Image processing model training method and apparatus, image processing method and apparatus, and device and medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211196707.5A CN115565021A (en) 2022-09-28 2022-09-28 Neural network knowledge distillation method based on learnable feature transformation

Publications (1)

Publication Number Publication Date
CN115565021A true CN115565021A (en) 2023-01-03

Family

ID=84743371

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211196707.5A Pending CN115565021A (en) 2022-09-28 2022-09-28 Neural network knowledge distillation method based on learnable feature transformation

Country Status (2)

Country Link
CN (1) CN115565021A (en)
WO (1) WO2024066111A1 (en)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN118015316B (en) * 2024-04-07 2024-06-11 之江实验室 Image matching model training method, device, storage medium and equipment

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111160409A (en) * 2019-12-11 2020-05-15 浙江大学 Heterogeneous neural network knowledge reorganization method based on common feature learning
CN112819050B (en) * 2021-01-22 2023-10-27 北京市商汤科技开发有限公司 Knowledge distillation and image processing method, apparatus, electronic device and storage medium
CN114998694A (en) * 2022-06-08 2022-09-02 上海商汤智能科技有限公司 Method, apparatus, device, medium and program product for training image processing model

Also Published As

Publication number Publication date
WO2024066111A1 (en) 2024-04-04


Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination