XAI - Explaining deep learning models

    The goal of explainability, also called eXplainable Artificial Intelligence (XAI), is to create explanations for a model’s predictions. These explanations are going to look different depending on the type of data we are working with.

    Here is an example of what explainability can look like when we use images. In this example, a model tries to classify whether a picture contains a bell pepper. We can use explainability methods to highlight the parts of the image that are most important for making that decision.

    Interpretability and explainability are two terms that are often used interchangeably, and there is no generalized consensus over their formal definitions. We will refer to explainability as the methods and techniques used to explain Machine Learning or Deep Learning reasoning in human-understandable terms. Interpretability, on the other hand, is the extent to which cause and effect can be observed within a system.

    In this post, we will talk about the problems explainability is trying to solve. We will also touch upon some technical details of SHAP, one of the many explainability algorithms that exist today.

    02/ Why do we need XAI?

    The reasons why explainability is important can vary, depending on the problem we are trying to solve.

    Problem 1: Performance-Interpretability trade-off

    Traditional ML algorithms (like linear regression or decision trees) are inherently interpretable and have been widely used, but they can often be outperformed by more complex Machine Learning algorithms or Deep Learning models.

    The higher the complexity of a model, the harder it gets to interpret it. As a result, we may end up with models that score well on the mainstream metrics but work like black boxes.

    Problem 2: The right to an explanation

    Predictions delivered by a highly performing model can be impressive, but they might not inspire trust if it is unclear how or why they are generated. This is especially true when models are used in decision-making procedures that directly affect us (e.g. models predicting whether applicants are eligible for a loan, models used in health care, or self-driving cars). A user who is subject to an automated decision like this holds a right to explanation as defined by the GDPR. It is therefore essential that we integrate explainability into all Machine Learning and Deep Learning automated production pipelines.

    Problem 3: our models are good, but they could be better

    If we could shed more light on the inner workings of the models, we could build better models by improving our methods in sensible ways.

    Problem 4: we have little understanding of the biases our models might be perpetuating

    Our datasets are reflections of our society. Stereotypical biases that might be present in the way we behave can be projected in them. Using those datasets to train our models can lead to biased predictions that may be harmful to specific groups of people that have already been marginalized in our communities.

    03/ What has been done to address this problem?

    Interpretability is more focused on model understanding techniques, while explainability is a broader research field that aims to create explanations for models’ predictions and to develop interfaces for delivering these explanations to users in human-understandable terms.

    Interpretability and explainability constitute a key subfield that is part of a broader human-centric concept called Responsible AI, as shown in the picture below.

    Responsible AI is an umbrella term for methods and practices that ensure that AI benefits society.

    Another subfield of Responsible AI is Fairness. Bias and fairness can go hand in hand with explainability but the two research fields should not be confused.

    Fairness focuses on ensuring that models treat all target users fairly and do not reinforce bias. Explainability and interpretability can be used to ensure fairness as they could shed light into how and why a model makes predictions but they can also be used outside of the context of fairness. 

    Explainability’s definition can be considered fluid because it changes according to who is using it, their goals and the ways they could benefit from it. 

    There are three groups of stakeholders who could be interested in using explainable AI. Each of these groups will take different actions once they have the explanations:

    Data scientists, engineers, model builders (Internal stakeholders)

    This group consists of people who are managing the operations and productionizing models. They might want to understand why the models do not behave as they would expect. More complex models can be difficult to debug, understand and control, thereby impeding their adoption in production. Understanding how they work could help those stakeholders come up with ways of improving them in meaningful ways, and ensuring long term success of the Machine Learning and Deep learning pipelines. These users will use the explanations to go back and see which part in the pipeline should be improved, whether it is the training data, the way features are extracted or how the model architecture is created.

    End users (External stakeholders)

    This group consists of practitioners, like doctors, that need to verify that they can trust the models’ output, or other users that consume and can benefit from the predictions of machine learning models in any way. These users usually want to receive information about the features the predictions are based upon, rather than understanding how the model reached those predictions, aiming to build trust that a model decision is reliable and equitable.

    Regulators, executives (External stakeholders)

    These users want to ensure that a model’s decision is fit for the desired purpose, is in compliance with laws and regulations, and does not amplify undesirable bias from underlying datasets. These stakeholders think of explainability as a means of transparency and governance.

    They might want to trace unexpected predictions back to their inputs to inform corrective actions. They also aim to understand how Machine Learning models make predictions in order to communicate the insights from that to other external stakeholders at a high level. Their ultimate goal is to guide industries towards integrating AI in a responsible way. 

    04/ Different methods for solving the problem

    There are different ways to tackle the problem of explainable AI, but they fall into two large categories: post hoc explanations and intrinsic explanations.

    Intrinsic explainability refers to methods that are interpretable by design. These are simpler models like linear regression or short decision trees. More complicated models, like Deep Learning models, are rarely inherently interpretable. A notable exception is TabNet, a Deep Learning model for tabular data that is a product of recent research. Attention-based models, like Transformer models, are considered self-explaining models by some researchers. This is due to the attention weights being part of the prediction process and used as explanations.

    The post hoc explanation methods look at relationships between features and model predictions without focusing on the model’s inner workings. Within the post hoc category, explanations can concern individual predictions (local) or describe the average behaviour of a machine learning model (global). In this post we will focus on this category, and we will specifically look at a method called SHAP. SHAP is a model-agnostic method that belongs to the family of XAI attribution methods. In the next section, we will see this method in action by going through some examples of using it on different types of data. Before we do that, let’s first delve into how SHAP works and build some intuition around it.

    05/ SHAP

    What is it and how does it work?

    SHAP is an abbreviation for SHapley Additive exPlanations. Back in 1951, Lloyd Shapley introduced the Shapley value, and in 2012 he was awarded the Nobel Memorial Prize in Economic Sciences for this line of work. Shapley values provide a fair solution to the following problem:

    Suppose we have a group of cooperative members in a team (a coalition). All members in the team collaborate to produce some final value. What is the contribution of each member towards obtaining this final value? An example could be a group of friends who shared a dinner and are now trying to figure out how to split the bill fairly, given that not all of them contributed equally to the meal or ordered dishes of the same cost.

    This is not an easy question as the members could be interacting in ways that cause some to contribute more.  Those interactions should be taken into account in a fair way. 

    Even though Shapley introduced these values for Game Theory problems like the one mentioned above, they have also been used in Machine Learning. Just think of the features of a sample as the members of the group and the model prediction as the final value. The features can be columns if we are working with tabular data, pixels or superpixels of an image, or words/tokens of text. The Shapley value assigned to each feature then determines how that feature contributed to the final prediction.

    In tabular data, every column is a feature and each row is one sample. To obtain some intuition of how Shapley values work, we are going to be working with a generic schematic representation of a tabular data sample: 

    Think of the patterns 🔼,⏹,⏺ and *️⃣ as features constituting a data sample. We will from now on refer to this group as a coalition. Those features could be the members of that group of friends we mentioned earlier who are having dinner together and the predictions would then be the final bill in the restaurant, but the concept could apply to any tabular data case. To compute the Shapley values, we could start by looking at how much 🔼 contributes given the subset of the coalition that includes ⏹,⏺ and *️⃣. To this end, we form a subset of the coalition that does not include 🔼 and one that does include it (that would be the full coalition).

    We then compute the final values (predictions) derived by the two sets and compare them by subtraction. The difference between them is what we call the marginal contribution of 🔼 given the subset of the coalition that includes ⏹,⏺ and *️⃣. We are essentially answering this question: we have all of the features in the coalition but 🔼 and then we put it back. How does it change the prediction? 

    We then do the same for all possible pairs of subsets of the coalition where one subset includes 🔼 and the other excludes it. With three remaining features, there are 8 such pairs:

    We repeat the same process by computing the marginal contribution P1-P2 for all pairs. A weighted average of these marginal contributions is the Shapley value of 🔼, which we denote by φ🔼.

    Here is a generic mathematical formula for calculating the Shapley value φi of any feature i:

    φi = Σ_{S ⊆ F\{i}} [ |S|! (|F| − |S| − 1)! / |F|! ] · ( f(S ∪ {i}) − f(S) )

    where f(S) is the model’s prediction when only the features in S are present,
    S is a subset of the coalition that does not include feature i,

    |S| is the size of this subset and 

    |F| is the size of the full coalition. 
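    When the number of features is small, this formula can be evaluated directly by brute force. Here is a minimal Python sketch using the dinner-bill story as an illustrative value function (the names and amounts are made up):

    ```python
    from itertools import combinations
    from math import factorial

    def shapley_value(f, features, i):
        """Exact Shapley value of feature i.
        f maps a frozenset of features to the coalition's value."""
        others = [j for j in features if j != i]
        n = len(features)
        phi = 0.0
        for size in range(len(others) + 1):
            for subset in combinations(others, size):
                S = frozenset(subset)
                # Shapley weight |S|!(|F| - |S| - 1)!/|F|! for this subset
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                # marginal contribution of i given the subset S
                phi += weight * (f(S | {i}) - f(S))
        return phi

    # toy "dinner bill": a coalition's value is the sum of what each
    # member ordered, so each member's fair share is their own order
    orders = {"ann": 10.0, "bo": 20.0, "cy": 30.0}
    bill = lambda S: sum(orders[m] for m in S)
    ```

    For an additive value function like this, each member’s Shapley value is simply their own order, which makes the sketch easy to sanity-check.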

    The general idea is to obtain different combinations of how features can appear together in subsets of the coalition and find out what effect these combinations have on the predictions made by the model. Those combinations are often called permutations of the data sample.

    But what happens with the features that we remove? We cannot expect our model to just handle missing values at inference mode.

    What we do instead is that we define a background dataset.  How we define this background dataset can vary and depends on the type of the problem. One approach is to have the whole training dataset as the background dataset if its size is small. Another approach is to use some algorithms like k-means to summarize the dataset in a meaningful way so that the background dataset is representative of the data the model was trained on. 

    Once we have our background dataset, every time a feature is set to missing, we replace it with values (or the one single value) it takes in the background dataset. 
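    This replacement step can be sketched with NumPy as follows (the sample, mask and background rows are made up for illustration):

    ```python
    import numpy as np

    def mask_with_background(x, present, background):
        """Replace 'missing' features of x with their values in each row
        of the background dataset; present is True where x's value is kept."""
        return np.where(present, x, background)

    x = np.array([5.0, 7.0, 9.0])                 # sample to explain
    background = np.array([[0.0, 1.0, 2.0],       # background dataset
                           [1.0, 0.0, 3.0]])
    present = np.array([True, False, True])       # feature 1 is "missing"
    masked = mask_with_background(x, present, background)
    # the model is then run on every masked row and its outputs averaged
    ```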

    It is probably obvious by now that to compute the Shapley value of a feature, one has to sum over all possible pairs of subsets of features. That is 2^(n−1) different pairs, where n is the number of features in our dataset. Now imagine having to run the model in inference mode for each one of those pairs of subsets, a number that grows exponentially with n.

    Is there a better way of doing that? The solution lies in approximation. Instead of taking all subsets of the coalition into account, we sample them. But two new problems arise. The first is that to obtain a good estimate we need a large number of samples. The second is that the approximation method we would naturally reach for, a Monte Carlo approximation, is naive in this case because it samples combinations uniformly at random rather than in a principled way. That is probably not the best use of our computational power. We still need a better way.
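    A plain Monte Carlo estimate like this draws random orderings of the features and averages a feature’s marginal contribution at the moment it joins the coalition. A small sketch, reusing the illustrative dinner-bill value function from before:

    ```python
    import random
    from statistics import mean

    def monte_carlo_shapley(f, features, i, n_samples=2000, seed=0):
        """Approximate the Shapley value of feature i by averaging its
        marginal contribution over randomly sampled feature orderings."""
        rng = random.Random(seed)
        contributions = []
        for _ in range(n_samples):
            order = list(features)
            rng.shuffle(order)
            pos = order.index(i)
            before = frozenset(order[:pos])   # coalition formed before i joins
            contributions.append(f(before | {i}) - f(before))
        return mean(contributions)

    orders = {"ann": 10.0, "bo": 20.0, "cy": 30.0}
    bill = lambda S: sum(orders[m] for m in S)
    ```

    Every model call here is a full inference pass, which is why sampling uniformly at random quickly becomes wasteful for models that are expensive to evaluate.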

    This is when SHAP (SHapley Additive exPlanations) comes into play. The algorithm introduced in this paper, alongside having other properties, suggests improved sampling methods. For Kernel SHAP in particular, which is a model-agnostic algorithm, the sampling method the authors suggest does not use a uniform prior like plain Monte Carlo. Rather, it tends to prefer subsets that include either very few or very many features. Think of it like this: a subset in which many features are missing, say one with only a single feature left in it, tells a lot about how that feature is going to affect the prediction. Intuitively it makes a lot of sense: the pair (🔼, ⏹, - , *️⃣) - (🔼, ⏹, ⏺, *️⃣) will lead to more information gain than the pair (-, -, -, *️⃣) - (-, -, ⏺, *️⃣) when studying the effect of feature ⏺, for example. Based on this concept, Kernel SHAP assigns larger weights to both subsets with small coalitions and subsets with large coalitions. This weighting can be thought of as a distance function between subsets that determines which ones get assigned larger weights. The word kernel generally refers to such a distance function, hence the name of the algorithm.
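    The Shapley kernel behind this weighting can be written as a small function of the coalition size; a sketch, following the standard Kernel SHAP formulation:

    ```python
    from math import comb

    def shapley_kernel_weight(M, s):
        """Weight of a subset of size s out of M features in Kernel SHAP.
        Empty and full coalitions get infinite weight (their constraints
        are enforced exactly rather than through the regression)."""
        if s == 0 or s == M:
            return float("inf")
        return (M - 1) / (comb(M, s) * s * (M - s))

    # with 4 features, subsets of size 1 or 3 weigh more than size 2
    ```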

    If the kernel tells you how to weigh every single subset of the coalition, how does Kernel SHAP work? It is a beautiful combination of another algorithm called LIME and the Shapley values we introduced before. The LIME algorithm generates permuted samples around the sample you want to explain, weighting the ones that are closer to it more heavily. Then, it fits an interpretable model locally on these permuted samples.

    In Kernel SHAP, we pick a sample, sample the subsets of the coalition using the method explained above, and run the model in inference mode on those subsets. Just like LIME, we then fit a local linear model on them which aims to approximate our complex model. If the sample weights are picked in the right way, the coefficients of the linear model are exactly the Shapley values as computed by the Kernel SHAP algorithm.
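    Putting the pieces together, here is a compact brute-force sketch of Kernel SHAP for one sample: enumerate all proper subsets, weight them with the Shapley kernel, and solve the weighted least-squares problem. The linear model f and the single baseline vector are illustrative simplifications; real implementations such as the shap package sample subsets instead of enumerating them and average over a background dataset:

    ```python
    import numpy as np
    from itertools import combinations
    from math import comb

    def kernel_shap(f, x, baseline):
        """Kernel SHAP for one sample via weighted least squares.
        f maps a 1-D feature vector to a scalar; baseline stands in
        for 'missing' feature values."""
        M = len(x)
        base, full = f(baseline), f(x)
        masks, weights, preds = [], [], []
        for size in range(1, M):                  # proper, non-empty subsets
            w = (M - 1) / (comb(M, size) * size * (M - size))
            for S in combinations(range(M), size):
                z = np.zeros(M)
                z[list(S)] = 1.0
                masks.append(z)
                weights.append(w)
                preds.append(f(np.where(z == 1, x, baseline)))
        Z, W = np.array(masks), np.diag(weights)
        y = np.array(preds) - base
        # enforce sum(phi) = f(x) - f(baseline) by eliminating the last phi
        Zr = Z[:, :-1] - Z[:, [-1]]
        yr = y - Z[:, -1] * (full - base)
        phi_rest = np.linalg.solve(Zr.T @ W @ Zr, Zr.T @ W @ yr)
        return np.append(phi_rest, (full - base) - phi_rest.sum())
    ```

    For a linear model, the recovered coefficients coincide with the exact Shapley values, which makes this sketch easy to sanity-check.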

    What does SHAP look like for different data types?

    This depends largely on the type of data that we used to train our model. Explanations are going to look different depending on whether we are working with text, image, or tabular data. In this section, we go through examples for images and text and discuss them on a high level. 


    The above figure shows the explanations for an image classifier, a VGG16 trained on ImageNet. The image on the left is fed to this classification model, whose outputs include artichoke, bell pepper and grocery store among the classes with the highest logits, in decreasing order. Instead of highlighting specific pixels, the image here is divided into superpixels, i.e. regions in which color properties are similar. To obtain these regions in this case, we used k-means clustering. If a region contributed to predicting a specific class (positive SHAP values), it is colored green, while regions that contributed negatively (negative SHAP values) are colored red. The higher the contribution of a region (absolute SHAP value), the higher the intensity of the color.

    By looking at the example above we can conclude that:

    • Superpixels containing artichokes are in general the ones responsible for the model picking the artichoke class, which is correct and as expected. Note how the superpixels containing bell pepper, lemon and aubergine do not contribute to the artichoke class.
    • For the bell pepper class, the unshadowed regions of the bell pepper are the ones positively contributing instead. Note, however, how the green area of the bell pepper is mistakenly not in favour of the bell pepper class, which reveals a weakness of the model.


    For text, we are going to look at a real-life example that was part of a research project conducted by Peltarion. For this project, local and global explanations were extracted for large language models. This work was done in collaboration with Folktandvården Västra Götaland, as part of the Swedish Medical Language Data Lab. The goal was to decrease the actual number of incorrect antibiotics prescriptions by creating a model capable of distinguishing when antibiotics should and shouldn’t be prescribed based on patient journal data, which is text data. Explaining the model’s predictions was essential, since medical experts often cannot trust a model’s predictions if unaccompanied by explanations. 

    As discussed, SHAP assigns a value to each feature. For text data the features are often tokens (parts of words), but in order to improve the clarity of explanations, the tokens can be joined to form the original words, which are then represented by the sum of the corresponding SHAP values. 
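    For instance, with BERT-style subword tokenization, where a "##" prefix marks a continuation piece (an assumption for this sketch; the tokens and values are made up), the joining step can look like:

    ```python
    def merge_token_shap(tokens, shap_values):
        """Join subword tokens back into words, summing their SHAP values.
        Assumes BERT-style tokenization where '##' marks a continuation."""
        words, values = [], []
        for tok, val in zip(tokens, shap_values):
            if tok.startswith("##") and words:
                words[-1] += tok[2:]     # glue the piece onto the last word
                values[-1] += val        # and add its attribution
            else:
                words.append(tok)
                values.append(val)
        return words, values
    ```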

    Below you can see a sample of a patient’s journal text. The words in green and red are positively contributing towards prescribing and not prescribing antibiotics, respectively. The more intense the color, the more important the word is for the classification.

    A closer look reveals that the word with the largest contribution towards prescription is ”kåvepenin” (an antibiotic) for the given sample. This shows that the model is not behaving properly, since it is heavily relying on the presence of antibiotics in the text, instead of on the patients’ actual condition. Apart from the name of antibiotics, from SHAP it could also be seen that the model placed a large emphasis on the name of the doctor and the year a prescription was made. This observation triggered a data cleaning procedure, in which names of antibiotics, names of doctors, and years were removed from the dataset.

    After cleaning the data, SHAP highlighted words such as ”akut” (acute) and “febrig” (feverish), which indicates an improvement in model performance.

    This particular example shows how explainability and fairness/bias can go hand in hand. In an ideal scenario of a classification case, the model should not be picking up on any identity terms (e.g., doctor’s name) to determine a class of a prediction. 

    More information about this research project done by Omar Contreras can be found here.

    Important note: The algorithms that generate the explanations are models themselves. Being models, they are merely predicting what they “think” the model under study is basing its predictions on. Unlike interpretability, it is not the model’s inner workings that carry the explanation information. 

    06/ Looking ahead

    We foresee that explainability will become a standard component of the Machine Learning and Deep Learning automated production pipeline.

    Currently, the concept of explainability is distinct from the concept of causal inference but they are expected to converge as explainable AI is receiving more research interest. 

    Explanations are expected to improve human understanding by translating predictions into a more human-understandable form, and to reach more stakeholders, perhaps in the form of codeless, interactive explanation interfaces.