commit 8032e4a7c1a80d0b344919be4d722c9bba5a281a
Author: weichonghit <53960867+weichonghit@users.noreply.github.com>
Date:   Thu Jul 8 11:04:18 2021 +0800

    Add files via upload

diff --git a/Readme.md b/Readme.md
new file mode 100644
index 0000000..8e03795
--- /dev/null
+++ b/Readme.md
@@ -0,0 +1,25 @@
# Task: Segmentation of arbitrarily large remote-sensing images

### Environment

OS: Linux

GPU: GTX 1080 Ti

Python 3.7

PyTorch

CUDA 10.0

### Models

UNet

DeepLabv3

### Loss functions

Dice + cross-entropy (CE)

diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..be69789
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,44 @@
import os
import torch
import numpy as np
import cv2 as cv
from os.path import join
from torch.utils.data.dataset import Dataset


class DataTrain(Dataset):
    def __init__(self, data_path, transforms=None):
        self.data_dir = data_path
        self.image_list = os.listdir(join(data_path, 'train_ori_images'))
        files_len = len(self.image_list)
        try:
            # Pre-load the whole dataset into memory: 256x256 RGB tiles and
            # their single-channel masks.
            imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8)
            labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8)
            for idx, file in enumerate(self.image_list):
                fname = file.split('.')[0]
                img = cv.imread(join(self.data_dir, 'train_ori_images', fname + '.tif'))
                img = np.asarray(img, dtype=np.uint8)
                label = cv.imread(
                    join(self.data_dir, 'train_pupil_images', fname + '.png'),
                    cv.IMREAD_UNCHANGED)
                # Keep only value % 100 of the stored mask values.
                label = np.asarray(label, dtype=np.uint8) % 100
                imgs[idx, :, :, :] = img
                labels[idx, :, :] = label
            self.images = imgs
            self.labels = labels
            self.transforms = transforms
        except Exception as e:
            raise RuntimeError('failed to read training data') from e

    def __getitem__(self, index):
        img = self.images[index]
        label = self.labels[index]
        label[label > 0.5] = 1  # binarize: any non-zero class id becomes foreground
        tx_sample = self.transforms({'img': img, 'label': label})
        img = tx_sample['img']
        label = tx_sample['label']

        return img, label

    def __len__(self):
        return len(self.image_list)

diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000..e90f3dc
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F


class DiceLoss(_Loss):
    def forward(self, output, target, weights=None, ignore_index=None):
        eps = 0.001

        # Detach from the graph so the one-hot target does not receive gradients.
        encoded_target = output.detach() * 0

        if ignore_index is not None:
            mask = target == ignore_index
            target = target.clone()
            target[mask] = 0
            # scatter_(dim, index, src) writes src into the tensor along dim at
            # the given indices; here it one-hot encodes the class labels
            # (unsqueeze adds the class dimension).
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            mask = mask.unsqueeze(1).expand_as(encoded_target)
            encoded_target[mask] = 0
        else:
            encoded_target.scatter_(1, target.unsqueeze(1), 1)

        if weights is None:
            weights = 1

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(0).sum(1).sum(1)
        denominator = output + encoded_target

        if ignore_index is not None:
            denominator[mask] = 0

        # Count classes that appear in neither the target nor the denominator,
        # so they can be excluded from the per-class average below.
        absent_in_target = encoded_target.sum(0).sum(1).sum(1) == 0
        absent_in_denominator = denominator.sum(0).sum(1).sum(1) == 0
        count = (absent_in_target & absent_in_denominator).sum().item()

        denominator = denominator.sum(0).sum(1).sum(1) + eps
        # Per-class (channel-wise) dice loss, optionally weighted.
        loss_per_channel = weights * (1 - (numerator / denominator))

        # Average only over classes that are actually present.
        return (loss_per_channel.sum() - count) / (output.size(1) - count)


class CrossEntropy2D(nn.Module):
    """
    2D cross-entropy loss, implemented as negative log-likelihood.
    """
    def __init__(self, weight=None, reduction='none'):
        super(CrossEntropy2D, self).__init__()
        self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        return self.nll_loss(inputs, targets)


class CombinedLoss(nn.Module):
    """
    Weighted sum of dice loss and 2D cross-entropy.
    For the cross-entropy term the target has to be a long tensor.
    Args:
        inputx - N x C x H x W (N is the batch size)
        target - N x H x W - int type
        weight - N x H x W - float
    """
    def __init__(self, weight_dice, weight_ce):
        super(CombinedLoss, self).__init__()
        self.cross_entropy_loss = CrossEntropy2D()
        self.dice_loss = DiceLoss()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce

    def forward(self, inputx, target):
        target = target.type(torch.LongTensor)  # typecast to long tensor
        if inputx.is_cuda:
            target = target.cuda()

        input_soft = F.softmax(inputx, dim=1)  # softmax along the class dimension
        dice_val = torch.mean(self.dice_loss(input_soft, target))
        ce_val = torch.mean(self.cross_entropy_loss.forward(inputx, target))
        total_loss = torch.add(torch.mul(dice_val, self.weight_dice),
                               torch.mul(ce_val, self.weight_ce))
        return total_loss, dice_val, ce_val
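As a quick sanity check of the loss interface — a minimal sketch, assuming a 2-class setup like the one configured in train.py; shapes and values are illustrative only:

import torch
from losses import CombinedLoss

criterion = CombinedLoss(weight_dice=1, weight_ce=1)
logits = torch.randn(2, 2, 256, 256)             # N x C x H x W raw network output
target = torch.randint(0, 2, (2, 256, 256))      # N x H x W integer class ids
total, dice, ce = criterion(logits, target)
print(total.item(), dice.item(), ce.item())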
diff --git a/model_define.py b/model_define.py
new file mode 100644
index 0000000..3268fde
--- /dev/null
+++ b/model_define.py
@@ -0,0 +1,336 @@
import os
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict


class ASPP(nn.Module):
    # With bias and ReLU, without batch norm.
    # NOTE: despite the attribute names, the dilation rates used here are
    # 3, 6 and 9 rather than the 6, 12 and 18 of the original DeepLab ASPP.
    def __init__(self, in_channel=512, depth=256):
        super().__init__()
        # Global average pooling: nn.AdaptiveAvgPool2d here; the equivalent in
        # forward would be torch.mean(..., keepdim=True).
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Sequential(nn.Conv2d(in_channel, depth, 1, 1),
                                  nn.ReLU(inplace=True))

        self.atrous_block1 = nn.Sequential(nn.Conv2d(in_channel, depth, 1, 1),
                                           nn.ReLU(inplace=True))
        self.atrous_block6 = nn.Sequential(
            nn.Conv2d(in_channel, depth, 3, 1, padding=3, dilation=3),
            nn.ReLU(inplace=True))
        self.atrous_block12 = nn.Sequential(
            nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6),
            nn.ReLU(inplace=True))
        self.atrous_block18 = nn.Sequential(
            nn.Conv2d(in_channel, depth, 3, 1, padding=9, dilation=9),
            nn.ReLU(inplace=True))

        self.conv_1x1_output = nn.Sequential(nn.Conv2d(depth * 5, depth, 1, 1),
                                             nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[2:]

        # Image-level feature: pool to 1x1, project, then upsample back.
        image_features = self.mean(x)
        image_features = self.conv(image_features)
        image_features = F.interpolate(image_features,
                                       size=size,
                                       mode='bilinear',
                                       align_corners=True)

        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)

        # Concatenate the five branches and fuse with a 1x1 convolution.
        net = self.conv_1x1_output(
            torch.cat([
                image_features, atrous_block1, atrous_block6, atrous_block12,
                atrous_block18
            ], dim=1))
        return net
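A quick shape check for the ASPP head — a minimal sketch, assuming the 2048-channel, 8x8 encoder output that Deeplab_v3 below produces for a 256x256 input:

import torch
from model_define import ASPP

aspp = ASPP(in_channel=2048, depth=256)
x = torch.randn(1, 2048, 8, 8)  # ResNet-50 layer4 output for a 256x256 image
print(aspp(x).shape)            # torch.Size([1, 256, 8, 8])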
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=18, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each element in the tuple indicates whether to replace the
            # 2x2 stride with a dilated convolution instead.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # branch starts with zeros and each residual block behaves like an
        # identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    # NOTE: pretrained and progress are accepted for API compatibility with
    # torchvision, but no pretrained weights are loaded here.
    model = ResNet(block, layers, **kwargs)

    return model


def resnet50(pretrained=False, progress=True, **kwargs):
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False,
                     dilation=dilation)
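Because conv3x3 ties the padding to the dilation rate, its dilated variants stay size-preserving at stride 1 — a minimal check with illustrative shapes:

import torch
from model_define import conv3x3

# padding == dilation keeps a stride-1 3x3 convolution size-preserving
for d in (1, 2, 4):
    conv = conv3x3(16, 16, dilation=d)
    print(d, conv(torch.randn(1, 16, 32, 32)).shape)  # always [1, 16, 32, 32]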
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Deeplab_v3(nn.Module):
    # in_channel = 3, fine-tune
    def __init__(self, class_number=18):
        super().__init__()
        encoder = resnet50()
        self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)

        self.maxpool = encoder.maxpool

        # 1x1 projections of the encoder's low-level features.
        self.low_feature1 = nn.Sequential(nn.Conv2d(64, 32, 1, 1),
                                          nn.BatchNorm2d(32), nn.ReLU(inplace=True))
        self.low_feature3 = nn.Sequential(nn.Conv2d(256, 64, 1, 1),
                                          nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.low_feature4 = nn.Sequential(nn.Conv2d(512, 128, 1, 1),
                                          nn.BatchNorm2d(128), nn.ReLU(inplace=True))

        self.layer1 = encoder.layer1  # 256
        self.layer2 = encoder.layer2  # 512
        self.layer3 = encoder.layer3  # 1024
        self.layer4 = encoder.layer4  # 2048

        self.aspp = ASPP(in_channel=2048, depth=256)

        self.conv_cat4 = nn.Sequential(
            nn.Conv2d(256 + 128, 256, 3, 1, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        self.conv_cat3 = nn.Sequential(
            nn.Conv2d(256 + 64, 256, 3, 1, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 64, 3, 1, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.conv_cat1 = nn.Sequential(
            nn.Conv2d(64 + 32, 64, 3, 1, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, class_number, 3, 1, padding=1))  # output one channel per class

    def forward(self, x):
        size0 = x.shape[2:]  # input size, needed for the final upsample
        # Spatial sizes below assume a 256x256 input.
        x1 = self.start(x)     # 64,   128*128
        x2 = self.maxpool(x1)  # 64,   64*64
        x3 = self.layer1(x2)   # 256,  64*64
        x4 = self.layer2(x3)   # 512,  32*32
        x5 = self.layer3(x4)   # 1024, 16*16
        x = self.layer4(x5)    # 2048, 8*8
        x = self.aspp(x)       # 256,  8*8

        low_feature1 = self.low_feature1(x1)  # 64, 128*128 -> 32, 128*128
        low_feature3 = self.low_feature3(x3)  # 256, 64*64  -> 64, 64*64
        low_feature4 = self.low_feature4(x4)  # 512, 32*32  -> 128, 32*32

        size1 = low_feature1.shape[2:]
        size3 = low_feature3.shape[2:]
        size4 = low_feature4.shape[2:]

        # Decoder: repeatedly upsample and fuse with the low-level features.
        decoder_feature4 = F.interpolate(x, size=size4, mode='bilinear', align_corners=True)
        x = self.conv_cat4(torch.cat([low_feature4, decoder_feature4], dim=1))

        decoder_feature3 = F.interpolate(x, size=size3, mode='bilinear', align_corners=True)
        x = self.conv_cat3(torch.cat([low_feature3, decoder_feature3], dim=1))

        decoder_feature1 = F.interpolate(x, size=size1, mode='bilinear', align_corners=True)
        x = self.conv_cat1(torch.cat([low_feature1, decoder_feature1], dim=1))

        score = F.interpolate(x,
                              size=size0,
                              mode='bilinear',
                              align_corners=True)

        return score
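A quick end-to-end shape check for the decoder — a minimal sketch with an illustrative random input:

import torch
from model_define import Deeplab_v3

model = Deeplab_v3(class_number=18).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 18, 256, 256])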
def init_model():
    model_path = os.path.join(os.path.dirname(__file__), 'model.pkl')
    model = Deeplab_v3()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model_state = torch.load(model_path, map_location=device)
    new_state_dict = OrderedDict()
    # Strip the "module." prefix that nn.DataParallel adds to parameter names.
    for k, v in model_state["model_state_dict"].items():
        if k[:7] == "module.":
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    return model

diff --git a/model_predict.py b/model_predict.py
new file mode 100644
index 0000000..d1c3aed
--- /dev/null
+++ b/model_predict.py
@@ -0,0 +1,53 @@
import os
import torch
import numpy as np
from PIL import Image


def cutImage(file_name):
    """Resize the image to a multiple of 256 and cut it into 256x256 tiles.

    NOTE: n is derived from the width only and the image is resized to
    n*256 x n*256, so non-square inputs are distorted.
    """
    img = Image.open(file_name)
    orig_size = img.size
    n = img.size[0] // 256
    img = img.resize((n * 256, n * 256))
    res = []
    for i in range(n):
        for j in range(n):
            pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
            res.append(img.crop(pos))
    return n, res, orig_size, img.size


def mergeImage(n, imgs, o_size, n_size):
    """Stitch the predicted tiles back together and restore the original size."""
    img = Image.new('L', n_size)
    for i in range(n):
        for j in range(n):
            pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
            img.paste(imgs[i * n + j], pos)
    img = img.resize(o_size)
    a = np.array(img)
    a[a > 17] = 17  # clamp to the largest valid class id (18 classes: 0..17)
    return Image.fromarray(a)


def predict(model, input_path, output_dir):
    name, _ = os.path.splitext(input_path)
    name = os.path.split(name)[-1] + ".png"
    n, imgs, o_size, n_size = cutImage(input_path)
    res = []
    for img in imgs:
        img = torch.from_numpy(np.array(img)).float().unsqueeze(0)
        img = img / 255
        if torch.cuda.is_available():
            img = img.cuda()
        img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
        label = model(img)
        label = torch.argmax(label,
                             dim=1).cpu().squeeze().numpy().astype(np.uint8)
        res.append(Image.fromarray(label))
    img = mergeImage(n, res, o_size, n_size)
    img.save(os.path.join(output_dir, name))


if __name__ == "__main__":
    from model_define import init_model
    model = init_model()
    predict(model, './test.tif', '')
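cutImage derives the tile grid from the width alone, so non-square rasters get distorted by the resize. A hedged sketch of a two-axis variant (not part of the commit; the name cut_image_rect is illustrative):

from PIL import Image

def cut_image_rect(file_name, tile=256):
    """Variant of cutImage with separate horizontal/vertical tile counts."""
    img = Image.open(file_name)
    orig_size = img.size
    nx, ny = max(img.size[0] // tile, 1), max(img.size[1] // tile, 1)
    img = img.resize((nx * tile, ny * tile))
    tiles = [img.crop((tile * j, tile * i, tile * (j + 1), tile * (i + 1)))
             for i in range(ny) for j in range(nx)]
    return (nx, ny), tiles, orig_size, img.size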
diff --git a/solver.py b/solver.py
new file mode 100644
index 0000000..fc0102b
--- /dev/null
+++ b/solver.py
@@ -0,0 +1,148 @@
import os
import torch
import time
import matplotlib.pyplot as plt
import numpy as np
import glob
from torch.optim import lr_scheduler
from torchvision import utils
from skimage import color
from losses import CombinedLoss


def create_exp_directory(exp_dir_name):
    if not os.path.exists(exp_dir_name):
        os.makedirs(exp_dir_name)


def plot_predictions(images_batch, labels_batch, batch_output, plt_title,
                     file_save_name):
    """Save a side-by-side figure of input slices, ground truth and prediction."""
    f = plt.figure(figsize=(20, 20))
    grid = utils.make_grid(images_batch.cpu(), nrow=4)
    plt.subplot(131)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Slices')
    grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0]
    color_grid = color.label2rgb(grid.numpy(), bg_label=0)
    plt.subplot(132)
    plt.imshow(color_grid)
    plt.title('Ground Truth')
    grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0]
    color_grid = color.label2rgb(grid.numpy(), bg_label=0)
    plt.subplot(133)
    plt.imshow(color_grid)
    plt.title('Prediction')

    plt.suptitle(plt_title)
    plt.tight_layout()

    f.savefig(file_save_name, bbox_inches='tight')
    plt.close(f)
    plt.gcf().clear()
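A note on the grid trick above: make_grid expands a single-channel batch to three identical channels, so indexing [0] recovers the integer label map that label2rgb expects. A minimal illustration with a random batch:

import torch
from torchvision import utils

labels = torch.randint(0, 18, (4, 1, 256, 256))  # N x 1 x H x W label maps
grid = utils.make_grid(labels, nrow=4)           # 3 x H' x W', channels repeated
assert torch.equal(grid[0], grid[1]) and torch.equal(grid[1], grid[2])
label_map = grid[0].numpy()                      # back to a single-channel map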
class Solver(object):
    def __init__(self, num_classes, optimizer, lr_args, optimizer_args):
        self.lr_scheduler_args = lr_args
        self.optimizer_args = optimizer_args
        self.optimizer = optimizer
        # NOTE: with weight_dice=0 and weight_ce=100 this is effectively a
        # (scaled) cross-entropy loss; the dice term is computed but unweighted.
        self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100)
        self.num_classes = num_classes
        self.classes = list(range(self.num_classes))

    def train(self,
              model,
              train_loader,
              num_epochs,
              log_params,
              expdir,
              resume=True):
        create_exp_directory(expdir)
        create_exp_directory(log_params["logdir"])
        optimizer = self.optimizer(model.parameters(), **self.optimizer_args)
        scheduler = lr_scheduler.StepLR(
            optimizer,
            step_size=self.lr_scheduler_args["step_size"],
            gamma=self.lr_scheduler_args["gamma"])
        epoch = -1
        print('-------> Starting to train')
        if resume:
            try:
                # Resume from the most recently written checkpoint, if any.
                prior_model_paths = sorted(glob.glob(
                    os.path.join(expdir, 'Epoch_*')),
                                           key=os.path.getmtime)
                current_model = prior_model_paths.pop()
                state = torch.load(current_model)
                model.load_state_dict(state["model_state_dict"])
                epoch = state["epoch"]
                print("Successfully Resuming from Epoch {}".format(epoch + 1))
            except Exception as e:
                print("No model to restore. {}".format(e))
        log_params["logger"].info("{} parameters in total".format(
            sum(x.numel() for x in model.parameters())))

        model.train()
        while epoch < num_epochs:
            epoch = epoch + 1
            epoch_start = time.time()
            loss_batch = np.zeros(1)
            loss_dice_batch = np.zeros(1)
            loss_ce_batch = np.zeros(1)
            for batch_idx, sample in enumerate(train_loader):
                images, labels = sample
                if torch.cuda.is_available():
                    images, labels = images.cuda(), labels.cuda()
                optimizer.zero_grad()
                predictions = model(images)
                loss_total, loss_dice, loss_ce = self.loss_func(
                    inputx=predictions, target=labels)
                loss_total.backward()
                optimizer.step()

                loss_batch += loss_total.item()
                loss_dice_batch += loss_dice.item()
                loss_ce_batch += loss_ce.item()
                _, batch_output = torch.max(predictions, dim=1)
                if batch_idx == len(train_loader) - 2:
                    plt_title = 'Train Results Epoch ' + str(epoch)
                    file_save_name = os.path.join(
                        log_params["logdir"],
                        'Epoch_{}_Train_Predictions.pdf'.format(epoch))
                    plot_predictions(images, labels, batch_output, plt_title,
                                     file_save_name)

                if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == (len(train_loader) - 1):
                    log_params["logger"].info(
                        "Epoch: {} lr: {} [{}/{} ({:.0f}%)] "
                        "with loss: {},\ndice_loss: {}, ce_loss: {}".format(
                            epoch, optimizer.param_groups[0]['lr'], batch_idx,
                            len(train_loader),
                            100. * batch_idx / len(train_loader),
                            loss_batch / (batch_idx + 1),
                            loss_dice_batch / (batch_idx + 1),
                            loss_ce_batch / (batch_idx + 1)))

            scheduler.step()
            epoch_finish = time.time() - epoch_start
            log_params["logger"].info(
                "Train Epoch {} finished in {:.04f} seconds.\n".format(
                    epoch, epoch_finish))

            # Save a checkpoint every log_iter epochs.
            if epoch % log_params["log_iter"] == 0:
                save_name = os.path.join(
                    expdir,
                    'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl')
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch
                }
                if scheduler is not None:
                    checkpoint["scheduler_state_dict"] = scheduler.state_dict()
                torch.save(checkpoint, save_name)
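The checkpoints written above carry more than the weights, but the resume path restores only the model state and epoch. A minimal sketch of a fuller restore — assuming a checkpoint trained with the same architecture; the path is illustrative:

import torch
from model_define import Deeplab_v3

model = Deeplab_v3()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

state = torch.load('log/ckpts/Epoch_10_training_state.pkl', map_location='cpu')
model.load_state_dict(state["model_state_dict"])
optimizer.load_state_dict(state["optimizer_state_dict"])
if "scheduler_state_dict" in state:
    scheduler.load_state_dict(state["scheduler_state_dict"])
start_epoch = state["epoch"] + 1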
+ """ + + def __call__(self, sample): + img, label = sample['img'], sample['label'] + + img = img.astype(np.float32) + img = img/255.0 + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = img.transpose((2, 0, 1)) + + return {'img': torch.from_numpy(img), 'label': label} + + +class AugmentationPadImage(object): + """ + Pad Image with either zero padding or reflection padding of img, label and weight + """ + + def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"): + + assert isinstance(pad_size, (int, tuple)) + + if isinstance(pad_size, int): + + # Do not pad along the channel dimension + self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0)) + self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size)) + + else: + self.pad_size = pad_size + + self.pad_type = pad_type + + def __call__(self, sample): + img, label = sample['img'], sample['label'] + + img = np.pad(img, self.pad_size_image, self.pad_type) + label = np.pad(label, self.pad_size_mask, self.pad_type) + + return {'img': img, 'label': label} + + +class AugmentationRandomCrop(object): + """ + Randomly Crop Image to given size + """ + + def __init__(self, output_size, crop_type='Random'): + + assert isinstance(output_size, (int, tuple)) + + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + + else: + self.output_size = output_size + + self.crop_type = crop_type + + def __call__(self, sample): + img, label = sample['img'], sample['label'] + + h, w, _ = img.shape + + if self.crop_type == 'Center': + top = (h - self.output_size[0]) // 2 + left = (w - self.output_size[1]) // 2 + + else: + top = np.random.randint(0, h - self.output_size[0]) + left = np.random.randint(0, w - self.output_size[1]) + + bottom = top + self.output_size[0] + right = left + self.output_size[1] + + # print(img.shape) + img = img[top:bottom, left:right, :] + label = label[top:bottom, left:right] + # weight = weight[top:bottom, left:right] + + return {'img': img, 'label': label} + +def log_init(): + if not os.path.exists(args['log_dir']): + os.makedirs(args['log_dir']) + logger = logging.getLogger("train") + logger.setLevel(logging.DEBUG) + logger.handlers = [] + logger.addHandler(logging.StreamHandler()) + logger.addHandler( + logging.FileHandler(os.path.join(args['log_dir'], "log.txt"))) + logger.info("%s", repr(args)) + return logger + + + +def train(): + logger = log_init() + transform_train = transforms.Compose([AugmentationPadImage(pad_size=8), AugmentationRandomCrop(output_size=256), ToTensor()]) + dataset_train = DataTrain(args['data_dir'], transforms = transform_train) + train_dataloader = DataLoader(dataset=dataset_train, + batch_size=args['batch_size'], + shuffle=True, num_workers=4) + # model = Deeplab_v3() + # model = UNet() + config_vit = CONFIGS_ViT_seg[args['vit_name']] + config_vit.n_classes = args['num_classes'] + config_vit.n_skip = args['n_skip'] + if args['vit_name'].find('R50') != -1: + config_vit.patches.grid = (int(args['img_size'] / args['vit_patches_size']), int(args['img_size'] / args['vit_patches_size'])) + model = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes) + # model.load_from(weights=np.load(config_vit.pretrained_path)) + + if torch.cuda.is_available(): + if torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + model.cuda() + + solver = Solver(num_classes=args['num_classes'], + lr_args={ + "gamma": args['gamma'], + "step_size": args['step'] + }, + optimizer_args={ + "lr": 
def train():
    logger = log_init()
    transform_train = transforms.Compose([AugmentationPadImage(pad_size=8),
                                          AugmentationRandomCrop(output_size=256),
                                          ToTensor()])
    dataset_train = DataTrain(args['data_dir'], transforms=transform_train)
    train_dataloader = DataLoader(dataset=dataset_train,
                                  batch_size=args['batch_size'],
                                  shuffle=True, num_workers=4)
    # model = Deeplab_v3()
    # model = UNet()
    config_vit = CONFIGS_ViT_seg[args['vit_name']]
    config_vit.n_classes = args['num_classes']
    config_vit.n_skip = args['n_skip']
    if args['vit_name'].find('R50') != -1:
        # Hybrid ResNet50+ViT backbones need the patch grid, e.g. 256/16 = 16.
        config_vit.patches.grid = (int(args['img_size'] / args['vit_patches_size']),
                                   int(args['img_size'] / args['vit_patches_size']))
    model = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes)
    # model.load_from(weights=np.load(config_vit.pretrained_path))

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.cuda()

    solver = Solver(num_classes=args['num_classes'],
                    lr_args={
                        "gamma": args['gamma'],
                        "step_size": args['step']
                    },
                    optimizer_args={
                        "lr": args['lr'],
                        "betas": (0.9, 0.999),
                        "eps": 1e-8,
                        "weight_decay": 0.01
                    },
                    optimizer=torch.optim.Adam)
    solver.train(model,
                 train_dataloader,
                 num_epochs=args['epochs'],
                 log_params={
                     'logdir': args['log_dir'] + "/logs",
                     'log_iter': args['log_interval'],
                     'logger': logger
                 },
                 expdir=args['log_dir'] + "/ckpts",
                 resume=args['resume'])


if __name__ == "__main__":
    train()
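Once training has produced a checkpoint, inference goes through model_define.init_model and model_predict.predict. A hedged usage sketch — it assumes the Deeplab_v3 branch was the model actually trained (init_model builds a Deeplab_v3), that a checkpoint has been copied to model.pkl next to model_define.py, and that the paths are illustrative:

# Hypothetical inference run; paths are illustrative.
from model_define import init_model
from model_predict import predict

model = init_model()                    # loads model.pkl, strips any "module." prefix
predict(model, './scene.tif', './out')  # tiles, segments and stitches the raster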