| @@ -0,0 +1,25 @@ | |||||
| # 任务:无限遥感图像分割 | |||||
| ### 环境 | |||||
| 操作系统:linux | |||||
| 显卡:1080ti | |||||
| python3.7 | |||||
| pytroch | |||||
| cuda10.0 | |||||
| ### 模型 | |||||
| Unet | |||||
| deeplabv3 | |||||
| ### 损失函数 | |||||
| dice ce | |||||
| @@ -0,0 +1,44 @@ | |||||
| import os | |||||
| import torch | |||||
| import numpy as np | |||||
| import cv2 as cv | |||||
| from os.path import join | |||||
| from torch.utils.data.dataset import Dataset | |||||
| class DataTrain(Dataset): | |||||
| def __init__(self, data_path, transforms=None): | |||||
| self.data_dir = data_path | |||||
| self.image_list = os.listdir(join(data_path, 'train_ori_images')) | |||||
| files_len = len(self.image_list) | |||||
| try: | |||||
| imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8) | |||||
| labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8) | |||||
| for idx, file in enumerate(self.image_list): | |||||
| fname = file.split('.')[0] | |||||
| img = cv.imread(join(self.data_dir, 'train_ori_images', fname + '.tif')) | |||||
| img = np.asarray(img, dtype=np.uint8) | |||||
| label = cv.imread( | |||||
| join(self.data_dir, 'train_pupil_images', fname + '.png'), | |||||
| cv.IMREAD_UNCHANGED) | |||||
| label = np.asarray(label, dtype=np.uint8) % 100 | |||||
| imgs[idx, :, :, :] = img | |||||
| labels[idx, :, :] = label | |||||
| self.images = imgs | |||||
| self.labels = labels | |||||
| self.transforms = transforms | |||||
| except Exception: | |||||
| raise Exception('read error') | |||||
| def __getitem__(self, index): | |||||
| img = self.images[index] | |||||
| label = self.labels[index] | |||||
| label[label > 0.5] = 1 | |||||
| tx_sample = self.transforms({'img': img, 'label': label}) | |||||
| img = tx_sample['img'] | |||||
| label = tx_sample['label'] | |||||
| return img, label | |||||
| def __len__(self): | |||||
| return len(self.image_list) | |||||
| @@ -0,0 +1,103 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch.nn.modules.loss import _Loss | |||||
| import torch.nn.functional as F | |||||
| import numpy as np | |||||
| class DiceLoss(_Loss): | |||||
| def forward(self, output, target, weights=None, ignore_index=None): | |||||
| eps = 0.001 | |||||
| encoded_target = output.detach() * 0 # 将variable参数从网络中隔离开,不参与参数更新。 | |||||
| if ignore_index is not None: | |||||
| mask = target == ignore_index | |||||
| target = target.clone() | |||||
| target[mask] = 0 | |||||
| encoded_target.scatter_(1, target.unsqueeze(1), | |||||
| 1) # unsqueeze增加一个维度 | |||||
| mask = mask.unsqueeze(1).expand_as(encoded_target) | |||||
| encoded_target[mask] = 0 | |||||
| else: | |||||
| encoded_target.scatter_(1, target.unsqueeze(1), | |||||
| 1) # unsqueeze增加一个维度 | |||||
| # scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向输出。 | |||||
| if weights is None: | |||||
| weights = 1 | |||||
| # print(output.min(),output.max()) | |||||
| # print(output.shape,output[0,:,0,0]) | |||||
| intersection = output * encoded_target | |||||
| numerator = 2 * intersection.sum(0).sum(1).sum(1) | |||||
| denominator = output + encoded_target | |||||
| if ignore_index is not None: | |||||
| denominator[mask] = 0 | |||||
| # 计算无效的类别数量 | |||||
| count1 = [] | |||||
| for i in encoded_target.sum(0).sum(1).sum(1): | |||||
| if i == 0: | |||||
| count1.append(1) | |||||
| else: | |||||
| count1.append(0) | |||||
| count2 = [] | |||||
| for i in denominator.sum(0).sum(1).sum(1): | |||||
| if i == 0: | |||||
| count2.append(1) | |||||
| else: | |||||
| count2.append(0) | |||||
| count = sum(np.array(count1) * np.array(count2)) | |||||
| # print(count) | |||||
| denominator = denominator.sum(0).sum(1).sum(1) + eps | |||||
| loss_per_channel = weights * (1 - (numerator / denominator) | |||||
| ) # Channel-wise weights | |||||
| # print(loss_per_channel) # 每一个类别的平均dice | |||||
| return (loss_per_channel.sum() - count) / (output.size(1) - count) | |||||
| # return loss_per_channel.sum() / output.size(1) | |||||
| class CrossEntropy2D(nn.Module): | |||||
| """ | |||||
| 2D Cross-entropy loss implemented as negative log likelihood | |||||
| """ | |||||
| def __init__(self, weight=None, reduction='none'): | |||||
| super(CrossEntropy2D, self).__init__() | |||||
| self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction) | |||||
| def forward(self, inputs, targets): | |||||
| return self.nll_loss(inputs, targets) | |||||
| class CombinedLoss(nn.Module): | |||||
| """ | |||||
| For CrossEntropy the input has to be a long tensor | |||||
| Args: | |||||
| -- inputx N x C x H x W (其中N为batch_size) | |||||
| -- target - N x H x W - int type | |||||
| -- weight - N x H x W - float | |||||
| """ | |||||
| def __init__(self, weight_dice, weight_ce): | |||||
| super(CombinedLoss, self).__init__() | |||||
| self.cross_entropy_loss = CrossEntropy2D() | |||||
| self.dice_loss = DiceLoss() | |||||
| self.weight_dice = weight_dice | |||||
| self.weight_ce = weight_ce | |||||
| def forward(self, inputx, target): | |||||
| target = target.type(torch.LongTensor) # Typecast to long tensor | |||||
| if inputx.is_cuda: | |||||
| target = target.cuda() | |||||
| # print(inputx.min(),inputx.max()) | |||||
| input_soft = F.softmax(inputx, dim=1) # Along Class Dimension | |||||
| dice_val = torch.mean(self.dice_loss(input_soft, target)) | |||||
| ce_val = torch.mean(self.cross_entropy_loss.forward(inputx, target)) | |||||
| # ce_val = torch.mean(self.cross_entropy_loss.forward(inputx, target)) | |||||
| total_loss = torch.add(torch.mul(dice_val, self.weight_dice), | |||||
| torch.mul(ce_val, self.weight_ce)) | |||||
| # print(weight.max()) | |||||
| return total_loss, dice_val, ce_val | |||||
| @@ -0,0 +1,336 @@ | |||||
| import os | |||||
| import torch | |||||
| from torch import nn | |||||
| import torch.nn.functional as F | |||||
| from collections import OrderedDict | |||||
| class ASPP(nn.Module): | |||||
| # have bias and relu, no bn | |||||
| def __init__(self, in_channel=512, depth=256): | |||||
| super().__init__() | |||||
| # global average pooling : init nn.AdaptiveAvgPool2d ;also forward torch.mean(,,keep_dim=True) | |||||
| self.mean = nn.AdaptiveAvgPool2d((1, 1)) | |||||
| self.conv = nn.Sequential(nn.Conv2d(in_channel, depth, 1, 1), | |||||
| nn.ReLU(inplace=True)) | |||||
| self.atrous_block1 = nn.Sequential(nn.Conv2d(in_channel, depth, 1, 1), | |||||
| nn.ReLU(inplace=True)) | |||||
| self.atrous_block6 = nn.Sequential( | |||||
| nn.Conv2d(in_channel, depth, 3, 1, padding=3, dilation=3), | |||||
| nn.ReLU(inplace=True)) | |||||
| self.atrous_block12 = nn.Sequential( | |||||
| nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6), | |||||
| nn.ReLU(inplace=True)) | |||||
| self.atrous_block18 = nn.Sequential( | |||||
| nn.Conv2d(in_channel, depth, 3, 1, padding=9, dilation=9), | |||||
| nn.ReLU(inplace=True)) | |||||
| self.conv_1x1_output = nn.Sequential(nn.Conv2d(depth * 5, depth, 1, 1), | |||||
| nn.ReLU(inplace=True)) | |||||
| def forward(self, x): | |||||
| size = x.shape[2:] | |||||
| image_features = self.mean(x) | |||||
| image_features = self.conv(image_features) | |||||
| image_features = F.interpolate(image_features, | |||||
| size=size, | |||||
| mode='bilinear', | |||||
| align_corners=True) | |||||
| atrous_block1 = self.atrous_block1(x) | |||||
| atrous_block6 = self.atrous_block6(x) | |||||
| atrous_block12 = self.atrous_block12(x) | |||||
| atrous_block18 = self.atrous_block18(x) | |||||
| net = self.conv_1x1_output( | |||||
| torch.cat([ | |||||
| image_features, atrous_block1, atrous_block6, atrous_block12, | |||||
| atrous_block18 | |||||
| ], | |||||
| dim=1)) | |||||
| return net | |||||
| class ResNet(nn.Module): | |||||
| def __init__(self, block, layers, num_classes=18, zero_init_residual=False, | |||||
| groups=1, width_per_group=64, replace_stride_with_dilation=None, | |||||
| norm_layer=None): | |||||
| super(ResNet, self).__init__() | |||||
| if norm_layer is None: | |||||
| norm_layer = nn.BatchNorm2d | |||||
| self._norm_layer = norm_layer | |||||
| self.inplanes = 64 | |||||
| self.dilation = 1 | |||||
| if replace_stride_with_dilation is None: | |||||
| # each element in the tuple indicates if we should replace | |||||
| # the 2x2 stride with a dilated convolution instead | |||||
| replace_stride_with_dilation = [False, False, False] | |||||
| if len(replace_stride_with_dilation) != 3: | |||||
| raise ValueError("replace_stride_with_dilation should be None " | |||||
| "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | |||||
| self.groups = groups | |||||
| self.base_width = width_per_group | |||||
| self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, | |||||
| bias=False) | |||||
| self.bn1 = norm_layer(self.inplanes) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||||
| self.layer1 = self._make_layer(block, 64, layers[0]) | |||||
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | |||||
| dilate=replace_stride_with_dilation[0]) | |||||
| self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | |||||
| dilate=replace_stride_with_dilation[1]) | |||||
| self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | |||||
| dilate=replace_stride_with_dilation[2]) | |||||
| self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |||||
| self.fc = nn.Linear(512 * block.expansion, num_classes) | |||||
| for m in self.modules(): | |||||
| if isinstance(m, nn.Conv2d): | |||||
| nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | |||||
| elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||||
| nn.init.constant_(m.weight, 1) | |||||
| nn.init.constant_(m.bias, 0) | |||||
| # Zero-initialize the last BN in each residual branch, | |||||
| # so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||||
| # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||||
| if zero_init_residual: | |||||
| for m in self.modules(): | |||||
| if isinstance(m, Bottleneck): | |||||
| nn.init.constant_(m.bn3.weight, 0) | |||||
| elif isinstance(m, BasicBlock): | |||||
| nn.init.constant_(m.bn2.weight, 0) | |||||
| def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | |||||
| norm_layer = self._norm_layer | |||||
| downsample = None | |||||
| previous_dilation = self.dilation | |||||
| if dilate: | |||||
| self.dilation *= stride | |||||
| stride = 1 | |||||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||||
| downsample = nn.Sequential( | |||||
| conv1x1(self.inplanes, planes * block.expansion, stride), | |||||
| norm_layer(planes * block.expansion), | |||||
| ) | |||||
| layers = [] | |||||
| layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | |||||
| self.base_width, previous_dilation, norm_layer)) | |||||
| self.inplanes = planes * block.expansion | |||||
| for _ in range(1, blocks): | |||||
| layers.append(block(self.inplanes, planes, groups=self.groups, | |||||
| base_width=self.base_width, dilation=self.dilation, | |||||
| norm_layer=norm_layer)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x): | |||||
| x = self.conv1(x) | |||||
| x = self.bn1(x) | |||||
| x = self.relu(x) | |||||
| x = self.maxpool(x) | |||||
| x = self.layer1(x) | |||||
| x = self.layer2(x) | |||||
| x = self.layer3(x) | |||||
| x = self.layer4(x) | |||||
| x = self.avgpool(x) | |||||
| x = torch.flatten(x, 1) | |||||
| x = self.fc(x) | |||||
| return x | |||||
| def _resnet(arch, block, layers, pretrained, progress, **kwargs): | |||||
| model = ResNet(block, layers, **kwargs) | |||||
| return model | |||||
| def resnet50(pretrained=False, progress=True, **kwargs): | |||||
| return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | |||||
| **kwargs) | |||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||||
| """3x3 convolution with padding""" | |||||
| return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||||
| padding=dilation, groups=groups, bias=False, dilation=dilation) | |||||
| def conv1x1(in_planes, out_planes, stride=1): | |||||
| """1x1 convolution""" | |||||
| return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||||
| class BasicBlock(nn.Module): | |||||
| expansion = 1 | |||||
| def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||||
| base_width=64, dilation=1, norm_layer=None): | |||||
| super(BasicBlock, self).__init__() | |||||
| if norm_layer is None: | |||||
| norm_layer = nn.BatchNorm2d | |||||
| if groups != 1 or base_width != 64: | |||||
| raise ValueError('BasicBlock only supports groups=1 and base_width=64') | |||||
| if dilation > 1: | |||||
| raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |||||
| # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||||
| self.conv1 = conv3x3(inplanes, planes, stride) | |||||
| self.bn1 = norm_layer(planes) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.conv2 = conv3x3(planes, planes) | |||||
| self.bn2 = norm_layer(planes) | |||||
| self.downsample = downsample | |||||
| self.stride = stride | |||||
| def forward(self, x): | |||||
| identity = x | |||||
| out = self.conv1(x) | |||||
| out = self.bn1(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv2(out) | |||||
| out = self.bn2(out) | |||||
| if self.downsample is not None: | |||||
| identity = self.downsample(x) | |||||
| out += identity | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class Bottleneck(nn.Module): | |||||
| expansion = 4 | |||||
| def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||||
| base_width=64, dilation=1, norm_layer=None): | |||||
| super(Bottleneck, self).__init__() | |||||
| if norm_layer is None: | |||||
| norm_layer = nn.BatchNorm2d | |||||
| width = int(planes * (base_width / 64.)) * groups | |||||
| # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||||
| self.conv1 = conv1x1(inplanes, width) | |||||
| self.bn1 = norm_layer(width) | |||||
| self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||||
| self.bn2 = norm_layer(width) | |||||
| self.conv3 = conv1x1(width, planes * self.expansion) | |||||
| self.bn3 = norm_layer(planes * self.expansion) | |||||
| self.relu = nn.ReLU(inplace=True) | |||||
| self.downsample = downsample | |||||
| self.stride = stride | |||||
| def forward(self, x): | |||||
| identity = x | |||||
| out = self.conv1(x) | |||||
| out = self.bn1(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv2(out) | |||||
| out = self.bn2(out) | |||||
| out = self.relu(out) | |||||
| out = self.conv3(out) | |||||
| out = self.bn3(out) | |||||
| if self.downsample is not None: | |||||
| identity = self.downsample(x) | |||||
| out += identity | |||||
| out = self.relu(out) | |||||
| return out | |||||
| class Deeplab_v3(nn.Module): | |||||
| # in_channel = 3 fine-tune | |||||
| def __init__(self, class_number=18): | |||||
| super().__init__() | |||||
| encoder = resnet50() | |||||
| self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu) | |||||
| self.maxpool = encoder.maxpool | |||||
| self.low_feature1 = nn.Sequential(nn.Conv2d( | |||||
| 64, 32, 1, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=True)) | |||||
| self.low_feature3 = nn.Sequential(nn.Conv2d( | |||||
| 256, 64, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | |||||
| self.low_feature4 = nn.Sequential(nn.Conv2d( | |||||
| 512, 128, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True)) | |||||
| self.layer1 = encoder.layer1 #256 | |||||
| self.layer2 = encoder.layer2 #512 | |||||
| self.layer3 = encoder.layer3 #1024 | |||||
| self.layer4 = encoder.layer4 #2048 | |||||
| self.aspp = ASPP(in_channel=2048, depth=256) | |||||
| self.conv_cat4 = nn.Sequential(nn.Conv2d(256 + 128, 256, 3, 1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True)) | |||||
| self.conv_cat3 = nn.Sequential(nn.Conv2d(256 + 64, 256, 3, 1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), | |||||
| nn.Conv2d(256, 64, 3, 1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | |||||
| self.conv_cat1 = nn.Sequential(nn.Conv2d(64 + 32, 64, 3, 1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), | |||||
| nn.Conv2d(64, 18, 3, 1, padding=1)) | |||||
| def forward(self, x): | |||||
| size0 = x.shape[2:] # need upsample input size | |||||
| x1 = self.start(x) # 64, 128*128 | |||||
| x2 = self.maxpool(x1) # 64, 64*64 | |||||
| x3 = self.layer1(x2) # 256, 64*64 | |||||
| x4 = self.layer2(x3) # 512, 32*32 | |||||
| x5 = self.layer3(x4) # 1024,16*16 | |||||
| x = self.layer4(x5) # 2048,8*8 | |||||
| x = self.aspp(x) # 256, 8*8 | |||||
| low_feature1 = self.low_feature1(x1) # 64, 128*128 | |||||
| # low_feature2 = self.low_feature2(x2) # 64, 64*64 | |||||
| low_feature3 = self.low_feature3(x3) # 256, 64*64 | |||||
| low_feature4 = self.low_feature4(x4) # 512, 32*32 -> 128, 32*32 | |||||
| # low_feature5 = self.low_feature5(x5) # 1024,16*16 | |||||
| size1 = low_feature1.shape[2:] | |||||
| # size2 = low_feature2.shape[2:] | |||||
| size3 = low_feature3.shape[2:] | |||||
| size4 = low_feature4.shape[2:] | |||||
| # size5 = low_feature5.shape[2:] | |||||
| decoder_feature4 = F.interpolate(x, size=size4, mode='bilinear', align_corners=True) | |||||
| x = self.conv_cat4(torch.cat([low_feature4, decoder_feature4], dim=1)) | |||||
| decoder_feature3 = F.interpolate(x, size=size3, mode='bilinear', align_corners=True) | |||||
| x = self.conv_cat3(torch.cat([low_feature3, decoder_feature3], dim=1)) | |||||
| decoder_feature1 = F.interpolate(x, size=size1, mode='bilinear', align_corners=True) | |||||
| x = self.conv_cat1(torch.cat([low_feature1, decoder_feature1], dim=1)) | |||||
| score = F.interpolate(x, | |||||
| size=size0, | |||||
| mode='bilinear', | |||||
| align_corners=True) | |||||
| return score | |||||
| def init_model(): | |||||
| model_path = os.path.join(os.path.dirname(__file__), 'model.pkl') | |||||
| model = Deeplab_v3() | |||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
| model.to(device) | |||||
| model_state = torch.load(model_path, map_location=device) | |||||
| new_state_dict = OrderedDict() | |||||
| for k, v in model_state["model_state_dict"].items(): | |||||
| if k[:7] == "module.": | |||||
| new_state_dict[k[7:]] = v | |||||
| else: | |||||
| new_state_dict[k] = v | |||||
| model.load_state_dict(new_state_dict) | |||||
| model.eval() | |||||
| return model | |||||
| @@ -0,0 +1,53 @@ | |||||
| import os | |||||
| import torch | |||||
| import numpy as np | |||||
| from PIL import Image | |||||
| def cutImage(file_name): | |||||
| img = Image.open(file_name) | |||||
| oring_size = img.size | |||||
| n = img.size[0] // 256 | |||||
| img = img.resize((n * 256, n * 256)) | |||||
| res = [] | |||||
| for i in range(n): | |||||
| for j in range(n): | |||||
| pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1)) | |||||
| res.append(img.crop(pos)) | |||||
| return n, res, oring_size, img.size | |||||
| def mergeImage(n, imgs, o_size, n_size): | |||||
| img = Image.new('L', n_size) | |||||
| for i in range(n): | |||||
| for j in range(n): | |||||
| pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1)) | |||||
| img.paste(imgs[i * n + j], pos) | |||||
| img = img.resize(o_size) | |||||
| a = np.array(img) | |||||
| a[a > 17] = 17 | |||||
| return Image.fromarray(a) | |||||
| def predict(model, input_path, output_dir): | |||||
| name, _ = os.path.splitext(input_path) | |||||
| name = os.path.split(name)[-1] + ".png" | |||||
| n, imgs, o_size, n_size = cutImage(input_path) | |||||
| res = [] | |||||
| for img in imgs: | |||||
| img = torch.from_numpy(np.array(img)).float().unsqueeze(0) | |||||
| img = img / 255 | |||||
| if torch.cuda.is_available(): | |||||
| img = img.cuda() | |||||
| img = img.permute(0, 3, 1, 2) | |||||
| label = model(img) | |||||
| label = torch.argmax(label, | |||||
| dim=1).cpu().squeeze().numpy().astype(np.uint8) | |||||
| res.append(Image.fromarray(label)) | |||||
| img = mergeImage(n, res, o_size, n_size) | |||||
| img.save(os.path.join(output_dir, name)) | |||||
| if __name__ == "__main__": | |||||
| from model_define import init_model | |||||
| model = init_model() | |||||
| predict(model, './test.tif', '') | |||||
| @@ -0,0 +1,148 @@ | |||||
| import os | |||||
| import torch | |||||
| import time | |||||
| import matplotlib.pyplot as plt | |||||
| import numpy as np | |||||
| import glob | |||||
| from torch.autograd import Variable | |||||
| from torch.optim import lr_scheduler | |||||
| from torchvision import utils | |||||
| from skimage import color | |||||
| from losses import CombinedLoss | |||||
| def create_exp_directory(exp_dir_name): | |||||
| if not os.path.exists(exp_dir_name): | |||||
| os.makedirs(exp_dir_name) | |||||
| def plot_predictions(images_batch, labels_batch, batch_output, plt_title, | |||||
| file_save_name): | |||||
| f = plt.figure(figsize=(20, 20)) | |||||
| # n, c, h, w = images_batch.shape | |||||
| # mid_slice = c // 2 | |||||
| # images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1) | |||||
| grid = utils.make_grid(images_batch.cpu(), nrow=4) | |||||
| plt.subplot(131) | |||||
| plt.imshow(grid.numpy().transpose((1, 2, 0))) | |||||
| plt.title('Slices') | |||||
| grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0] | |||||
| color_grid = color.label2rgb(grid.numpy(), bg_label=0) | |||||
| plt.subplot(132) | |||||
| plt.imshow(color_grid) | |||||
| plt.title('Ground Truth') | |||||
| grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0] | |||||
| color_grid = color.label2rgb(grid.numpy(), bg_label=0) | |||||
| plt.subplot(133) | |||||
| plt.imshow(color_grid) | |||||
| plt.title('Prediction') | |||||
| plt.suptitle(plt_title) | |||||
| plt.tight_layout() | |||||
| f.savefig(file_save_name, bbox_inches='tight') | |||||
| plt.close(f) | |||||
| plt.gcf().clear() | |||||
| class Solver(object): | |||||
| def __init__(self, num_classes, optimizer, lr_args, optimizer_args): | |||||
| self.lr_scheduler_args = lr_args | |||||
| self.optimizer_args = optimizer_args | |||||
| self.optimizer = optimizer | |||||
| self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100) | |||||
| self.num_classes = num_classes | |||||
| self.classes = list(range(self.num_classes)) | |||||
| def train(self, | |||||
| model, | |||||
| train_loader, | |||||
| num_epochs, | |||||
| log_params, | |||||
| expdir, | |||||
| resume=True): | |||||
| create_exp_directory(expdir) | |||||
| create_exp_directory(log_params["logdir"]) | |||||
| optimizer = self.optimizer(model.parameters(), **self.optimizer_args) | |||||
| scheduler = lr_scheduler.StepLR( | |||||
| optimizer, | |||||
| step_size=self.lr_scheduler_args["step_size"], | |||||
| gamma=self.lr_scheduler_args["gamma"]) | |||||
| epoch = -1 | |||||
| print('-------> Starting to train') | |||||
| if resume: | |||||
| try: | |||||
| prior_model_paths = sorted(glob.glob( | |||||
| os.path.join(expdir, 'Epoch_*')), | |||||
| key=os.path.getmtime) | |||||
| current_model = prior_model_paths.pop() | |||||
| state = torch.load(current_model) | |||||
| model.load_state_dict(state["model_state_dict"]) | |||||
| epoch = state["epoch"] | |||||
| print("Successfully Resuming from Epoch {}".format(epoch + 1)) | |||||
| except Exception as e: | |||||
| print("No model to restore. {}".format(e)) | |||||
| log_params["logger"].info("{} parameters in total".format( | |||||
| sum(x.numel() for x in model.parameters()))) | |||||
| model.train() | |||||
| while epoch < num_epochs: | |||||
| epoch = epoch + 1 | |||||
| epoch_start = time.time() | |||||
| loss_batch = np.zeros(1) | |||||
| loss_dice_batch = np.zeros(1) | |||||
| loss_ce_batch = np.zeros(1) | |||||
| for batch_idx, sample in enumerate(train_loader): | |||||
| images, labels = sample | |||||
| images, labels = Variable(images), Variable(labels) | |||||
| if torch.cuda.is_available(): | |||||
| images, labels = images.cuda(), labels.cuda() | |||||
| optimizer.zero_grad() | |||||
| predictions = model(images) | |||||
| loss_total, loss_dice, loss_ce = self.loss_func( | |||||
| inputx=predictions, target=labels) | |||||
| loss_total.backward() | |||||
| optimizer.step() | |||||
| loss_batch += loss_total.item() | |||||
| loss_dice_batch += loss_dice.item() | |||||
| loss_ce_batch += loss_ce.item() | |||||
| _, batch_output = torch.max(predictions, dim=1) | |||||
| if batch_idx == len(train_loader) - 2: | |||||
| plt_title = 'Trian Results Epoch ' + str(epoch) | |||||
| file_save_name = os.path.join( | |||||
| log_params["logdir"], | |||||
| 'Epoch_{}_Trian_Predictions.pdf'.format(epoch)) | |||||
| plot_predictions(images, labels, batch_output, plt_title, | |||||
| file_save_name) | |||||
| if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == (len(train_loader) - 1): | |||||
| log_params["logger"].info( | |||||
| "Epoch: {} lr:{} [{}/{}] ({:.0f}%)]" | |||||
| "with loss: {},\ndice_loss:{},ce_loss:{}".format( | |||||
| epoch, optimizer.param_groups[0]['lr'], batch_idx, | |||||
| len(train_loader), | |||||
| 100. * batch_idx / len(train_loader), | |||||
| loss_batch / (batch_idx + 1), | |||||
| loss_dice_batch / (batch_idx + 1), | |||||
| loss_ce_batch / (batch_idx + 1))) | |||||
| scheduler.step() | |||||
| epoch_finish = time.time() - epoch_start | |||||
| log_params["logger"].info( | |||||
| "Train Epoch {} finished in {:.04f} seconds.\n".format( | |||||
| epoch, epoch_finish)) | |||||
| # Saving Models | |||||
| if epoch % log_params["log_iter"] == 0: # 每log_iter次保存一次模型 | |||||
| save_name = os.path.join( | |||||
| expdir, | |||||
| 'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl') | |||||
| checkpoint = { | |||||
| "model_state_dict": model.state_dict(), | |||||
| "optimizer_state_dict": optimizer.state_dict(), | |||||
| "epoch": epoch | |||||
| } | |||||
| if scheduler is not None: | |||||
| checkpoint["scheduler_state_dict"] = scheduler.state_dict() | |||||
| torch.save(checkpoint, save_name) | |||||
| @@ -0,0 +1,193 @@ | |||||
| import os | |||||
| import logging | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch.utils.data.dataloader import DataLoader | |||||
| from dataset import DataTrain | |||||
| from model_define import Deeplab_v3 | |||||
| # from model_define_unet import UNet | |||||
| from solver import Solver | |||||
| from torchvision import transforms, utils | |||||
| import numpy as np | |||||
| from networks.vit_seg_modeling import VisionTransformer as ViT_seg | |||||
| from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg | |||||
| args = { | |||||
| 'batch_size': 2, | |||||
| 'log_interval': 1, | |||||
| 'log_dir': 'log', | |||||
| 'num_classes': 2, | |||||
| 'epochs': 1000, | |||||
| 'lr': 1e-5, | |||||
| 'resume': True, | |||||
| 'data_dir': "../small_data", | |||||
| 'gamma': 0.5, | |||||
| 'step': 5, | |||||
| 'vit_name': 'R50-ViT-B_16', | |||||
| 'num_classes': 2, | |||||
| 'n_skip': 3, | |||||
| 'img_size': 256, | |||||
| 'vit_patches_size': 16, | |||||
| } | |||||
| ''' | |||||
| 文件目录: | |||||
| data | |||||
| images | |||||
| *.tif | |||||
| labels | |||||
| *.png | |||||
| code | |||||
| train.py | |||||
| ... | |||||
| ''' | |||||
| class ToTensor(object): | |||||
| """ | |||||
| Convert ndarrays in sample to Tensors. | |||||
| """ | |||||
| def __call__(self, sample): | |||||
| img, label = sample['img'], sample['label'] | |||||
| img = img.astype(np.float32) | |||||
| img = img/255.0 | |||||
| # swap color axis because | |||||
| # numpy image: H x W x C | |||||
| # torch image: C X H X W | |||||
| img = img.transpose((2, 0, 1)) | |||||
| return {'img': torch.from_numpy(img), 'label': label} | |||||
| class AugmentationPadImage(object): | |||||
| """ | |||||
| Pad Image with either zero padding or reflection padding of img, label and weight | |||||
| """ | |||||
| def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"): | |||||
| assert isinstance(pad_size, (int, tuple)) | |||||
| if isinstance(pad_size, int): | |||||
| # Do not pad along the channel dimension | |||||
| self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0)) | |||||
| self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size)) | |||||
| else: | |||||
| self.pad_size = pad_size | |||||
| self.pad_type = pad_type | |||||
| def __call__(self, sample): | |||||
| img, label = sample['img'], sample['label'] | |||||
| img = np.pad(img, self.pad_size_image, self.pad_type) | |||||
| label = np.pad(label, self.pad_size_mask, self.pad_type) | |||||
| return {'img': img, 'label': label} | |||||
| class AugmentationRandomCrop(object): | |||||
| """ | |||||
| Randomly Crop Image to given size | |||||
| """ | |||||
| def __init__(self, output_size, crop_type='Random'): | |||||
| assert isinstance(output_size, (int, tuple)) | |||||
| if isinstance(output_size, int): | |||||
| self.output_size = (output_size, output_size) | |||||
| else: | |||||
| self.output_size = output_size | |||||
| self.crop_type = crop_type | |||||
| def __call__(self, sample): | |||||
| img, label = sample['img'], sample['label'] | |||||
| h, w, _ = img.shape | |||||
| if self.crop_type == 'Center': | |||||
| top = (h - self.output_size[0]) // 2 | |||||
| left = (w - self.output_size[1]) // 2 | |||||
| else: | |||||
| top = np.random.randint(0, h - self.output_size[0]) | |||||
| left = np.random.randint(0, w - self.output_size[1]) | |||||
| bottom = top + self.output_size[0] | |||||
| right = left + self.output_size[1] | |||||
| # print(img.shape) | |||||
| img = img[top:bottom, left:right, :] | |||||
| label = label[top:bottom, left:right] | |||||
| # weight = weight[top:bottom, left:right] | |||||
| return {'img': img, 'label': label} | |||||
| def log_init(): | |||||
| if not os.path.exists(args['log_dir']): | |||||
| os.makedirs(args['log_dir']) | |||||
| logger = logging.getLogger("train") | |||||
| logger.setLevel(logging.DEBUG) | |||||
| logger.handlers = [] | |||||
| logger.addHandler(logging.StreamHandler()) | |||||
| logger.addHandler( | |||||
| logging.FileHandler(os.path.join(args['log_dir'], "log.txt"))) | |||||
| logger.info("%s", repr(args)) | |||||
| return logger | |||||
| def train(): | |||||
| logger = log_init() | |||||
| transform_train = transforms.Compose([AugmentationPadImage(pad_size=8), AugmentationRandomCrop(output_size=256), ToTensor()]) | |||||
| dataset_train = DataTrain(args['data_dir'], transforms = transform_train) | |||||
| train_dataloader = DataLoader(dataset=dataset_train, | |||||
| batch_size=args['batch_size'], | |||||
| shuffle=True, num_workers=4) | |||||
| # model = Deeplab_v3() | |||||
| # model = UNet() | |||||
| config_vit = CONFIGS_ViT_seg[args['vit_name']] | |||||
| config_vit.n_classes = args['num_classes'] | |||||
| config_vit.n_skip = args['n_skip'] | |||||
| if args['vit_name'].find('R50') != -1: | |||||
| config_vit.patches.grid = (int(args['img_size'] / args['vit_patches_size']), int(args['img_size'] / args['vit_patches_size'])) | |||||
| model = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes) | |||||
| # model.load_from(weights=np.load(config_vit.pretrained_path)) | |||||
| if torch.cuda.is_available(): | |||||
| if torch.cuda.device_count() > 1: | |||||
| model = nn.DataParallel(model) | |||||
| model.cuda() | |||||
| solver = Solver(num_classes=args['num_classes'], | |||||
| lr_args={ | |||||
| "gamma": args['gamma'], | |||||
| "step_size": args['step'] | |||||
| }, | |||||
| optimizer_args={ | |||||
| "lr": args['lr'], | |||||
| "betas": (0.9, 0.999), | |||||
| "eps": 1e-8, | |||||
| "weight_decay": 0.01 | |||||
| }, | |||||
| optimizer=torch.optim.Adam) | |||||
| solver.train(model, | |||||
| train_dataloader, | |||||
| num_epochs=args['epochs'], | |||||
| log_params={ | |||||
| 'logdir': args['log_dir'] + "/logs", | |||||
| 'log_iter': args['log_interval'], | |||||
| 'logger': logger | |||||
| }, | |||||
| expdir=args['log_dir'] + "/ckpts", | |||||
| resume=args['resume']) | |||||
| if __name__ == "__main__": | |||||
| train() | |||||