Example using fa_convnav to select modules for investigation with PyTorch hooks.
#all_flag
# Exclude the whole notebook from testing. 
# ...Contains GPU code which raises an exception on a Colab CPU instance, 
# ...and testing on a GPU instance works but is very time consuming.
# ...Include the notebook in tests with 'nbdev_test_nbs --flags all_flag'

Import the fastai deep learning library, including pretrained vision models.

from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
import torch
import PIL
import cv2

Import the fa_convnav.navigator module

from fa_convnav.navigator import *

Create a fastai DataBlock and DataLoaders using the Pets dataset (downloaded with untar_data), and apply some simple image transforms in the process.

pets = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                 get_items=get_image_files, 
                 splitter=RandomSplitter(),
                 get_y=RegexLabeller(pat = r'/([^/]+)_\d+.jpg$'),
                 item_tfms=Resize(460),
                 batch_tfms=[*aug_transforms(size=224, max_rotate=30, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

dls = pets.dataloaders(untar_data(URLs.PETS)/"images",  bs=128)

Choose the pretrained model architecture we want to use (the weights are downloaded when the Learner is created).

model = resnet18

Create a fastai Learner object from the dataloader, chosen model, and an optimiser.

learn = cnn_learner(
    dls, 
    model, 
    opt_func=partial(Adam, lr=slice(3e-3), wd=0.01, eps=1e-8), 
    metrics=error_rate, 
    config=cnn_config(ps=0.33)).to_fp16()

Create a ConvNav instance.

cn = ConvNav(learn, learn.summary())

Check the CNDF dataframe.

cn.view(top=True)
Resnet: Resnet18
Input shape: [128 x 3 x 224 x 224] (bs, ch, h, w)
Output features: [128 x 37] (bs, classes)
Currently frozen to parameter group 3 out of 3

Module_name Model Division Container_child Container_block Layer_description Torch_class Output_dimensions Parameters Trainable Currently
Index
0 Sequential torch.nn.modules.container.Sequential
1 0 Sequential torch.nn.modules.container.Sequential
2 0.0 Conv2d Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) torch.nn.modules.conv.Conv2d [128 x 64 x 112 x 112] 9,408 False Frozen
3 0.1 BatchNorm2d BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.nn.modules.batchnorm.BatchNorm2d [128 x 64 x 112 x 112] 128 True
4 0.2 ReLU ReLU(inplace=True) torch.nn.modules.activation.ReLU [128 x 64 x 112 x 112] 0 False
5 0.3 MaxPool2d MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) torch.nn.modules.pooling.MaxPool2d [128 x 64 x 56 x 56] 0 False
6 0.4 Sequential torch.nn.modules.container.Sequential
7 0.4.0 BasicBlock torchvision.models.resnet.BasicBlock
8 0.4.0.conv1 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) torch.nn.modules.conv.Conv2d [128 x 64 x 56 x 56] 36,864 False Frozen
9 0.4.0.bn1 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.nn.modules.batchnorm.BatchNorm2d [128 x 64 x 56 x 56] 128 True
...69 more layers

Select a spread of equally spaced blocks from the model and store their module objects in spread.

spread = cn.spread('block', 4)
resnet18
Spread of block where n = 4

Module_name Model Division Container_child Container_block Num_layers Torch_class Output_dimensions Parameters Trainable Currently
Index
7 0.4.0 0 4 BasicBlock 5 torchvision.models.resnet.BasicBlock [128 x 64 x 56 x 56] Frozen
20 0.5.0 0 5 BasicBlock 8 torchvision.models.resnet.BasicBlock [128 x 128 x 28 x 28] Frozen
36 0.6.0 0 6 BasicBlock 8 torchvision.models.resnet.BasicBlock [128 x 256 x 14 x 14] Frozen
61 0.7.1 0 7 BasicBlock 5 torchvision.models.resnet.BasicBlock [128 x 512 x 7 x 7] Frozen

View the modules.

for b in spread:
  print(f'\n{b}')
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

BasicBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (downsample): Sequential(
    (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

BasicBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (downsample): Sequential(
    (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

BasicBlock(
  (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

The following code registers forward and backward hooks on a target_module and then passes a single image from the dataloader, input_img, through the model. During the forward pass the target_module's input and output activations are captured and stored, and during the backward pass its gradients are kept. For each target_module these are then converted into image representations of the input activations, gradients and output activations (or feature map), and into a gradient class activation map (gradcam).

If you are not familiar with PyTorch hooks, don't worry: they are just a way to 'hook' into a model at a specific point (i.e. the target module or layer) to capture and store activations and gradients (https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/).
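
As a minimal sketch of the idea (a standalone, hypothetical example; none of these names come from fa_convnav), a forward hook can be registered on any nn.Module and fires every time that module runs:

import torch
import torch.nn as nn

stored = {}

def forward_hook(module, inp, out):
    "Save the module's input and output activations each time it runs"
    stored['input'] = inp[0].detach()
    stored['output'] = out.detach()

layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)    # hypothetical layer: any nn.Module can be hooked
handle = layer.register_forward_hook(forward_hook)   # the handle is used later to remove the hook

_ = layer(torch.randn(1, 3, 32, 32))                 # the forward pass triggers the hook
print(stored['output'].shape)                        # torch.Size([1, 8, 32, 32])

handle.remove()                                      # always remove hooks when finished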

Class activation maps (CAM) and gradient CAMs (Gradcam) help visualize the image regions a model finds relevant for class prediction. The code here is adapted and abridged from a number of sources, including https://github.com/anhquan0412/animation-classification/blob/master/gradcam.py and https://arxiv.org/pdf/1610.02391.pdf. The latter article and https://arxiv.org/pdf/1710.11063.pdf are good places to learn more about gradcams.
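
For reference, the Grad-CAM paper defines the weight of each feature map channel as the spatial mean of the gradients of the class score with respect to that channel, and the localization map as the ReLU of the weighted sum of the feature maps:

$$\alpha^c_k = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}} \qquad\qquad L^c_{Grad\text{-}CAM} = \mathrm{ReLU}\Big(\sum_k \alpha^c_k A^k\Big)$$

These two equations correspond to the alpha/weights and saliency_map computations in the gradcam method below.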

class examine_modules():
  "Gets activation stats and gradients for a module"
  def __init__(self, model, target_module, input_img):
    self.model = model
    self.gradients = dict()
    self.activations_out = dict()
    self.activations_in = dict()
    self.input_img_resized = input_img.resize((224, 224), Image.BILINEAR)

    def forward_hook(module, inp, output):
        "Store the input and output activations from the forward hook"
        self.activations_in['value'] = inp[0]
        self.activations_out['value'] = output
        return None
        
    def backward_hook(module, grad_input, grad_output):
        "Store the gradients from the backward hook"
        self.gradients['value'] = grad_output[0]
        return None

    # forward_hook is called and passed the target_module activations during the forward pass, 
    # ... backward_hook the gradients during the backward pass 
    self.handle_fwd = target_module.register_forward_hook(forward_hook)
    self.handle_bck = target_module.register_backward_hook(backward_hook)

    def normalize(tensor):
      "Normalises an input tensor using imagenet mean and std deviation"
      mean, std = imagenet_stats
      mean = torch.FloatTensor(mean).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device)
      std = torch.FloatTensor(std).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device)
      return tensor.sub(mean).div(std)

    # convert the input image to a tensor, rearrange dims (h, w, c -> c, h, w), add a batch dim, scale to floats between 0 and 1 and move to the GPU
    self.torch_img = torch.from_numpy(np.asarray(self.input_img_resized)).permute(2, 0, 1).unsqueeze(0).float().div(255).cuda()
    self.normed_torch_image = normalize(self.torch_img)

  def model_pass(self):
    "make a pass through the model with input image (shape: 1, 3, H, W)"

    preds = self.model(self.torch_img)                  # forward pass gives predictions (preds)
    loss = preds[:, preds.max(1)[-1]].squeeze()         # class prediction (highest pred) becomes the loss 

    self.model.zero_grad()
    loss.backward(retain_graph=False)                   # backward pass with highest prediction as loss

    self.acts_in = self.activations_in['value']         # input activations 
    self.acts_out = self.activations_out['value']       # output feature maps, shape (1, 512, 7, 7) for example
    self.grads = self.gradients['value']                # gradients of the loss w.r.t. the module output
    return None

  def gradcam(self):
    "Make gradcam"

    # create a weight matrix of shape (1, c, 1, 1) containing the gradient mean of each feature map for the class prediction
    b, c, h, w = self.grads.size()
    alpha = self.grads.view(b, c, -1).mean(2)                 
    weights = alpha.view(b, c, 1, 1)                    

    # create a class discriminatory localization map (saliency map) from the gradient means (weights) and feature maps (acts_out)
    saliency_map = (weights*self.acts_out).sum(1, keepdim=True)
    saliency_map = F.relu(saliency_map)
    saliency_map = F.interpolate(saliency_map, size=(224, 224), mode='bilinear', align_corners=False)
    saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
    saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data

    # create a heatmap by applying an OpenCV colormap to the saliency map mask, then convert the channel order from BGR to RGB
    mask = saliency_map
    heatmap = cv2.applyColorMap(np.uint8(255 * mask.squeeze().cpu()), cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
    b, g, r = heatmap.split(1)
    heatmap = torch.cat([r, g, b])
    
    # combine the heatmap and input image and merge the three colour channels into the final gradcam
    cam_rgb = heatmap + self.torch_img.cpu()
    cam_rgb = cam_rgb.div(cam_rgb.max()).squeeze()
    self.gradcam = cv2.merge((cam_rgb.numpy())) 
    return None

  def display(self, n):
    "Display input image, input activations, output activations, gradient and gradcam images in a row"

    def display_image(x):
      "Denormalises, processea and displays a torch.image , `x`"
      x -= x.mean()
      x /= (x.std() + 1e-5)
      x *= 0.1
      x += 0.5
      x = np.clip(x, 0, 1)
      x *= 255
      x = np.clip(x, 0, 255).astype('uint8')
      plt.imshow(x)
      return None

    fig, ax = plt.subplots(1,5, figsize=(20, 4))

    plt.subplot(1, 5, 1)
    img = self.input_img_resized.resize(self.acts_in[0][1].shape, Image.BILINEAR)
    plt.imshow(img)
    if not n: plt.title("\nInput image \n(from dataloader)\n", fontsize=18)
    plt.ylabel(f'Block {n}', fontsize=20)

    plt.subplot(1, 5, 2)
    f_in = self.acts_in[0][1].cpu().detach().numpy()
    if not n: plt.title("\nInput activations\n", fontsize=20)
    display_image(f_in)
    
    plt.subplot(1, 5, 3)
    g = self.grads[0][1].cpu().detach().numpy()
    if not n: plt.title("\nGradients\n", fontsize=18)
    display_image(g)
    
    plt.subplot(1, 5, 4)
    f_out = self.acts_out[0][1].cpu().detach().numpy()
    if not n: plt.title("\nOutput activations \n(feature map)\n", fontsize=18)
    display_image(f_out)
    
    plt.subplot(1, 5, 5)
    if not n: plt.title("\nGradcam\n", fontsize=18)
    plt.imshow(self.gradcam)

    plt.show()
    return None
  
  def remove_hooks(self):
    self.handle_fwd.remove()
    self.handle_bck.remove()

By plotting the input activations, gradients, feature maps and gradcam in succession for each of the four modules previously selected with fa_convnav (and stored in spread), we can easily follow an image as it is processed by the model. By repeating the process at different stages of model training we can gain insight into the training process itself.

# place the model onto the GPU
model = learn.model.eval().cuda()

# select an image from the dataloader
fname = dls.valid_ds.items[0]
pil_img = PIL.Image.open(fname) 

#Loop through our selected modules in spread
def display_img_grid(model, modules, input_img):
  for n, m in enumerate(modules):
    x = examine_modules(model, m, input_img)
    x.model_pass()
    x.gradcam()
    x.display(n)
    x.remove_hooks()

display_img_grid(model, spread, pil_img)

Resnet18 has been pretrained on the ImageNet dataset, and the gradcam shows that it already does a pretty good job of identifying the dog or cat (there were lots of dog and cat images in ImageNet), even without additional training. Now let's train the model on our Pets dataset and see if it can classify different breeds of dog and cat.

learn.fit_one_cycle(5, 1e-3)
display_img_grid(model, spread, pil_img)
epoch train_loss valid_loss error_rate time
0 2.758332 0.546939 0.161028 00:58
1 1.190638 0.355577 0.106225 00:58
2 0.689494 0.318071 0.091340 00:58
3 0.456033 0.291165 0.079161 00:58
4 0.364217 0.294581 0.083221 01:00

After 5 epochs the error rate for breed classification on the Pets dataset has fallen to about 0.08, and the gradcam shows that the model increasingly localises those features of the dog or cat which distinguish it from other breeds. The model appears to be training correctly.

Next we can unfreeze the body of the model and train for a further 5 epochs.

learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-5, 1e-4, 1e-4))
display_img_grid(model, spread, pil_img)
epoch train_loss valid_loss error_rate time
0 0.303732 0.284637 0.079161 01:01
1 0.284223 0.285909 0.078484 01:00
2 0.251341 0.268647 0.081867 01:00
3 0.218321 0.265787 0.077808 01:00
4 0.208654 0.265001 0.077131 00:59

After a further 5 epochs of training the unfrozen model, close examination of the gradcam images above shows an even greater localisation of the distinguishing features of the dog or cat.

Now let's try something different and train the model again, but this time at a much higher learning rate.

learn.fit_one_cycle(1, 1e-2)
display_img_grid(model, spread, pil_img)
epoch train_loss valid_loss error_rate time
0 4.272862 4.995016 0.947903 00:58

As expected, the high learning rate hinders rather than helps gradient descent and reduces model performance. The gradcam images become noisier and less focused. The model is no longer training effectively.