|
- """
- 和计算iou相关的函数和类, 包括计算iou loss
- """
- import torch
- from torch import nn
- from network_module.compute_utils import torch_nanmean
-
-
- def fast_hist(pred, label, n_classes):
- # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)
- return torch.bincount(n_classes * label + pred, minlength=n_classes ** 2).reshape(n_classes, n_classes)
-
-
- def per_class_iu(hist):
- # 计算所有验证集图片的逐类别mIoU值
- # 分别为每个类别计算mIoU,hist的形状(n, n)
- # 矩阵的对角线上的值组成的一维数组/矩阵的所有元素之和,返回值形状(n,)
- # hist.sum(0)=按列相加 hist.sum(1)按行相加, 行表示标签, 列表示预测
- return (torch.diag(hist)) / (torch.sum(hist, 1) + torch.sum(hist, 0) - torch.diag(hist))
-
-
- def get_ious(pred, label, n_classes):
- hist = fast_hist(pred.flatten(), label.flatten(), n_classes)
- IoUs = per_class_iu(hist)
- mIoU = torch_nanmean(IoUs[1:n_classes])
- return mIoU, IoUs
-
-
- class IOU_loss(nn.Module):
- def __init__(self, n_classes):
- super(IOU_loss, self).__init__()
- self.n_classes = n_classes
-
- def forward(self, pred, label):
- mIoU, _ = get_ious(pred, label, self.n_classes)
- return 1 - mIoU
-
-
- class IOU:
- def __init__(self, n_classes):
- self.n_classes = n_classes
- self.hist = None
-
- def add_data(self, preds, label):
- self.hist += torch.zeros((self.n_classes, self.n_classes)).type_as(
- preds) if self.hist is None else self.hist + fast_hist(preds.int(), label, self.n_classes)
-
- def get_miou(self):
- IoUs = per_class_iu(self.hist)
- self.hist = None
- mIoU = torch_nanmean(IoUs[1:self.n_classes])
- return mIoU, IoUs
|