|
- # Copyright (c) OpenMMLab. All rights reserved.
- from abc import ABCMeta, abstractmethod
-
- import torch.nn.functional as F
- from mmcv.runner import BaseModule, force_fp32
-
- from ..builder import build_loss
- from ..utils import interpolate_as
-
-
class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
    """Base module of Semantic Head.

    Subclasses must implement :meth:`forward`, which returns a dict whose
    ``'seg_preds'`` entry holds the per-class logits used by :meth:`loss`
    and :meth:`simple_test`.

    Args:
        num_classes (int): The number of classes. The logits are reshaped
            to ``(N*H*W, num_classes)`` before the loss is computed.
        init_cfg (dict, optional): The initialization config.
            Default: None.
        loss_seg (dict, optional): Config of the loss of the semantic
            head, passed to ``build_loss``. When None, a
            ``CrossEntropyLoss`` config with ``ignore_index=255`` and
            ``loss_weight=1.0`` is used. Default: None.
    """

    def __init__(self, num_classes, init_cfg=None, loss_seg=None):
        super().__init__(init_cfg)
        if loss_seg is None:
            # Build the default config per call rather than sharing a
            # single mutable dict as a default argument across instances.
            loss_seg = dict(
                type='CrossEntropyLoss', ignore_index=255, loss_weight=1.0)
        self.loss_seg = build_loss(loss_seg)
        self.num_classes = num_classes

    @force_fp32(apply_to=('seg_preds', ))
    def loss(self, seg_preds, gt_semantic_seg):
        """Get the loss of semantic head.

        Args:
            seg_preds (Tensor): The input logits with the shape (N, C, H, W).
            gt_semantic_seg (Tensor): The ground truth of semantic
                segmentation with the shape (N, H, W). It is flattened and
                cast to ``long`` before being passed to the loss.

        Returns:
            dict[str, Tensor]: The loss of semantic head, stored under the
                key ``'loss_seg'``.
        """
        if seg_preds.shape[-2:] != gt_semantic_seg.shape[-2:]:
            # Resize the logits to the spatial size of the ground truth so
            # that every labelled pixel has a matching prediction.
            seg_preds = interpolate_as(seg_preds, gt_semantic_seg)
        # (N, C, H, W) -> (N, H, W, C): after the reshape below, each row
        # holds the C class logits of a single pixel.
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_seg.reshape(-1).long())
        return dict(loss_seg=loss_seg)

    @abstractmethod
    def forward(self, x):
        """Placeholder of forward function.

        Args:
            x: Input features; their type and layout are defined by the
                subclass.

        Returns:
            dict[str, Tensor]: A dictionary, including features
                and predicted scores. Required keys: 'seg_preds'
                and 'feats'.
        """
        pass

    def forward_train(self, x, gt_semantic_seg):
        """Run :meth:`forward` and compute the loss of its predictions.

        Args:
            x: Input features passed to :meth:`forward`.
            gt_semantic_seg (Tensor): The ground truth of semantic
                segmentation, passed unchanged to :meth:`loss`.

        Returns:
            dict[str, Tensor]: The loss dict returned by :meth:`loss`.
        """
        output = self.forward(x)
        seg_preds = output['seg_preds']
        return self.loss(seg_preds, gt_semantic_seg)

    def simple_test(self, x, img_metas, rescale=False):
        """Predict semantic segmentation logits without augmentation.

        The logits from :meth:`forward` are first bilinearly upsampled to
        the padded input size. When ``rescale`` is True, the padding is
        cropped away (keeping the ``img_shape`` region) and the result is
        bilinearly resized to ``ori_shape``.

        The shapes are read from ``img_metas[0]`` only, so every image in
        the batch is assumed to share them.

        Args:
            x: Input features passed to :meth:`forward`.
            img_metas (list[dict]): Meta information of the images. The
                first element must provide ``'pad_shape'``, and, when
                ``rescale`` is True, ``'img_shape'`` and ``'ori_shape'``
                as ``(h, w, c)`` triples.
            rescale (bool): Whether to map the logits back to the original
                image size. Default: False.

        Returns:
            Tensor: The logits, sized to ``pad_shape[:2]`` when ``rescale``
                is False, or to ``ori_shape[:2]`` when it is True.
        """
        output = self.forward(x)
        seg_preds = output['seg_preds']
        seg_preds = F.interpolate(
            seg_preds,
            size=img_metas[0]['pad_shape'][:2],
            mode='bilinear',
            align_corners=False)

        if rescale:
            # Drop the padding added to the bottom/right of the image.
            h, w, _ = img_metas[0]['img_shape']
            seg_preds = seg_preds[:, :, :h, :w]

            # Undo the test-time resize to recover the original size.
            h, w, _ = img_metas[0]['ori_shape']
            seg_preds = F.interpolate(
                seg_preds, size=(h, w), mode='bilinear', align_corners=False)
        return seg_preds
|