Visualizing an Image Classification Model

 February 05, 2018


If you followed along with our last post, we developed a deep-learning model that identifies Simpsons characters in an image. However, as with all software development tasks, getting a working program is only half the battle. To maintain a program and fix its bugs, a developer must understand the system; in particular, they must understand how it fails as well as how it succeeds. This can be quite difficult for deep-learning models, which are black boxes by the nature of their construction. Fortunately, there are techniques that let us open up the black box and see what is happening inside a trained model; these can help us find “bugs” in what the model has learned and even suggest how to resolve them. Among the many techniques for visualizing the internals of a deep-learning model, we will focus on class activation maps.

What are class activation maps?

Class activation maps, or CAMs, provide a way to visualize which pixels in an image contribute most to the model’s classification: effectively, a map of how “important” each region of an input image is for a given class. The map is generated from the last layer in the model that still contains spatial information (typically the final convolutional layer). At that layer, CAM methods weight each feature map by how strongly it contributes to the chosen class score (in grad-CAM, the weight is the spatially averaged gradient of the class score with respect to that feature map) and then sum the weighted maps. This contrasts with other visualization methods such as saliency maps, which compute the gradient of the class score directly with respect to the input pixels instead of referring back to an intermediate layer with spatial content.
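To make the weighting step concrete, here is a minimal NumPy sketch of the grad-CAM combination, assuming the last spatial layer’s activations and the gradient of the class score with respect to them have already been extracted (for example with an autodiff framework). The function name `grad_cam` and the channels-last (H, W, K) layout are our own choices for illustration, not the API of the tooling we used.

```python
import numpy as np


def grad_cam(feature_maps: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Combine a conv layer's activations into a class activation map.

    feature_maps: activations A of shape (H, W, K) from the last spatial layer.
    gradients:    d(class score)/dA, same shape (assumed precomputed elsewhere).
    Returns an (H, W) map scaled to [0, 1].
    """
    if feature_maps.shape != gradients.shape:
        raise ValueError("feature_maps and gradients must have the same shape")
    # One weight per channel: the spatially averaged gradient (global average pooling).
    weights = gradients.mean(axis=(0, 1))            # shape (K,)
    # Weighted sum over channels, then ReLU to keep only evidence *for* the class.
    cam = np.maximum(feature_maps @ weights, 0.0)    # shape (H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Random stand-ins for real activations and gradients, just to show the shapes.
rng = np.random.default_rng(0)
activations = rng.random((7, 7, 512))
grads = rng.standard_normal((7, 7, 512))
heatmap = grad_cam(activations, grads)
print(heatmap.shape)  # (7, 7)
```

Averaging the gradients over each channel yields one weight per feature map; the ReLU then discards regions that argue against the class, which is why the resulting map highlights only supporting evidence.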

We can use CAMs to show why our model performs well on some images and poorly on others. They are also useful for visualizing how different types of models learn, which can help in deciding which model to use for a given problem.

CAM plots for our classifier

We want to investigate how a model learns in order to identify whether it is over- or under-fitting our data. To visualize this, we calculated CAMs for various input images and weighted each pixel’s intensity by its relative importance to the model’s classification. We repeated this at every epoch during training, giving us a view into what the model “looks” at as it learns to identify characters.
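The pixel weighting described above can be sketched as follows, assuming the CAM is a coarse (h, w) array scaled to [0, 1] and the image is an (H, W, 3) float array. We use nearest-neighbor upsampling to keep the example dependency-free; smoother interpolation such as bilinear resizing is more common in practice. Both function names are hypothetical.

```python
import numpy as np


def upsample_nearest(cam: np.ndarray, height: int, width: int) -> np.ndarray:
    """Nearest-neighbor resize of an (h, w) map to (height, width)."""
    h, w = cam.shape
    rows = np.arange(height) * h // height   # source row for each output row
    cols = np.arange(width) * w // width     # source column for each output column
    return cam[rows[:, None], cols[None, :]]


def weight_by_importance(image: np.ndarray, cam: np.ndarray) -> np.ndarray:
    """Scale every pixel of an (H, W, C) image by its CAM importance in [0, 1]."""
    H, W = image.shape[:2]
    mask = upsample_nearest(cam, H, W)       # (H, W), same resolution as the image
    return image * mask[..., None]           # broadcast the mask over color channels


cam = np.array([[1.0, 0.0], [0.5, 0.0]])     # toy 2x2 CAM
image = np.ones((4, 4, 3))                   # toy all-white image
weighted = weight_by_importance(image, cam)
print(weighted.shape)  # (4, 4, 3)
```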

The good: paying attention to identifying features

So what exactly does a proper image classification look like? We can see an example here:

The model starts out with no idea of what to look at, but it quickly converges on Homer’s face as the identifying feature. The probability the model assigns to Homer (shown in the histogram at the bottom of the image) increases as the model focuses on the face and assigns more importance to those pixels.
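If bounding boxes for the characters’ faces were available (they are not part of the dataset as we have described it, so this is a hypothetical addition), the “focus” described above could be tracked as a number per epoch instead of judged by eye. A minimal NumPy sketch, where `fraction_inside_box` and the box format are our own illustrative choices:

```python
import numpy as np


def fraction_inside_box(cam: np.ndarray, box: tuple) -> float:
    """Share of a CAM's total importance that falls inside a box.

    cam: (H, W) non-negative map at the same resolution as the box coordinates.
    box: (top, left, bottom, right) in pixels; bottom and right are exclusive.
    """
    top, left, bottom, right = box
    total = float(cam.sum())
    if total == 0.0:
        return 0.0                     # an all-zero map focuses on nothing
    return float(cam[top:bottom, left:right].sum()) / total


cam = np.zeros((4, 4))
cam[0, 0] = 3.0                        # most importance inside the "face" region
cam[3, 3] = 1.0                        # some importance on the background
face_box = (0, 0, 2, 2)                # hypothetical face annotation
print(fraction_inside_box(cam, face_box))  # 0.75
```

A value that rises across epochs would correspond to the convergence onto the face seen in the plot, while a correct prediction paired with a low value hints that the model may be relying on background cues.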

This learning behavior is probably the most familiar to us, since humans also use faces to identify individuals. However, it’s not the only solution a deep-learning classifier can produce! Let’s compare two different deep-learning models, Inception and VGG-19, to see how different approaches can converge on the same correct answer.

Although VGG-19 focuses mostly on Homer’s face, Inception appears to look at a wider area and puts more emphasis on Homer’s body. It’s hard to say which approach is more effective. Inception had a higher accuracy than VGG-19 at the end of training; however, it’s not clear that higher accuracy on the training data translates into a better model. Training accuracy is inflated when a model overfits, so it should be taken with a grain of salt and compared against accuracy on held-out validation data.
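Comparing training and validation accuracy epoch by epoch is the usual check. A small pure-Python sketch, assuming both accuracies are recorded per epoch (in Keras, `model.fit` returns a `History` object whose `history` dictionary holds such per-epoch metrics); the function name is hypothetical and the numbers are made up for illustration:

```python
def overfitting_report(train_acc, val_acc):
    """Return (epoch with best validation accuracy, per-epoch train-minus-validation gap)."""
    if len(train_acc) != len(val_acc) or not val_acc:
        raise ValueError("need one training and one validation accuracy per epoch")
    best_epoch = max(range(len(val_acc)), key=lambda i: val_acc[i])
    gaps = [t - v for t, v in zip(train_acc, val_acc)]
    return best_epoch, gaps


# Made-up accuracies: training keeps rising while validation peaks, then falls.
train = [0.50, 0.70, 0.90, 0.98]
val = [0.45, 0.62, 0.66, 0.60]
best, gaps = overfitting_report(train, val)
print(best)                          # 2
print([round(g, 2) for g in gaps])   # [0.05, 0.08, 0.24, 0.38]
```

A gap that keeps widening while validation accuracy stalls or drops, as in the last epoch here, is the classic signature of overfitting.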

Here are a few other interesting CAM plots from the dataset:

The bad: not knowing what to look at

Sometimes it can be even more informative to look at scenarios where the classifier gets the wrong answer. In this example, we can see that the model is looking for something to use as an identifying feature.

Although Kent Brockman (the character actually in this image) occasionally appears in the top three predictions, he does so rarely. The model doesn’t lock on to any identifying feature in the image, so there is a lot of rapid turnover in the top three and no single classification rises to the top. This could indicate that we need to add more pictures of Kent Brockman to our input dataset, or that we need more variety in the images containing him so that the model learns to identify him in a range of surroundings.
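The turnover described above can be quantified from the model’s per-epoch predictions for a single image. A minimal sketch, assuming the softmax outputs for that image are saved after each epoch; the helper names, class labels, and probabilities below are hypothetical and for illustration only.

```python
import numpy as np


def top_k_names(probs, class_names, k=3):
    """Names of the k most probable classes, most probable first."""
    order = np.argsort(np.asarray(probs))[::-1][:k]
    return [class_names[i] for i in order]


def top_k_turnover(history):
    """Fraction of consecutive epoch pairs whose top-k *set* of classes changed."""
    if len(history) < 2:
        return 0.0
    changes = sum(set(prev) != set(curr) for prev, curr in zip(history, history[1:]))
    return changes / (len(history) - 1)


def presence_rate(history, true_class):
    """Fraction of epochs in which the true class made the top-k list."""
    return sum(true_class in top for top in history) / len(history)


names = ["homer_simpson", "bart_simpson", "kent_brockman", "moe_szyslak"]
per_epoch_probs = [            # made-up softmax outputs, one row per epoch
    [0.40, 0.30, 0.20, 0.10],
    [0.10, 0.35, 0.15, 0.40],
    [0.15, 0.30, 0.10, 0.45],
]
history = [top_k_names(p, names) for p in per_epoch_probs]
print(top_k_turnover(history))                            # 1.0
print(round(presence_rate(history, "kent_brockman"), 2))  # 0.67
```

High turnover combined with a low presence rate for the true class is the numeric counterpart of a model that never settles on an identifying feature.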

The ugly: getting it right for the wrong reason

Even more interesting are cases where the classifier quickly learns what to look at, and it’s not at all what we want. Take a look at this CAM plot for Krusty the Clown:

Although the classifier starts by looking at Krusty’s gloves, it quickly focuses on the red background in a TV frame and sticks with it throughout training. This is an example of overfitting: our model learned to classify our data by picking up on unrelated commonalities between input images rather than looking for the character. Since Krusty is frequently shown on TV with a red background, the model treated that background as more informative than Krusty himself on the left side of the image. This is understandable once we look more carefully at our input images and see how often Krusty appears in this environment:

This CAM plot is potentially very useful, since it suggests ways to improve our classifier. We can try to get more images of Krusty for our input set, specifically ones where he is not on a TV or in front of a red background. Alternatively, we could synthesize images in which we remove Krusty from this background and add other Simpsons characters in his place. We could also apply stronger image augmentation to our input data to keep the classifier from focusing so strongly on the straight edge of the TV frame. Regardless, the CAM plot suggests that in its current state, our classifier is likely to identify anybody on a TV with a red background as Krusty.
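As a sketch of what stronger augmentation might look like for this failure mode, the NumPy function below applies random crops (which sometimes cut away the TV frame’s edge), horizontal flips, and, with some probability, a random permutation of the color channels so that a red background is not always red. The function, its parameters, and the probabilities are illustrative assumptions rather than the pipeline we used; Keras’s `ImageDataGenerator` also offers built-in options such as `horizontal_flip`, `width_shift_range`, and `channel_shift_range`.

```python
import numpy as np


def augment(image, rng, crop_fraction=0.8, p_flip=0.5, p_channel_shuffle=0.3):
    """Randomly crop, flip, and sometimes permute the color channels of an (H, W, 3) image."""
    h, w, _ = image.shape
    ch, cw = max(1, int(h * crop_fraction)), max(1, int(w * crop_fraction))
    top = rng.integers(0, h - ch + 1)          # random crop origin
    left = rng.integers(0, w - cw + 1)
    out = image[top:top + ch, left:left + cw]
    if rng.random() < p_flip:
        out = out[:, ::-1]                     # horizontal flip
    if rng.random() < p_channel_shuffle:
        out = out[..., rng.permutation(3)]     # e.g. a red background may become green or blue
    return out.copy()


rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))                # stand-in for a training image
samples = [augment(image, rng) for _ in range(4)]
print(samples[0].shape)  # (51, 51, 3)
```

Channel shuffling is applied only part of the time because color is a genuine cue for many Simpsons characters (Marge’s blue hair, for instance). The cropped output would also need to be resized back to the network’s input size before training.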

Here are some more examples we saw in the dataset that appear to be overfitting issues:

The unknown: limitations of the CAM plot

While CAM plots can be useful for investigating how a deep-learning model classifies images, they are far from perfect. Results often don’t align with our conceptual model of highlighting what a model “looks at,” particularly with more complex architectures such as Inception. In the following example, it’s unclear whether Inception is overfitting or whether the class activation map simply isn’t capturing the decision-making process accurately:

It is also important to note that CAMs depend on the structure of the underlying deep-learning model, and that dependence can affect what the plot shows. Class activation maps (including grad-CAM as implemented in the keras-vis package for Keras) operate on the nearest 2D feature map preceding the layer you want to visualize, so the distance between that layer and the nearest layer that still retains spatial information can affect the quality of the resulting map. If that layer’s index (called the “penultimate layer”) is not specified, keras-vis selects it automatically; from its documentation:

penultimate_layer_idx: The pre-layer to layer_idx whose feature maps should be used to compute gradients wrt filter output. If not provided, it is set to the nearest penultimate Conv or Pooling layer.
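The selection rule quoted above can be approximated in plain Python. This is our own sketch operating on layer class names (for a Keras model, something like `[type(l).__name__ for l in model.layers]`), not keras-vis’s actual code; we additionally skip global pooling layers on the assumption that a layer with no remaining spatial dimensions cannot yield a map.

```python
def nearest_spatial_layer(layer_types, layer_idx):
    """Index of the closest Conv or (non-global) Pooling layer before layer_idx."""
    if layer_idx < 0:
        layer_idx += len(layer_types)          # allow -1 for the output layer
    for i in range(layer_idx - 1, -1, -1):     # walk backwards toward the input
        name = layer_types[i]
        if name.startswith("Global"):
            continue                           # global pooling has already collapsed H and W
        if "Conv" in name or "Pool" in name:
            return i
    raise ValueError("no Conv or Pooling layer precedes layer_idx")


layers = ["InputLayer", "Conv2D", "MaxPooling2D", "Conv2D",
          "GlobalAveragePooling2D", "Dense"]
print(nearest_spatial_layer(layers, -1))  # 3
```

The more layers sit between the chosen spatial layer and the output being explained, the more the map reflects that intermediate layer rather than the final decision, which is the quality concern raised above.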

Check out the original paper for a comprehensive overview of the grad-CAM visualization technique with some excellent visuals.

Conclusion

The uses of deep learning go far beyond identifying characters in an image; in fact, much of deep learning’s power comes from how general it is. We chose a Simpsons dataset because it is easily relatable and visually distinctive, and it can be explained without diving into medical jargon; however, the same techniques described here can be applied to medical images as well. For example, if we were training a classification network to detect melanoma, we would expect the CAMs to lock onto the characteristic features of melanoma such as asymmetry, irregular borders, heterogeneous color, and large diameter.

Deep learning models can help us solve challenging classification problems in medical imaging, but we need tools to analyze their performance and investigate errors. Class activation maps can serve as a useful tool for visualizing activity at the top layers of a model and can even indicate non-trivial learning errors such as overfitting, making them particularly useful for investigating error cases. With the use of tools like CAM plots, we can develop deep learning models with more confidence and understand their behavior in a more transparent way.

Footnote on adversarial interpretation (added 3/2/2019)

An article that recently surfaced on Hacker News raised some well-founded criticisms of using visualization techniques such as CAMs to interpret neural networks. The paper cited in that article introduces the concept of adversarial interpretation, in which input images are modified so that they produce the same label but dramatically different feature maps. In the context of our Simpsons classifier, a user could provide a modified image of Homer that is visually indistinguishable from the original, and the network would still correctly identify it as containing Homer; however, the feature map would highlight nonsensical features as important.

In the paper by Ghorbani et al., gradient-based interpretation techniques such as grad-CAM are among those listed as vulnerable to these kinds of interpretation distortions. They are vulnerable because of the complex contours of the decision function in neural network models: even small movements along these contours can produce dramatically different gradients, and thus dramatically different feature maps. The paper does a wonderful job of explaining how this works and is worth a read if you are working on visualizing neural network predictions.
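One simple way to quantify how far two interpretations of the same prediction diverge is to ask how many of their most important pixels they share. The sketch below computes a top-k intersection score between two importance maps; it is our illustration of the idea, similar in spirit to the similarity measures used in that line of work, and the paper’s exact evaluation protocol may differ. The function name and toy maps are hypothetical.

```python
import numpy as np


def top_k_intersection(map_a, map_b, k):
    """Fraction of the k highest-scoring pixels that two importance maps have in common."""
    if map_a.shape != map_b.shape:
        raise ValueError("maps must have the same shape")
    top_a = np.argsort(map_a, axis=None)[-k:]   # flat indices of the k largest values
    top_b = np.argsort(map_b, axis=None)[-k:]
    return len(np.intersect1d(top_a, top_b)) / k


original = np.array([[0.9, 0.1], [0.8, 0.0]])    # toy map for the unmodified image
perturbed = np.array([[0.95, 0.7], [0.2, 0.1]])  # toy map for a perturbed image
print(top_k_intersection(original, perturbed, k=2))  # 0.5
```

Under an adversarial interpretation attack, the predicted label stays the same while a score like this falls toward zero, even though the two input images look identical.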

As machine learning algorithms are used more widely, and as their interpretations become the basis for critical decisions in high-risk industries, it is important to keep the limitations of these techniques in mind. Visualization techniques such as grad-CAM can be useful and powerful tools for understanding how a model makes decisions, but they are not foolproof, as we saw in the “limitations of the CAM plot” section. Still, they are often a good place to start when investigating model behavior.
