You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

solver.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import os
  2. import torch
  3. import time
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import glob
  7. from torch.autograd import Variable
  8. from torch.optim import lr_scheduler
  9. from torchvision import utils
  10. from skimage import color
  11. from losses import CombinedLoss
  12. def create_exp_directory(exp_dir_name):
  13. if not os.path.exists(exp_dir_name):
  14. os.makedirs(exp_dir_name)
  15. def plot_predictions(images_batch, labels_batch, batch_output, plt_title,
  16. file_save_name):
  17. f = plt.figure(figsize=(20, 20))
  18. # n, c, h, w = images_batch.shape
  19. # mid_slice = c // 2
  20. # images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1)
  21. grid = utils.make_grid(images_batch.cpu(), nrow=4)
  22. plt.subplot(131)
  23. plt.imshow(grid.numpy().transpose((1, 2, 0)))
  24. plt.title('Slices')
  25. grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0]
  26. color_grid = color.label2rgb(grid.numpy(), bg_label=0)
  27. plt.subplot(132)
  28. plt.imshow(color_grid)
  29. plt.title('Ground Truth')
  30. grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0]
  31. color_grid = color.label2rgb(grid.numpy(), bg_label=0)
  32. plt.subplot(133)
  33. plt.imshow(color_grid)
  34. plt.title('Prediction')
  35. plt.suptitle(plt_title)
  36. plt.tight_layout()
  37. f.savefig(file_save_name, bbox_inches='tight')
  38. plt.close(f)
  39. plt.gcf().clear()
  40. class Solver(object):
  41. def __init__(self, num_classes, optimizer, lr_args, optimizer_args):
  42. self.lr_scheduler_args = lr_args
  43. self.optimizer_args = optimizer_args
  44. self.optimizer = optimizer
  45. self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100)
  46. self.num_classes = num_classes
  47. self.classes = list(range(self.num_classes))
  48. def train(self,
  49. model,
  50. train_loader,
  51. num_epochs,
  52. log_params,
  53. expdir,
  54. resume=True):
  55. create_exp_directory(expdir)
  56. create_exp_directory(log_params["logdir"])
  57. optimizer = self.optimizer(model.parameters(), **self.optimizer_args)
  58. scheduler = lr_scheduler.StepLR(
  59. optimizer,
  60. step_size=self.lr_scheduler_args["step_size"],
  61. gamma=self.lr_scheduler_args["gamma"])
  62. epoch = -1
  63. print('-------> Starting to train')
  64. if resume:
  65. try:
  66. prior_model_paths = sorted(glob.glob(
  67. os.path.join(expdir, 'Epoch_*')),
  68. key=os.path.getmtime)
  69. current_model = prior_model_paths.pop()
  70. state = torch.load(current_model)
  71. model.load_state_dict(state["model_state_dict"])
  72. epoch = state["epoch"]
  73. print("Successfully Resuming from Epoch {}".format(epoch + 1))
  74. except Exception as e:
  75. print("No model to restore. {}".format(e))
  76. log_params["logger"].info("{} parameters in total".format(
  77. sum(x.numel() for x in model.parameters())))
  78. model.train()
  79. while epoch < num_epochs:
  80. epoch = epoch + 1
  81. epoch_start = time.time()
  82. loss_batch = np.zeros(1)
  83. loss_dice_batch = np.zeros(1)
  84. loss_ce_batch = np.zeros(1)
  85. for batch_idx, sample in enumerate(train_loader):
  86. images, labels = sample
  87. images, labels = Variable(images), Variable(labels)
  88. if torch.cuda.is_available():
  89. images, labels = images.cuda(), labels.cuda()
  90. optimizer.zero_grad()
  91. predictions = model(images)
  92. loss_total, loss_dice, loss_ce = self.loss_func(
  93. inputx=predictions, target=labels)
  94. loss_total.backward()
  95. optimizer.step()
  96. loss_batch += loss_total.item()
  97. loss_dice_batch += loss_dice.item()
  98. loss_ce_batch += loss_ce.item()
  99. _, batch_output = torch.max(predictions, dim=1)
  100. if batch_idx == len(train_loader) - 2:
  101. plt_title = 'Trian Results Epoch ' + str(epoch)
  102. file_save_name = os.path.join(
  103. log_params["logdir"],
  104. 'Epoch_{}_Trian_Predictions.pdf'.format(epoch))
  105. plot_predictions(images, labels, batch_output, plt_title,
  106. file_save_name)
  107. if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == (len(train_loader) - 1):
  108. log_params["logger"].info(
  109. "Epoch: {} lr:{} [{}/{}] ({:.0f}%)]"
  110. "with loss: {},\ndice_loss:{},ce_loss:{}".format(
  111. epoch, optimizer.param_groups[0]['lr'], batch_idx,
  112. len(train_loader),
  113. 100. * batch_idx / len(train_loader),
  114. loss_batch / (batch_idx + 1),
  115. loss_dice_batch / (batch_idx + 1),
  116. loss_ce_batch / (batch_idx + 1)))
  117. scheduler.step()
  118. epoch_finish = time.time() - epoch_start
  119. log_params["logger"].info(
  120. "Train Epoch {} finished in {:.04f} seconds.\n".format(
  121. epoch, epoch_finish))
  122. # Saving Models
  123. if epoch % log_params["log_iter"] == 0: # 每log_iter次保存一次模型
  124. save_name = os.path.join(
  125. expdir,
  126. 'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl')
  127. checkpoint = {
  128. "model_state_dict": model.state_dict(),
  129. "optimizer_state_dict": optimizer.state_dict(),
  130. "epoch": epoch
  131. }
  132. if scheduler is not None:
  133. checkpoint["scheduler_state_dict"] = scheduler.state_dict()
  134. torch.save(checkpoint, save_name)