import os
import logging

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import DataTrain
from model_define import Deeplab_v3
# from model_define_unet import UNet
from solver import Solver
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg


args = {
    'batch_size': 2,
    'log_interval': 1,
    'log_dir': 'log',
    'num_classes': 2,
    'epochs': 1000,
    'lr': 1e-5,
    'resume': True,
    'data_dir': "../small_data",
    'gamma': 0.5,
    'step': 5,
    'vit_name': 'R50-ViT-B_16',
    'n_skip': 3,
    'img_size': 256,
    'vit_patches_size': 16,
}
'''
Expected directory layout:
data
    images
        *.tif
    labels
        *.png
code
    train.py
    ...
(See the illustrative sketch below for the sample format DataTrain is assumed to return.)
'''
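
# --- Illustrative sketch, not used by this script -------------------------------
# dataset.py is not included here, so the class below is only a hypothetical
# example of the interface the transforms and DataLoader assume DataTrain
# provides: constructed as DataTrain(data_dir, transforms=...) and yielding
# samples of the form {'img': H x W x C uint8 array, 'label': H x W array of
# class indices}. Pairing images and labels by identical file stem
# (xxx.tif / xxx.png) and reading them with Pillow are assumptions.
from torch.utils.data import Dataset
from PIL import Image


class DataTrainSketch(Dataset):
    def __init__(self, data_dir, transforms=None):
        self.img_dir = os.path.join(data_dir, 'images')
        self.label_dir = os.path.join(data_dir, 'labels')
        self.stems = sorted(os.path.splitext(f)[0]
                            for f in os.listdir(self.img_dir)
                            if f.endswith('.tif'))
        self.transforms = transforms

    def __len__(self):
        return len(self.stems)

    def __getitem__(self, idx):
        img = np.array(Image.open(os.path.join(self.img_dir, self.stems[idx] + '.tif')))
        label = np.array(Image.open(os.path.join(self.label_dir, self.stems[idx] + '.png')))
        sample = {'img': img, 'label': label}
        if self.transforms:
            sample = self.transforms(sample)
        return sample
# ---------------------------------------------------------------------------------

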
class ToTensor(object):
    """
    Convert ndarrays in sample to Tensors.
    """

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        img = img.astype(np.float32)
        img = img / 255.0
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        img = img.transpose((2, 0, 1))

        # the label is returned as a numpy array; the DataLoader's default
        # collate function converts it to a tensor
        return {'img': torch.from_numpy(img), 'label': label}


class AugmentationPadImage(object):
    """
    Pad img and label with either constant (zero) padding or reflection padding.
    """

    def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"):

        assert isinstance(pad_size, (int, tuple))

        if isinstance(pad_size, int):
            # Do not pad along the channel dimension
            self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0))
            self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size))
        else:
            # pad_size is given as ((top, bottom), (left, right))
            self.pad_size_image = (pad_size[0], pad_size[1], (0, 0))
            self.pad_size_mask = pad_size

        self.pad_type = pad_type

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        img = np.pad(img, self.pad_size_image, self.pad_type)
        label = np.pad(label, self.pad_size_mask, self.pad_type)

        return {'img': img, 'label': label}


class AugmentationRandomCrop(object):
    """
    Crop img and label to the given output size, either randomly or centered.
    """

    def __init__(self, output_size, crop_type='Random'):

        assert isinstance(output_size, (int, tuple))

        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

        self.crop_type = crop_type

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        h, w, _ = img.shape

        if self.crop_type == 'Center':
            top = (h - self.output_size[0]) // 2
            left = (w - self.output_size[1]) // 2
        else:
            # upper bound is inclusive of h - output_size so that inputs whose
            # size equals the crop size are still valid
            top = np.random.randint(0, h - self.output_size[0] + 1)
            left = np.random.randint(0, w - self.output_size[1] + 1)

        bottom = top + self.output_size[0]
        right = left + self.output_size[1]

        img = img[top:bottom, left:right, :]
        label = label[top:bottom, left:right]

        return {'img': img, 'label': label}


def log_init():
    if not os.path.exists(args['log_dir']):
        os.makedirs(args['log_dir'])
    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    logger.handlers = []
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(
        logging.FileHandler(os.path.join(args['log_dir'], "log.txt")))
    logger.info("%s", repr(args))
    return logger


def train():
    logger = log_init()

    # pad by 8 px on each side, randomly crop back to 256x256, then convert to tensors
    transform_train = transforms.Compose([AugmentationPadImage(pad_size=8),
                                          AugmentationRandomCrop(output_size=256),
                                          ToTensor()])
    dataset_train = DataTrain(args['data_dir'], transforms=transform_train)
    train_dataloader = DataLoader(dataset=dataset_train,
                                  batch_size=args['batch_size'],
                                  shuffle=True, num_workers=4)

    # model = Deeplab_v3()
    # model = UNet()
    config_vit = CONFIGS_ViT_seg[args['vit_name']]
    config_vit.n_classes = args['num_classes']
    config_vit.n_skip = args['n_skip']
    if 'R50' in args['vit_name']:
        grid = int(args['img_size'] / args['vit_patches_size'])
        config_vit.patches.grid = (grid, grid)
    model = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes)
    # model.load_from(weights=np.load(config_vit.pretrained_path))

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.cuda()

    # how lr_args / optimizer_args are presumably consumed is illustrated in the
    # sketch after this function
    solver = Solver(num_classes=args['num_classes'],
                    lr_args={
                        "gamma": args['gamma'],
                        "step_size": args['step']
                    },
                    optimizer_args={
                        "lr": args['lr'],
                        "betas": (0.9, 0.999),
                        "eps": 1e-8,
                        "weight_decay": 0.01
                    },
                    optimizer=torch.optim.Adam)
    solver.train(model,
                 train_dataloader,
                 num_epochs=args['epochs'],
                 log_params={
                     'logdir': args['log_dir'] + "/logs",
                     'log_iter': args['log_interval'],
                     'logger': logger
                 },
                 expdir=args['log_dir'] + "/ckpts",
                 resume=args['resume'])
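

# --- Illustrative sketch, not executed by this script ----------------------------
# solver.py is not included here; this hypothetical function only illustrates how
# the arguments passed to Solver above are presumably consumed: optimizer_args go
# to the optimizer constructor (Adam here) and lr_args to a StepLR schedule that
# decays the learning rate by `gamma` every `step_size` epochs. The loss function,
# device handling, and the absence of logging/checkpointing are simplifications,
# not the actual Solver API.
def solver_train_sketch(model, dataloader, num_epochs, optimizer_cls, optimizer_args, lr_args):
    optimizer = optimizer_cls(model.parameters(), **optimizer_args)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **lr_args)
    criterion = nn.CrossEntropyLoss()  # assumed loss for pixel-wise classification
    device = next(model.parameters()).device
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            img = batch['img'].to(device)
            label = batch['label'].long().to(device)   # N x H x W class indices
            logits = model(img)                        # N x num_classes x H x W
            loss = criterion(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()  # one scheduler step per epoch
# ----------------------------------------------------------------------------------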


if __name__ == "__main__":
    train()