You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

iou.py 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """
  2. 和计算iou相关的函数和类, 包括计算iou loss
  3. """
  4. import torch
  5. from torch import nn
  6. from network_module.compute_utils import torch_nanmean
  7. def fast_hist(pred, label, n_classes):
  8. # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)
  9. return torch.bincount(n_classes * label + pred, minlength=n_classes ** 2).reshape(n_classes, n_classes)
  10. def per_class_iu(hist):
  11. # 计算所有验证集图片的逐类别mIoU值
  12. # 分别为每个类别计算mIoU,hist的形状(n, n)
  13. # 矩阵的对角线上的值组成的一维数组/矩阵的所有元素之和,返回值形状(n,)
  14. # hist.sum(0)=按列相加 hist.sum(1)按行相加, 行表示标签, 列表示预测
  15. return (torch.diag(hist)) / (torch.sum(hist, 1) + torch.sum(hist, 0) - torch.diag(hist))
  16. def get_ious(pred, label, n_classes):
  17. hist = fast_hist(pred.flatten(), label.flatten(), n_classes)
  18. IoUs = per_class_iu(hist)
  19. mIoU = torch_nanmean(IoUs[1:n_classes])
  20. return mIoU, IoUs
  21. class IOU_loss(nn.Module):
  22. def __init__(self, n_classes):
  23. super(IOU_loss, self).__init__()
  24. self.n_classes = n_classes
  25. def forward(self, pred, label):
  26. mIoU, _ = get_ious(pred, label, self.n_classes)
  27. return 1 - mIoU
  28. class IOU:
  29. def __init__(self, n_classes):
  30. self.n_classes = n_classes
  31. self.hist = None
  32. def add_data(self, preds, label):
  33. self.hist += torch.zeros((self.n_classes, self.n_classes)).type_as(
  34. preds) if self.hist is None else self.hist + fast_hist(preds.int(), label, self.n_classes)
  35. def get_miou(self):
  36. IoUs = per_class_iu(self.hist)
  37. self.hist = None
  38. mIoU = torch_nanmean(IoUs[1:self.n_classes])
  39. return mIoU, IoUs

基于pytorch lightning的机器学习模板, 用于对机器学习算法进行训练, 验证, 测试等, 目前实现了神经网路, 深度学习, k折交叉, 自动保存训练信息等.

Contributors (1)