|
- import os
- import time
- import shutil
- import tempfile
- import torch
- from collections import OrderedDict
- import glob
-
class Saver(object):
    """Manage per-experiment output directories, checkpoints and run config.

    Each instance claims a fresh ``experiment_<N>`` directory under
    ``<root_dir>/<args.checkname>``, where ``N`` is one more than the largest
    numeric id of the ``experiment_*`` entries already present (0 if none).
    """

    def __init__(self, args, root_dir='/tmp'):
        """Create the directory for a new experiment run.

        Args:
            args: Options object; ``args.checkname`` names the directory that
                groups all runs. Other attributes are read by
                ``save_experiment_config``.
            root_dir: Parent directory of ``checkname``. Defaults to
                ``'/tmp'``, the previously hard-coded location.
        """
        self.args = args
        self.directory = os.path.join(root_dir, args.checkname)

        # Order runs by numeric id, not by string: a lexicographic sort puts
        # 'experiment_10' before 'experiment_9', which made the "next" id
        # collide with an existing run. Entries whose suffix is not a plain
        # integer (e.g. 'experiment_old') are ignored instead of crashing int().
        prefix = 'experiment_'
        numbered = []
        for path in glob.glob(os.path.join(self.directory, prefix + '*')):
            suffix = os.path.basename(path)[len(prefix):]
            if suffix.isdecimal():
                numbered.append((int(suffix), path))
        numbered.sort()
        self.runs = [path for _, path in numbered]
        run_id = numbered[-1][0] + 1 if numbered else 0

        self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(run_id))
        os.makedirs(self.experiment_dir, exist_ok=True)

    def save_checkpoint(self, state, is_best):
        """Save ``state`` as a timestamped checkpoint in this run's directory.

        When ``is_best`` is true, ``state['best_pred']`` is written to this
        run's ``best_pred.txt``. The checkpoint is also copied to
        ``<directory>/model_best.pth`` when either:

        - there are no previous runs, or
        - the score is strictly greater than every previous run's recorded
          score (and than 0.0).

        Args:
            state: Object passed to ``torch.save``; must contain the key
                ``'best_pred'`` when ``is_best`` is true.
            is_best: Whether this checkpoint is the best so far in this run.

        Returns:
            The ``model_best.pth`` path if the checkpoint was copied there,
            otherwise the path of the timestamped checkpoint.
        """
        # NOTE(review): the extension was previously '.pth.tar'; '.pth' is used now.
        filename = f'checkpoint_{time.time()}.pth'
        checkpoint_path = os.path.join(self.experiment_dir, filename)
        torch.save(state, checkpoint_path)
        if not is_best:
            return checkpoint_path

        best_pred = state['best_pred']
        with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f:
            f.write(str(best_pred))

        if self.runs:
            # 0.0 acts as a floor when earlier runs exist.
            previous_best = max([0.0] + self._previous_scores())
            beats_previous = best_pred > previous_best
        else:
            # First run under this checkname: always becomes model_best.
            beats_previous = True

        if beats_previous:
            checkpoint_path_best = os.path.join(self.directory, 'model_best.pth')
            shutil.copyfile(checkpoint_path, checkpoint_path_best)
            checkpoint_path = checkpoint_path_best
        return checkpoint_path

    def _previous_scores(self):
        """Return the scores recorded in earlier runs' ``best_pred.txt`` files.

        Runs without the file, or whose first line does not parse as a float,
        are skipped so one bad file cannot abort a save.
        """
        scores = []
        for run in self.runs:
            path = os.path.join(run, 'best_pred.txt')
            if not os.path.exists(path):
                continue
            with open(path, 'r') as f:
                try:
                    scores.append(float(f.readline()))
                except ValueError:
                    continue
        return scores

    def save_experiment_config(self):
        """Write selected hyper-parameters to ``parameters.txt`` in this run's directory.

        One ``key:value`` line per entry, in a fixed order.
        """
        params = OrderedDict([
            ('lr', self.args.lr),
            ('lr_scheduler', self.args.lr_scheduler),
            ('loss_type', self.args.loss_type),
            ('epoch', self.args.epochs),
            ('base_size', self.args.base_size),
            ('crop_size', self.args.crop_size),
        ])
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')
        # The context manager closes the file even if a write fails.
        with open(logfile, 'w') as log_file:
            for key, val in params.items():
                log_file.write(key + ':' + str(val) + '\n')
|