You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import os
  2. import logging
  3. import torch
  4. import torch.nn as nn
  5. from torch.utils.data.dataloader import DataLoader
  6. from dataset import DataTrain
  7. from model_define import Deeplab_v3
  8. # from model_define_unet import UNet
  9. from solver import Solver
  10. from torchvision import transforms, utils
  11. import numpy as np
  12. from networks.vit_seg_modeling import VisionTransformer as ViT_seg
  13. from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
  14. args = {
  15. 'batch_size': 2,
  16. 'log_interval': 1,
  17. 'log_dir': 'log',
  18. 'num_classes': 2,
  19. 'epochs': 1000,
  20. 'lr': 1e-5,
  21. 'resume': True,
  22. 'data_dir': "../small_data",
  23. 'gamma': 0.5,
  24. 'step': 5,
  25. 'vit_name': 'R50-ViT-B_16',
  26. 'num_classes': 2,
  27. 'n_skip': 3,
  28. 'img_size': 256,
  29. 'vit_patches_size': 16,
  30. }
  31. '''
  32. 文件目录:
  33. data
  34. images
  35. *.tif
  36. labels
  37. *.png
  38. code
  39. train.py
  40. ...
  41. '''
  42. class ToTensor(object):
  43. """
  44. Convert ndarrays in sample to Tensors.
  45. """
  46. def __call__(self, sample):
  47. img, label = sample['img'], sample['label']
  48. img = img.astype(np.float32)
  49. img = img/255.0
  50. # swap color axis because
  51. # numpy image: H x W x C
  52. # torch image: C X H X W
  53. img = img.transpose((2, 0, 1))
  54. return {'img': torch.from_numpy(img), 'label': label}
  55. class AugmentationPadImage(object):
  56. """
  57. Pad Image with either zero padding or reflection padding of img, label and weight
  58. """
  59. def __init__(self, pad_size=((16, 16), (16, 16)), pad_type="constant"):
  60. assert isinstance(pad_size, (int, tuple))
  61. if isinstance(pad_size, int):
  62. # Do not pad along the channel dimension
  63. self.pad_size_image = ((pad_size, pad_size), (pad_size, pad_size), (0, 0))
  64. self.pad_size_mask = ((pad_size, pad_size), (pad_size, pad_size))
  65. else:
  66. self.pad_size = pad_size
  67. self.pad_type = pad_type
  68. def __call__(self, sample):
  69. img, label = sample['img'], sample['label']
  70. img = np.pad(img, self.pad_size_image, self.pad_type)
  71. label = np.pad(label, self.pad_size_mask, self.pad_type)
  72. return {'img': img, 'label': label}
  73. class AugmentationRandomCrop(object):
  74. """
  75. Randomly Crop Image to given size
  76. """
  77. def __init__(self, output_size, crop_type='Random'):
  78. assert isinstance(output_size, (int, tuple))
  79. if isinstance(output_size, int):
  80. self.output_size = (output_size, output_size)
  81. else:
  82. self.output_size = output_size
  83. self.crop_type = crop_type
  84. def __call__(self, sample):
  85. img, label = sample['img'], sample['label']
  86. h, w, _ = img.shape
  87. if self.crop_type == 'Center':
  88. top = (h - self.output_size[0]) // 2
  89. left = (w - self.output_size[1]) // 2
  90. else:
  91. top = np.random.randint(0, h - self.output_size[0])
  92. left = np.random.randint(0, w - self.output_size[1])
  93. bottom = top + self.output_size[0]
  94. right = left + self.output_size[1]
  95. # print(img.shape)
  96. img = img[top:bottom, left:right, :]
  97. label = label[top:bottom, left:right]
  98. # weight = weight[top:bottom, left:right]
  99. return {'img': img, 'label': label}
  100. def log_init():
  101. if not os.path.exists(args['log_dir']):
  102. os.makedirs(args['log_dir'])
  103. logger = logging.getLogger("train")
  104. logger.setLevel(logging.DEBUG)
  105. logger.handlers = []
  106. logger.addHandler(logging.StreamHandler())
  107. logger.addHandler(
  108. logging.FileHandler(os.path.join(args['log_dir'], "log.txt")))
  109. logger.info("%s", repr(args))
  110. return logger
  111. def train():
  112. logger = log_init()
  113. transform_train = transforms.Compose([AugmentationPadImage(pad_size=8), AugmentationRandomCrop(output_size=256), ToTensor()])
  114. dataset_train = DataTrain(args['data_dir'], transforms = transform_train)
  115. train_dataloader = DataLoader(dataset=dataset_train,
  116. batch_size=args['batch_size'],
  117. shuffle=True, num_workers=4)
  118. # model = Deeplab_v3()
  119. # model = UNet()
  120. config_vit = CONFIGS_ViT_seg[args['vit_name']]
  121. config_vit.n_classes = args['num_classes']
  122. config_vit.n_skip = args['n_skip']
  123. if args['vit_name'].find('R50') != -1:
  124. config_vit.patches.grid = (int(args['img_size'] / args['vit_patches_size']), int(args['img_size'] / args['vit_patches_size']))
  125. model = ViT_seg(config_vit, img_size=args['img_size'], num_classes=config_vit.n_classes)
  126. # model.load_from(weights=np.load(config_vit.pretrained_path))
  127. if torch.cuda.is_available():
  128. if torch.cuda.device_count() > 1:
  129. model = nn.DataParallel(model)
  130. model.cuda()
  131. solver = Solver(num_classes=args['num_classes'],
  132. lr_args={
  133. "gamma": args['gamma'],
  134. "step_size": args['step']
  135. },
  136. optimizer_args={
  137. "lr": args['lr'],
  138. "betas": (0.9, 0.999),
  139. "eps": 1e-8,
  140. "weight_decay": 0.01
  141. },
  142. optimizer=torch.optim.Adam)
  143. solver.train(model,
  144. train_dataloader,
  145. num_epochs=args['epochs'],
  146. log_params={
  147. 'logdir': args['log_dir'] + "/logs",
  148. 'log_iter': args['log_interval'],
  149. 'logger': logger
  150. },
  151. expdir=args['log_dir'] + "/ckpts",
  152. resume=args['resume'])
  153. if __name__ == "__main__":
  154. train()