|
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from ..builder import LOSSES
- from .utils import weight_reduce_loss
-
-
- def _expand_onehot_labels(labels, label_weights, label_channels):
- bin_labels = labels.new_full((labels.size(0), label_channels), 0)
- inds = torch.nonzero(
- (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
- if inds.numel() > 0:
- bin_labels[inds, labels[inds]] = 1
- bin_label_weights = label_weights.view(-1, 1).expand(
- label_weights.size(0), label_channels)
- return bin_labels, bin_label_weights
-
-
- # TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMC(nn.Module):
    """Gradient Harmonizing Mechanism loss for classification (GHM-C).

    Samples are grouped into bins by the magnitude of their gradient and
    each sample is re-weighted inversely to the population of its bin. See
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Args:
        bins (int): Number of equal-width bins dividing ``[0, 1]``.
        momentum (float): Moving-average factor for the per-bin counts.
            When it is not positive, the counts of the current batch are
            used directly and no running buffer is created.
        use_sigmoid (bool): Must be True; only the sigmoid/BCE form is
            implemented.
        loss_weight (float): Scalar multiplied onto the final loss.
        reduction (str): One of "none", "mean" and "sum".
            Defaults to "mean"
    """

    def __init__(self,
                 bins=10,
                 momentum=0,
                 use_sigmoid=True,
                 loss_weight=1.0,
                 reduction='mean'):
        super().__init__()
        self.bins = bins
        self.momentum = momentum

        bin_edges = torch.arange(bins + 1).float() / bins
        # Bins are half-open [lo, hi); nudge the last edge up so that a
        # gradient norm of exactly 1 still lands in the final bin.
        bin_edges[-1] += 1e-6
        self.register_buffer('edges', bin_edges)

        if momentum > 0:
            # Running per-bin sample counts.
            self.register_buffer('acc_sum', torch.zeros(bins))

        self.use_sigmoid = use_sigmoid
        if not use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                label_weight,
                reduction_override=None,
                **kwargs):
        """Compute the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                Raw classification logits.
            target (tensor): Either binary targets with the same number of
                dimensions as ``pred``, or class-index labels, which are
                expanded to one-hot form (together with ``label_weight``)
                when the dimensionality differs from ``pred``.
            label_weight (tensor): Positive where the sample is valid and
                0 (or less) where it must be ignored.
            reduction_override (str, optional): Replaces the reduction
                chosen at construction time. Defaults to None.
            **kwargs: Accepted and ignored.

        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Class-index labels -> one-hot targets with per-class weights.
        if pred.dim() != target.dim():
            target, label_weight = _expand_onehot_labels(
                target, label_weight, pred.size(-1))
        target = target.float()
        label_weight = label_weight.float()

        momentum = self.momentum
        bin_edges = self.edges

        # Gradient norm of sigmoid BCE w.r.t. the logit: |sigmoid(x) - t|.
        grad_norm = (pred.sigmoid().detach() - target).abs()

        is_valid = label_weight > 0
        num_valid = max(is_valid.float().sum().item(), 1.0)

        sample_weights = torch.zeros_like(pred)
        nonempty_bins = 0
        for idx in range(self.bins):
            lower, upper = bin_edges[idx], bin_edges[idx + 1]
            in_bin = (grad_norm >= lower) & (grad_norm < upper) & is_valid
            count = in_bin.sum().item()
            if count == 0:
                continue
            if momentum > 0:
                # Exponential moving average of this bin's population.
                self.acc_sum[idx] = (
                    momentum * self.acc_sum[idx] + (1 - momentum) * count)
                sample_weights[in_bin] = num_valid / self.acc_sum[idx]
            else:
                sample_weights[in_bin] = num_valid / count
            nonempty_bins += 1

        # Average the contribution over the bins that actually held samples.
        if nonempty_bins > 0:
            sample_weights = sample_weights / nonempty_bins

        elementwise = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none')
        reduced = weight_reduce_loss(
            elementwise,
            sample_weights,
            reduction=reduction,
            avg_factor=num_valid)
        return reduced * self.loss_weight
-
-
- # TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMR(nn.Module):
    """GHM Regression Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Args:
        mu (float): The parameter for the Authentic Smooth L1 loss.
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average. When it is not
            positive, the per-bin counts of the current batch are used
            directly and no running buffer is created.
        loss_weight (float): The weight of the total GHM-R loss.
        reduction (str): Options are "none", "mean" and "sum".
            Defaults to "mean"
    """

    def __init__(self,
                 mu=0.02,
                 bins=10,
                 momentum=0,
                 loss_weight=1.0,
                 reduction='mean'):
        super().__init__()
        self.mu = mu
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
        self.register_buffer('edges', edges)
        # Bins are half-open [lo, hi). The gradient norm used in forward is
        # at most 1, so a very large last edge guarantees that values which
        # round to 1 are still captured by the final bin.
        self.edges[-1] = 1e3
        self.momentum = momentum
        if momentum > 0:
            # Running per-bin sample counts.
            acc_sum = torch.zeros(bins)
            self.register_buffer('acc_sum', acc_sum)
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                target,
                label_weight,
                avg_factor=None,
                reduction_override=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can be 4
                or 4 * class_num depending on whether it is class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size of pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored. Only whether an
                entry is positive matters; the magnitude is not used.
            avg_factor (optional): Accepted for interface compatibility but
                not used; the loss is always normalised by the number of
                valid elements.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        Returns:
            The gradient harmonized loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        mu = self.mu
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length: |d(ASL1)/d(diff)|, detached so that the
        # re-weighting itself receives no gradient
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        # Normalise by the number of valid elements -- the same population
        # the bins below are counted over (and the same rule GHMC uses) --
        # rather than by the sum of the weight values, which only coincides
        # with the count when the weights are strictly 0/1.
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n: valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    # exponential moving average of this bin's population
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        # average over the bins that actually contained samples
        if n > 0:
            weights /= n
        loss = weight_reduce_loss(
            loss, weights, reduction=reduction, avg_factor=tot)
        return loss * self.loss_weight
|