Browse Source

考虑保存版本信息时分割字符为'/'的情况; 删除读取表格相关的内容;

master
shenyan 4 years ago
parent
commit
665bca9171
2 changed files with 4 additions and 21 deletions
  1. +4
    -2
      save_checkpoint.py
  2. +0
    -19
      utils.py

+ 4
- 2
save_checkpoint.py View File

@@ -5,6 +5,7 @@ import pytorch_lightning as pl
import shutil import shutil
from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities import rank_zero_info
from utils import zip_dir from utils import zip_dir
import re




class SaveCheckpoint(ModelCheckpoint): class SaveCheckpoint(ModelCheckpoint):
@@ -67,14 +68,15 @@ class SaveCheckpoint(ModelCheckpoint):


if self.check_monitor_top_k(trainer, current): if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates) self._update_best_and_save(current, trainer, monitor_candidates)
if self.mode=='max':
if self.mode == 'max':
best_model_value = max([float(item) for item in list(self.best_k_models.values())]) best_model_value = max([float(item) for item in list(self.best_k_models.values())])
else: else:
best_model_value = min([float(item) for item in list(self.best_k_models.values())]) best_model_value = min([float(item) for item in list(self.best_k_models.values())])
version_name = 'version_unkown' version_name = 'version_unkown'
for item in self.dirpath.split('\\'):
for item in re.split(r'[/|\\]', self.dirpath):
if 'version_' in item: if 'version_' in item:
version_name = item version_name = item
break
# 保存版本信息(准确率等)到txt中 # 保存版本信息(准确率等)到txt中
if not os.path.exists('./logs/default/version_info.txt'): if not os.path.exists('./logs/default/version_info.txt'):
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:


+ 0
- 19
utils.py View File

@@ -106,24 +106,5 @@ def get_ckpt_path(version_nth: int, kth_fold: int):
return ckpt_path[0].replace('\\', '/') return ckpt_path[0].replace('\\', '/')




def rwxl():
# 写
# dataset_xl = xl.Workbook(write_only=True)
# dataset_sh = dataset_xl.create_sheet('dataset', 0)
# for row in range(self.x.shape[0]):
# for col in range(self.x.shape[1]):
# dataset_sh.cell(row + 1, col + 1).value = float(self.x[row, col])
# dataset_sh.cell(row + 1, self.x.shape[1] + 1).value = float(self.y[row])
# dataset_xl.save(dataset_path + '/dataset.xlsx')
# dataset_xl.close()
# 读
# dataset_xl = xl.load_workbook(dataset_path + '/dataset_list.xlsx', read_only=True)
# dataset_sh = dataset_xl.get_sheet_by_name('dataset_list')
# temp = [[dataset_sh[row + 1][col].value for col in range(config['dim_in'] + 1)] for row in
# range(config['dataset_len'])]
# dataset_xl.close()
pass


if __name__ == "__main__": if __name__ == "__main__":
get_ckpt_path('version_0') get_ckpt_path('version_0')

Loading…
Cancel
Save