Hacking Into FasterRcnn in Pytorch
- Brief Intro
- Custom Backbone
- Custom Predictor
- Custom BoxHead
- Custom Loss Function
- Note on how to vary the anchor generator
- Credits
Brief Intro
In this post I will show how to tweak some of the internals of FasterRcnn in PyTorch. I am assuming the reader has already trained an object detection model using PyTorch. If not, there is an excellent tutorial on the PyTorch website.
Small Insight into the model
Basically, Faster Rcnn is a two-stage detector:
- The first stage is the Region Proposal Network, which is responsible for predicting the objectness and the corresponding bounding boxes. So essentially the RegionProposalNetwork gives proposals of whether an object is there or not.
- These proposals are used by the RoIHeads, which output the detections.
- Inside the RoIHeads, RoI align is done.
- There is a box head and a box predictor.
- The losses for the predictions are also computed here.
- In this post I will try to show how we can add custom parts to the torchvision FasterRcnn.
Custom Backbone
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.nn as nn
import torch.nn.functional as F
print(f'torch version {torch.__version__}')
print(f'torchvision version {torchvision.__version__}')
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# we need to set out_channels on this backbone explicitly, because it is used as the
# number of input channels of the RPNHead, which produces the output of the RegionProposalNetwork
# we can find the number of output channels by inspecting the backbone ("backbone??")
backbone.out_channels = 1280
# by default FasterRCNN assigns an anchor generator meant for an FPN backbone, so
# we need to specify a different anchor generator
anchor_generator = AnchorGenerator(sizes=((128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
# here at each position in the grid there will be 3x3=9 anchors
# if our backbone is not an FPN, the forward method assigns the name '0' to the feature map,
# so we need to specify '0' as the feature map name
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=9,
                                                sampling_ratio=2)
# output_size is the spatial size of the RoI-pooled features that are used by the box head
model = FasterRCNN(backbone, num_classes=2, rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 600)]
predictions = model(x)
The Resnet50Fpn available in torchvision
# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2 # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)
# by default all four layers (layer1, layer2, layer3, layer4) are returned, i.e. returned_layers=[1, 2, 3, 4]
backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone('resnet101',pretrained=True)
model = FasterRCNN(backbone,num_classes=2)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)
The size of the last feature map in a Resnet50. Later I will show the sizes of the feature maps we get when we use FPN.
# just to show what the output of a plain resnet without FPN looks like
res = torchvision.models.resnet50()
pure = nn.Sequential(*list(res.children())[:-2])
temp = torch.rand(1,3,400,400)
pure(temp).shape
The required layers can be obtained by specifying the returned_layers parameter. Resnet backbones of different depths can also be used.
# returned_layers indexes layer1..layer4; here we keep only layer2, layer3 and layer4
backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone('resnet101', pretrained=True,
                                                                           returned_layers=[2, 3, 4])
Here we are using feature maps of the following shapes.
out = backbone(temp)
for i in out.keys():
    print(i, ' ', out[i].shape)
# from the above we can see that the feature map names are 0, 1, 2 and pool,
# where pool comes from the default extra block
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', 'pool'],
                                                output_size=7,
                                                sampling_ratio=2)
So essentially we selected the last three layers of the FPN by specifying them in returned_layers. By default, the backbone adds a pool layer on top of the last layer, so we are left with four feature maps. RoIAlign now needs to be done on these four feature maps. If we don't specify the RoIAlign, the default one assumes we have used all the layers from the FPN in torchvision, so we need to explicitly give the feature maps that we used. Which feature maps to use is application specific: sometimes you need to detect small objects, and sometimes the objects of interest are only large ones.
# we need to give an anchor_generator because the default anchor generator assumes we use all layers in the FPN
# since we have four feature maps here, we need to specify four anchor sizes, one tuple per feature map
# note the trailing commas: (32) is just the int 32, while (32,) is a one-element tuple
anchor_sizes = ((32,), (64,), (128,), (256,))
aspect_ratios = ((0.5, 1.0, 1.5, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
Since we have four feature maps in our FPN, we need to specify the anchors for each of them. Here each feature map will have 4 anchors at each position. So the first feature map will have anchors of size 32, four of them at each position, with aspect_ratios (0.5, 1.0, 1.5, 2.0). Now we can pass these to the FasterRCNN class.
model = FasterRCNN(backbone,num_classes=2,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)
Custom Predictor
The predictor is what outputs the classes and the corresponding bboxes. By default it has two layers, one for the classes and one for the bboxes, but we can add more layers before them if we want to, which might come in handy if you have a ton of data. (Remember there is already a box head before the predictor, so you might not need this.)
class Custom_predictor(nn.Module):
    def __init__(self,in_channels,num_classes):
        super(Custom_predictor,self).__init__()
        self.additional_layer = nn.Linear(in_channels,in_channels) #this is the additional layer
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
    def forward(self,x):
        if x.dim() == 4:
            assert list(x.shape[2:]) == [1, 1]
        x = x.flatten(start_dim=1)
        x = self.additional_layer(x)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)
        return scores, bbox_deltas
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# we need the output features of the box head to pass to the custom predictor
in_features = model.roi_heads.box_head.fc7.out_features
# now we can add the custom predictor to the model
num_classes = 2
model.roi_heads.box_predictor = Custom_predictor(in_features, num_classes)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)
Custom BoxHead
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
class CustomHead(nn.Module):
    def __init__(self, in_channels, roi_outshape, representation_size):
        super(CustomHead, self).__init__()
        # this is the additional layer added before the two fully connected layers
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        # the conv output is flattened before fc6, so its input size is in_channels*w*h,
        # where w = h = roi_outshape (the output size of the RoI pooler)
        self.fc6 = nn.Linear(in_channels * roi_outshape ** 2, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)
    def forward(self, x):
        x = self.conv(x)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x
- We need in_channels and representation_size. Remember the output of the box head is the input of box_predictor, so we can get the representation size of the box head from the input features of box_predictor.
- The in_channels can be obtained from the backbone's out_channels.
- After flattening, the width and height also need to be taken into account; we get them from the roi_pool output size.
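The arithmetic behind the three points above can be checked in isolation. This is a minimal sketch with assumed values: 256 channels and a 7x7 RoI output are what the resnet50 FPN model uses as far as I know, and the number of RoIs is made up.

```python
import torch

in_channels, roi_outshape = 256, 7   # assumed: backbone out_channels and RoI pooler output size
num_rois = 8                         # hypothetical number of pooled proposals

# What the RoI pooler hands to the box head: one (C, H, W) block per proposal.
pooled = torch.rand(num_rois, in_channels, roi_outshape, roi_outshape)

# Flattening keeps the proposal dimension and merges C, H and W.
flat = pooled.flatten(start_dim=1)
fc6_in_features = in_channels * roi_outshape ** 2

print(flat.shape)        # second dimension must match fc6's in_features
print(fc6_in_features)
```

If the RoI pooler output size changes, the in_features of the first linear layer has to change with it.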
in_channels = model.backbone.out_channels
roi_outshape = model.roi_heads.box_roi_pool.output_size[0]
representation_size=model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_head = CustomHead(in_channels,roi_outshape,representation_size)
num_classes=2
model.roi_heads.box_predictor = FastRCNNPredictor(representation_size, num_classes)
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)
Custom Loss Function
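This is the modification for the loss of the FasterRcnn predictor.
- You can modify the loss by defining your own fastrcnn_loss and making changes where you want.
- Then patch it in, say torchvision.models.detection.roi_heads.fastrcnn_loss = custom_fastrcnn_loss. RoIHeads calls the module-level fastrcnn_loss function, so setting an attribute on model.roi_heads alone is not picked up.
- Usually we replace the F.cross_entropy loss with, say, focal loss or a label smoothing loss.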
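Before swapping a classification loss into the detector it is worth checking it on its own. The sketch below is a compact functional version of label smoothing (the function name and test tensors are my own, not part of torchvision) and confirms that with eps=0 it falls back to plain cross entropy.

```python
import torch
import torch.nn.functional as F

def label_smoothing_ce(logits, target, eps=0.1):
    # (1 - eps) * cross entropy with the true label + eps * cross entropy with a uniform target
    num_classes = logits.size(-1)
    log_preds = F.log_softmax(logits, dim=-1)
    uniform_term = -log_preds.sum(dim=-1).mean() / num_classes
    nll = F.nll_loss(log_preds, target)
    return (1 - eps) * nll + eps * uniform_term

torch.manual_seed(0)
logits = torch.randn(6, 2)                 # 6 hypothetical proposals, 2 classes
target = torch.tensor([0, 1, 1, 0, 1, 0])

plain = F.cross_entropy(logits, target)
no_smoothing = label_smoothing_ce(logits, target, eps=0.0)
smoothed = label_smoothing_ce(logits, target, eps=0.1)
print(torch.allclose(plain, no_smoothing))
print(smoothed.item())
```

The same kind of check works for a focal loss, where setting the focusing parameter to zero should also recover cross entropy.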
import torchvision.models.detection._utils as det_utils
import torch.nn.functional as F
The loss function below is taken from Aman Arora's blog.
# Helper functions from fastai
def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss
# Implementation from fastai https://github.com/fastai/fastai2/blob/master/fastai2/layers.py#L338
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, ε:float=0.1, reduction='mean'):
        super().__init__()
        self.ε,self.reduction = ε,reduction
    def forward(self, output, target):
        # number of classes
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        # (1-ε)* H(q,p) + ε*H(u,p)
        return (1-self.ε)*nll + self.ε*(loss/c)
custom_loss = LabelSmoothingCrossEntropy()
#torchvision.models.detection.roi_heads.fastrcnn_loss??
def custom_fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.
    Arguments:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)
    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    classification_loss = custom_loss(class_logits, labels) #ADDING THE CUSTOM LOSS HERE
    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, -1, 4)
    box_loss = det_utils.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        size_average=False,
    )
    box_loss = box_loss / labels.numel()
    return classification_loss, box_loss
Note on how to vary the anchor generator
The way anchor generators are specified differs between backbones with and without FPN. When we are not using FPN there is only one feature map, and for that feature map we need to specify anchors of different shapes.
anchor_generator = AnchorGenerator(sizes=((128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))
In the above case, suppose we have a feature map of shape 7x7; then at each cell there will be 9 anchors, three each of sizes 128, 256 and 512, with the corresponding aspect ratios. But when we are using FPN we have several feature maps, so it is more effective to use different anchor sizes for different feature maps. Small objects are detected using the earlier feature maps, so for those we can specify a small anchor size, say 32, and for the later layers we can specify larger anchors.
# one tuple of sizes per feature map; (32,) needs the trailing comma to be a tuple
anchor_sizes = ((32,), (64,), (128,), (256,))
aspect_ratios = ((0.5, 1.0, 1.5, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
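In the above I am using the same aspect ratios for all the sizes, so I am just multiplying by the length of anchor_sizes, but if we want to specify different aspect ratios for each feature map that is totally possible. Just be careful to specify the same number of aspect ratios for each tuple of anchor sizes, since the RPN head is shared across feature maps and every feature map needs the same number of anchors per location.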
Credits
All the above hacks are just modifications of the existing wonderful torchvision library.