Integrated Gradients Method for Image Classification
XAI Course Project | Anatoliy Pushkarev
Goal
Develop a robust image classification model and analyze its behavior with the help of the integrated gradients method.
Integrated Gradients
Integrated Gradients (IG) is a model interpretability technique that attributes a classification model's prediction to its input features. For a particular prediction, it assigns an importance score to each pixel (or feature) of the input, and these scores can be visualized to show how the input features relate to the model's prediction.
The original Integrated Gradients paper (Sundararajan, Taly and Yan, "Axiomatic Attribution for Deep Networks", ICML 2017) can be found here.
Steps
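- Select a baseline input for which the model's prediction is close to a uniform distribution over all classes (for images, often an all-black image).
- Interpolate along a straight-line path from the baseline to the input, in many small steps.
- At each step, compute the gradients of the model's output with respect to the input, to see how changing the input affects the prediction.
- Integrate (accumulate) all these gradients along the path.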
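The four steps above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the implementation used in this project: the model is a hypothetical linear softmax classifier whose gradient is written by hand (a Keras model would obtain it with automatic differentiation, e.g. `tf.GradientTape`), the baseline is an all-zero input, and the integral is approximated with a midpoint Riemann sum.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def integrated_gradients(x, baseline, grad_fn, steps=50):
    """Midpoint Riemann approximation of Integrated Gradients.

    grad_fn(batch) must return the gradient of the target score with
    respect to the input for every row of batch.
    """
    # Step 2: points on the straight line from the baseline to the input
    alphas = (np.arange(steps) + 0.5) / steps            # midpoints in (0, 1)
    path = baseline + alphas[:, None] * (x - baseline)   # (steps, n_features)
    # Step 3: gradients of the target score at every point on the path
    grads = grad_fn(path)                                # (steps, n_features)
    # Step 4: average the gradients, scale by the input-baseline difference
    return (x - baseline) * grads.mean(axis=0)


# Hypothetical toy model: linear softmax classifier, 3 classes, 4 features.
# Zero bias, so the all-zero baseline gets a uniform prediction (step 1).
rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
b = np.zeros(3)
target = 1  # class whose probability we explain


def model_prob(batch):
    return softmax(batch @ W.T + b)[:, target]


def grad_prob(batch):
    # d p_t / d x = p_t * (W_t - sum_k p_k * W_k) for p = softmax(W x + b)
    p = softmax(batch @ W.T + b)
    return p[:, [target]] * (W[target] - p @ W)


x = rng.normal(size=4)
baseline = np.zeros(4)
attr = integrated_gradients(x, baseline, grad_prob, steps=200)

# Completeness check: both numbers should be (nearly) equal
print(attr.sum(), model_prob(x[None])[0] - model_prob(baseline[None])[0])
```

IG satisfies a completeness property: the attributions sum to F(x) − F(x′). Comparing the two printed numbers is therefore a cheap way to check whether the number of steps is large enough.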
This is the original integrated gradients formula:
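Written out as in the original paper, with F the model's output for the class of interest, x the input, x′ the baseline and i the index of an input feature (pixel):

```latex
\mathrm{IntegratedGrads}_i(x) = (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha
```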
This is the Riemann approximation of the original formula, which is what is used in practice, because the integral cannot be computed analytically for a neural network.
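In the paper's notation, with m interpolation steps along the path, the approximation reads as follows. The paper reports that somewhere between 20 and 300 steps are usually enough to approximate the integral within 5%.

```latex
\mathrm{IntegratedGrads}^{\mathrm{approx}}_i(x) = (x_i - x'_i) \times \sum_{k=1}^{m} \frac{\partial F\big(x' + \tfrac{k}{m}\,(x - x')\big)}{\partial x_i} \times \frac{1}{m}
```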
A very good article about Integrated Gradients and the Riemann approximation can be found here:
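This picture shows why exactly we need to accumulate gradients. If we only take the gradients of the model with respect to the input itself, we get a lot of noise, as illustrated on the right side of the picture: near the actual input the model's output is often saturated, so the local gradients carry little useful signal. But if we scale the input down towards the baseline, at some point along the path the gradients become informative, and these are the gradients the method accumulates.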
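The saturation effect can be reproduced with the one-ReLU example from the original paper, f(x) = 1 − ReLU(1 − x), with baseline x′ = 0 and input x = 2. This NumPy sketch compares the plain gradient at the input with a midpoint Riemann approximation of IG; it is an illustration only, with the derivative written by hand.

```python
import numpy as np


def f(x):
    # One-ReLU "network" that saturates: f(x) = 1 - ReLU(1 - x)
    return 1.0 - np.maximum(0.0, 1.0 - x)


def df(x):
    # Derivative: 1 where x < 1 (unsaturated), 0 where x > 1 (flat region)
    return np.where(x < 1.0, 1.0, 0.0)


x, baseline, steps = 2.0, 0.0, 100
plain_gradient = df(x)                       # gradient at the input only
alphas = (np.arange(steps) + 0.5) / steps    # midpoints along the path
ig = (x - baseline) * df(baseline + alphas * (x - baseline)).mean()

print(plain_gradient)  # 0.0: the input sits in the flat, saturated region
print(ig)              # 1.0: equals f(2) - f(0)
```

The plain gradient assigns zero importance to a feature that changed the output from 0 to 1, while the accumulated gradients recover the full change.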
Fashion-MNIST Dataset
The dataset I use for testing is Fashion-MNIST, a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image associated with a label from 10 classes. Zalando intends Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms: it shares the same image size and the same structure of training and testing splits.
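When working with the data, it helps to map label indices to class names (label 6 is Shirt) and to bring the images into the shape a Keras convolutional model expects. A minimal sketch, assuming the images arrive as uint8 arrays of shape (n, 28, 28), which is the format returned by `tf.keras.datasets.fashion_mnist.load_data()`; `preprocess` is a hypothetical helper, not part of any library.

```python
import numpy as np

# Fashion-MNIST label index -> class name (the dataset's 10 classes)
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def preprocess(images):
    """Scale uint8 images of shape (n, 28, 28) to float32 in [0, 1] and add a
    trailing channel axis, giving (n, 28, 28, 1) as Keras Conv2D layers
    expect with the default channels_last data format."""
    images = np.asarray(images, dtype=np.float32) / 255.0
    return images[..., np.newaxis]
```

Scaling the pixels to [0, 1] also makes an all-zero (black) image a natural baseline for Integrated Gradients.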
Model
- Keras Sequential model
- 3 convolutional layers
- 2 max pooling layers
- Dropout layers
- 2 fully connected (Dense) layers
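The list above does not record kernel sizes, filter counts or padding. As a worked example only, assuming 3x3 kernels with stride 1 and no padding ("valid"), 2x2 max pooling, and one plausible layer order (conv, pool, conv, pool, conv), the spatial size of the feature maps can be traced as follows. All of these hyperparameters are assumptions, not the project's actual configuration.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    # Output size of a convolution: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, pool=2):
    # Max pooling with stride equal to the pool size and no padding
    return size // pool


# Assumed layer order: conv -> pool -> conv -> pool -> conv
size = 28
trace = [("input", size)]
for layer in ["conv", "pool", "conv", "pool", "conv"]:
    size = conv_out(size) if layer == "conv" else pool_out(size)
    trace.append((layer, size))

print(trace)
# [('input', 28), ('conv', 26), ('pool', 13), ('conv', 11), ('pool', 5), ('conv', 3)]
```

If the last convolutional layer had, for example, 64 filters (again an assumption), the flattened feature vector passed to the first fully connected layer would have 3 * 3 * 64 = 576 values.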
Training
Below is a screenshot of the model training loop, which is fairly standard.
Results of Training
As a result, I got a robust model with accuracy above 90%, which is a good basis for explanation.
Some classes have lower precision and recall, so it will be interesting to check their IGs (for example, the Shirt class).
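Per-class precision and recall, which is how a weaker class such as Shirt shows up, can be computed directly from the true and predicted labels. A minimal NumPy sketch; `per_class_precision_recall` is a hypothetical helper (scikit-learn's `classification_report` reports the same metrics), and the toy labels are made up for illustration, not the project's results.

```python
import numpy as np


def per_class_precision_recall(y_true, y_pred, n_classes):
    """Return (precision, recall) arrays of length n_classes.

    Classes that are never predicted (or never present) get 0.0 instead of
    a division-by-zero NaN.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    precision = np.zeros(n_classes)
    recall = np.zeros(n_classes)
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # correctly predicted as c
        predicted = np.sum(y_pred == c)             # everything predicted as c
        actual = np.sum(y_true == c)                # everything that really is c
        precision[c] = tp / predicted if predicted else 0.0
        recall[c] = tp / actual if actual else 0.0
    return precision, recall


y_true = np.array([6, 6, 0, 0, 2])  # toy labels, not real results
y_pred = np.array([6, 0, 0, 6, 2])
precision, recall = per_class_precision_recall(y_true, y_pred, n_classes=10)
wrong = np.flatnonzero(y_true != y_pred)  # indices of misclassified examples
```

The `wrong` indices are a convenient starting point for picking incorrectly predicted examples to explain with IG.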
Examples of correctly predicted classes:
Examples of incorrectly predicted classes:
Applying Integrated Gradients
Below are the IG maps I obtained after running the method. Red marks pixels that are important for a particular class, and blue marks negative features, i.e. pixels that count against that class.
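Maps like these can be rendered by sending positive attributions to the red channel and negative ones to the blue channel. A minimal NumPy sketch; `attribution_to_rgb` is a hypothetical helper, the percentile-based scaling is an assumption, and this is not necessarily the coloring used for the figures above.

```python
import numpy as np


def attribution_to_rgb(attr, clip_percentile=99.0):
    """Map a 2-D attribution map to an RGB image (floats in [0, 1]).

    Positive attributions go to the red channel, negative ones to the blue
    channel. Values are scaled by a high percentile of |attr| so that a few
    extreme pixels do not wash out the rest of the map.
    """
    attr = np.asarray(attr, dtype=np.float64)
    scale = np.percentile(np.abs(attr), clip_percentile)
    if scale == 0:
        scale = 1.0  # all-zero map: avoid division by zero
    norm = np.clip(attr / scale, -1.0, 1.0)
    rgb = np.zeros(attr.shape + (3,))
    rgb[..., 0] = np.clip(norm, 0.0, 1.0)   # red: evidence for the class
    rgb[..., 2] = np.clip(-norm, 0.0, 1.0)  # blue: evidence against the class
    return rgb
```

For an attribution of shape (28, 28, 1), pass `attr[..., 0]`; the result can be displayed with matplotlib's `imshow`, which accepts float RGB images with values in [0, 1].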
More Examples
This is an example of how IG works on regression tasks. The x-axis shows the features, and the y-axis shows each feature's attribution (contribution) to the model's output.
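Such per-feature attributions can be computed for any differentiable regressor. A minimal sketch with a hypothetical model f(x) = 2·x0 + x1² − 0.5·x2 and an all-zero baseline (both assumptions for illustration); the gradient is written by hand here, whereas for a real network it would come from automatic differentiation such as `tf.GradientTape`.

```python
import numpy as np


def f(x):
    # Hypothetical regression model with three features
    return 2.0 * x[..., 0] + x[..., 1] ** 2 - 0.5 * x[..., 2]


def grad_f(x):
    # Analytic gradient of f for each row of x: [2, 2 * x1, -0.5]
    g = np.empty_like(x)
    g[..., 0] = 2.0
    g[..., 1] = 2.0 * x[..., 1]
    g[..., 2] = -0.5
    return g


def integrated_gradients(x, baseline, grad_fn, steps=64):
    alphas = (np.arange(steps) + 0.5) / steps  # midpoint Riemann sum
    path = baseline + alphas[:, None] * (x - baseline)
    return (x - baseline) * grad_fn(path).mean(axis=0)


x = np.array([1.0, 3.0, 2.0])
baseline = np.zeros(3)
attr = integrated_gradients(x, baseline, grad_f)
print(attr)  # ≈ [2.0, 9.0, -1.0], one bar per feature
```

Each value would be one bar in a plot like the one described above, and together they sum to f(x) − f(x′) = 10.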
Challenges
- Lack of stable open-source implementations that work with any model.
- For a complex model, you will most likely have to implement IG yourself.
- Less support and a smaller community, since IG is not as popular (SHAP and GradCAM papers: ~20K citations; IG: ~5K citations).