You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

saver.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. import time
  3. import shutil
  4. import tempfile
  5. import torch
  6. from collections import OrderedDict
  7. import glob
  class Saver(object):
      """Manage per-run output directories and checkpoint files.

      Each instance creates a fresh ``/tmp/<checkname>/experiment_<N>``
      directory, where ``N`` is one greater than the highest numbered
      ``experiment_<N>`` entry already present (0 when there is none).
      """

      def __init__(self, args):
          """Create the next ``experiment_<N>`` directory for this run.

          Args:
              args: Options object. ``args.checkname`` names the parent
                  directory under ``/tmp``; the object is kept as
                  ``self.args`` for ``save_experiment_config``.
          """
          self.args = args
          self.directory = os.path.join('/tmp', args.checkname)
          # Sort numerically: a plain lexicographic sort puts experiment_10
          # before experiment_9, which made the "next" run id collide with an
          # existing directory once ten or more runs existed.
          self.runs = sorted(
              glob.glob(os.path.join(self.directory, 'experiment_*')),
              key=self._run_sort_key,
          )
          run_ids = [n for n in map(self._run_number, self.runs) if n is not None]
          run_id = max(run_ids) + 1 if run_ids else 0
          self.experiment_dir = os.path.join(
              self.directory, 'experiment_{}'.format(run_id))
          # exist_ok avoids the check-then-create race of a separate exists().
          os.makedirs(self.experiment_dir, exist_ok=True)

      @staticmethod
      def _run_number(path):
          """Return the integer suffix of an ``experiment_<N>`` path, or None."""
          suffix = os.path.basename(path).rsplit('_', 1)[-1]
          try:
              return int(suffix)
          except ValueError:
              # e.g. a stray "experiment_old" entry; it does not affect numbering.
              return None

      @classmethod
      def _run_sort_key(cls, path):
          """Sort key: non-numbered entries first, then runs in numeric order."""
          number = cls._run_number(path)
          return (-1 if number is None else number, path)

      def _previous_best(self):
          """Return the highest score in earlier runs' ``best_pred.txt`` (floor 0.0)."""
          scores = [0.0]
          for run in self.runs:
              path = os.path.join(run, 'best_pred.txt')
              if not os.path.exists(path):
                  continue
              with open(path, 'r') as f:
                  line = f.readline()
              try:
                  scores.append(float(line))
              except ValueError:
                  # Empty or truncated file (e.g. a run killed mid-write):
                  # ignore it instead of aborting the current run.
                  continue
          return max(scores)

      def save_checkpoint(self, state, is_best):
          """Save ``state`` to a timestamped ``.pth`` file in this run's directory.

          When ``is_best`` is true, ``state['best_pred']`` is written to this
          run's ``best_pred.txt``. The checkpoint is also copied to
          ``<directory>/model_best.pth`` when there are no earlier runs, or
          when it beats every score recorded by earlier runs.

          Args:
              state: Object passed to ``torch.save``; must support
                  ``state['best_pred']`` when ``is_best`` is true.
              is_best: Whether this checkpoint is the best so far in this run.

          Returns:
              The ``model_best.pth`` path if the checkpoint was promoted,
              otherwise the path of the timestamped checkpoint file.
          """
          filename = f'checkpoint_{time.time()}.pth'
          checkpoint_path = os.path.join(self.experiment_dir, filename)
          torch.save(state, checkpoint_path)
          if not is_best:
              return checkpoint_path

          best_pred = state['best_pred']
          with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f:
              f.write(str(best_pred))

          # `not >` (rather than `<=`) keeps the original NaN behaviour: a NaN
          # score is never promoted when earlier runs exist.
          if self.runs and not best_pred > self._previous_best():
              return checkpoint_path
          checkpoint_path_best = os.path.join(self.directory, 'model_best.pth')
          shutil.copyfile(checkpoint_path, checkpoint_path_best)
          return checkpoint_path_best

      def save_experiment_config(self):
          """Write selected hyper-parameters from ``self.args`` to ``parameters.txt``.

          Writes one ``key:value`` line per entry, in a fixed order, and
          overwrites any existing file in this run's directory.
          """
          p = OrderedDict([
              ('lr', self.args.lr),
              ('lr_scheduler', self.args.lr_scheduler),
              ('loss_type', self.args.loss_type),
              ('epoch', self.args.epochs),
              ('base_size', self.args.base_size),
              ('crop_size', self.args.crop_size),
          ])
          logfile = os.path.join(self.experiment_dir, 'parameters.txt')
          # `with` guarantees the file is closed even if a write fails.
          with open(logfile, 'w') as log_file:
              for key, val in p.items():
                  log_file.write(key + ':' + str(val) + '\n')