|
- """
- # -*- coding: utf-8 -*-
- -----------------------------------------------------------------------------------
- # Author: Nguyen Mau Dung
- # DoC: 2020.07.31
- # email: nguyenmaudung93.kstn@gmail.com
- -----------------------------------------------------------------------------------
- # Description: This script provides helper utilities for logging
-
- """
-
- import os
- import torch
- import time
-
def make_folder(folder_name):
    """Create the directory ``folder_name``, including missing parents.

    Calling this for a directory that already exists is a no-op.

    Args:
        folder_name: Path of the directory to create.

    Raises:
        FileExistsError: If ``folder_name`` exists but is not a directory.
    """
    # exist_ok=True replaces the former "check os.path.exists, then create"
    # sequence, which could fail if another process created the folder
    # between the check and the makedirs call.
    os.makedirs(folder_name, exist_ok=True)
-
-
class AverageMeter(object):
    """Store the latest value of a metric together with its running average.

    Attributes:
        name: Label printed in front of the values by ``__str__``.
        fmt: Format spec appended to the fields in ``__str__``; it includes
            the leading colon, e.g. ``':f'`` or ``':6.2f'``.
        val: Value passed to the most recent ``update`` call.
        sum: Accumulated ``val * n`` over all ``update`` calls since reset.
        count: Accumulated ``n`` over all ``update`` calls since reset.
        avg: ``sum / count`` (0 while ``count`` is 0).
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Set the current value, average, sum and count back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new value.

        Args:
            val: The new value; it becomes ``self.val``.
            n: Weight of ``val`` in the running average (``val * n`` is added
                to the sum and ``n`` to the count).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A total count of zero (e.g. update(val, n=0) right after reset)
        # would otherwise raise ZeroDivisionError; keep the reset value.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` with ``fmt`` applied to both numbers."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)
-
-
class ProgressMeter(object):
    """Build (and optionally print) one progress line for a batch.

    The line has the form ``<prefix>[<batch>/<num_batches>]`` followed by the
    string form of every meter, all joined with tab characters.

    Args:
        num_batches: Total number of batches; must be formattable with the
            ``d`` format code (i.e. an integer).
        meters: Iterable of objects whose ``str()`` is shown on the line.
            It is iterated on every call, so it must be re-iterable.
        prefix: Text placed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for ``batch`` to standard output."""
        # Reuse get_message so the printed and returned text cannot diverge.
        print(self.get_message(batch))

    def get_message(self, batch):
        """Return the progress line for ``batch`` as a string."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries.extend(str(meter) for meter in self.meters)
        return '\t'.join(entries)

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for the batch counter.

        The current batch index is padded to the width of ``num_batches`` so
        successive lines stay aligned.
        """
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
-
-
def time_synchronized():
    """Return the current time, synchronizing with CUDA first when available.

    When ``torch.cuda.is_available()`` is true, ``torch.cuda.synchronize()``
    is called before the clock is read, so that work already queued on the
    GPU is not left out of a timing measurement.

    Returns:
        The value of ``time.time()``, taken after the optional synchronize.
    """
    # Plain `if` instead of a conditional expression evaluated only for its
    # side effect.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
|