CN111753995B - Local interpretable method based on gradient lifting tree - Google Patents

Info

Publication number
CN111753995B
Authority
CN
China
Prior art keywords
model
gradient lifting
tree model
interpretation
importance
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010580912.6A
Other languages
Chinese (zh)
Other versions
CN111753995A (en)
Inventor
仇鑫
李鑫
张瑞
徐宏刚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
East China Normal University
Original Assignee
East China Normal University
Filing date
2020-06-23
Publication date
2024-06-28
Application filed by East China Normal University
Priority to CN202010580912.6A
Publication of CN111753995A
Application granted
Publication of CN111753995B
Active (current legal status)
Anticipated expiration

Abstract

The invention discloses a local interpretable method based on gradient boosting trees. A complex model is distilled into a gradient boosting tree model by knowledge distillation. The conventional mean decrease in impurity (MDI) feature importance is modified so that it is computed as a weighted average of each boosted tree's contribution to the node information gain. The resulting values are sorted to rank the input features, which yields a local explanation and thereby explains the complex model. The invention is a general interpretable method that can extract explanations for datasets from various fields, such as natural language processing, image, and tabular datasets. In addition, the local explanations can be extended to a global explanation of the model by means of a submodular selection method.

Description

Local interpretable method based on gradient lifting tree
Technical Field
The invention relates to the field of artificial intelligence, and in particular to a local interpretable method based on a gradient boosting tree, which is used to extract explanations from a variety of artificial intelligence models.
Background
As machine learning models are increasingly used in critical areas such as autonomous vehicles, healthcare, financial markets, and legal systems, it becomes essential for humans to understand the predictions made by machine learning algorithms. Many complex models (e.g., deep neural networks and ensemble learning) are fine-tuned to optimize prediction accuracy, which makes their predictions difficult to interpret. Interpretable machine learning addresses this problem from two directions. The first approach builds inherently interpretable models based on decision trees, decision sets or rules, GAMs (generalized additive models), logistic regression, and so on, which often comes at the cost of reduced prediction accuracy. The second approach provides either a global understanding of the entire model or a local explanation of individual predictions. Some interpretation methods are model-agnostic and can be applied to any classifier or regressor, while others are designed for a particular model. The form of the explanation ranges from feature importance to decision sets or rules.
The field of interpretable machine learning has recently attracted a growing number of researchers. With the resurgence of deep learning, understanding complex neural networks has become increasingly difficult. Deep neural networks remain challenging because they typically contain a large number of hidden layers and parameters, together with the activations residing in those hidden layers. Meanwhile, GBM (gradient boosting machine) is a powerful ensemble learning algorithm that has demonstrated competitive performance on many tasks (e.g., online advertising). Boosting is a powerful supervised learning method that improves predictive performance by iteratively refining and combining multiple weak learners (typically decision trees). Gradient boosting generalizes boosting to any differentiable loss function, so it can be used for both regression and classification problems. In practice, GBM works well in many application areas and is supported by many publicly available implementations. Tree-based gradient boosting methods, especially LightGBM and XGBoost, are widely used in machine learning competitions such as Kaggle. One of the most popular base learners for GBM is a fixed-size CART (classification and regression tree), which yields GBDT (gradient boosted decision trees, also called gradient tree boosting). The present invention focuses on explaining individual predictions of tree-based GBMs in terms of feature importance. For tree-based ensemble methods, a single decision tree is relatively easy to understand, but the final additive model becomes much less transparent once the trees are combined.
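The following minimal pure-Python sketch makes the additive structure of gradient boosting concrete for squared loss. It assumes depth-1 regression trees (stumps) as the weak learners and uses a fixed shrinkage factor `lr` instead of a per-tree line search. With squared loss the negative gradient is simply the residual, so each round fits one stump to the current residuals. The names `fit_gbm` and `predict_gbm` are illustrative and do not come from any library.

```python
from typing import List, Sequence, Tuple

# A stump is (feature index, threshold, left-leaf value, right-leaf value).
Stump = Tuple[int, float, float, float]


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def fit_stump(X: List[List[float]], r: List[float]) -> Stump:
    """Fit a depth-1 regression tree to residuals r by exhaustive search."""
    n = len(X)
    # Fallback: a constant stump (every row goes left) if no valid split exists.
    best_sse, best = float("inf"), (0, float("inf"), mean(r), mean(r))
    for j in range(len(X[0])):
        for t in sorted({row[j] for row in X}):
            left = [r[i] for i in range(n) if X[i][j] <= t]
            right = [r[i] for i in range(n) if X[i][j] > t]
            if not left or not right:
                continue
            lv, rv = mean(left), mean(right)
            sse = sum((v - lv) ** 2 for v in left) + sum((v - rv) ** 2 for v in right)
            if sse < best_sse:
                best_sse, best = sse, (j, t, lv, rv)
    return best


def predict_stump(s: Stump, x: Sequence[float]) -> float:
    j, t, lv, rv = s
    return lv if x[j] <= t else rv


def fit_gbm(X, y, n_trees=50, lr=0.1):
    """Squared-loss gradient boosting: each stump fits the negative
    gradient (the residuals) and is added with shrinkage lr."""
    f0 = mean(y)
    pred = [f0] * len(y)
    stumps: List[Stump] = []
    for _ in range(n_trees):
        residuals = [yi - pi for yi, pi in zip(y, pred)]
        s = fit_stump(X, residuals)
        stumps.append(s)
        pred = [pi + lr * predict_stump(s, xi) for pi, xi in zip(pred, X)]
    return f0, lr, stumps


def predict_gbm(model, x) -> float:
    """Additive model F(x) = f0 + lr * sum_m h_m(x)."""
    f0, lr, stumps = model
    return f0 + lr * sum(predict_stump(s, x) for s in stumps)
```

As an example, take `X = [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]` with labels `[0, 0, 0, 1, 1, 1]`. Fifty stumps with `lr=0.3` push the prediction for `[0.0]` close to 0 and the prediction for `[5.0]` close to 1. Libraries such as XGBoost and LightGBM add deeper trees, regularization, and many engineering optimizations.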
Recent model-agnostic interpretation methods can be used to explain ensemble methods. Model-agnostic methods treat the target model as a black box, so they can explain any classifier or regressor. Existing model-agnostic work is typically a post hoc analysis of a black-box model that has already been fitted to the data. A common approach is to learn another model that approximates the predictions of the original model and is relatively easy to interpret. Earlier work approximated the original predictions globally, whereas more recent methods, such as LIME and Anchor, obtain interpretable models for individual samples. Most model-agnostic methods perturb the input instance according to some perturbation distribution and use the resulting changes in prediction to identify the features most likely to be important for that prediction. For complex models, it is often difficult to explain the behavior of the whole model with a simple interpretable set or rules built from selected important features, and an explanation of the entire model may not accurately explain a single prediction. In such cases, a local interpretation method that yields a concise explanation is preferable. To further evaluate the model as a whole, explanations generated for a selected subset of inputs can then be applied to unseen instances.
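The simplest perturbation scheme, occlusion, shows the idea behind perturbation-based attribution. It replaces one feature at a time with a baseline value and records how far the black-box prediction moves. This is a hedged illustration of the general idea, not LIME or Anchor themselves: those methods sample many perturbations and then fit a local surrogate model or search for rules. The baseline value of 0.0 and the function names are assumptions.

```python
from typing import Callable, List, Sequence


def occlusion_importance(
    predict: Callable[[Sequence[float]], float],
    x: Sequence[float],
    baseline: float = 0.0,
) -> List[float]:
    """Score each feature by how much the prediction changes when that
    single feature is replaced by a baseline value."""
    reference = predict(x)
    scores = []
    for j in range(len(x)):
        perturbed = list(x)
        perturbed[j] = baseline  # perturb one feature at a time
        scores.append(abs(reference - predict(perturbed)))
    return scores


def rank_features(scores: Sequence[float]) -> List[int]:
    """Feature indices sorted from most to least important."""
    return sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
```

Consider a black box `f(v) = 3*v[0] + 0.5*v[1]` evaluated at `x = [1.0, 2.0, 5.0]`. The scores are `[3.0, 1.0, 0.0]`, so feature 0 ranks first and the unused feature 2 ranks last.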
Disclosure of Invention
The invention aims to provide a local interpretable method based on a gradient boosting tree. The method improves the interpretability of a model by improving how the feature importance of an ensemble model is calculated, and it explains the original complex model by means of knowledge distillation.
The specific technical scheme for realizing the aim of the invention is as follows:
A local interpretable method based on a gradient boosting tree, characterized in that the method comprises the following specific steps:
Step 1: train the parameters of the initial complex model on a training dataset, and extract the input features;
Step 2: apply knowledge distillation to the trained model to obtain the soft-label outputs corresponding to the input features;
Step 3: train a gradient boosting tree model on the input features obtained in Step 1 and the soft labels obtained in Step 2, yielding a trained gradient boosting tree model;
Step 4: extract feature importance from the trained gradient boosting tree model, rank the features by importance, and select the features with the highest importance as the explanation of the initial complex model.
The training dataset in Step 1 is a natural language dataset, an image dataset, or a tabular dataset. The initial model is correspondingly an attention-based long short-term memory (LSTM) network, a convolutional neural network, or a multi-layer perceptron. During parameter training, the natural language dataset uses the attention-based LSTM network, the image dataset uses the convolutional neural network, and the tabular dataset uses the multi-layer perceptron.
In Step 2, knowledge distillation yields the soft-label output for the input features according to:

$$\mathrm{Label}_{\mathrm{soft},i} = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}$$

where $\mathrm{Label}_{\mathrm{soft},i}$ is the soft-label output for the $i$-th predicted class, $z_i$ is the final output (logit) of the initial model for class $i$, $T$ is the temperature parameter, and $j$ runs over all predicted classes of the task.
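The soft-label formula is a temperature-scaled softmax, sketched below in pure Python. The function name `soft_labels` is hypothetical. Subtracting the maximum scaled logit is a standard numerical-stability step that does not change the result.

```python
import math
from typing import List, Sequence


def soft_labels(logits: Sequence[float], T: float = 1.0) -> List[float]:
    """Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [z / T for z in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

With logits `[2.0, 0.0]`, `T = 1` gives roughly `[0.881, 0.119]` and `T = 2` gives roughly `[0.731, 0.269]`. A higher temperature flattens the distribution, which exposes more of the teacher's relative confidence across classes.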
The trained gradient boosting tree model obtained in Step 3 contains M weak learners (weak discriminators). Each weak learner is a decision tree model, and M is a hyperparameter of the gradient boosting tree model.
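The sketch below makes $h_m(x)$ and the path $T_m(x)$ concrete. It assumes a hypothetical dictionary layout for each tree: internal nodes hold `feature`, `threshold`, `gain`, `left`, and `right`, and leaves hold `value`. For every tree it records the split nodes a sample visits, and it sums the additive prediction $f_0 + \sum_m \gamma_m h_m(x)$. Real libraries expose their trees differently, so this is not any library's API.

```python
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical node layout: internal nodes carry "feature", "threshold",
# "gain", "left", "right"; leaves carry "value".
Node = Dict[str, Any]
PathStep = Tuple[int, float]  # (feature index split on, impurity gain)


def predict_with_path(tree: Node, x: Sequence[float]) -> Tuple[float, List[PathStep]]:
    """Return the leaf value h_m(x) and the split nodes on the path T_m(x)."""
    path: List[PathStep] = []
    node = tree
    while "value" not in node:
        path.append((node["feature"], node["gain"]))
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["value"], path


def ensemble_predict(f0: float, trees: List[Node], gammas: List[float], x: Sequence[float]):
    """F(x) = f0 + sum_m gamma_m * h_m(x), plus the per-tree paths."""
    total, paths = f0, []
    for tree, gamma in zip(trees, gammas):
        value, path = predict_with_path(tree, x)
        total += gamma * value
        paths.append(path)
    return total, paths
```

The recorded paths are the raw material that a path-based importance needs, because only the nodes a sample actually visits contribute to its explanation.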
In Step 4, feature importance is extracted from the trained gradient boosting tree model, the features are ranked by importance, and the features with the highest importance are selected as the explanation of the initial complex model. Specifically:
the feature importance is calculated as follows.
The importance expectation of a feature $P$ is taken over its $K$ components. Here $P_k$ is the $k$-th component of the feature, and $\mathrm{Imp}(P_k)$ is the feature importance of that component. In $\mathrm{Imp}(P_k)$, each weight $\gamma_m h_m(x)$ is the contribution of the $m$-th weak learner of the trained gradient boosting tree model to the whole model, and it weights the impurity reduction rate of the $m$-th weak learner for input $P_k$. The impurity reduction rate is the ratio of the impurity reduction achieved by the node splits that use $P_k$ to the total impurity reduction when the weak learner makes its prediction. The impurity reduction at a node $n$ that $P_k$ passes through in the decision tree is

$$\mathrm{Gain}(P_k, n) = i(n) - p_L\, i(n_L) - p_R\, i(n_R)$$

where $i(n)$ is the impurity of node $n$ before splitting, and $p_L$ and $p_R$ are the fractions of the samples at $n$ that reach the child nodes $n_L$ and $n_R$ after the split. In the trained model, $T_m$ denotes the $m$-th weak learner, i.e., the $m$-th decision tree. For an input sample $x$, which contains multiple features $P$, $T_m(x)$ denotes the path that $x$ follows through $T_m$ during prediction. A higher importance expectation means that feature $P$ is more important to the model's decision. Sorting the importance expectations of all features in descending order gives the explanation extracted from the gradient boosting tree model, which also serves as the explanation of the initial complex model.
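The impurity decrease $\mathrm{Gain}(P_k, n) = i(n) - p_L\, i(n_L) - p_R\, i(n_R)$ translates directly into code. The source does not name the impurity function $i(\cdot)$. As an assumption, the sketch offers variance, which is natural for regression trees fitted to soft labels, and Gini impurity, which is common for classification trees.

```python
from collections import Counter
from typing import Callable, Hashable, Sequence


def variance(values: Sequence[float]) -> float:
    """Impurity for regression trees (e.g. trees fitted to soft labels)."""
    mu = sum(values) / len(values)
    return sum((v - mu) ** 2 for v in values) / len(values)


def gini(labels: Sequence[Hashable]) -> float:
    """Gini impurity for classification trees."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def impurity_gain(parent, left, right, impurity: Callable = variance) -> float:
    """Gain(n) = i(n) - p_L * i(n_L) - p_R * i(n_R)."""
    if not left or not right or len(left) + len(right) != len(parent):
        raise ValueError("children must be non-empty and partition the parent node")
    p_left = len(left) / len(parent)
    p_right = len(right) / len(parent)
    return impurity(parent) - p_left * impurity(left) - p_right * impurity(right)
```

Suppose the targets `[0.0, 0.0, 1.0, 1.0]` are split into `[0.0, 0.0]` and `[1.0, 1.0]`. Both children are pure, so the split gives a variance gain of 0.25 and a Gini gain of 0.5.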
The invention is a general interpretable method that can extract explanations for datasets from various fields, such as natural language processing, image, and tabular datasets. In addition, the local explanations can be extended to a global explanation of the model by means of a submodular selection method.
Drawings
FIG. 1 is a flowchart of Embodiment 1 of the present invention;
FIG. 2 is a diagram of the initial model framework for image processing in Embodiment 2 of the present invention;
FIG. 3 is a diagram of the initial model framework for natural language processing in Embodiment 1 of the present invention;
FIG. 4 is a diagram of the initial model framework for the tabular task in Embodiment 3 of the present invention;
FIG. 5 is a flowchart of the present invention.
Detailed Description
The present invention will be described in further detail with reference to the drawings and examples, in order to make the objects, technical solutions and advantages of the present invention more apparent. It should be understood that the particular embodiments described herein are illustrative only and should not be taken as limiting the invention.
The invention provides a local interpretation algorithm based on a gradient boosting tree model. It calculates relative importance from the nodes that a sample passes through during prediction, and it sorts these values to rank the input features, which yields a local explanation. The invention is a general interpretable method that can extract explanations for datasets from various fields, such as natural language processing, image, and tabular datasets.
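A single tree is enough to show how conventional mean decrease in impurity (MDI) differs from an importance computed only along one sample's prediction path. The sketch uses a hypothetical dictionary tree layout with an `n_samples` count on each split node. It computes classic MDI over every split node and, for comparison, an importance that uses only the nodes on the path of one sample. It is an illustration under these assumptions, not the patented formula.

```python
from collections import defaultdict
from typing import Any, Dict, Sequence

# Hypothetical layout: split nodes hold "feature", "threshold", "gain",
# "n_samples", "left", "right"; leaves hold "value".
Node = Dict[str, Any]


def _normalize(scores):
    total = sum(scores.values())
    return {f: s / total for f, s in scores.items()} if total > 0 else {}


def global_mdi(tree: Node) -> Dict[int, float]:
    """Classic MDI: every split node counts, weighted by the share of
    training samples that reach it."""
    root_n = tree["n_samples"]
    scores = defaultdict(float)
    stack = [tree]
    while stack:
        node = stack.pop()
        if "value" in node:
            continue
        scores[node["feature"]] += node["n_samples"] / root_n * node["gain"]
        stack.extend([node["left"], node["right"]])
    return _normalize(scores)


def local_path_importance(tree: Node, x: Sequence[float]) -> Dict[int, float]:
    """Only the split nodes on the path followed by x contribute."""
    scores = defaultdict(float)
    node = tree
    while "value" not in node:
        scores[node["feature"]] += node["gain"]
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return _normalize(scores)
```

Consider a tree whose root splits on feature 0 and whose two children split on features 1 and 2. A sample routed to the left child receives importance only for features 0 and 1, whereas global MDI also credits feature 2.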
As shown in FIG. 5, the process of the invention comprises training the initial complex model, extracting the input features and soft-label outputs, training the gradient boosting tree model, extracting feature importance, and ranking the features by importance to generate the explanation.
First, the original sample data are divided into a training set and a test set. Second, the original model is trained on the training set, and its soft-label outputs are extracted by knowledge distillation. Next, a gradient boosting tree model is trained on the training-set inputs and the soft-label outputs. Finally, for a single sample from the test set, the feature importance is calculated with the feature importance method of the invention, and the features are ranked to obtain the explanation of that sample.
The invention calculates feature importance as follows.
The importance expectation of a feature $P$ is taken over its $K$ components. Here $P_k$ is the $k$-th component of the feature, and $\mathrm{Imp}(P_k)$ is the feature importance of that component. In $\mathrm{Imp}(P_k)$, each weight $\gamma_m h_m(x)$ is the contribution of the $m$-th weak learner of the trained gradient boosting tree model to the whole model, and it weights the impurity reduction rate of the $m$-th weak learner for input $P_k$. The impurity reduction rate is the ratio of the impurity reduction achieved by the node splits that use $P_k$ to the total impurity reduction when the weak learner predicts the sample $x$. The impurity reduction at a node $n$ that $P_k$ passes through in the decision tree is

$$\mathrm{Gain}(P_k, n) = i(n) - p_L\, i(n_L) - p_R\, i(n_R)$$

where $i(n)$ is the impurity of node $n$ before splitting, and $p_L$ and $p_R$ are the fractions of the samples at $n$ that reach the child nodes $n_L$ and $n_R$ after the split. In the trained model, $T_m$ denotes the $m$-th weak learner, i.e., the $m$-th decision tree. For an input sample $x$, which contains multiple features $P$, $T_m(x)$ denotes the path that $x$ follows through $T_m$ during prediction. A higher importance expectation means that feature $P$ is more important to the model's decision. Sorting the importance expectations of all features in descending order gives the explanation extracted from the gradient boosting tree model, which also serves as the explanation of the initial complex model.
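The sketch below shows one way to combine the per-tree impurity reduction rates into a per-sample ranking. The source describes the combination as a weighted average but does not spell out the normalization. The sketch therefore assumes that each tree's rate is weighted by the absolute value of its contribution $\gamma_m h_m(x)$ and normalized by the sum of those absolute weights. It also assumes that `paths` holds, for each tree, the `(feature, gain)` pairs of the split nodes on $T_m(x)$.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

PathStep = Tuple[int, float]  # (feature index, impurity gain) for one visited node


def gain_ratios(path: Sequence[PathStep]) -> Dict[int, float]:
    """Impurity reduction rate: each feature's share of the gain along one path."""
    total = sum(g for _, g in path)
    ratios: Dict[int, float] = defaultdict(float)
    if total > 0:
        for feature, gain in path:
            ratios[feature] += gain / total
    return ratios


def local_importance(paths: List[Sequence[PathStep]], weights: Sequence[float]) -> Dict[int, float]:
    """Weighted average of per-tree gain ratios.

    ASSUMPTION: weights[m] stands for gamma_m * h_m(x); its absolute value
    is used, and the result is normalized by the sum of absolute weights.
    """
    norm = sum(abs(w) for w in weights)
    if norm == 0:
        return {}
    importance: Dict[int, float] = defaultdict(float)
    for path, w in zip(paths, weights):
        for feature, ratio in gain_ratios(path).items():
            importance[feature] += abs(w) * ratio / norm
    return dict(importance)


def rank(importance: Dict[int, float]) -> List[int]:
    """Features sorted from most to least important."""
    return sorted(importance, key=importance.get, reverse=True)
```

Take two trees whose paths are `[(0, 0.4), (1, 0.1)]` and `[(1, 0.3)]`, with weights `[0.2, 0.1]`. Feature 0 scores about 0.53 and feature 1 about 0.47, so the ranking is `[0, 1]`.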
Example 1
The main parts and implementation strategies of the invention are set forth below:
FIG. 1 shows the specific flow of Embodiment 1. It includes training the initial complex model, extracting the input features and soft-label outputs, training the gradient boosting tree model, extracting feature importance from the gradient boosting tree model, and ranking the features to form the explanation.
Step one: train the parameters of the initial complex model on the training dataset, and extract the input features
First, a training dataset is constructed from the natural language processing dataset SST2, and an initial complex model for the natural language processing task is designed and built. As shown in FIG. 3, its structure comprises a word embedding layer, a long short-term memory network layer, an attention layer, a dropout layer, and a fully connected layer. The initial complex model is then trained on the training dataset. After training, the output of the word embedding layer is extracted as the input features.
Step two: apply knowledge distillation to the trained model to obtain the soft-label outputs for the input features
Knowledge distillation is a technique commonly used in model compression and transfer learning that transfers the knowledge of a complex network into another, simpler model. Complex models are very difficult to explain directly, so knowledge distillation is a useful tool for interpretation: a less interpretable model is distilled into a more interpretable one, and the former is explained by interpreting the latter. The invention therefore uses knowledge distillation to extract, from the final output layer of the initial model, the soft label corresponding to the input features. The soft label is computed as:

$$\mathrm{Label}_{\mathrm{soft},i} = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}$$

where $z_i$ is the logit output of the model for class $i$, $T$ is the distillation temperature, and $i$ indexes the classes of the prediction task. For natural language processing tasks, $T$ is set to 2. This gives the soft label $\mathrm{Label}_{\mathrm{soft}}$ corresponding to the input features.
Step three: train the gradient boosting tree on the input features and soft-label outputs
The input features and the soft labels obtained in step two are assembled into a dataset, and a gradient boosting tree model is trained on it. The parameters of the gradient boosting tree are tuned for each task so that the trained model is sufficiently accurate, because a high-accuracy gradient boosting tree model is a precondition for extracting a good explanation. For the natural language processing task, the parameter M is set to 100, so the gradient boosting tree model consists of 100 weak learners (decision tree models).
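Before trusting the explanations of the distilled tree model, one can check how closely it reproduces the teacher. The sketch assumes both models' soft-label outputs are available as lists of per-class probabilities. It reports top-class agreement (often called fidelity) and the mean absolute gap between soft labels. These particular metrics are an assumption, since the source only requires sufficiently high accuracy.

```python
from typing import Sequence


def argmax(values: Sequence[float]) -> int:
    return max(range(len(values)), key=lambda i: values[i])


def fidelity(teacher: Sequence[Sequence[float]], student: Sequence[Sequence[float]]) -> float:
    """Fraction of samples on which the student's top class matches the
    teacher's top class (1.0 means the tree model fully mimics the teacher)."""
    if len(teacher) != len(student) or not teacher:
        raise ValueError("need equally many, non-zero predictions")
    agree = sum(argmax(t) == argmax(s) for t, s in zip(teacher, student))
    return agree / len(teacher)


def mean_abs_error(teacher, student) -> float:
    """Average absolute gap between teacher and student soft labels."""
    gaps = [abs(a - b) for t, s in zip(teacher, student) for a, b in zip(t, s)]
    return sum(gaps) / len(gaps)
```

Take teacher outputs `[[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]` and student outputs `[[0.8, 0.2], [0.4, 0.6], [0.45, 0.55]]`. The top classes agree on two of the three samples, which gives a fidelity of about 0.67.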
Step four: extract feature importance from the trained gradient boosting tree model, rank the features by importance, and select the features with the highest importance as the explanation of the initial complex model
Using the calculation method of the invention, any sample is selected from the test set and predicted with the gradient boosting tree model trained in step three. While the sample is being predicted, the importance of each of its features is calculated. A higher importance expectation means that feature $P$ is more important to the model's decision. Sorting the importance expectations of all features in descending order gives the explanation extracted from the gradient boosting tree model, which also serves as the explanation of the initial complex model, i.e., a local explanation. In the natural language processing task, the input sample is a sentence, and each word in the sentence is a feature of the sample. The calculation above gives the importance of each word, sorting gives the ranking of the words, and the most important words serve as the explanation of the sample.
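In this example the input features are word-embedding outputs, so each word $P$ spans $K$ embedding components $P_k$. The sketch assumes that the importance expectation is the plain mean of $\mathrm{Imp}(P_k)$ over those $K$ components, and that the per-component importances are laid out word by word. Under those assumptions, turning them into a ranked word list takes only a few lines. The function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def word_scores(dim_importance: Sequence[float], n_words: int, dim: int) -> List[float]:
    """Mean importance over the K = dim embedding components of each word.

    dim_importance is word-major: entries [w*dim, (w+1)*dim) belong to word w.
    """
    if len(dim_importance) != n_words * dim:
        raise ValueError("expected n_words * dim importance values")
    return [sum(dim_importance[w * dim:(w + 1) * dim]) / dim for w in range(n_words)]


def explain_sentence(words: Sequence[str], dim_importance: Sequence[float],
                     dim: int, top_k: int = 3) -> List[Tuple[str, float]]:
    """Return the top_k words with their scores, most important first."""
    scores = word_scores(dim_importance, len(words), dim)
    order = sorted(range(len(words)), key=lambda i: scores[i], reverse=True)
    return [(words[i], scores[i]) for i in order[:top_k]]
```

Take the sentence `the movie was great` with two-dimensional embeddings and component importances `[0.0, 0.0, 0.1, 0.3, 0.0, 0.0, 0.4, 0.2]`. The top two words are `great` (0.3) and `movie` (0.2).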
Example 2
The main parts and implementation strategies of the invention are set forth below:
Step one: train the parameters of the initial complex model on the training dataset, and extract the input features
First, a training dataset is constructed from the image processing dataset MNIST, and an initial complex model for the image processing task is designed and built. As shown in FIG. 2, its structure comprises a convolutional layer, an activation layer, a pooling layer, a dropout layer, and a fully connected layer. The initial complex model is then trained on the training dataset. The two-dimensional pixel data of each image are used directly as the input features.
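When raw pixels are the features, an explanation is easiest to read once the importances are mapped back onto the image grid. The sketch assumes row-major flattening, so a 28 x 28 MNIST image becomes 784 features. It returns the `(row, col)` positions of the most important pixels. The helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def flatten(image: Sequence[Sequence[float]]) -> List[float]:
    """Row-major flattening, e.g. a 28 x 28 MNIST image -> 784 features."""
    return [pixel for row in image for pixel in row]


def top_pixels(importance: Sequence[float], width: int, k: int) -> List[Tuple[int, int]]:
    """(row, col) of the k most important pixels in a row-major feature vector."""
    order = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
    return [divmod(i, width) for i in order[:k]]
```

For a 2 x 3 image with importances `[0.0, 0.9, 0.0, 0.0, 0.0, 0.5]`, the two most important pixels are at `(0, 1)` and `(1, 2)`.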
Step two: apply knowledge distillation to the trained model to obtain the soft-label outputs for the input features
For the image processing task, $T$ is set to 1, and the corresponding soft label $\mathrm{Label}_{\mathrm{soft}}$ is computed for the input features.
The subsequent steps are the same as those of example 1.
Example 3
The main parts and implementation strategies of the invention are set forth below:
Step one: train the parameters of the initial complex model on the training dataset, and extract the input features
First, a training dataset is constructed from the tabular dataset add, and an initial complex model for the tabular task is designed and built. As shown in FIG. 4, its structure comprises a fully connected network layer. The initial complex model is then trained on the training dataset, and the tabular data are used directly as the input features.
Step two: apply knowledge distillation to the trained model to obtain the soft-label outputs for the input features
For the tabular task, $T$ is set to 1, and the corresponding soft label $\mathrm{Label}_{\mathrm{soft}}$ is computed for the input features.
The subsequent steps are the same as those of example 1.
The foregoing describes specific embodiments of the invention with reference to the drawings. Those skilled in the art will understand that various changes may be made and equivalents substituted without departing from the spirit and principles of the invention. Such modifications and equivalents fall within the scope of the claims of the invention.

Claims (1)

1. A local interpretable method based on a gradient boosting tree, characterized by comprising the following specific steps:
Step 1: train the parameters of the initial complex model on a training dataset, and extract the input features;
Step 2: apply knowledge distillation to the trained model to obtain the soft-label outputs corresponding to the input features;
Step 3: train a gradient boosting tree model on the input features obtained in Step 1 and the soft labels obtained in Step 2, yielding a trained gradient boosting tree model;
Step 4: extract feature importance from the trained gradient boosting tree model, rank the features by importance, and select the features with the highest importance as the explanation of the initial complex model; wherein:
the training dataset in Step 1 is a natural language dataset, an image dataset, or a tabular dataset, and the initial model is correspondingly an attention-based long short-term memory (LSTM) network, a convolutional neural network, or a multi-layer perceptron; during parameter training, the natural language dataset uses the attention-based LSTM network, the image dataset uses the convolutional neural network, and the tabular dataset uses the multi-layer perceptron;
in Step 2, knowledge distillation yields the soft-label output for the input features according to:

$$\mathrm{Label}_{\mathrm{soft},i} = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}$$

where $\mathrm{Label}_{\mathrm{soft},i}$ is the soft-label output for the $i$-th predicted class, $z_i$ is the final output (logit) of the initial model for class $i$, $T$ is the temperature parameter, and $j$ runs over all predicted classes of the task;
the trained gradient boosting tree model obtained in Step 3 contains M weak learners, each of which is a decision tree model, and M is a hyperparameter of the gradient boosting tree model;
in Step 4, feature importance is extracted from the trained gradient boosting tree model, the features are ranked by importance, and the features with the highest importance are selected as the explanation of the initial complex model, specifically as follows:
the feature importance is calculated as follows. The importance expectation of a feature $P$ is taken over its $K$ components, where $P_k$ is the $k$-th component of the feature and $\mathrm{Imp}(P_k)$ is the feature importance of that component. In $\mathrm{Imp}(P_k)$, each weight $\gamma_m h_m(x)$ is the contribution of the $m$-th weak learner of the trained gradient boosting tree model to the whole model, and it weights the impurity reduction rate of the $m$-th weak learner for input $P_k$. The impurity reduction rate is the ratio of the impurity reduction achieved by the node splits that use $P_k$ to the total impurity reduction when the weak learner makes its prediction. The impurity reduction at a node $n$ that $P_k$ passes through in the decision tree is $\mathrm{Gain}(P_k, n) = i(n) - p_L\, i(n_L) - p_R\, i(n_R)$, where $i(n)$ is the impurity of node $n$ before splitting, and $p_L$ and $p_R$ are the fractions of the samples at $n$ that reach the child nodes $n_L$ and $n_R$ after the split. In the trained model, $T_m$ denotes the $m$-th weak learner, i.e., the $m$-th decision tree. For an input sample $x$, which contains multiple features $P$, $T_m(x)$ denotes the path that $x$ follows through $T_m$ during prediction. A higher importance expectation means that feature $P$ is more important to the model's decision. Sorting the importance expectations of all features in descending order gives the explanation extracted from the gradient boosting tree model, which also serves as the explanation of the initial complex model.
CN202010580912.6A 2020-06-23 Local interpretable method based on gradient lifting tree Active CN111753995B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010580912.6A CN111753995B (en) 2020-06-23 Local interpretable method based on gradient lifting tree

Publications (2)

Publication Number Publication Date
CN111753995A CN111753995A (en) 2020-10-09
CN111753995B true CN111753995B (en) 2024-06-28

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110443346A (en) * 2019-08-12 2019-11-12 Tencent Technology (Shenzhen) Co., Ltd. Model explanation method and device based on input feature importance
CN111027060A (en) * 2019-12-17 2020-04-17 University of Electronic Science and Technology of China Knowledge distillation-based neural network black box attack type defense method

Similar Documents

Publication Publication Date Title
CN107480261B (en) Fine-grained face image fast retrieval method based on deep learning
CN110163234B (en) Model training method and device and storage medium
US8239336B2 (en) Data processing using restricted boltzmann machines
CN109272332B (en) Client loss prediction method based on recurrent neural network
Barman et al. Transfer learning for small dataset
CN111581368A (en) Intelligent expert recommendation-oriented user image drawing method based on convolutional neural network
CN111882042B (en) Neural network architecture automatic search method, system and medium for liquid state machine
CN111898704B (en) Method and device for clustering content samples
Dai et al. Hybrid deep model for human behavior understanding on industrial internet of video things
Liu et al. EACP: An effective automatic channel pruning for neural networks
CN111783688B (en) Remote sensing image scene classification method based on convolutional neural network
CN111259938B (en) Manifold learning and gradient lifting model-based image multi-label classification method
CN112989843A (en) Intention recognition method and device, computing equipment and storage medium
CN111753995B (en) Local interpretable method based on gradient lifting tree
Xie et al. Scalenet: Searching for the model to scale
US11468352B2 (en) Method and system for predictive modeling of geographic income distribution
CN115100543A (en) Self-supervision self-distillation element learning method for small sample remote sensing image scene classification
Haw et al. Improving the prediction resolution time for customer support ticket system
Yang et al. iCausalOSR: invertible Causal Disentanglement for Open-set Recognition
CN114610953A (en) Data classification method, device, equipment and storage medium
CN111753995A (en) Local interpretable method based on gradient lifting tree
Wang et al. Improved fine-grained object retrieval with Hard Global Softmin Loss objective
CN116579722B (en) Commodity distribution warehouse-in and warehouse-out management method based on deep learning
CN116737607B (en) Sample data caching method, system, computer device and storage medium
CN113723456B (en) Automatic astronomical image classification method and system based on unsupervised machine learning

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant