import albumentations as A
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
Building and Training a RetinaNet
Overview and Setup
The goal of this post is to build a RetinaNet model and train it on the PASCAL VOC dataset. We’ll use a couple of tricks, including fine-tuning the RetinaNet’s backbone on a related task, to speed up training and get results in about 2 hours total. Our trained model will predict bounding boxes for 20 classes of objects in images. For example, the model we train in this post makes the following predictions for the image below. The decimal number in each box label is the model’s confidence score for that label.
Package | Version |
---|---|
python | 3.9.16 |
PyTorch | 2.2.1+cu121 |
torchvision | 0.17.1+cu121 |
fastai | 2.7.14 |
matplotlib | 3.8.3 |
numpy | 1.25.2 |
albumentations | 1.4.2 |
pycocotools | 2.0.7 |
Get and Format the Data
We’ll use fastai’s versions of the PASCAL VOC datasets. We’ll combine the 2007 and 2012 training and validation sets to use as our training set, and we’ll use the 2007 test set as our validation set.
from fastai.vision.all import untar_data, URLs
path_2007 = untar_data(URLs.PASCAL_2007)
os.listdir(path_2007)
['test.json',
'test.csv',
'segmentation',
'valid.json',
'train',
'test',
'train.csv',
'train.json']
path_2012 = untar_data(URLs.PASCAL_2012)
os.listdir(path_2012)
['segmentation', 'valid.json', 'train', 'test', 'train.csv', 'train.json']
train_data_sources = [
    path_2007/'train.json',
    path_2007/'valid.json',
    path_2012/'train.json',
    path_2012/'valid.json',
]
valid_data_source = path_2007/'test.json'
The annotation JSON files are in the COCO format, and fastai has a convenience function, get_annotations, to extract the data relevant for our task. To speed up training the RetinaNet, we’ll first fine-tune the ResNet backbone on a multilabel classification task. For organizational efficiency, we’ll generate and store the targets for both the classification and object detection tasks together and extract the relevant targets for each task when needed with custom Datasets and DataLoader collation functions.
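Based on how we use it below, get_annotations returns two parallel lists: image filenames and their annotations, where each annotation is a list of [x0, y0, x1, y1] boxes plus a list of string labels. A minimal sketch of that structure:

from fastai.vision.all import get_annotations

imgs, targets = get_annotations(path_2007/'train.json')
# imgs:    one filename per image
# targets: one entry per image, each a tuple of
#          ([[x0, y0, x1, y1], ...], ['label', ...]), aligned with imgs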
from fastai.vision.all import get_annotations
from itertools import chain, starmap
def get_vocab_dicts(data_source):
    _, targets = get_annotations(data_source)
    all_labels = set().union(*(set(o[1]) for o in targets))
    all_labels = sorted(all_labels)
    idx2label = dict(enumerate(all_labels))
    label2idx = {v: k for k, v in idx2label.items()}
    return idx2label, label2idx
def organize_annotations(train_data_sources, valid_data_source):
    train_images, train_targets = [
        list(chain.from_iterable(o))
        for o in zip(*[get_annotations(source, prefix=str(source.parent)+'/train/')
                       for source in train_data_sources])
    ]
    valid_images, valid_targets = get_annotations(
        valid_data_source, prefix=str(valid_data_source.parent)+'/test/'
    )
    _, label2idx = get_vocab_dicts(train_data_sources[0])
    return (
        (train_images, train_targets, label2idx),
        (valid_images, valid_targets, label2idx)
    )
def reformat_data(imgs, targs, label2idx):
    result = [
        {'image': img,
         'bboxes': targ[0],
         'labels': [label2idx[o] for o in targ[1]]}
        for img, targ in zip(imgs, targs)
    ]
    for item in result:
        multilabel = [0] * len(label2idx)
        for i in set(item['labels']):
            multilabel[i] = 1
        item |= {'multilabel': multilabel}
    return result
def organize_data(train_data_sources, valid_data_source):
    results = organize_annotations(
        train_data_sources, valid_data_source
    )
    return starmap(reformat_data, results)

train_data, valid_data = organize_data(
    train_data_sources, valid_data_source
)
Now let’s take a look at how we’ve formatted the data.
train_data[1]
{'image': '/root/.fastai/data/pascal_2007/train/000017.jpg',
'bboxes': [[184, 61, 279, 199], [89, 77, 403, 336]],
'labels': [14, 12],
'multilabel': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0]}
For later use, we’ll generate dictionaries to go back and forth between integer and string class labels.
idx2label, label2idx = get_vocab_dicts(train_data_sources[0])
idx2label
{0: 'aeroplane',
1: 'bicycle',
2: 'bird',
3: 'boat',
4: 'bottle',
5: 'bus',
6: 'car',
7: 'cat',
8: 'chair',
9: 'cow',
10: 'diningtable',
11: 'dog',
12: 'horse',
13: 'motorbike',
14: 'person',
15: 'pottedplant',
16: 'sheep',
17: 'sofa',
18: 'train',
19: 'tvmonitor'}
Visualize the Data
Let’s take a look at some of the images in our dataset together with their bounding box annotations.
Image Plotting Functions
import matplotlib.patheffects as path_effects
from fastai.vision.all import PILImage, Path
from PIL import Image

cmap = matplotlib.colormaps['tab20']
imagenet_stats = {
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
}
def draw_bbox(ax, bbox, label, alpha=1.0):
    x0, y0, x1, y1 = bbox
    box = matplotlib.patches.Rectangle(
        (x0, y0), x1 - x0, y1 - y0,
        facecolor='none',
        linewidth=2,
        edgecolor=cmap(label),
        alpha=alpha,
    )
    border_effects = [
        path_effects.Stroke(linewidth=5, foreground='black', alpha=alpha),
        path_effects.Normal()
    ]
    box.set_path_effects(border_effects)
    ax.add_patch(box)
def add_bbox_label(ax, bbox, label, alpha=1.0):
    props = dict(
        boxstyle='square',
        facecolor=cmap(label),
        edgecolor='black',
        linewidth=1.1,
        alpha=alpha
    )
    x0, y0, _, _ = bbox
    if isinstance(label, torch.Tensor):
        label = label.item()
    ax.text(
        x0, y0, idx2label[label],
        bbox=props,
        color='black',
        fontsize=8,
        in_layout=True
    )
def decode(image):
    if isinstance(image, Path):
        image = Image.open(image)
    if isinstance(image, torch.Tensor):
        if image.shape[0] <= 3:
            image = image.permute(1, 2, 0)
        image = image.cpu().numpy()
    if isinstance(image, np.ndarray):
        if image.dtype == np.float32:
            mean = np.array(imagenet_stats['mean'])[None, None]
            std = np.array(imagenet_stats['std'])[None, None]
            image = (image*std + mean) * 255.0
        image = image.astype(np.uint8)
        image = Image.fromarray(image)
    return image
def plot_image_with_annotations(image, bboxes, labels, ax=None, alpha=1.0, **kwargs):
    if isinstance(image, str):
        image = PILImage.create(image)
    image = decode(image)
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    if ax is None:
        fig, ax = plt.subplots()
        fig.tight_layout()
    ax.imshow(image)
    ax.axis('off')

    for bbox, label in zip(bboxes, labels):
        draw_bbox(ax, bbox, label, alpha=alpha)
        add_bbox_label(ax, bbox, label, alpha=alpha)
fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(16, 16))
for idx, ax in enumerate(axs.flat):
    plot_image_with_annotations(**train_data[idx], ax=ax)
fig.tight_layout()
Multi-Label Classification
Our first task is to fine-tune a ResNet to do multi-label classification on our dataset. As mentioned above, we’ll use a custom Dataset to extract only the targets that we need for this task.
def_device = ('cuda' if torch.cuda.is_available()
              else 'mps' if torch.backends.mps.is_available()
              else 'cpu')
class MultilabelDataset(torch.utils.data.Dataset):
    def __init__(self, items, tfms=None, device=def_device):
        self.items = items
        self.tfms = tfms
        self.device = device

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image_as_array = {'image': self.open_as_array(item['image'])}
        item = item | image_as_array
        item = {k: item[k] for k in ['image', 'multilabel']}
        if self.tfms is not None:
            item = item | self.tfms(image=item['image'])
        return {k: torch.tensor(item[k]).float()
                for k in ['image', 'multilabel']}

    def open_as_array(self, image):
        return np.array(Image.open(image))
Now we’ll set up our image transforms and create our training and validation datasets for this task. We’ll pad all of the images to be \(600 \times 600\) to make them a uniform size while maintaining the aspect ratios of the bounding boxes in our later object detection model. We’ll also use minimal data augmentation: just random flips, since that’s all we’ll use for the object detection model.
image_size = 600

imagenet_stats = {
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
}

pad_params = {
    'min_height': image_size,
    'min_width': image_size,
    'position': 'top_left',
    'border_mode': 0,
    'value': 0,
}
multilabel_train_tfms = A.Compose(
    transforms=[
        A.Flip(p=0.5),
        A.PadIfNeeded(**pad_params),
        A.Normalize(**imagenet_stats),
    ]
)

multilabel_valid_tfms = A.Compose(
    transforms=[
        A.PadIfNeeded(**pad_params),
        A.Normalize(**imagenet_stats),
    ]
)
multilabel_train_ds = MultilabelDataset(
    train_data, tfms=multilabel_train_tfms
)

multilabel_valid_ds = MultilabelDataset(
    valid_data, tfms=multilabel_valid_tfms
)
Padding the images in this way was inspired by this KerasCV tutorial.
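As a quick sanity check on this padding strategy (a toy example, not from the post): because 'top_left' puts all of the padding on the right and bottom edges, existing bounding box coordinates pass through unchanged.

# pad a fake 480x320 image to 600x600 and confirm the box is untouched
check_pad = A.Compose(
    [A.PadIfNeeded(**pad_params)],
    bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels'])
)
padded = check_pad(
    image=np.zeros((480, 320, 3), dtype=np.uint8),
    bboxes=[[50, 60, 200, 300]],
    labels=[3]
)
padded['image'].shape, padded['bboxes']
# expected: (600, 600, 3) and a box still at (50, 60, 200, 300)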
Now let’s visualize our transformed data.
fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(16, 20))
for idx, ax in enumerate(axs.flat):
    item = multilabel_train_ds[idx]
    image, target = item['image'], item['multilabel']
    title = ', '.join([idx2label[i.item()] for i in target.nonzero()])
    ax.imshow(decode(image))
    ax.set(xticks=[], yticks=[], title=title)
fig.tight_layout()
We’ll use a custom collation function to put our images in channels-first format.
from fastai.vision.all import DataLoader, DataLoaders
def multilabel_collate(batch):
    images = torch.stack([o['image'].permute(2, 0, 1) for o in batch])
    multilabels = torch.stack([o['multilabel'] for o in batch])
    return images, multilabels

multilabel_config = {
    'bs': 24,
    'create_batch': multilabel_collate,
    'device': def_device,
    'num_workers': 8
}

multilabel_train_dl = DataLoader(
    multilabel_train_ds, shuffle=True, **multilabel_config
)
multilabel_valid_dl = DataLoader(
    multilabel_valid_ds, shuffle=False, **multilabel_config
)

multilabel_dls = DataLoaders(multilabel_train_dl, multilabel_valid_dl)
Now we’ll define our multi-label image classification model. One of the tricks we’ll use to get better object detection results later on is to use Mish as our activation function instead of ReLU.
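Mish is a smooth, self-gated activation, defined as \(\mathrm{Mish}(x) = x \cdot \tanh(\mathrm{softplus}(x))\). A quick sketch (not from the original post) confirming that PyTorch’s nn.Mish matches this formula:

# nn.Mish should agree with x * tanh(softplus(x))
x = torch.linspace(-5, 5, steps=11)
assert torch.allclose(nn.Mish()(x), x * torch.tanh(F.softplus(x)), atol=1e-6)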
from fastai.vision.all import create_body, create_head, resnet101
class MultilabelModel(nn.Module):
    def __init__(self, n_out):
        super().__init__()
        self.backbone = create_body(resnet101(weights='DEFAULT'))
        self.head = create_head(2048, n_out)
        self.backbone = self.swap_activation(self.backbone)
        self.head = self.swap_activation(self.head)

    def forward(self, x):
        x = self.backbone(x)
        return self.head(x)

    def swap_activation(self, module, old_act=nn.ReLU, new_act=nn.Mish):
        for name, submodule in module._modules.items():
            if len(list(submodule.children())) > 0:
                module._modules[name] = self.swap_activation(submodule)
            if isinstance(submodule, old_act):
                module._modules[name] = new_act(inplace=True)
        return module
The code for swap_activation was adapted from this notebook by Radek Osmulski.
Now we can find a good learning rate and fine-tune our multi-label classification model.
from fastai.vision.all import (
params, L, accuracy_multi, Learner,
minimum, steep, valley, slide
)
def multilabel_split(model):
    return L(
        model.backbone[:6],
        model.backbone[6:],
        model.head
    ).map(params)

multilabel_learn = Learner(
    dls=multilabel_dls,
    model=MultilabelModel(20),
    loss_func=nn.BCEWithLogitsLoss(),
    metrics=accuracy_multi,
    splitter=multilabel_split
)

multilabel_learn.freeze()
multilabel_lrs = multilabel_learn.lr_find(
    suggest_funcs=(minimum, steep, valley, slide)
)

multilabel_learn.fine_tune(3, multilabel_lrs.slide)
epoch | train_loss | valid_loss | accuracy_multi | time |
---|---|---|---|---|
0 | 0.121236 | 0.070566 | 0.975909 | 06:49 |
epoch | train_loss | valid_loss | accuracy_multi | time |
---|---|---|---|---|
0 | 0.107960 | 0.075465 | 0.974031 | 08:07 |
1 | 0.081165 | 0.059422 | 0.979069 | 08:07 |
2 | 0.051065 | 0.047460 | 0.983330 | 08:07 |
Object Detection
Now we’ll build a RetinaNet model using our fine-tuned ResNet as the backbone. We’ll need a new dataset class since our targets for object detection are different.
class ObjectDetectionDataset(torch.utils.data.Dataset):
    def __init__(self, items, tfms=None, device=def_device):
        self.items = items
        self.tfms = tfms
        self.device = device

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        image_as_array = {'image': self.open_as_array(item['image'])}
        item = item | image_as_array
        item = {k: item[k] for k in ['image', 'bboxes', 'labels']}
        if self.tfms is not None:
            item = self.tfms(**item)
        return {
            k: torch.tensor(item[k]).to(t)
            for k, t in [
                ['image', torch.float32],
                ['bboxes', torch.float32],
                ['labels', torch.int64]
            ]
        }

    def open_as_array(self, image):
        return np.array(Image.open(image))
Now we’ll set up the data processing pipeline for our object detection model. We’ll re-use the image_size, imagenet_stats, and pad_params from the multi-label classification model.
bbox_params = {
    'format': 'pascal_voc',
    'min_visibility': 0.2,
    'label_fields': ['labels']
}

object_detection_train_tfms = A.Compose(
    transforms=[
        A.BBoxSafeRandomCrop(),
        A.Flip(p=0.5),
        A.PadIfNeeded(**pad_params),
        A.Normalize(**imagenet_stats),
    ], bbox_params=A.BboxParams(**bbox_params)
)

object_detection_valid_tfms = A.Compose(
    transforms=[
        A.PadIfNeeded(**pad_params),
        A.Normalize(**imagenet_stats),
    ], bbox_params=A.BboxParams(**bbox_params)
)
train_ds = ObjectDetectionDataset(
    train_data, tfms=object_detection_train_tfms
)

valid_ds = ObjectDetectionDataset(
    valid_data, tfms=object_detection_valid_tfms
)
Let’s visualize our object detection data with the transforms above.
fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(16, 16))
for idx, ax in enumerate(axs.flat):
    plot_image_with_annotations(**train_ds[idx], ax=ax)
fig.tight_layout()
The number of bounding boxes and labels varies from image to image, so we’ll need to pad those targets when collating each batch.
def pad_bboxes(bboxes, max_targets):
    pad_size = max_targets - len(bboxes)
    padding = torch.tensor([[-100] * 4] * pad_size, device=bboxes.device)
    return torch.cat([bboxes, padding])

def pad_labels(labels, max_targets):
    pad_size = max_targets - len(labels)
    padding = torch.tensor([-100] * pad_size, device=labels.device)
    return torch.cat([labels, padding])

def object_detection_collate(batch):
    images = torch.stack([o['image'].permute(2, 0, 1) for o in batch])
    max_targets = max(len(o['labels']) for o in batch)
    bboxes = torch.stack([pad_bboxes(o['bboxes'], max_targets) for o in batch])
    labels = torch.stack([pad_labels(o['labels'], max_targets) for o in batch])
    return images, bboxes, labels
object_detection_config = {
    'bs': 8,
    'create_batch': object_detection_collate,
    'device': def_device,
    'num_workers': 8
}

object_detection_train_dl = DataLoader(
    train_ds, shuffle=True, **object_detection_config
)

object_detection_valid_dl = DataLoader(
    valid_ds, shuffle=False, **object_detection_config
)
Now we’ll define our object detection model.
from fastai.vision.all import hook_outputs
from functools import partial
import math
def conv_with_init(
    n_in,
    n_out,
    kernel_size=3,
    stride=1,
    bias=True,
    weight_init=nn.init.kaiming_normal_,
    bias_init=partial(nn.init.constant_, val=0.0)
):
    conv = nn.Conv2d(n_in, n_out, kernel_size=kernel_size,
                     stride=stride, padding=kernel_size//2, bias=bias)
    weight_init(conv.weight)
    bias_init(conv.bias)
    return conv
class ObjectDetectionHead(nn.Module):
    def __init__(self, n_out, n_anchors, bias_init_val):
        super().__init__()
        self.n_out = n_out
        layers = []
        for _ in range(4):
            layers += [conv_with_init(256, 256),
                       nn.Mish()]
        layers += [conv_with_init(
            256, n_out * n_anchors,
            weight_init=partial(nn.init.constant_, val=0.0),
            bias_init=partial(nn.init.constant_, val=bias_init_val)
        )]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        return torch.cat(
            [self.reshape(self.head(o), self.n_out) for o in x], dim=1
        )

    def reshape(self, x, n_out):
        return (x.permute(0, 2, 3, 1)
                .reshape(x.shape[0], -1, n_out))
class ObjectDetectionModel(nn.Module):
    def __init__(self, n_anchors, backbone=None):
        super().__init__()
        self.backbone = create_body(resnet101()) if backbone is None else backbone
        # FPN top path
        self.c5top5 = conv_with_init(2048, 256, kernel_size=1)
        self.c5top6 = conv_with_init(2048, 256, stride=2)
        self.p6top7 = nn.Sequential(nn.Mish(),
                                    conv_with_init(256, 256, stride=2))
        # FPN down path
        self.c4_cross = conv_with_init(1024, 256, kernel_size=1)
        self.c3_cross = conv_with_init(512, 256, kernel_size=1)
        # smooth results of FPN down path
        self.p3_out = conv_with_init(256, 256)
        self.p4_out = conv_with_init(256, 256)
        self.p5_out = conv_with_init(256, 256)
        # bounding box regression head and image classification head
        self.box_head = ObjectDetectionHead(4, n_anchors, 0.0)
        prior = -math.log((1 - 0.01) / 0.01)
        self.class_head = ObjectDetectionHead(20, n_anchors, prior)

    def forward(self, x):
        hook_layers = self.backbone[-3:-1]
        with hook_outputs(hook_layers, detach=False) as h:
            c5 = self.backbone(x)
            c3, c4 = h.stored

        # FPN top path
        p5 = self.c5top5(c5)
        p6 = self.c5top6(c5)
        p7 = self.p6top7(p6)

        # FPN down path
        p4 = self.c4_cross(c4) + F.interpolate(p5, size=38, mode='nearest-exact')
        p3 = self.c3_cross(c3) + F.interpolate(p4, size=75, mode='nearest-exact')

        # smooth results of FPN down path
        p3 = self.p3_out(p3)
        p4 = self.p4_out(p4)
        p5 = self.p5_out(p5)

        fpn_out = [p3, p4, p5, p6, p7]
        return self.box_head(fpn_out), self.class_head(fpn_out)
The code for the model is partially based on code from this Keras tutorial and this notebook from an old version of the fastai course.
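As a quick (if somewhat heavyweight) shape check, not from the original post: the box head emits 4 numbers per anchor and the class head 20 logits per anchor, concatenated across all five FPN levels, so both outputs share the same middle dimension.

# sketch: run a random 600x600 image through a freshly built model on the CPU
check_model = ObjectDetectionModel(n_anchors=9).eval()
with torch.no_grad():
    check_boxes, check_classes = check_model(torch.randn(1, 3, 600, 600))
check_boxes.shape, check_classes.shape
# expected: (1, N, 4) and (1, N, 20) with the same anchor count N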
Now we’ll define our loss function, starting with code to generate anchor boxes.
from torchvision.ops import box_convert, box_iou
from itertools import product
def generate_centers(grid_size):
    coords = torch.arange(0, grid_size, dtype=torch.float32, device=def_device) + 0.5
    centers = torch.stack(torch.meshgrid(coords, coords.clone(), indexing='xy'), dim=-1)
    return centers
def generate_one_anchor_grid(
    image_size,
    grid_size,
    area,
    aspect_ratio,
    scale
):
    height = math.sqrt(area / aspect_ratio)
    width = area / height
    centers = generate_centers(grid_size) * (image_size / grid_size)
    heights = torch.full((grid_size, grid_size, 1), scale * height, device=def_device)
    widths = torch.full((grid_size, grid_size, 1), scale * width, device=def_device)
    return torch.cat([centers, widths, heights], dim=-1)
def generate_anchor_grids(
    image_size,
    grid_size,
    area,
    aspect_ratios=None,
    scales=None
):
    if aspect_ratios is None:
        aspect_ratios = [1/2, 1, 2]
    if scales is None:
        scales = [math.pow(2, i / 3) for i in range(3)]
    anchors = torch.empty(
        grid_size, grid_size, len(aspect_ratios) * len(scales), 4,
        dtype=torch.float32, device=def_device
    )
    for i, (r, s) in enumerate(product(aspect_ratios, scales)):
        anchors[..., i, :] = generate_one_anchor_grid(
            image_size, grid_size, area, r, s
        )
    return anchors.view(-1, 4)
def generate_anchor_boxes(
    image_size,
    areas=None,
    grid_sizes=None,
    aspect_ratios=None,
    scales=None
):
    if grid_sizes is None:
        grid_sizes = [75, 38, 19, 10, 5]
    if areas is None:
        areas = [(4 * image_size / grid_size)**2 for grid_size in grid_sizes]
    return torch.cat(
        [generate_anchor_grids(image_size, grid_size, area, aspect_ratios, scales)
         for area, grid_size in zip(areas, grid_sizes)]
    )
The code to generate anchor boxes is partially based on code from this Keras tutorial and this notebook from an old version of the fastai course.
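A quick count, for reference (not in the original post): with grid sizes of 75, 38, 19, 10, and 5 and \(3 \times 3 = 9\) anchors per location, generate_anchor_boxes(600) should return \(9 \times (75^2 + 38^2 + 19^2 + 10^2 + 5^2) = 67995\) anchors in cxcywh format.

example_anchors = generate_anchor_boxes(600)
example_anchors.shape   # expected: torch.Size([67995, 4])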
For the loss function itself, we’ll use complete IoU (CIoU) loss for the bounding box targets and focal loss for the classification targets.
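As a reminder (a hedged summary, not from the original post), focal loss down-weights easy examples so the huge number of background anchors doesn’t swamp training:

\[ \mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t), \]

where \(p_t\) is the predicted probability of the true class; torchvision’s sigmoid_focal_loss uses \(\alpha = 0.25\) and \(\gamma = 2\) by default. Complete IoU loss augments \(1 - \mathrm{IoU}\) with penalties for the distance between box centers and the mismatch in aspect ratios, which gives a useful gradient even when the boxes don’t overlap.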
from torchvision.ops import sigmoid_focal_loss, complete_box_iou_loss
class ObjectDetectionLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.anchors = generate_anchor_boxes(600)
        self.box_loss_func = complete_box_iou_loss
        self.class_loss_func = sigmoid_focal_loss

    def forward(self, pred, *targ):
        losses = []
        batch_size = pred[0].shape[0]
        for pb, pc, tb, tc in zip(*pred, *targ):
            tb, tc = self.unpad(tb, tc)
            assignments = self.assign_anchors(self.anchors, tb)
            box_mask = assignments >= 0
            if box_mask.sum() > 0:
                tb = tb[assignments[box_mask]]
                an = self.anchors[box_mask]
                pb = pb[box_mask]
                pred_boxes = self.compute_pred_boxes(pb, an)
                pred_boxes = box_convert(pred_boxes, 'cxcywh', 'xyxy')
                box_loss = self.box_loss_func(pred_boxes, tb, reduction='mean')
            else:
                box_loss = 0.0
            class_mask = assignments >= -1
            if class_mask.sum() > 0:
                class_assignments = tc[assignments[box_mask]] + 1
                tc = self.compute_one_hot_targets(class_assignments, box_mask, class_mask)
                class_loss = self.class_loss_func(pc[class_mask], tc.float(), reduction='sum')
                class_loss = class_loss / box_mask.sum().clamp_min(1)
            else:
                class_loss = 0.0
            losses.append(box_loss + class_loss)
        return sum(losses) / batch_size

    def unpad(self, targ_boxes, targ_classes):
        mask = (targ_classes != -100)
        return targ_boxes[mask], targ_classes[mask]

    def assign_anchors(self, anchors, targ_boxes, foreground_thresh=0.5, background_thresh=0.4):
        anchors = box_convert(anchors, in_fmt='cxcywh', out_fmt='xyxy')
        iou_matrix = box_iou(anchors, targ_boxes)
        max_iou, prelim_assignments = iou_matrix.max(dim=1)
        foreground_mask = (max_iou > foreground_thresh)
        background_mask = (max_iou < background_thresh)
        assignments = torch.full((anchors.shape[0],), fill_value=-2, device=def_device)
        assignments[foreground_mask] = prelim_assignments[foreground_mask]
        assignments[background_mask] = -1
        return assignments

    def compute_pred_boxes(self, box_preds, anchors):
        result = torch.empty_like(anchors)
        result[..., :2] = box_preds[..., :2] * anchors[..., 2:] + anchors[..., :2]
        result[..., 2:] = anchors[..., 2:] * torch.exp(box_preds[..., 2:])
        return result

    def compute_one_hot_targets(self, class_assignments, box_mask, class_mask):
        result = torch.zeros(
            (box_mask.shape[0],), device=class_assignments.device
        )
        result[box_mask] = class_assignments
        result = result[class_mask]
        return F.one_hot(result.long(), num_classes=21)[:, 1:]
The code for the loss function is partially based on code from this Keras tutorial and this notebook from an old version of the fastai course.
Finally, we’ll create a fastai DataLoaders and train our model.
object_detection_dls = DataLoaders(
    object_detection_train_dl, object_detection_valid_dl
)
# the Learner needs to know the number of inputs to the model
object_detection_dls.n_inp = 1

object_detection_backbone = multilabel_learn.model.backbone

object_detection_model = ObjectDetectionModel(
    n_anchors=9, backbone=object_detection_backbone
)

object_detection_learn = Learner(
    dls=object_detection_dls,
    model=object_detection_model,
    loss_func=ObjectDetectionLoss(),
).to_fp16()

object_detection_lrs = object_detection_learn.lr_find(
    suggest_funcs=(minimum, steep, valley, slide)
)

object_detection_learn.fit_one_cycle(10, 1e-4)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 0.721214 | 0.723528 | 09:23 |
1 | 0.545930 | 0.582645 | 09:23 |
2 | 0.459409 | 0.476774 | 09:24 |
3 | 0.399870 | 0.447026 | 09:24 |
4 | 0.350284 | 0.405484 | 09:26 |
5 | 0.309018 | 0.382009 | 09:28 |
6 | 0.283841 | 0.362980 | 09:29 |
7 | 0.250841 | 0.359760 | 09:28 |
8 | 0.232529 | 0.359746 | 09:27 |
9 | 0.207503 | 0.358098 | 09:27 |
Now let’s visualize the outputs of our trained RetinaNet.
Inference Image Plotting Functions
def draw_inference_bbox(ax, bbox, label, score, alpha=1.0):
    x0, y0, x1, y1 = bbox
    box = matplotlib.patches.Rectangle(
        (x0, y0), x1 - x0, y1 - y0,
        facecolor='none',
        linewidth=2,
        edgecolor=cmap(label),
        alpha=alpha,
        zorder=100*score.item(),
    )
    border_effects = [
        path_effects.Stroke(linewidth=5, foreground='black', alpha=alpha),
        path_effects.Normal()
    ]
    box.set_path_effects(border_effects)
    ax.add_patch(box)
def add_bbox_inference_label(ax, bbox, label, score, alpha=1.0):
    props = dict(
        boxstyle='square',
        facecolor=cmap(label),
        edgecolor='black',
        linewidth=1.1,
        alpha=alpha,
        zorder=100*score.item()
    )
    x0, y0, _, _ = bbox
    if isinstance(label, torch.Tensor):
        label = label.item()
    ax.text(
        x0, y0, idx2label[label] + f' {score.item():.2f}',
        bbox=props,
        color='black',
        fontsize=8,
        in_layout=True,
        zorder=100*score.item()
    )
def plot_image_with_inference_annotations(
    image,
    bboxes,
    labels,
    scores,
    ax=None,
    alpha=1.0
):
    image = decode(image)
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    if ax is None:
        fig, ax = plt.subplots()
        fig.tight_layout()
    ax.imshow(image)
    ax.axis('off')
    for bbox, label, score in zip(bboxes, labels, scores):
        draw_inference_bbox(ax, bbox, label, score, alpha=alpha)
        add_bbox_inference_label(ax, bbox, label, score, alpha=alpha)
from torchvision.ops import clip_boxes_to_image, batched_nms
def compute_pred_boxes(boxes, anchors):
    result = torch.empty_like(anchors)
    result[..., :2] = boxes[..., :2] * anchors[..., 2:] + anchors[..., :2]
    result[..., 2:] = anchors[..., 2:] * torch.exp(boxes[..., 2:])
    return result
def postprocess(box_preds, class_preds, detect_thresh, iou_thresh):
    pred_scores, pred_classes = class_preds.max(dim=-1)
    pred_scores = pred_scores.sigmoid()

    idxs_to_keep = pred_scores > detect_thresh
    pred_boxes = box_preds[idxs_to_keep]
    pred_classes = pred_classes[idxs_to_keep]
    pred_scores = pred_scores[idxs_to_keep]

    pred_boxes = clip_boxes_to_image(pred_boxes, (image_size, image_size))

    idxs_to_keep = batched_nms(
        pred_boxes, pred_scores, pred_classes, iou_threshold=iou_thresh
    )
    pred_boxes = pred_boxes[idxs_to_keep]
    pred_classes = pred_classes[idxs_to_keep]
    pred_scores = pred_scores[idxs_to_keep]

    return pred_boxes, pred_classes, pred_scores
def inference(model, batch_num, detect_thresh=0.35, iou_thresh=0.5):
    valid_iter = iter(object_detection_valid_dl)
    for _ in range(batch_num + 1):
        batch = next(valid_iter)

    images, targ_boxes, targ_classes = batch
    batch_size = images.shape[0]
    image_size = images.shape[-1]

    model.eval()
    with torch.no_grad():
        boxes, classes = model(images.to(def_device))

    anchors = torch.stack([
        generate_anchor_boxes(image_size).to(torch.float32)
        for _ in range(batch_size)
    ])
    pred_boxes = compute_pred_boxes(boxes, anchors)
    pred_boxes = box_convert(pred_boxes, 'cxcywh', 'xyxy')

    fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(16, 32))
    fig.tight_layout()
    plot_data = zip(images, pred_boxes, classes, targ_boxes, targ_classes, axs)
    for image, pred_box, pred_class, targ_box, targ_class, ax_row in plot_data:
        pred_box, pred_class, pred_scores = postprocess(
            pred_box, pred_class, detect_thresh, iou_thresh
        )

        targ_alpha = 1.0
        non_padding = targ_class > -100
        targ_box = targ_box[non_padding]
        targ_class = targ_class[non_padding].long()

        ax_row[0].imshow(decode(image))
        ax_row[0].axis('off')
        ax_row[0].set_title('Ground Truth')
        for box, label in zip(targ_box, targ_class):
            draw_bbox(ax_row[0], box.cpu(), label.cpu(), alpha=targ_alpha)
            add_bbox_label(ax_row[0], box.cpu(), label.cpu(), alpha=targ_alpha)

        plot_image_with_inference_annotations(
            image=image,
            bboxes=pred_box,
            labels=pred_class,
            scores=pred_scores,
            ax=ax_row[1],
        )
        ax_row[1].set_title('Predictions')
inference(
    model=object_detection_learn.model,
    batch_num=10,
    detect_thresh=0.5,
    iou_thresh=0.2
)
COCO Metrics
We can get some quantitative estimates of how well our model performs using pycocotools.
COCO metric helper functions
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from fastai.vision.all import chunked
import json
with open(path_2007/'test.json') as f:
    test_ground_truth = json.load(f)
def record_predictions(model, detect_thresh=0.35, iou_thresh=0.5):
    results = []
    batch_size = 8
    image_size = 600
    anchors = torch.stack([
        generate_anchor_boxes(image_size).to(torch.float32)
        for _ in range(batch_size)
    ])
    image_ids = [o['id'] for o in test_ground_truth['images']]
    image_id_batches = chunked(image_ids, chunk_sz=8)
    model.eval()

    for batch, image_id_batch in zip(object_detection_valid_dl, image_id_batches):
        images, *_ = batch

        with torch.no_grad():
            boxes, classes = model(images.to(def_device))

        pred_boxes = compute_pred_boxes(boxes, anchors)
        pred_boxes = box_convert(pred_boxes, 'cxcywh', 'xyxy')

        for pred_box, pred_class, image_id in zip(pred_boxes, classes, image_id_batch):
            pred_box, pred_class, pred_scores = postprocess(
                pred_box, pred_class, detect_thresh, iou_thresh
            )
            pred_box = box_convert(pred_box, in_fmt='xyxy', out_fmt='xywh')
            for cat_id, box, score in zip(pred_class, pred_box, pred_scores):
                results.append({
                    'image_id': image_id,
                    'category_id': cat_id.item() + 1,
                    'bbox': [round(o, 2) for o in box.tolist()],
                    'score': round(score.item(), 3),
                })
    return results
def get_coco_scores(
    model,
    ground_truth_path=path_2007/'test.json',
    annotation_type='bbox',
    detect_thresh=0.35,
    iou_thresh=0.5
):
    test_results = record_predictions(
        model,
        detect_thresh=detect_thresh,
        iou_thresh=iou_thresh
    )
    with open('results/results.json', 'w') as f:
        json.dump(test_results, f)
    coco_ground_truth = COCO(ground_truth_path)
    coco_predictions = coco_ground_truth.loadRes('results/results.json')
    coco_eval = COCOeval(coco_ground_truth, coco_predictions, annotation_type)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
get_coco_scores(
    model=object_detection_learn.model,
    detect_thresh=0.5,
    iou_thresh=0.2
)
loading annotations into memory...
Done (t=0.34s)
creating index...
index created!
Loading and preparing results...
DONE (t=0.03s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=3.28s).
Accumulating evaluation results...
DONE (t=0.53s).
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.440
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.640
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.485
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.036
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.254
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.560
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.388
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.488
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.489
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.043
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.300
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.618
Results on Example Image
Finally, we’ll do inference on the example image we saw at the beginning of this post and get the bounding boxes we saw there.
from fastai.vision.all import load_image
img = load_image('example_image.png', mode='RGB')
img = np.array(img)
img = multilabel_valid_tfms(image=img)['image']
img = torch.from_numpy(img)
img = img.permute(2, 0, 1)[None]
img = img.to(def_device)

with torch.no_grad():
    boxes, classes = object_detection_learn.model(img)

boxes, classes = boxes[0], classes[0]
anchors = generate_anchor_boxes(600)
pred_boxes = compute_pred_boxes(boxes, anchors)
pred_boxes = box_convert(pred_boxes, 'cxcywh', 'xyxy')
boxes, classes, scores = postprocess(
    pred_boxes, classes, detect_thresh=0.5, iou_thresh=0.2
)
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
axs[0].imshow(load_image('example_image.png'))
axs[0].axis('off')
axs[0].set_title('Input Image')
plot_image_with_inference_annotations(
    load_image('example_image.png'), boxes, classes, scores, ax=axs[1]
)
axs[1].set_title('Model Predictions')
fig.tight_layout()
plt.show()