| @@ -0,0 +1,25 @@ | |||
| # 任务:无限遥感图像分割 | |||
| ### 环境 | |||
| 操作系统:linux | |||
| 显卡:1080ti | |||
| python3.7 | |||
| pytroch | |||
| cuda10.0 | |||
| ### 模型 | |||
| Unet | |||
| deeplabv3 | |||
| ### 损失函数 | |||
| dice ce | |||
| @@ -0,0 +1,44 @@ | |||
| import os | |||
| import torch | |||
| import numpy as np | |||
| import cv2 as cv | |||
| from os.path import join | |||
| from torch.utils.data.dataset import Dataset | |||
| class DataTrain(Dataset): | |||
| def __init__(self, data_path, transforms=None): | |||
| self.data_dir = data_path | |||
| self.image_list = os.listdir(join(data_path, 'train_ori_images')) | |||
| files_len = len(self.image_list) | |||
| try: | |||
| imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8) | |||
| labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8) | |||
| for idx, file in enumerate(self.image_list): | |||
| fname = file.split('.')[0] | |||
| img = cv.imread(join(self.data_dir, 'train_ori_images', fname + '.tif')) | |||
| img = np.asarray(img, dtype=np.uint8) | |||
| label = cv.imread( | |||
| join(self.data_dir, 'train_pupil_images', fname + '.png'), | |||
| cv.IMREAD_UNCHANGED) | |||
| label = np.asarray(label, dtype=np.uint8) % 100 | |||
| imgs[idx, :, :, :] = img | |||
| labels[idx, :, :] = label | |||
| self.images = imgs | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| except Exception: | |||
| raise Exception('read error') | |||
| def __getitem__(self, index): | |||
| img = self.images[index] | |||
| label = self.labels[index] | |||
| label[label > 0.5] = 1 | |||
| tx_sample = self.transforms({'img': img, 'label': label}) | |||
| img = tx_sample['img'] | |||
| label = tx_sample['label'] | |||
| return img, label | |||
| def __len__(self): | |||
| return len(self.image_list) | |||
| @@ -0,0 +1,103 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch.nn.modules.loss import _Loss | |||
| import torch.nn.functional as F | |||
| import numpy as np | |||
| class DiceLoss(_Loss): | |||
| def forward(self, output, target, weights=None, ignore_index=None): | |||
| eps = 0.001 | |||
| encoded_target = output.detach() * 0 # 将variable参数从网络中隔离开,不参与参数更新。 | |||
| if ignore_index is not None: | |||
| mask = target == ignore_index | |||
| target = target.clone() | |||
| target[mask] = 0 | |||
| encoded_target.scatter_(1, target.unsqueeze(1), | |||
| 1) # unsqueeze增加一个维度 | |||
| mask = mask.unsqueeze(1).expand_as(encoded_target) | |||
| encoded_target[mask] = 0 | |||
| else: | |||
| encoded_target.scatter_(1, target.unsqueeze(1), | |||
| 1) # unsqueeze增加一个维度 | |||
| # scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向输出。 | |||
| if weights is None: | |||
| weights = 1 | |||
| # print(output.min(),output.max()) | |||
| # print(output.shape,output[0,:,0,0]) | |||
| intersection = output * encoded_target | |||
| numerator = 2 * intersection.sum(0).sum(1).sum(1) | |||
| denominator = output + encoded_target | |||
| if ignore_index is not None: | |||
| denominator[mask] = 0 | |||
| # 计算无效的类别数量 | |||
| count1 = [] | |||
| for i in encoded_target.sum(0).sum(1).sum(1): | |||
| if i == 0: | |||
| count1.append(1) | |||
| else: | |||
| count1.append(0) | |||
| count2 = [] | |||
| for i in denominator.sum(0).sum(1).sum(1): | |||
| if i == 0: | |||
| count2.append(1) | |||
| else: | |||
| count2.append(0) | |||
| count = sum(np.array(count1) * np.array(count2)) | |||
| # print(count) | |||
| denominator = denominator.sum(0).sum(1).sum(1) + eps | |||
| loss_per_channel = weights * (1 - (numerator / denominator) | |||
| ) # Channel-wise weights | |||
| # print(loss_per_channel) # 每一个类别的平均dice | |||
| return (loss_per_channel.sum() - count) / (output.size(1) - count) | |||
| # return loss_per_channel.sum() / output.size(1) | |||
| class CrossEntropy2D(nn.Module): | |||
| """ | |||
| 2D Cross-entropy loss implemented as negative log likelihood | |||
| """ | |||
| def __init__(self, weight=None, reduction='none'): | |||
| super(CrossEntropy2D, self).__init__() | |||
| self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction) | |||
| def forward(self, inputs, targets): | |||
| return self.nll_loss(inputs, targets) | |||
| class CombinedLoss(nn.Module): | |||
| """ | |||
| For CrossEntropy the input has to be a long tensor | |||
| Args: | |||
| -- inputx N x C x H x W (其中N为batch_size) | |||
| -- target - N x H x W - int type | |||
| -- weight - N x H x W - float | |||
| """ | |||
| def __init__(self, weight_dice, weight_ce): | |||
| super(CombinedLoss, self).__init__() | |||
| self.cross_entropy_loss = CrossEntropy2D() | |||
| self.dice_loss = DiceLoss() | |||
| self.weight_dice = weight_dice | |||
| self.weight_ce = weight_ce | |||
| def forward(self, inputx, target): | |||
| target = target.type(torch.LongTensor) # Typecast to long tensor | |||
| if inputx.is_cuda: | |||
| target = target.cuda() | |||
| # print(inputx.min(),inputx.max()) | |||
| input_soft = F.softmax(inputx, dim=1) # Along Class Dimension | |||
| dice_val = torch.mean(self.dice_loss(input_soft, target)) | |||
| ce_val = torch.mean(self.cross_entropy_loss.forward(inputx, target)) | |||
| # ce_val = torch.mean(self.cross_entropy_loss.forward(inputx, target)) | |||
| total_loss = torch.add(torch.mul(dice_val, self.weight_dice), | |||
| torch.mul(ce_val, self.weight_ce)) | |||
| # print(weight.max()) | |||
| return total_loss, dice_val, ce_val | |||
| @@ -0,0 +1,336 @@ | |||
| import os | |||
| import torch | |||
| from torch import nn | |||
| import torch.nn.functional as F | |||
| from collections import OrderedDict | |||
| class ASPP(nn.Module): | |||
| # have bias and relu, no bn | |||
| def __init__(self, in_channel=512, depth=256): | |||
| super().__init__() | |||
| # global average pooling : init nn.AdaptiveAvgPool2d ;also forward torch.mean(,,keep_dim=True) | |||
| self.mean = nn.AdaptiveAvgPool2d((1, 1)) | |||
| self.conv = nn.Sequential(nn.Conv2d(in_channel, depth, 1, 1), | |||
| nn.ReLU(inplace=True)) | |||
| self.atrous_block1 = nn.Sequential(nn.Conv2d(in_channel, depth, 1, 1), | |||
| nn.ReLU(inplace=True)) | |||
| self.atrous_block6 = nn.Sequential( | |||
| nn.Conv2d(in_channel, depth, 3, 1, padding=3, dilation=3), | |||
| nn.ReLU(inplace=True)) | |||
| self.atrous_block12 = nn.Sequential( | |||
| nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6), | |||
| nn.ReLU(inplace=True)) | |||
| self.atrous_block18 = nn.Sequential( | |||
| nn.Conv2d(in_channel, depth, 3, 1, padding=9, dilation=9), | |||
| nn.ReLU(inplace=True)) | |||
| self.conv_1x1_output = nn.Sequential(nn.Conv2d(depth * 5, depth, 1, 1), | |||
| nn.ReLU(inplace=True)) | |||
| def forward(self, x): | |||
| size = x.shape[2:] | |||
| image_features = self.mean(x) | |||
| image_features = self.conv(image_features) | |||
| image_features = F.interpolate(image_features, | |||
| size=size, | |||
| mode='bilinear', | |||
| align_corners=True) | |||
| atrous_block1 = self.atrous_block1(x) | |||
| atrous_block6 = self.atrous_block6(x) | |||
| atrous_block12 = self.atrous_block12(x) | |||
| atrous_block18 = self.atrous_block18(x) | |||
| net = self.conv_1x1_output( | |||
| torch.cat([ | |||
| image_features, atrous_block1, atrous_block6, atrous_block12, | |||
| atrous_block18 | |||
| ], | |||
| dim=1)) | |||
| return net | |||
| class ResNet(nn.Module): | |||
| def __init__(self, block, layers, num_classes=18, zero_init_residual=False, | |||
| groups=1, width_per_group=64, replace_stride_with_dilation=None, | |||
| norm_layer=None): | |||
| super(ResNet, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| self._norm_layer = norm_layer | |||
| self.inplanes = 64 | |||
| self.dilation = 1 | |||
| if replace_stride_with_dilation is None: | |||
| # each element in the tuple indicates if we should replace | |||
| # the 2x2 stride with a dilated convolution instead | |||
| replace_stride_with_dilation = [False, False, False] | |||
| if len(replace_stride_with_dilation) != 3: | |||
| raise ValueError("replace_stride_with_dilation should be None " | |||
| "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | |||
| self.groups = groups | |||
| self.base_width = width_per_group | |||
| self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, | |||
| bias=False) | |||
| self.bn1 = norm_layer(self.inplanes) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
| self.layer1 = self._make_layer(block, 64, layers[0]) | |||
| self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | |||
| dilate=replace_stride_with_dilation[0]) | |||
| self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | |||
| dilate=replace_stride_with_dilation[1]) | |||
| self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | |||
| dilate=replace_stride_with_dilation[2]) | |||
| self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |||
| self.fc = nn.Linear(512 * block.expansion, num_classes) | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | |||
| elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | |||
| nn.init.constant_(m.weight, 1) | |||
| nn.init.constant_(m.bias, 0) | |||
| # Zero-initialize the last BN in each residual branch, | |||
| # so that the residual branch starts with zeros, and each residual block behaves like an identity. | |||
| # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | |||
| if zero_init_residual: | |||
| for m in self.modules(): | |||
| if isinstance(m, Bottleneck): | |||
| nn.init.constant_(m.bn3.weight, 0) | |||
| elif isinstance(m, BasicBlock): | |||
| nn.init.constant_(m.bn2.weight, 0) | |||
| def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | |||
| norm_layer = self._norm_layer | |||
| downsample = None | |||
| previous_dilation = self.dilation | |||
| if dilate: | |||
| self.dilation *= stride | |||
| stride = 1 | |||
| if stride != 1 or self.inplanes != planes * block.expansion: | |||
| downsample = nn.Sequential( | |||
| conv1x1(self.inplanes, planes * block.expansion, stride), | |||
| norm_layer(planes * block.expansion), | |||
| ) | |||
| layers = [] | |||
| layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | |||
| self.base_width, previous_dilation, norm_layer)) | |||
| self.inplanes = planes * block.expansion | |||
| for _ in range(1, blocks): | |||
| layers.append(block(self.inplanes, planes, groups=self.groups, | |||
| base_width=self.base_width, dilation=self.dilation, | |||
| norm_layer=norm_layer)) | |||
| return nn.Sequential(*layers) | |||
| def forward(self, x): | |||
| x = self.conv1(x) | |||
| x = self.bn1(x) | |||
| x = self.relu(x) | |||
| x = self.maxpool(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| x = self.layer4(x) | |||
| x = self.avgpool(x) | |||
| x = torch.flatten(x, 1) | |||
| x = self.fc(x) | |||
| return x | |||
| def _resnet(arch, block, layers, pretrained, progress, **kwargs): | |||
| model = ResNet(block, layers, **kwargs) | |||
| return model | |||
| def resnet50(pretrained=False, progress=True, **kwargs): | |||
| return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | |||
| **kwargs) | |||
| def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | |||
| """3x3 convolution with padding""" | |||
| return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||
| padding=dilation, groups=groups, bias=False, dilation=dilation) | |||
| def conv1x1(in_planes, out_planes, stride=1): | |||
| """1x1 convolution""" | |||
| return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | |||
| class BasicBlock(nn.Module): | |||
| expansion = 1 | |||
| def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||
| base_width=64, dilation=1, norm_layer=None): | |||
| super(BasicBlock, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| if groups != 1 or base_width != 64: | |||
| raise ValueError('BasicBlock only supports groups=1 and base_width=64') | |||
| if dilation > 1: | |||
| raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | |||
| # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | |||
| self.conv1 = conv3x3(inplanes, planes, stride) | |||
| self.bn1 = norm_layer(planes) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.conv2 = conv3x3(planes, planes) | |||
| self.bn2 = norm_layer(planes) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| def forward(self, x): | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class Bottleneck(nn.Module): | |||
| expansion = 4 | |||
| def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | |||
| base_width=64, dilation=1, norm_layer=None): | |||
| super(Bottleneck, self).__init__() | |||
| if norm_layer is None: | |||
| norm_layer = nn.BatchNorm2d | |||
| width = int(planes * (base_width / 64.)) * groups | |||
| # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | |||
| self.conv1 = conv1x1(inplanes, width) | |||
| self.bn1 = norm_layer(width) | |||
| self.conv2 = conv3x3(width, width, stride, groups, dilation) | |||
| self.bn2 = norm_layer(width) | |||
| self.conv3 = conv1x1(width, planes * self.expansion) | |||
| self.bn3 = norm_layer(planes * self.expansion) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.downsample = downsample | |||
| self.stride = stride | |||
| def forward(self, x): | |||
| identity = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| out = self.relu(out) | |||
| out = self.conv3(out) | |||
| out = self.bn3(out) | |||
| if self.downsample is not None: | |||
| identity = self.downsample(x) | |||
| out += identity | |||
| out = self.relu(out) | |||
| return out | |||
| class Deeplab_v3(nn.Module): | |||
| # in_channel = 3 fine-tune | |||
| def __init__(self, class_number=18): | |||
| super().__init__() | |||
| encoder = resnet50() | |||
| self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu) | |||
| self.maxpool = encoder.maxpool | |||
| self.low_feature1 = nn.Sequential(nn.Conv2d( | |||
| 64, 32, 1, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=True)) | |||
| self.low_feature3 = nn.Sequential(nn.Conv2d( | |||
| 256, 64, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | |||
| self.low_feature4 = nn.Sequential(nn.Conv2d( | |||
| 512, 128, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True)) | |||
| self.layer1 = encoder.layer1 #256 | |||
| self.layer2 = encoder.layer2 #512 | |||
| self.layer3 = encoder.layer3 #1024 | |||
| self.layer4 = encoder.layer4 #2048 | |||
| self.aspp = ASPP(in_channel=2048, depth=256) | |||
| self.conv_cat4 = nn.Sequential(nn.Conv2d(256 + 128, 256, 3, 1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True)) | |||
| self.conv_cat3 = nn.Sequential(nn.Conv2d(256 + 64, 256, 3, 1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), | |||
| nn.Conv2d(256, 64, 3, 1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | |||
| self.conv_cat1 = nn.Sequential(nn.Conv2d(64 + 32, 64, 3, 1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), | |||
| nn.Conv2d(64, 18, 3, 1, padding=1)) | |||
| def forward(self, x): | |||
| size0 = x.shape[2:] # need upsample input size | |||
| x1 = self.start(x) # 64, 128*128 | |||
| x2 = self.maxpool(x1) # 64, 64*64 | |||
| x3 = self.layer1(x2) # 256, 64*64 | |||
| x4 = self.layer2(x3) # 512, 32*32 | |||
| x5 = self.layer3(x4) # 1024,16*16 | |||
| x = self.layer4(x5) # 2048,8*8 | |||
| x = self.aspp(x) # 256, 8*8 | |||
| low_feature1 = self.low_feature1(x1) # 64, 128*128 | |||
| # low_feature2 = self.low_feature2(x2) # 64, 64*64 | |||
| low_feature3 = self.low_feature3(x3) # 256, 64*64 | |||
| low_feature4 = self.low_feature4(x4) # 512, 32*32 -> 128, 32*32 | |||
| # low_feature5 = self.low_feature5(x5) # 1024,16*16 | |||
| size1 = low_feature1.shape[2:] | |||
| # size2 = low_feature2.shape[2:] | |||
| size3 = low_feature3.shape[2:] | |||
| size4 = low_feature4.shape[2:] | |||
| # size5 = low_feature5.shape[2:] | |||
| decoder_feature4 = F.interpolate(x, size=size4, mode='bilinear', align_corners=True) | |||
| x = self.conv_cat4(torch.cat([low_feature4, decoder_feature4], dim=1)) | |||
| decoder_feature3 = F.interpolate(x, size=size3, mode='bilinear', align_corners=True) | |||
| x = self.conv_cat3(torch.cat([low_feature3, decoder_feature3], dim=1)) | |||
| decoder_feature1 = F.interpolate(x, size=size1, mode='bilinear', align_corners=True) | |||
| x = self.conv_cat1(torch.cat([low_feature1, decoder_feature1], dim=1)) | |||
| score = F.interpolate(x, | |||
| size=size0, | |||
| mode='bilinear', | |||
| align_corners=True) | |||
| return score | |||
| def init_model(): | |||
| model_path = os.path.join(os.path.dirname(__file__), 'model.pkl') | |||
| model = Deeplab_v3() | |||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| model.to(device) | |||
| model_state = torch.load(model_path, map_location=device) | |||
| new_state_dict = OrderedDict() | |||
| for k, v in model_state["model_state_dict"].items(): | |||
| if k[:7] == "module.": | |||
| new_state_dict[k[7:]] = v | |||
| else: | |||
| new_state_dict[k] = v | |||
| model.load_state_dict(new_state_dict) | |||
| model.eval() | |||
| return model | |||
| @@ -0,0 +1,53 @@ | |||
| import os | |||
| import torch | |||
| import numpy as np | |||
| from PIL import Image | |||
| def cutImage(file_name): | |||
| img = Image.open(file_name) | |||
| oring_size = img.size | |||
| n = img.size[0] // 256 | |||
| img = img.resize((n * 256, n * 256)) | |||
| res = [] | |||
| for i in range(n): | |||
| for j in range(n): | |||
| pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1)) | |||
| res.append(img.crop(pos)) | |||
| return n, res, oring_size, img.size | |||
| def mergeImage(n, imgs, o_size, n_size): | |||
| img = Image.new('L', n_size) | |||
| for i in range(n): | |||
| for j in range(n): | |||
| pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1)) | |||
| img.paste(imgs[i * n + j], pos) | |||
| img = img.resize(o_size) | |||
| a = np.array(img) | |||
| a[a > 17] = 17 | |||
| return Image.fromarray(a) | |||
| def predict(model, input_path, output_dir): | |||
| name, _ = os.path.splitext(input_path) | |||
| name = os.path.split(name)[-1] + ".png" | |||
| n, imgs, o_size, n_size = cutImage(input_path) | |||
| res = [] | |||
| for img in imgs: | |||
| img = torch.from_numpy(np.array(img)).float().unsqueeze(0) | |||
| img = img / 255 | |||
| if torch.cuda.is_available(): | |||
| img = img.cuda() | |||
| img = img.permute(0, 3, 1, 2) | |||
| label = model(img) | |||
| label = torch.argmax(label, | |||
| dim=1).cpu().squeeze().numpy().astype(np.uint8) | |||
| res.append(Image.fromarray(label)) | |||
| img = mergeImage(n, res, o_size, n_size) | |||
| img.save(os.path.join(output_dir, name)) | |||
| if __name__ == "__main__": | |||
| from model_define import init_model | |||
| model = init_model() | |||
| predict(model, './test.tif', '') | |||
| @@ -0,0 +1,148 @@ | |||
| import os | |||
| import torch | |||
| import time | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import glob | |||
| from torch.autograd import Variable | |||
| from torch.optim import lr_scheduler | |||
| from torchvision import utils | |||
| from skimage import color | |||
| from losses import CombinedLoss | |||
| def create_exp_directory(exp_dir_name): | |||
| if not os.path.exists(exp_dir_name): | |||
| os.makedirs(exp_dir_name) | |||
| def plot_predictions(images_batch, labels_batch, batch_output, plt_title, | |||
| file_save_name): | |||
| f = plt.figure(figsize=(20, 20)) | |||
| # n, c, h, w = images_batch.shape | |||
| # mid_slice = c // 2 | |||
| # images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1) | |||
| grid = utils.make_grid(images_batch.cpu(), nrow=4) | |||
| plt.subplot(131) | |||
| plt.imshow(grid.numpy().transpose((1, 2, 0))) | |||
| plt.title('Slices') | |||
| grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0] | |||
| color_grid = color.label2rgb(grid.numpy(), bg_label=0) | |||
| plt.subplot(132) | |||
| plt.imshow(color_grid) | |||
| plt.title('Ground Truth') | |||
| grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0] | |||
| color_grid = color.label2rgb(grid.numpy(), bg_label=0) | |||
| plt.subplot(133) | |||
| plt.imshow(color_grid) | |||
| plt.title('Prediction') | |||
| plt.suptitle(plt_title) | |||
| plt.tight_layout() | |||
| f.savefig(file_save_name, bbox_inches='tight') | |||
| plt.close(f) | |||
| plt.gcf().clear() | |||
| class Solver(object): | |||
| def __init__(self, num_classes, optimizer, lr_args, optimizer_args): | |||
| self.lr_scheduler_args = lr_args | |||
| self.optimizer_args = optimizer_args | |||
| self.optimizer = optimizer | |||
| self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100) | |||
| self.num_classes = num_classes | |||
| self.classes = list(range(self.num_classes)) | |||
| def train(self, | |||
| model, | |||
| train_loader, | |||
| num_epochs, | |||
| log_params, | |||
| expdir, | |||
| resume=True): | |||
| create_exp_directory(expdir) | |||
| create_exp_directory(log_params["logdir"]) | |||
| optimizer = self.optimizer(model.parameters(), **self.optimizer_args) | |||
| scheduler = lr_scheduler.StepLR( | |||
| optimizer, | |||
| step_size=self.lr_scheduler_args["step_size"], | |||
| gamma=self.lr_scheduler_args["gamma"]) | |||
| epoch = -1 | |||
| print('-------> Starting to train') | |||
| if resume: | |||
| try: | |||
| prior_model_paths = sorted(glob.glob( | |||
| os.path.join(expdir, 'Epoch_*')), | |||
| key=os.path.getmtime) | |||
| current_model = prior_model_paths.pop() | |||
| state = torch.load(current_model) | |||
| model.load_state_dict(state["model_state_dict"]) | |||
| epoch = state["epoch"] | |||
| print("Successfully Resuming from Epoch {}".format(epoch + 1)) | |||
| except Exception as e: | |||
| print("No model to restore. {}".format(e)) | |||
| log_params["logger"].info("{} parameters in total".format( | |||
| sum(x.numel() for x in model.parameters()))) | |||
| model.train() | |||
| while epoch < num_epochs: | |||
| epoch = epoch + 1 | |||
| epoch_start = time.time() | |||
| loss_batch = np.zeros(1) | |||
| loss_dice_batch = np.zeros(1) | |||
| loss_ce_batch = np.zeros(1) | |||
| for batch_idx, sample in enumerate(train_loader): | |||
| images, labels = sample | |||
| images, labels = Variable(images), Variable(labels) | |||
| if torch.cuda.is_available(): | |||
| images, labels = images.cuda(), labels.cuda() | |||
| optimizer.zero_grad() | |||
| predictions = model(images) | |||
| loss_total, loss_dice, loss_ce = self.loss_func( | |||
| inputx=predictions, target=labels) | |||
| loss_total.backward() | |||
| optimizer.step() | |||
| loss_batch += loss_total.item() | |||
| loss_dice_batch += loss_dice.item() | |||
| loss_ce_batch += loss_ce.item() | |||
| _, batch_output = torch.max(predictions, dim=1) | |||
| if batch_idx == len(train_loader) - 2: | |||
| plt_title = 'Trian Results Epoch ' + str(epoch) | |||
| file_save_name = os.path.join( | |||
| log_params["logdir"], | |||
| 'Epoch_{}_Trian_Predictions.pdf'.format(epoch)) | |||
| plot_predictions(images, labels, batch_output, plt_title, | |||
| file_save_name) | |||
| if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == (len(train_loader) - 1): | |||
| log_params["logger"].info( | |||
| "Epoch: {} lr:{} [{}/{}] ({:.0f}%)]" | |||
| "with loss: {},\ndice_loss:{},ce_loss:{}".format( | |||
| epoch, optimizer.param_groups[0]['lr'], batch_idx, | |||
| len(train_loader), | |||
| 100. * batch_idx / len(train_loader), | |||
| loss_batch / (batch_idx + 1), | |||
| loss_dice_batch / (batch_idx + 1), | |||
| loss_ce_batch / (batch_idx + 1))) | |||
| scheduler.step() | |||
| epoch_finish = time.time() - epoch_start | |||
| log_params["logger"].info( | |||
| "Train Epoch {} finished in {:.04f} seconds.\n".format( | |||
| epoch, epoch_finish)) | |||
| # Saving Models | |||
| if epoch % log_params["log_iter"] == 0: # 每log_iter次保存一次模型 | |||
| save_name = os.path.join( | |||
| expdir, | |||
| 'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl') | |||
| checkpoint = { | |||
| "model_state_dict": model.state_dict(), | |||
| "optimizer_state_dict": optimizer.state_dict(), | |||
| "epoch": epoch | |||
| } | |||
| if scheduler is not None: | |||
| checkpoint["scheduler_state_dict"] = scheduler.state_dict() | |||
| torch.save(checkpoint, save_name) | |||
| @@ -0,0 +1,193 @@ | |||
| import os | |||
| import logging | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch.utils.data.dataloader import DataLoader | |||
| from dataset import DataTrain | |||
| from model_define import Deeplab_v3 | |||
| # from model_define_unet import UNet | |||
| from solver import Solver | |||
| from torchvision import transforms, utils | |||
| import numpy as np | |||
| from networks.vit_seg_modeling import VisionTransformer as ViT_seg | |||
| from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg | |||
| args = { | |||
| 'batch_size': 2, | |||
| 'log_interval': 1, | |||
| 'log_dir': 'log', | |||
| 'num_classes': 2, | |||
| 'epochs': 1000, | |||
| 'lr': 1e-5, | |||
| 'resume': True, | |||
| 'data_dir': "../small_data", | |||
| 'gamma': 0.5, | |||
| 'step': 5, | |||
| 'vit_name': 'R50-ViT-B_16', | |||
| 'num_classes': 2, | |||
| 'n_skip': 3, | |||
| 'img_size': 256, | |||
| 'vit_patches_size': 16, | |||
| } | |||
| ''' | |||
| 文件目录: | |||
| data | |||
| images | |||
| *.tif | |||
| labels | |||
| *.png | |||
| code | |||
| train.py | |||
| ... | |||
| ''' | |||
| class ToTensor(object): | |||
| """ | |||
| Convert ndarrays in sample to Tensors. | |||
| """ | |||
| def __call__(self, sample): | |||
| img, label = sample['img'], sample['label'] | |||
| img = img.astype(np.float32) | |||
| img = img/255.0 | |||
| # swap color axis because | |||
| # numpy image: H x W x C | |||
| # torch image: C X H X W | |||
| img = img.transpose((2, 0, 1)) | |||
| return {'img': torch.from_numpy(img), 'label': label} | |||
| class AugmentationPadImage(object): | |||
| """ | |||
| Pad Image with either zero padding or reflection padding of img, label and weight | |||
| """ | |||
| def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"): | |||
| assert isinstance(pad_size, (int, tuple)) | |||
| if isinstance(pad_size, int): | |||
| # Do not pad along the channel dimension | |||
| self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0)) | |||
| self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size)) | |||
| else: | |||
| self.pad_size = pad_size | |||
| self.pad_type = pad_type | |||
| def __call__(self, sample): | |||
| img, label = sample['img'], sample['label'] | |||
| img = np.pad(img, self.pad_size_image, self.pad_type) | |||
| label = np.pad(label, self.pad_size_mask, self.pad_type) | |||
| return {'img': img, 'label': label} | |||
| class AugmentationRandomCrop(object): | |||
| """ | |||
| Randomly Crop Image to given size | |||
| """ | |||
| def __init__(self, output_size, crop_type='Random'): | |||
| assert isinstance(output_size, (int, tuple)) | |||
| if isinstance(output_size, int): | |||
| self.output_size = (output_size, output_size) | |||
| else: | |||
| self.output_size = output_size | |||
| self.crop_type = crop_type | |||
| def __call__(self, sample): | |||
| img, label = sample['img'], sample['label'] | |||
| h, w, _ = img.shape | |||
| if self.crop_type == 'Center': | |||
| top = (h - self.output_size[0]) // 2 | |||
| left = (w - self.output_size[1]) // 2 | |||
| else: | |||
| top = np.random.randint(0, h - self.output_size[0]) | |||
| left = np.random.randint(0, w - self.output_size[1]) | |||
| bottom = top + self.output_size[0] | |||
| right = left + self.output_size[1] | |||
| # print(img.shape) | |||
| img = img[top:bottom, left:right, :] | |||
| label = label[top:bottom, left:right] | |||
| # weight = weight[top:bottom, left:right] | |||
| return {'img': img, 'label': label} | |||
| def log_init(): | |||
| if not os.path.exists(args['log_dir']): | |||
| os.makedirs(args['log_dir']) | |||
| logger = logging.getLogger("train") | |||
| logger.setLevel(logging.DEBUG) | |||
| logger.handlers = [] | |||
| logger.addHandler(logging.StreamHandler()) | |||
| logger.addHandler( | |||
| logging.FileHandler(os.path.join(args['log_dir'], "log.txt"))) | |||
| logger.info("%s", repr(args)) | |||
| return logger | |||
| def train(): | |||
| logger = log_init() | |||
| transform_train = transforms.Compose([AugmentationPadImage(pad_size=8), AugmentationRandomCrop(output_size=256), ToTensor()]) | |||
| dataset_train = DataTrain(args['data_dir'], transforms = transform_train) | |||
| train_dataloader = DataLoader(dataset=dataset_train, | |||
| batch_size=args['batch_size'], | |||
| shuffle=True, num_workers=4) | |||
| # model = Deeplab_v3() | |||
| # model = UNet() | |||
| config_vit = CONFIGS_ViT_seg[args['vit_name']] | |||
| config_vit.n_classes = args['num_classes'] | |||
| config_vit.n_skip = args['n_skip'] | |||
| if args['vit_name'].find('R50') != -1: | |||
| config_vit.patches.grid = (int(args['img_size'] / args['vit_patches_size']), int(args['img_size'] / args['vit_patches_size'])) | |||
| model = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes) | |||
| # model.load_from(weights=np.load(config_vit.pretrained_path)) | |||
| if torch.cuda.is_available(): | |||
| if torch.cuda.device_count() > 1: | |||
| model = nn.DataParallel(model) | |||
| model.cuda() | |||
| solver = Solver(num_classes=args['num_classes'], | |||
| lr_args={ | |||
| "gamma": args['gamma'], | |||
| "step_size": args['step'] | |||
| }, | |||
| optimizer_args={ | |||
| "lr": args['lr'], | |||
| "betas": (0.9, 0.999), | |||
| "eps": 1e-8, | |||
| "weight_decay": 0.01 | |||
| }, | |||
| optimizer=torch.optim.Adam) | |||
| solver.train(model, | |||
| train_dataloader, | |||
| num_epochs=args['epochs'], | |||
| log_params={ | |||
| 'logdir': args['log_dir'] + "/logs", | |||
| 'log_iter': args['log_interval'], | |||
| 'logger': logger | |||
| }, | |||
| expdir=args['log_dir'] + "/ckpts", | |||
| resume=args['resume']) | |||
| if __name__ == "__main__": | |||
| train() | |||