import os
import logging

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, utils

from dataset import DataTrain
from model_define import Deeplab_v3
# from model_define_unet import UNet
from solver import Solver
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

# Training configuration.
# NOTE: the original dict listed 'num_classes' twice; the duplicate has been
# removed (the value was 2 in both occurrences, so behavior is unchanged).
args = {
    'batch_size': 2,
    'log_interval': 1,
    'log_dir': 'log',
    'num_classes': 2,
    'epochs': 1000,
    'lr': 1e-5,
    'resume': True,
    'data_dir': "../small_data",
    'gamma': 0.5,
    'step': 5,
    'vit_name': 'R50-ViT-B_16',
    'n_skip': 3,
    'img_size': 256,
    'vit_patches_size': 16,
}

'''
Expected directory layout:
    data
        images
            *.tif
        labels
            *.png
    code
        train.py
        ...
'''


class ToTensor(object):
    """Convert ndarrays in a sample to Tensors.

    The image is scaled to [0, 1] and transposed from HWC (numpy) to
    CHW (torch); the label is passed through unchanged.
    """

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        img = img.astype(np.float32)
        img = img / 255.0

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img = img.transpose((2, 0, 1))
        return {'img': torch.from_numpy(img), 'label': label}


class AugmentationPadImage(object):
    """Pad image and label with either zero (constant) or reflection padding.

    ``pad_size`` may be an int (symmetric padding on H and W) or a tuple of
    per-axis ``(before, after)`` pairs for the spatial dimensions.
    """

    def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"):
        assert isinstance(pad_size, (int, tuple))

        if isinstance(pad_size, int):
            # Do not pad along the channel dimension.
            self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0))
            self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size))
        else:
            # BUG FIX: the original stored this as ``self.pad_size`` while
            # __call__ reads ``pad_size_image`` / ``pad_size_mask``, raising
            # AttributeError for any tuple argument (including the default).
            # Also append (0, 0) for the channel axis of the HWC image so
            # np.pad receives one pad pair per dimension.
            self.pad_size_image = tuple(pad_size) + ((0, 0),)
            self.pad_size_mask = tuple(pad_size)

        self.pad_type = pad_type

    def __call__(self, sample):
        img, label = sample['img'], sample['label']

        img = np.pad(img, self.pad_size_image, self.pad_type)
        label = np.pad(label, self.pad_size_mask, self.pad_type)
        return {'img': img, 'label': label}


class AugmentationRandomCrop(object):
    """Crop image and label to a given size, randomly or centered.

    ``output_size`` may be an int (square crop) or an ``(h, w)`` tuple.
    ``crop_type`` is ``'Center'`` for a center crop; anything else yields a
    uniformly random crop position.
    """

    def __init__(self, output_size, crop_type='Random'):
        assert isinstance(output_size, (int, tuple))

        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

        self.crop_type = crop_type

    def __call__(self, sample):
        img, label = sample['img'], sample['label']
        h, w, _ = img.shape

        if self.crop_type == 'Center':
            top = (h - self.output_size[0]) // 2
            left = (w - self.output_size[1]) // 2
        else:
            # BUG FIX: the original used randint(0, h - size), which raises
            # ValueError when the image is exactly the crop size and can never
            # pick the bottom/right-most position. randint's upper bound is
            # exclusive, so +1 makes every valid offset reachable.
            top = np.random.randint(0, h - self.output_size[0] + 1)
            left = np.random.randint(0, w - self.output_size[1] + 1)

        bottom = top + self.output_size[0]
        right = left + self.output_size[1]

        img = img[top:bottom, left:right, :]
        label = label[top:bottom, left:right]
        return {'img': img, 'label': label}


def log_init():
    """Create the log directory and return a logger writing to console + file."""
    if not os.path.exists(args['log_dir']):
        os.makedirs(args['log_dir'])

    logger = logging.getLogger("train")
    logger.setLevel(logging.DEBUG)
    # Reset handlers so repeated calls do not duplicate log lines.
    logger.handlers = []
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(
        logging.FileHandler(os.path.join(args['log_dir'], "log.txt")))
    logger.info("%s", repr(args))
    return logger


def train():
    """Build the dataset, model, and solver, then run the training loop."""
    logger = log_init()

    transform_train = transforms.Compose([AugmentationPadImage(pad_size=8),
                                          AugmentationRandomCrop(output_size=256),
                                          ToTensor()])
    dataset_train = DataTrain(args['data_dir'], transforms=transform_train)
    train_dataloader = DataLoader(dataset=dataset_train,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  num_workers=4)

    # model = Deeplab_v3()
    # model = UNet()
    config_vit = CONFIGS_ViT_seg[args['vit_name']]
    config_vit.n_classes = args['num_classes']
    config_vit.n_skip = args['n_skip']
    if args['vit_name'].find('R50') != -1:
        # R50 hybrid backbones need an explicit patch grid.
        config_vit.patches.grid = (int(args['img_size'] / args['vit_patches_size']),
                                   int(args['img_size'] / args['vit_patches_size']))
    model = ViT_seg(config_vit,
                    img_size=args['img_size'],
                    num_classes=config_vit.n_classes)
    # model.load_from(weights=np.load(config_vit.pretrained_path))

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.cuda()

    solver = Solver(num_classes=args['num_classes'],
                    lr_args={
                        "gamma": args['gamma'],
                        "step_size": args['step']
                    },
                    optimizer_args={
                        "lr": args['lr'],
                        "betas": (0.9, 0.999),
                        "eps": 1e-8,
                        "weight_decay": 0.01
                    },
                    optimizer=torch.optim.Adam)

    solver.train(model,
                 train_dataloader,
                 num_epochs=args['epochs'],
                 log_params={
                     'logdir': args['log_dir'] + "/logs",
                     'log_iter': args['log_interval'],
                     'logger': logger
                 },
                 expdir=args['log_dir'] + "/ckpts",
                 resume=args['resume'])


if __name__ == "__main__":
    train()