You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

loss.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import torch
  2. import torch.nn as nn
  3. class SegmentationLosses(object):
  4. def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False): # ignore_index=255
  5. self.ignore_index = ignore_index
  6. self.weight = weight
  7. self.size_average = size_average
  8. self.batch_average = batch_average
  9. self.cuda = cuda
  10. def build_loss(self, mode='ce'):
  11. """Choices: ['ce' or 'focal']"""
  12. if mode == 'ce':
  13. return self.CrossEntropyLoss
  14. elif mode == 'focal':
  15. return self.FocalLoss
  16. else:
  17. raise NotImplementedError
  18. def CrossEntropyLoss(self, logit, target):
  19. n, c, h, w = logit.size()
  20. #criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
  21. #size_average=self.size_average)
  22. criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=self.ignore_index)
  23. if self.cuda:
  24. criterion = criterion.cuda()
  25. loss = criterion(logit, target.long())
  26. if self.batch_average:
  27. loss /= n
  28. return loss
  29. def FocalLoss(self, logit, target, gamma=2, alpha=0.5):
  30. n, c, h, w = logit.size()
  31. criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
  32. size_average=self.size_average)
  33. if self.cuda:
  34. criterion = criterion.cuda()
  35. logpt = -criterion(logit, target.long())
  36. pt = torch.exp(logpt)
  37. if alpha is not None:
  38. logpt *= alpha
  39. loss = -((1 - pt) ** gamma) * logpt
  40. if self.batch_average:
  41. loss /= n
  42. return loss
  43. if __name__ == "__main__":
  44. loss = SegmentationLosses(cuda=True)
  45. a = torch.rand(1, 3, 7, 7).cuda()
  46. b = torch.rand(1, 7, 7).cuda()
  47. print(loss.CrossEntropyLoss(a, b).item())
  48. print(loss.FocalLoss(a, b, gamma=0, alpha=None).item())
  49. print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item())