You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

summaries.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import os
  2. import torch
  3. from torchvision.utils import make_grid
  4. # from tensorboardX import SummaryWriter
  5. from torch.utils.tensorboard import SummaryWriter
  6. from dataloaders.utils import decode_seg_map_sequence
  7. class TensorboardSummary(object):
  8. def __init__(self, directory):
  9. self.directory = directory
  10. def create_summary(self):
  11. writer = SummaryWriter(log_dir=os.path.join(self.directory))
  12. return writer
  13. def visualize_image(self, writer, dataset, image, target, output, global_step, depth=None):
  14. if depth is None:
  15. grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
  16. writer.add_image('Image', grid_image, global_step)
  17. grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(),
  18. dataset=dataset), 3, normalize=False, range=(0, 255))
  19. writer.add_image('Predicted label', grid_image, global_step)
  20. grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(),
  21. dataset=dataset), 3, normalize=False, range=(0, 255))
  22. writer.add_image('Groundtruth label', grid_image, global_step)
  23. else:
  24. grid_image = make_grid(image[:3].clone().cpu().data, 4, normalize=True)
  25. writer.add_image('Image', grid_image, global_step)
  26. grid_image = make_grid(depth[:3].clone().cpu().data, 4, normalize=True) # normalize=False?
  27. writer.add_image('Depth', grid_image, global_step)
  28. grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(),
  29. dataset=dataset), 4, normalize=False, range=(0, 255))
  30. writer.add_image('Predicted label', grid_image, global_step)
  31. grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(),
  32. dataset=dataset), 4, normalize=False, range=(0, 255))
  33. writer.add_image('Groundtruth label', grid_image, global_step)