|
- import os
- import torch
- import time
- import matplotlib.pyplot as plt
- import numpy as np
- import glob
- from torch.autograd import Variable
- from torch.optim import lr_scheduler
- from torchvision import utils
- from skimage import color
- from losses import CombinedLoss
-
-
- def create_exp_directory(exp_dir_name):
- if not os.path.exists(exp_dir_name):
- os.makedirs(exp_dir_name)
-
-
- def plot_predictions(images_batch, labels_batch, batch_output, plt_title,
- file_save_name):
- f = plt.figure(figsize=(20, 20))
- # n, c, h, w = images_batch.shape
- # mid_slice = c // 2
- # images_batch = torch.unsqueeze(images_batch[:, mid_slice, :, :], 1)
- grid = utils.make_grid(images_batch.cpu(), nrow=4)
- plt.subplot(131)
- plt.imshow(grid.numpy().transpose((1, 2, 0)))
- plt.title('Slices')
- grid = utils.make_grid(labels_batch.unsqueeze_(1).cpu(), nrow=4)[0]
- color_grid = color.label2rgb(grid.numpy(), bg_label=0)
- plt.subplot(132)
- plt.imshow(color_grid)
- plt.title('Ground Truth')
- grid = utils.make_grid(batch_output.unsqueeze_(1).cpu(), nrow=4)[0]
- color_grid = color.label2rgb(grid.numpy(), bg_label=0)
- plt.subplot(133)
- plt.imshow(color_grid)
- plt.title('Prediction')
-
- plt.suptitle(plt_title)
- plt.tight_layout()
-
- f.savefig(file_save_name, bbox_inches='tight')
- plt.close(f)
- plt.gcf().clear()
-
-
- class Solver(object):
- def __init__(self, num_classes, optimizer, lr_args, optimizer_args):
- self.lr_scheduler_args = lr_args
- self.optimizer_args = optimizer_args
- self.optimizer = optimizer
- self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100)
- self.num_classes = num_classes
- self.classes = list(range(self.num_classes))
-
- def train(self,
- model,
- train_loader,
- num_epochs,
- log_params,
- expdir,
- resume=True):
- create_exp_directory(expdir)
- create_exp_directory(log_params["logdir"])
- optimizer = self.optimizer(model.parameters(), **self.optimizer_args)
- scheduler = lr_scheduler.StepLR(
- optimizer,
- step_size=self.lr_scheduler_args["step_size"],
- gamma=self.lr_scheduler_args["gamma"])
- epoch = -1
- print('-------> Starting to train')
- if resume:
- try:
- prior_model_paths = sorted(glob.glob(
- os.path.join(expdir, 'Epoch_*')),
- key=os.path.getmtime)
- current_model = prior_model_paths.pop()
- state = torch.load(current_model)
- model.load_state_dict(state["model_state_dict"])
- epoch = state["epoch"]
- print("Successfully Resuming from Epoch {}".format(epoch + 1))
- except Exception as e:
- print("No model to restore. {}".format(e))
- log_params["logger"].info("{} parameters in total".format(
- sum(x.numel() for x in model.parameters())))
-
- model.train()
- while epoch < num_epochs:
- epoch = epoch + 1
- epoch_start = time.time()
- loss_batch = np.zeros(1)
- loss_dice_batch = np.zeros(1)
- loss_ce_batch = np.zeros(1)
- for batch_idx, sample in enumerate(train_loader):
- images, labels = sample
- images, labels = Variable(images), Variable(labels)
- if torch.cuda.is_available():
- images, labels = images.cuda(), labels.cuda()
- optimizer.zero_grad()
- predictions = model(images)
- loss_total, loss_dice, loss_ce = self.loss_func(
- inputx=predictions, target=labels)
- loss_total.backward()
- optimizer.step()
-
- loss_batch += loss_total.item()
- loss_dice_batch += loss_dice.item()
- loss_ce_batch += loss_ce.item()
- _, batch_output = torch.max(predictions, dim=1)
- if batch_idx == len(train_loader) - 2:
- plt_title = 'Trian Results Epoch ' + str(epoch)
- file_save_name = os.path.join(
- log_params["logdir"],
- 'Epoch_{}_Trian_Predictions.pdf'.format(epoch))
- plot_predictions(images, labels, batch_output, plt_title,
- file_save_name)
-
- if batch_idx % (len(train_loader) // 2) == 0 or batch_idx == (len(train_loader) - 1):
- log_params["logger"].info(
- "Epoch: {} lr:{} [{}/{}] ({:.0f}%)]"
- "with loss: {},\ndice_loss:{},ce_loss:{}".format(
- epoch, optimizer.param_groups[0]['lr'], batch_idx,
- len(train_loader),
- 100. * batch_idx / len(train_loader),
- loss_batch / (batch_idx + 1),
- loss_dice_batch / (batch_idx + 1),
- loss_ce_batch / (batch_idx + 1)))
-
- scheduler.step()
- epoch_finish = time.time() - epoch_start
- log_params["logger"].info(
- "Train Epoch {} finished in {:.04f} seconds.\n".format(
- epoch, epoch_finish))
-
- # Saving Models
- if epoch % log_params["log_iter"] == 0: # 每log_iter次保存一次模型
- save_name = os.path.join(
- expdir,
- 'Epoch_' + str(epoch).zfill(2) + '_training_state.pkl')
- checkpoint = {
- "model_state_dict": model.state_dict(),
- "optimizer_state_dict": optimizer.state_dict(),
- "epoch": epoch
- }
- if scheduler is not None:
- checkpoint["scheduler_state_dict"] = scheduler.state_dict()
- torch.save(checkpoint, save_name)
|