Browse Source

Add files via upload

main
weichonghit GitHub 4 years ago
commit
8032e4a7c1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 902 additions and 0 deletions
  1. +25
    -0
      Readme.md
  2. +44
    -0
      dataset.py
  3. +103
    -0
      losses.py
  4. +336
    -0
      model_define.py
  5. +53
    -0
      model_predict.py
  6. +148
    -0
      solver.py
  7. +193
    -0
      train.py

+ 25
- 0
Readme.md View File

@@ -0,0 +1,25 @@
# 任务:无限遥感图像分割

### 环境

操作系统:linux

显卡:1080ti

python3.7

pytorch

cuda10.0


### 模型

Unet

deeplabv3

### 损失函数

dice ce


+ 44
- 0
dataset.py View File

@@ -0,0 +1,44 @@
import os
import torch
import numpy as np
import cv2 as cv
from os.path import join
from torch.utils.data.dataset import Dataset


class DataTrain(Dataset):
    """In-memory training set of 256x256 image tiles and their label masks.

    Every file listed in ``<data_path>/train_ori_images`` is loaded as
    ``<stem>.tif`` and paired with ``<data_path>/train_pupil_images/<stem>.png``.
    All tiles are read eagerly in ``__init__`` and kept as ``uint8`` arrays.

    Args:
        data_path: Root directory containing the two sub-folders above.
        transforms: Optional callable that receives and returns a dict with
            the keys ``'img'`` and ``'label'``. When ``None`` the raw arrays
            are returned unchanged.

    Raises:
        Exception: With the message ``'read error'`` when any tile cannot be
            loaded; the underlying error is attached as ``__cause__``.
    """

    def __init__(self, data_path, transforms=None):
        self.data_dir = data_path
        self.image_list = os.listdir(join(data_path, 'train_ori_images'))
        files_len = len(self.image_list)
        try:
            imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8)
            labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8)
            for idx, file in enumerate(self.image_list):
                # splitext keeps stems that themselves contain dots intact.
                fname = os.path.splitext(file)[0]
                img_path = join(self.data_dir, 'train_ori_images',
                                fname + '.tif')
                label_path = join(self.data_dir, 'train_pupil_images',
                                  fname + '.png')
                img = cv.imread(img_path)
                if img is None:
                    # cv.imread signals failure by returning None.
                    raise FileNotFoundError(img_path)
                label = cv.imread(label_path, cv.IMREAD_UNCHANGED)
                if label is None:
                    raise FileNotFoundError(label_path)
                imgs[idx, :, :, :] = np.asarray(img, dtype=np.uint8)
                # Label values are stored modulo 100.
                labels[idx, :, :] = np.asarray(label, dtype=np.uint8) % 100
            self.images = imgs
            self.labels = labels
            self.transforms = transforms
        except Exception as err:
            # Same exception type/message as before, but keep the root cause.
            raise Exception('read error') from err

    def __getitem__(self, index):
        """Return ``(img, label)`` for ``index`` with the label binarised to 0/1."""
        img = self.images[index]
        # Build a fresh binary mask instead of writing into the cached labels.
        label = (self.labels[index] > 0).astype(np.uint8)
        sample = {'img': img, 'label': label}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample['img'], sample['label']

    def __len__(self):
        """Return the number of files found in the image folder."""
        return len(self.image_list)

+ 103
- 0
losses.py View File

@@ -0,0 +1,103 @@
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
import numpy as np


class DiceLoss(_Loss):
    """Multi-class soft Dice loss averaged over the classes that are present."""

    def forward(self, output, target, weights=None, ignore_index=None):
        """Compute the mean per-class Dice loss.

        Args:
            output: Per-class scores of shape N x C x H x W (the caller in
                this file passes softmax probabilities).
            target: Integer class indices of shape N x H x W.
            weights: Optional per-class weight (scalar or broadcastable to
                ``C``); defaults to 1.
            ignore_index: Optional label value whose pixels are excluded from
                both numerator and denominator.

        Returns:
            A scalar tensor. A class is skipped when it has neither target
            pixels nor any remaining score mass; if every class is skipped
            (e.g. all pixels ignored) the result is 0 rather than NaN.
        """
        eps = 0.001
        reduce_dims = (0, 2, 3)

        # One-hot buffer that takes no part in autograd. zeros_like avoids the
        # NaNs that ``output * 0`` yields when output contains inf/NaN.
        encoded_target = torch.zeros_like(output)

        ignored = None
        if ignore_index is not None:
            ignored = target == ignore_index
            target = target.clone()
            # Point ignored pixels at a valid class so scatter_ cannot fail ...
            target[ignored] = 0
            encoded_target.scatter_(1, target.unsqueeze(1), 1)
            # ... then erase their one-hot entries again.
            ignored = ignored.unsqueeze(1).expand_as(encoded_target)
            encoded_target.masked_fill_(ignored, 0)
        else:
            # scatter_(dim, index, value) writes 1 at each pixel's class index.
            encoded_target.scatter_(1, target.unsqueeze(1), 1)

        if weights is None:
            weights = 1

        intersection = output * encoded_target
        numerator = 2 * intersection.sum(dim=reduce_dims)
        denominator = output + encoded_target
        if ignored is not None:
            denominator = denominator.masked_fill(ignored, 0)
        denominator = denominator.sum(dim=reduce_dims)

        # Classes with no target pixels and no score mass are invalid and
        # must not influence the average.
        invalid = (encoded_target.sum(dim=reduce_dims) == 0) & (denominator == 0)
        valid = ~invalid

        # Channel-wise weighted Dice loss.
        loss_per_channel = weights * (1 - numerator / (denominator + eps))
        loss_per_channel = loss_per_channel * valid.to(loss_per_channel.dtype)

        # clamp keeps the division defined when no class is valid.
        return loss_per_channel.sum() / valid.sum().clamp(min=1)


class CrossEntropy2D(nn.Module):
    """
    Thin wrapper around ``nn.CrossEntropyLoss`` for N x C x H x W logits.

    With the default ``reduction='none'`` the loss is returned per pixel.
    """

    def __init__(self, weight=None, reduction='none'):
        super().__init__()
        # Attribute name kept as-is; the wrapped module is a CrossEntropyLoss.
        criterion = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
        self.nll_loss = criterion

    def forward(self, inputs, targets):
        """Return the cross-entropy between ``inputs`` (logits) and ``targets``."""
        loss = self.nll_loss(inputs, targets)
        return loss


class CombinedLoss(nn.Module):
    """
    Weighted sum of Dice loss and cross-entropy loss.

    Args (forward):
        -- inputx N x C x H x W logits (N is the batch size)
        -- target - N x H x W - class indices (cast to long internally)

    Returns (forward):
        Tuple ``(total_loss, dice_val, ce_val)`` of scalar tensors where
        ``total_loss = weight_dice * dice_val + weight_ce * ce_val``.
    """

    def __init__(self, weight_dice, weight_ce):
        super(CombinedLoss, self).__init__()
        self.cross_entropy_loss = CrossEntropy2D()
        self.dice_loss = DiceLoss()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce

    def forward(self, inputx, target):
        # Cast to long on the logits' own device. The previous
        # ``type(torch.LongTensor)`` + ``.cuda()`` bounced through the CPU and
        # always landed on the default GPU, which breaks on e.g. cuda:1.
        target = target.to(device=inputx.device, dtype=torch.long)

        input_soft = F.softmax(inputx, dim=1)  # along the class dimension
        dice_val = torch.mean(self.dice_loss(input_soft, target))
        # Call the module (not .forward) so registered hooks still run.
        ce_val = torch.mean(self.cross_entropy_loss(inputx, target))
        total_loss = torch.add(torch.mul(dice_val, self.weight_dice),
                               torch.mul(ce_val, self.weight_ce))
        return total_loss, dice_val, ce_val

+ 336
- 0
model_define.py View File

@@ -0,0 +1,336 @@
import os
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict


class ASPP(nn.Module):
    """Atrous spatial pyramid pooling head (conv + ReLU branches, no batch norm).

    Five parallel branches are computed on the input: a globally pooled
    branch, a 1x1 convolution and three 3x3 convolutions with dilation
    3, 6 and 9. Their outputs are concatenated and fused by a 1x1 convolution.
    """

    def __init__(self, in_channel=512, depth=256):
        super().__init__()

        def conv_relu(kernel_size, dilation=None):
            # 3x3 branches pad by their dilation so the spatial size is kept.
            if dilation is None:
                conv = nn.Conv2d(in_channel, depth, kernel_size, 1)
            else:
                conv = nn.Conv2d(in_channel, depth, kernel_size, 1,
                                 padding=dilation, dilation=dilation)
            return nn.Sequential(conv, nn.ReLU(inplace=True))

        # Image-level branch: global average pooling followed by a 1x1 conv.
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = conv_relu(1)

        # NOTE: attribute names say 6/12/18 but the dilation rates are 3/6/9;
        # the names are kept because they are state-dict keys.
        self.atrous_block1 = conv_relu(1)
        self.atrous_block6 = conv_relu(3, dilation=3)
        self.atrous_block12 = conv_relu(3, dilation=6)
        self.atrous_block18 = conv_relu(3, dilation=9)

        self.conv_1x1_output = nn.Sequential(
            nn.Conv2d(depth * 5, depth, 1, 1),
            nn.ReLU(inplace=True))

    def forward(self, x):
        """Return the fused multi-scale feature map with ``depth`` channels."""
        spatial = x.shape[2:]

        pooled = self.conv(self.mean(x))
        pooled = F.interpolate(pooled, size=spatial, mode='bilinear',
                               align_corners=True)

        branches = [pooled]
        for branch in (self.atrous_block1, self.atrous_block6,
                       self.atrous_block12, self.atrous_block18):
            branches.append(branch(x))

        return self.conv_1x1_output(torch.cat(branches, dim=1))

class ResNet(nn.Module):
    """ResNet backbone with a classification head.

    Args:
        block: Residual block class (must expose an ``expansion`` attribute).
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
        zero_init_residual: Zero the last norm weight of every residual block.
        groups: Group count forwarded to each block.
        width_per_group: Base width forwarded to each block.
        replace_stride_with_dilation: Three flags (stages 2-4) selecting
            dilation instead of stride 2; ``None`` means all ``False``.
        norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have 3 items.
    """

    def __init__(self, block, layers, num_classes=18, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per strided stage: swap the stride-2 downsampling for
            # a dilated convolution when True.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        dilate2, dilate3, dilate4 = replace_stride_with_dilation
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=dilate2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=dilate3)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=dilate4)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero the last norm layer of each residual branch so every block
        # starts out as an identity mapping
        # (see https://arxiv.org/abs/1706.02677).
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks."""
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        # A projection shortcut is needed whenever the block changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        for _ in range(blocks - 1):
            stage.append(block(self.inplanes, planes, groups=self.groups,
                               base_width=self.base_width,
                               dilation=self.dilation,
                               norm_layer=norm_layer))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, the four stages, global pooling and the linear head."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` from ``block`` and ``layers``.

    ``arch``, ``pretrained`` and ``progress`` are accepted but not used here:
    no pretrained weights are loaded.
    """
    return ResNet(block, layers, **kwargs)

def resnet50(pretrained=False, progress=True, **kwargs):
    """Return a ResNet-50: ``Bottleneck`` blocks in a 3-4-6-3 stage layout."""
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress,
                   **kwargs)

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals its dilation."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 convolution."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class BasicBlock(nn.Module):
    """Two-convolution residual block (channel expansion factor 1).

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # conv1 and the optional downsample both apply the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with a 4x channel expansion."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # conv2 and the optional downsample both apply the stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class Deeplab_v3(nn.Module):
    """DeepLab-style segmentation net: ResNet-50 encoder, ASPP, skip decoder.

    Args:
        class_number: Number of output channels (classes) of the score map.
            Previously this argument was ignored and the classifier was
            hard-coded to 18 channels; the default of 18 keeps existing
            checkpoints loadable.

    The forward pass takes N x 3 x H x W input and returns N x class_number
    x H x W logits. Shape comments below assume a 256x256 input.
    """

    def __init__(self, class_number=18):
        super().__init__()
        encoder = resnet50()
        self.start = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)

        self.maxpool = encoder.maxpool

        # 1x1 projections that shrink the skip features before concatenation.
        self.low_feature1 = nn.Sequential(
            nn.Conv2d(64, 32, 1, 1), nn.BatchNorm2d(32), nn.ReLU(inplace=True))
        self.low_feature3 = nn.Sequential(
            nn.Conv2d(256, 64, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.low_feature4 = nn.Sequential(
            nn.Conv2d(512, 128, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True))

        self.layer1 = encoder.layer1  # 256 channels out
        self.layer2 = encoder.layer2  # 512 channels out
        self.layer3 = encoder.layer3  # 1024 channels out
        self.layer4 = encoder.layer4  # 2048 channels out

        self.aspp = ASPP(in_channel=2048, depth=256)

        self.conv_cat4 = nn.Sequential(
            nn.Conv2d(256 + 128, 256, 3, 1, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        self.conv_cat3 = nn.Sequential(
            nn.Conv2d(256 + 64, 256, 3, 1, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 64, 3, 1, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        # The final conv is the classifier: one channel per class.
        self.conv_cat1 = nn.Sequential(
            nn.Conv2d(64 + 32, 64, 3, 1, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, class_number, 3, 1, padding=1))

    def forward(self, x):
        size0 = x.shape[2:]   # output is upsampled back to the input size
        x1 = self.start(x)    # 64,   128*128
        x2 = self.maxpool(x1)  # 64,   64*64
        x3 = self.layer1(x2)  # 256,  64*64
        x4 = self.layer2(x3)  # 512,  32*32
        x5 = self.layer3(x4)  # 1024, 16*16
        x = self.layer4(x5)   # 2048, 8*8
        x = self.aspp(x)      # 256,  8*8

        low_feature1 = self.low_feature1(x1)  # 64 -> 32,   128*128
        low_feature3 = self.low_feature3(x3)  # 256 -> 64,  64*64
        low_feature4 = self.low_feature4(x4)  # 512 -> 128, 32*32

        size1 = low_feature1.shape[2:]
        size3 = low_feature3.shape[2:]
        size4 = low_feature4.shape[2:]

        # Decoder: upsample, concatenate the matching skip feature, fuse.
        decoder_feature4 = F.interpolate(x, size=size4, mode='bilinear',
                                         align_corners=True)
        x = self.conv_cat4(torch.cat([low_feature4, decoder_feature4], dim=1))

        decoder_feature3 = F.interpolate(x, size=size3, mode='bilinear',
                                         align_corners=True)
        x = self.conv_cat3(torch.cat([low_feature3, decoder_feature3], dim=1))

        decoder_feature1 = F.interpolate(x, size=size1, mode='bilinear',
                                         align_corners=True)
        x = self.conv_cat1(torch.cat([low_feature1, decoder_feature1], dim=1))

        score = F.interpolate(x, size=size0, mode='bilinear',
                              align_corners=True)

        return score

def init_model():
    """Load ``model.pkl`` (next to this file) into a ``Deeplab_v3`` in eval mode.

    The checkpoint must hold a ``"model_state_dict"`` entry. A leading
    ``"module."`` on any key (as written by ``nn.DataParallel``) is stripped
    before loading. The model is placed on CUDA when available, else CPU.
    """
    checkpoint_path = os.path.join(os.path.dirname(__file__), 'model.pkl')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Deeplab_v3()
    model.to(device)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    prefix = "module."
    cleaned = OrderedDict(
        (key[len(prefix):] if key[:len(prefix)] == prefix else key, value)
        for key, value in checkpoint["model_state_dict"].items())

    model.load_state_dict(cleaned)
    model.eval()
    return model

+ 53
- 0
model_predict.py View File

@@ -0,0 +1,53 @@
import os
import torch
import numpy as np
from PIL import Image
def cutImage(file_name):
    """Split an image file into an ``n x n`` grid of 256x256 tiles.

    The image is first resized to ``(n * 256, n * 256)`` where ``n`` is the
    image width divided by 256 (integer division, at least 1), so non-square
    inputs are stretched to a square.

    Args:
        file_name: Path of the image to open.

    Returns:
        Tuple ``(n, tiles, original_size, resized_size)``; ``tiles`` is in
        row-major order and both sizes are ``(width, height)``.
    """
    # The context manager closes the file; resize() returns a loaded copy.
    with Image.open(file_name) as source:
        oring_size = source.size
        # Images narrower than 256 px used to give n == 0 and an invalid
        # zero-sized resize; upsample them to a single tile instead.
        n = max(1, oring_size[0] // 256)
        img = source.resize((n * 256, n * 256))

    res = [
        img.crop((256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1)))
        for i in range(n)
        for j in range(n)
    ]
    return n, res, oring_size, img.size
def mergeImage(n, imgs, o_size, n_size):
    """Stitch ``n x n`` label tiles together and resize to the original size.

    Args:
        n: Tiles per row and per column.
        imgs: Row-major sequence of 256x256 label tiles.
        o_size: ``(width, height)`` of the final image.
        n_size: ``(width, height)`` of the stitched canvas.

    Returns:
        A mode ``'L'`` image of size ``o_size`` with values clamped to 0..17.
    """
    img = Image.new('L', n_size)
    for i in range(n):
        for j in range(n):
            pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
            img.paste(imgs[i * n + j], pos)
    # Pixels are class ids, not intensities: an interpolating filter would
    # blend neighbouring ids into unrelated classes, so use nearest-neighbour.
    img = img.resize(o_size, Image.NEAREST)
    a = np.array(img)
    a[a > 17] = 17  # safety clamp to the largest class id
    return Image.fromarray(a)
def predict(model, input_path, output_dir):
    """Segment one image tile by tile and save the label map as a PNG.

    Args:
        model: Callable network taking N x 3 x 256 x 256 float input scaled
            to [0, 1] and returning per-class scores.
        input_path: Path of the image to segment.
        output_dir: Directory for the result, saved as ``<input stem>.png``.
    """
    name, _ = os.path.splitext(input_path)
    name = os.path.split(name)[-1] + ".png"
    n, imgs, o_size, n_size = cutImage(input_path)
    res = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for img in imgs:
            # Force 3 channels so grayscale/RGBA inputs do not break the
            # permute below or the network's 3-channel stem.
            # NOTE(review): the training loader reads tiles with cv.imread
            # (BGR order) while this yields RGB - confirm the channel order
            # the checkpoint expects.
            img = torch.from_numpy(np.array(img.convert('RGB'))).float()
            img = img.unsqueeze(0)
            img = img / 255
            if torch.cuda.is_available():
                img = img.cuda()
            img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
            label = model(img)
            label = torch.argmax(label, dim=1).cpu().squeeze().numpy()
            res.append(Image.fromarray(label.astype(np.uint8)))
    img = mergeImage(n, res, o_size, n_size)
    img.save(os.path.join(output_dir, name))
# Manual run: load the bundled checkpoint and segment ./test.tif, writing the
# result into the current directory (empty output_dir).
if __name__ == "__main__":
    from model_define import init_model
    model = init_model()
    predict(model, './test.tif', '')

+ 148
- 0
solver.py View File

@@ -0,0 +1,148 @@
import os
import torch
import time
import matplotlib.pyplot as plt
import numpy as np
import glob
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import utils
from skimage import color
from losses import CombinedLoss


def create_exp_directory(exp_dir_name):
    """Create ``exp_dir_name`` (and missing parents) if it does not exist.

    ``exist_ok=True`` replaces the check-then-create sequence, which could
    fail when another process created the directory in between.
    """
    os.makedirs(exp_dir_name, exist_ok=True)


def plot_predictions(images_batch, labels_batch, batch_output, plt_title,
                     file_save_name):
    """Save a three-panel figure: input images, ground truth and prediction.

    Args:
        images_batch: Image tensor batch accepted by ``utils.make_grid``.
        labels_batch: Label tensor without a channel axis; a channel axis is
            inserted at dim 1 for the grid. The tensor itself is left
            untouched.
        batch_output: Predicted label tensor, handled like ``labels_batch``.
        plt_title: Figure title.
        file_save_name: Output path passed to ``Figure.savefig``.
    """
    f = plt.figure(figsize=(20, 20))
    try:
        grid = utils.make_grid(images_batch.cpu(), nrow=4)
        plt.subplot(131)
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.title('Slices')

        # Out-of-place unsqueeze: the in-place variant changed the shape of
        # the caller's tensors as a side effect.
        grid = utils.make_grid(labels_batch.unsqueeze(1).cpu(), nrow=4)[0]
        color_grid = color.label2rgb(grid.numpy(), bg_label=0)
        plt.subplot(132)
        plt.imshow(color_grid)
        plt.title('Ground Truth')

        grid = utils.make_grid(batch_output.unsqueeze(1).cpu(), nrow=4)[0]
        color_grid = color.label2rgb(grid.numpy(), bg_label=0)
        plt.subplot(133)
        plt.imshow(color_grid)
        plt.title('Prediction')

        plt.suptitle(plt_title)
        plt.tight_layout()

        f.savefig(file_save_name, bbox_inches='tight')
    finally:
        # Always release the figure. The former ``plt.gcf().clear()`` after
        # close() only created a fresh empty figure that stayed open.
        plt.close(f)


class Solver(object):
    """Training driver: optimizer, StepLR schedule, logging and checkpoints.

    Args:
        num_classes: Number of segmentation classes.
        optimizer: Optimizer class, called as
            ``optimizer(model.parameters(), **optimizer_args)``.
        lr_args: Dict with the ``StepLR`` settings ``"step_size"`` and
            ``"gamma"``.
        optimizer_args: Keyword arguments for the optimizer constructor.
    """

    def __init__(self, num_classes, optimizer, lr_args, optimizer_args):
        self.lr_scheduler_args = lr_args
        self.optimizer_args = optimizer_args
        self.optimizer = optimizer
        self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100)
        self.num_classes = num_classes
        self.classes = list(range(self.num_classes))

    def train(self,
              model,
              train_loader,
              num_epochs,
              log_params,
              expdir,
              resume=True):
        """Train ``model`` for ``num_epochs`` epochs (numbered from 0).

        Args:
            model: Network to optimise; moved to nothing here, inputs are sent
                to CUDA whenever it is available.
            train_loader: Sized iterable yielding ``(images, labels)`` batches.
            num_epochs: Total number of epochs; the last epoch index is
                ``num_epochs - 1``.
            log_params: Dict with ``"logdir"`` (plot directory), ``"log_iter"``
                (checkpoint every N epochs) and ``"logger"``.
            expdir: Checkpoint directory.
            resume: Continue from the most recently modified ``Epoch_*`` file
                in ``expdir`` when one can be loaded.
        """
        os.makedirs(expdir, exist_ok=True)
        os.makedirs(log_params["logdir"], exist_ok=True)
        optimizer = self.optimizer(model.parameters(), **self.optimizer_args)
        scheduler = lr_scheduler.StepLR(
            optimizer,
            step_size=self.lr_scheduler_args["step_size"],
            gamma=self.lr_scheduler_args["gamma"])
        epoch = -1
        print('-------> Starting to train')
        if resume:
            # Best effort: any failure simply means training starts fresh.
            try:
                prior_model_paths = sorted(glob.glob(
                    os.path.join(expdir, 'Epoch_*')),
                    key=os.path.getmtime)
                current_model = prior_model_paths.pop()
                state = torch.load(current_model)
                model.load_state_dict(state["model_state_dict"])
                epoch = state["epoch"]
                # Checkpoints also store optimizer and scheduler state;
                # restore them so the learning-rate schedule and optimizer
                # moments continue instead of restarting.
                if "optimizer_state_dict" in state:
                    optimizer.load_state_dict(state["optimizer_state_dict"])
                if "scheduler_state_dict" in state:
                    scheduler.load_state_dict(state["scheduler_state_dict"])
                print("Successfully Resuming from Epoch {}".format(epoch + 1))
            except Exception as e:
                print("No model to restore. {}".format(e))
        log_params["logger"].info("{} parameters in total".format(
            sum(x.numel() for x in model.parameters())))

        num_batches = len(train_loader)
        # Log about twice per epoch; max() avoids a modulo-by-zero when the
        # loader has a single batch.
        log_every = max(1, num_batches // 2)

        model.train()
        # ``epoch`` holds the last finished epoch, so stop at
        # num_epochs - 1 to run exactly num_epochs epochs (previously one
        # extra epoch was trained).
        while epoch < num_epochs - 1:
            epoch = epoch + 1
            epoch_start = time.time()
            loss_batch = np.zeros(1)
            loss_dice_batch = np.zeros(1)
            loss_ce_batch = np.zeros(1)
            for batch_idx, sample in enumerate(train_loader):
                images, labels = sample
                if torch.cuda.is_available():
                    images, labels = images.cuda(), labels.cuda()
                optimizer.zero_grad()
                predictions = model(images)
                loss_total, loss_dice, loss_ce = self.loss_func(
                    inputx=predictions, target=labels)
                loss_total.backward()
                optimizer.step()

                loss_batch += loss_total.item()
                loss_dice_batch += loss_dice.item()
                loss_ce_batch += loss_ce.item()

                # Plot the second-to-last batch of every epoch.
                if batch_idx == num_batches - 2:
                    _, batch_output = torch.max(predictions, dim=1)
                    plt_title = 'Trian Results Epoch ' + str(epoch)
                    file_save_name = os.path.join(
                        log_params["logdir"],
                        'Epoch_{}_Trian_Predictions.pdf'.format(epoch))
                    plot_predictions(images, labels, batch_output, plt_title,
                                     file_save_name)

                if batch_idx % log_every == 0 or batch_idx == num_batches - 1:
                    log_params["logger"].info(
                        "Epoch: {} lr:{} [{}/{}] ({:.0f}%)]"
                        "with loss: {},\ndice_loss:{},ce_loss:{}".format(
                            epoch, optimizer.param_groups[0]['lr'], batch_idx,
                            num_batches,
                            100. * batch_idx / num_batches,
                            loss_batch / (batch_idx + 1),
                            loss_dice_batch / (batch_idx + 1),
                            loss_ce_batch / (batch_idx + 1)))

            scheduler.step()
            epoch_finish = time.time() - epoch_start
            log_params["logger"].info(
                "Train Epoch {} finished in {:.04f} seconds.\n".format(
                    epoch, epoch_finish))

            # Save a checkpoint every ``log_iter`` epochs.
            if epoch % log_params["log_iter"] == 0:
                save_name = os.path.join(
                    expdir,
                    'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl')
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch
                }
                if scheduler is not None:
                    checkpoint["scheduler_state_dict"] = scheduler.state_dict()
                torch.save(checkpoint, save_name)

+ 193
- 0
train.py View File

@@ -0,0 +1,193 @@
import os
import logging
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from dataset import DataTrain
from model_define import Deeplab_v3
# from model_define_unet import UNet
from solver import Solver
from torchvision import transforms, utils
import numpy as np
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg


# Training configuration shared by the functions in this script.
args = {
    'batch_size': 2,
    'log_interval': 1,
    'log_dir': 'log',
    'num_classes': 2,  # was listed twice with the same value
    'epochs': 1000,
    'lr': 1e-5,
    'resume': True,
    'data_dir': "../small_data",
    'gamma': 0.5,
    'step': 5,
    'vit_name': 'R50-ViT-B_16',
    'n_skip': 3,
    'img_size': 256,
    'vit_patches_size': 16,
}
# Directory layout:
#   data
#       images
#           *.tif
#       labels
#           *.png
#   code
#       train.py
#       ...
# NOTE(review): DataTrain in dataset.py reads the sub-folders
# 'train_ori_images' and 'train_pupil_images' under data_dir - confirm which
# layout is current.
class ToTensor(object):
    """
    Turn the sample's image into a float32 CHW tensor scaled by 1/255.

    The label is passed through unchanged.
    """

    def __call__(self, sample):
        image = sample['img']
        label = sample['label']

        # numpy image is H x W x C, torch expects C x H x W.
        scaled = image.astype(np.float32) / 255.0
        chw = np.transpose(scaled, (2, 0, 1))

        return {'img': torch.from_numpy(chw), 'label': label}


class AugmentationPadImage(object):
    """
    Pad image and label with ``np.pad`` (zero padding by default).

    Args:
        pad_size: Either an int (same padding on every side of both spatial
            axes) or a tuple of ``(before, after)`` pairs for the two spatial
            axes. The channel axis of the image is never padded.
        pad_type: ``mode`` argument forwarded to ``np.pad``.
    """

    def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"):

        assert isinstance(pad_size, (int, tuple))

        if isinstance(pad_size, int):
            spatial_pad = ((pad_size, pad_size), (pad_size, pad_size))
        else:
            # The tuple form (including the default) previously left
            # pad_size_image / pad_size_mask unset, so __call__ raised
            # AttributeError.
            self.pad_size = pad_size
            spatial_pad = tuple(pad_size)

        # Do not pad along the channel dimension.
        self.pad_size_image = spatial_pad + ((0, 0),)
        self.pad_size_mask = spatial_pad

        self.pad_type = pad_type

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        img = np.pad(img, self.pad_size_image, self.pad_type)
        label = np.pad(label, self.pad_size_mask, self.pad_type)

        return {'img': img, 'label': label}


class AugmentationRandomCrop(object):
    """
    Crop image and label to ``output_size`` at a random or centred position.

    Args:
        output_size: Int (square crop) or ``(height, width)`` tuple.
        crop_type: ``'Center'`` for a centred crop; any other value picks the
            top-left corner uniformly at random.
    """

    def __init__(self, output_size, crop_type='Random'):

        assert isinstance(output_size, (int, tuple))

        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

        self.crop_type = crop_type

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        h, w, _ = img.shape
        crop_h, crop_w = self.output_size[0], self.output_size[1]

        if self.crop_type == 'Center':
            top = (h - crop_h) // 2
            left = (w - crop_w) // 2
        else:
            # randint's upper bound is exclusive: the +1 makes the last valid
            # offset reachable and lets a crop equal to the image size work
            # (it used to raise ValueError).
            top = np.random.randint(0, h - crop_h + 1)
            left = np.random.randint(0, w - crop_w + 1)

        bottom = top + crop_h
        right = left + crop_w

        img = img[top:bottom, left:right, :]
        label = label[top:bottom, left:right]

        return {'img': img, 'label': label}

def log_init():
    """Return the ``"train"`` logger writing to stderr and ``<log_dir>/log.txt``.

    Creates the log directory when missing, replaces any handlers already
    attached to the logger and records the current ``args`` as first entry.
    """
    log_dir = args['log_dir']
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []

    console_handler = logging.StreamHandler()
    logger.addHandler(console_handler)
    file_handler = logging.FileHandler(os.path.join(log_dir, "log.txt"))
    logger.addHandler(file_handler)

    logger.info("%s", repr(args))
    return logger



def train():
    """Build data pipeline, ViT segmentation model and solver, then train.

    All settings come from the module-level ``args`` dict.
    """
    logger = log_init()

    # Data: pad by 8 px, take a random 256 px crop, convert to tensor.
    augmentations = [
        AugmentationPadImage(pad_size=8),
        AugmentationRandomCrop(output_size=256),
        ToTensor(),
    ]
    transform_train = transforms.Compose(augmentations)
    dataset_train = DataTrain(args['data_dir'], transforms=transform_train)
    train_dataloader = DataLoader(dataset=dataset_train,
                                  batch_size=args['batch_size'],
                                  shuffle=True, num_workers=4)

    # Model configuration.
    config_vit = CONFIGS_ViT_seg[args['vit_name']]
    config_vit.n_classes = args['num_classes']
    config_vit.n_skip = args['n_skip']
    if 'R50' in args['vit_name']:
        # Hybrid R50 variants need the patch grid set explicitly.
        grid_side = int(args['img_size'] / args['vit_patches_size'])
        config_vit.patches.grid = (grid_side, grid_side)
    model = ViT_seg(config_vit, img_size=args['img_size'],
                    num_classes=config_vit.n_classes)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.cuda()

    lr_args = {"gamma": args['gamma'], "step_size": args['step']}
    optimizer_args = {
        "lr": args['lr'],
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 0.01,
    }
    solver = Solver(num_classes=args['num_classes'],
                    lr_args=lr_args,
                    optimizer_args=optimizer_args,
                    optimizer=torch.optim.Adam)

    log_params = {
        'logdir': args['log_dir'] + "/logs",
        'log_iter': args['log_interval'],
        'logger': logger,
    }
    solver.train(model,
                 train_dataloader,
                 num_epochs=args['epochs'],
                 log_params=log_params,
                 expdir=args['log_dir'] + "/ckpts",
                 resume=args['resume'])


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    train()

Loading…
Cancel
Save