
losses.py 3.8 kB

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
import numpy as np


class DiceLoss(_Loss):
    def forward(self, output, target, weights=None, ignore_index=None):
        eps = 0.001
        # Detach from the graph: the one-hot target must not receive gradients.
        encoded_target = output.detach() * 0
        if ignore_index is not None:
            mask = target == ignore_index
            target = target.clone()
            target[mask] = 0
            # unsqueeze adds the channel dimension; scatter_ one-hot encodes the target
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target[mask] = 0
        else:
            # unsqueeze adds the channel dimension; scatter_ one-hot encodes the target
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
        # scatter_(dim, index, src) writes src into the tensor at the positions given by index along dim.

        if weights is None:
            weights = 1

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = output + encoded_target
        if ignore_index is not None:
            denominator[mask] = 0

        # Count "invalid" classes that occur in neither the target nor the prediction;
        # their per-channel loss is 1 - 0 / eps = 1, so they are subtracted out below.
        count1 = []
        for i in encoded_target.sum(0).sum(1).sum(1):
            if i == 0:
                count1.append(1)
            else:
                count1.append(0)
        count2 = []
        for i in denominator.sum(0).sum(1).sum(1):
            if i == 0:
                count2.append(1)
            else:
                count2.append(0)
        count = sum(np.array(count1) * np.array(count2))

        denominator = denominator.sum(0).sum(1).sum(1) + eps
        loss_per_channel = weights * (1 - (numerator / denominator))  # channel-wise weights
        # loss_per_channel holds the average Dice loss of each class.
        return (loss_per_channel.sum() - count) / (output.size(1) - count)
        # return loss_per_channel.sum() / output.size(1)


class CrossEntropy2D(nn.Module):
    """
    2D cross-entropy loss (log-softmax + negative log-likelihood via nn.CrossEntropyLoss).
    """

    def __init__(self, weight=None, reduction='none'):
        super(CrossEntropy2D, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        return self.nll_loss(inputs, targets)


class CombinedLoss(nn.Module):
    """
    Weighted sum of Dice loss and cross-entropy loss.
    For the cross-entropy term the target has to be a long tensor.
    Args:
        -- inputx - N x C x H x W (N is the batch size)
        -- target - N x H x W - int type
        -- weight - N x H x W - float
    """

    def __init__(self, weight_dice, weight_ce):
        super(CombinedLoss, self).__init__()
        self.cross_entropy_loss = CrossEntropy2D()
        self.dice_loss = DiceLoss()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce

    def forward(self, inputx, target):
        target = target.type(torch.LongTensor)  # typecast to long tensor
        if inputx.is_cuda:
            target = target.cuda()
        input_soft = F.softmax(inputx, dim=1)  # softmax along the class dimension
        dice_val = torch.mean(self.dice_loss(input_soft, target))
        ce_val = torch.mean(self.cross_entropy_loss.forward(inputx, target))
        total_loss = torch.add(torch.mul(dice_val, self.weight_dice),
                               torch.mul(ce_val, self.weight_ce))
        return total_loss, dice_val, ce_val

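Below is a minimal usage sketch, not part of losses.py: the tensor shapes, loss weights, and random inputs are illustrative assumptions, showing how CombinedLoss can be called on raw network logits during training.

import torch
from losses import CombinedLoss  # assumes the code above is saved as losses.py

# Hypothetical setup: 4-class segmentation on a batch of 2 images of size 64 x 64.
num_classes, batch, h, w = 4, 2, 64, 64
logits = torch.randn(batch, num_classes, h, w, requires_grad=True)  # raw network output, N x C x H x W
target = torch.randint(0, num_classes, (batch, h, w))               # integer class labels, N x H x W

criterion = CombinedLoss(weight_dice=1.0, weight_ce=1.0)
total_loss, dice_val, ce_val = criterion(logits, target)
total_loss.backward()  # gradients flow through both the Dice and the cross-entropy terms
print(total_loss.item(), dice_val.item(), ce_val.item())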
Network code reproduction