# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  reduction='mean',
                  avg_factor=None,
                  class_weight=None,
                  ignore_index=-100):
    """Calculate the CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), where C is
            the number of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to the default value. Default: -100.

    Returns:
        torch.Tensor: The calculated loss
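
    Example:
        A minimal usage sketch; the shapes and values below are illustrative
        assumptions made for this documentation, not taken from the original
        code.

        >>> pred = torch.randn(4, 10)
        >>> label = torch.randint(0, 10, (4,))
        >>> loss = cross_entropy(pred, label)
        >>> assert loss.ndim == 0
        >>> loss_none = cross_entropy(pred, label, reduction='none')
        >>> assert loss_none.shape == (4,)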
- """
- # The default value of ignore_index is the same as F.cross_entropy
- ignore_index = -100 if ignore_index is None else ignore_index
- # element-wise losses
- loss = F.cross_entropy(
- pred,
- label,
- weight=class_weight,
- reduction='none',
- ignore_index=ignore_index)
-
- # apply weights and do the reduction
- if weight is not None:
- weight = weight.float()
- loss = weight_reduce_loss(
- loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
-
- return loss
-
-
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
    """Expand onehot labels to match the size of prediction.
    """
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    # entries with a negative label or the ignore_index are treated as invalid
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(
        valid_mask & (labels < label_channels), as_tuple=False)

    if inds.numel() > 0:
        bin_labels[inds, labels[inds]] = 1

    # invalid positions get zero weight so they do not contribute to the loss
    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
                                               label_channels).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
        bin_label_weights *= valid_mask

    return bin_labels, bin_label_weights


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=-100):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1) or (N, C),
            where C is the number of classes.
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored.
            If None, it will be set to the default value. Default: -100.

    Returns:
        torch.Tensor: The calculated loss.
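
    Example:
        A minimal usage sketch; the shapes and values below are illustrative
        assumptions, not taken from the original code. ``label`` is expanded
        to one-hot format internally because its dimension differs from
        ``pred``.

        >>> pred = torch.randn(4, 3)
        >>> label = torch.randint(0, 3, (4,))
        >>> loss = binary_cross_entropy(pred, label)
        >>> assert loss.ndim == 0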
- """
- # The default value of ignore_index is the same as F.cross_entropy
- ignore_index = -100 if ignore_index is None else ignore_index
- if pred.dim() != label.dim():
- label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
- ignore_index)
-
- # weighted element-wise losses
- if weight is not None:
- weight = weight.float()
- loss = F.binary_cross_entropy_with_logits(
- pred, label.float(), pos_weight=class_weight, reduction='none')
- # do the reduction for the weighted loss
- loss = weight_reduce_loss(
- loss, weight, reduction=reduction, avg_factor=avg_factor)
-
- return loss
-
-
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C, *), C is the
            number of classes. The trailing * indicates arbitrary shape.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the
            mask's corresponding object. It is used to select the mask of the
            class the object belongs to when the mask prediction is not
            class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder to be consistent with other losses.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss

    Example:
        >>> N, C = 3, 11
        >>> H, W = 2, 2
        >>> pred = torch.randn(N, C, H, W) * 1000
        >>> target = torch.rand(N, H, W)
        >>> label = torch.randint(0, C, size=(N,))
        >>> reduction = 'mean'
        >>> avg_factor = None
        >>> class_weights = None
        >>> loss = mask_cross_entropy(pred, target, label, reduction,
        ...                           avg_factor, class_weights)
        >>> assert loss.shape == (1,)
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]


@LOSSES.register_module()
class CrossEntropyLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 ignore_index=None,
                 loss_weight=1.0):
        """CrossEntropyLoss.

        Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to False.
            use_mask (bool, optional): Whether to use mask cross entropy loss.
                Defaults to False.
            reduction (str, optional): The method used to reduce the loss.
                Options are "none", "mean" and "sum". Defaults to 'mean'.
            class_weight (list[float], optional): Weight of each class.
                Defaults to None.
            ignore_index (int | None): The label index to be ignored.
                Defaults to None.
            loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
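
        Example:
            An illustrative usage sketch; the shapes and values below are
            assumptions made for this documentation, not taken from the
            original code.

            >>> loss_cls = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
            >>> cls_score = torch.randn(4, 10)
            >>> label = torch.randint(0, 10, (4,))
            >>> loss = loss_cls(cls_score, label)
            >>> assert loss.ndim == 0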
- """
- super(CrossEntropyLoss, self).__init__()
- assert (use_sigmoid is False) or (use_mask is False)
- self.use_sigmoid = use_sigmoid
- self.use_mask = use_mask
- self.reduction = reduction
- self.loss_weight = loss_weight
- self.class_weight = class_weight
- self.ignore_index = ignore_index
-
- if self.use_sigmoid:
- self.cls_criterion = binary_cross_entropy
- elif self.use_mask:
- self.cls_criterion = mask_cross_entropy
- else:
- self.cls_criterion = cross_entropy
-
    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                ignore_index=None,
                **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".
            ignore_index (int | None): The label index to be ignored.
                If not None, it will override the default value. Default: None.

        Returns:
            torch.Tensor: The calculated loss.
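
        Example:
            An illustrative call showing ``reduction_override``; the tensors
            below are assumptions made for this documentation.

            >>> self = CrossEntropyLoss()
            >>> cls_score = torch.randn(2, 5)
            >>> label = torch.tensor([1, 3])
            >>> out = self.forward(cls_score, label, reduction_override='none')
            >>> assert out.shape == (2,)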
- """
- assert reduction_override in (None, 'none', 'mean', 'sum')
- reduction = (
- reduction_override if reduction_override else self.reduction)
- if ignore_index is None:
- ignore_index = self.ignore_index
-
- if self.class_weight is not None:
- class_weight = cls_score.new_tensor(
- self.class_weight, device=cls_score.device)
- else:
- class_weight = None
- loss_cls = self.loss_weight * self.cls_criterion(
- cls_score,
- label,
- weight,
- class_weight=class_weight,
- reduction=reduction,
- avg_factor=avg_factor,
- ignore_index=ignore_index,
- **kwargs)
- return loss_cls