diff --git a/save_checkpoint.py b/save_checkpoint.py index 0ced157..24c1a0c 100644 --- a/save_checkpoint.py +++ b/save_checkpoint.py @@ -5,6 +5,7 @@ import pytorch_lightning as pl import shutil from pytorch_lightning.utilities import rank_zero_info from utils import zip_dir +import re class SaveCheckpoint(ModelCheckpoint): @@ -67,14 +68,15 @@ class SaveCheckpoint(ModelCheckpoint): if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, trainer, monitor_candidates) - if self.mode=='max': + if self.mode == 'max': best_model_value = max([float(item) for item in list(self.best_k_models.values())]) else: best_model_value = min([float(item) for item in list(self.best_k_models.values())]) version_name = 'version_unkown' - for item in self.dirpath.split('\\'): + for item in re.split(r'[/|\\]', self.dirpath): if 'version_' in item: version_name = item + break # 保存版本信息(准确率等)到txt中 if not os.path.exists('./logs/default/version_info.txt'): with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: diff --git a/utils.py b/utils.py index e34a8de..7e9a000 100644 --- a/utils.py +++ b/utils.py @@ -106,24 +106,5 @@ def get_ckpt_path(version_nth: int, kth_fold: int): return ckpt_path[0].replace('\\', '/') -def rwxl(): - # 写 - # dataset_xl = xl.Workbook(write_only=True) - # dataset_sh = dataset_xl.create_sheet('dataset', 0) - # for row in range(self.x.shape[0]): - # for col in range(self.x.shape[1]): - # dataset_sh.cell(row + 1, col + 1).value = float(self.x[row, col]) - # dataset_sh.cell(row + 1, self.x.shape[1] + 1).value = float(self.y[row]) - # dataset_xl.save(dataset_path + '/dataset.xlsx') - # dataset_xl.close() - # 读 - # dataset_xl = xl.load_workbook(dataset_path + '/dataset_list.xlsx', read_only=True) - # dataset_sh = dataset_xl.get_sheet_by_name('dataset_list') - # temp = [[dataset_sh[row + 1][col].value for col in range(config['dim_in'] + 1)] for row in - # range(config['dataset_len'])] - # dataset_xl.close() - pass - - if __name__ == "__main__": get_ckpt_path('version_0')