|
- import torch
- import torch.nn as nn
- from torch.nn.modules.loss import _Loss
- import torch.nn.functional as F
- import numpy as np
-
-
class DiceLoss(_Loss):
    """Soft Dice loss, averaged over the classes that can actually be scored.

    A class channel is skipped ("invalid") when, after removing ignored
    positions, it is absent from the target AND the prediction puts zero mass
    on it -- its Dice score is undefined (0/0), so it is excluded from the mean
    instead of being counted as a perfect or a failed class.
    """

    def forward(self, output, target, weights=None, ignore_index=None):
        """Compute the mean Dice loss over valid class channels.

        Args:
            output: class scores/probabilities of shape N x C x ... (class axis
                is dim 1). CombinedLoss passes softmax probabilities here.
            target: integer class indices with ``output``'s shape minus dim 1;
                used as a ``scatter_`` index, so it must be an int64 tensor.
            weights: optional channel-wise weight (scalar, or a tensor that
                broadcasts against the per-channel loss of shape (C,)).
                Defaults to 1.
            ignore_index: optional label value whose positions are excluded
                from both the target one-hot encoding and the denominator.

        Returns:
            0-dim tensor: weighted ``1 - Dice`` summed over valid channels and
            divided by the number of valid channels; 0 (still attached to the
            graph) when no channel is valid.
        """
        eps = 0.001  # smoothing term so empty channels never divide by zero

        # Reduce over every axis except the class axis (dim 1): batch + spatial.
        reduce_dims = (0, *range(2, output.dim()))

        mask = None
        if ignore_index is not None:
            mask = target == ignore_index
            target = target.clone()
            # Any in-range index works here; these positions are zeroed below.
            target[mask] = 0

        # One-hot encode the target. zeros_like is not part of the autograd
        # graph, so the encoding does not take part in gradient updates.
        encoded_target = torch.zeros_like(output)
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target.masked_fill_(mask, 0)

        if weights is None:
            weights = 1

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(dim=reduce_dims)
        denominator = output + encoded_target
        if mask is not None:
            denominator = denominator.masked_fill(mask, 0)

        target_per_channel = encoded_target.sum(dim=reduce_dims)
        denominator = denominator.sum(dim=reduce_dims)

        # A channel is scoreable unless both its target and its (masked)
        # prediction mass are empty. Vectorized: no per-channel host sync.
        valid = (target_per_channel != 0) | (denominator != 0)

        loss_per_channel = weights * (1 - numerator / (denominator + eps))  # channel-wise weights

        # Drop invalid channels by masking rather than subtracting a constant,
        # so weighted channels are removed correctly; clamp avoids 0/0 -> NaN.
        n_valid = valid.sum().clamp(min=1)
        return (loss_per_channel * valid).sum() / n_valid
-
-
class CrossEntropy2D(nn.Module):
    """Element-wise cross-entropy for dense (2D) class predictions.

    Thin wrapper around ``nn.CrossEntropyLoss`` (log-softmax over the class
    axis followed by negative log-likelihood).

    Args:
        weight: optional per-class rescaling weight, forwarded unchanged to
            ``nn.CrossEntropyLoss``.
        reduction: reduction mode forwarded to ``nn.CrossEntropyLoss``; the
            default ``'none'`` keeps one loss value per target element.
    """

    def __init__(self, weight=None, reduction='none'):
        super().__init__()
        # Attribute name kept as ``nll_loss`` so existing state_dict keys match.
        self.nll_loss = nn.CrossEntropyLoss(reduction=reduction, weight=weight)

    def forward(self, inputs, targets):
        """Return the cross-entropy of raw scores ``inputs`` against ``targets``."""
        criterion = self.nll_loss
        return criterion(inputs, targets)
-
-
class CombinedLoss(nn.Module):
    """Weighted sum of a soft Dice loss and an element-wise cross-entropy loss.

    forward(inputx, target):
        inputx: raw class scores of shape N x C x H x W (N = batch size).
            Softmax along dim 1 is applied internally for the Dice term; the
            cross-entropy term receives the raw scores.
        target: class-index map of shape N x H x W. It is cast to
            ``torch.long`` (required by CrossEntropyLoss) and moved to
            ``inputx``'s device.

    Returns:
        ``(total_loss, dice_val, ce_val)`` with
        ``total_loss = weight_dice * dice_val + weight_ce * ce_val``.
    """

    def __init__(self, weight_dice, weight_ce):
        super(CombinedLoss, self).__init__()
        self.cross_entropy_loss = CrossEntropy2D()
        self.dice_loss = DiceLoss()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce

    def forward(self, inputx, target):
        # Cast and move in one step, following inputx to its actual device.
        # The old CPU round-trip + .cuda() always picked the default GPU.
        target = target.to(device=inputx.device, dtype=torch.long)

        input_soft = F.softmax(inputx, dim=1)  # along the class dimension
        dice_val = torch.mean(self.dice_loss(input_soft, target))
        # Call the module rather than .forward() so registered hooks still run.
        ce_val = torch.mean(self.cross_entropy_loss(inputx, target))
        total_loss = dice_val * self.weight_dice + ce_val * self.weight_ce
        return total_loss, dice_val, ce_val
|