Browse Source

更改requirements版本; 改变保存版本信息时的最优值的获取逻辑; 修复保存版本信息时可能产生的BUG

master
shenyan 4 years ago
parent
commit
76b29f1a8f
2 changed files with 17 additions and 9 deletions
  1. BIN
      requirements.txt
  2. +17
    -9
      save_checkpoint.py

BIN
requirements.txt View File


+ 17
- 9
save_checkpoint.py View File

@@ -32,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint):
:param no_save_before_epoch:
"""
super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
self.mode = mode
numpy.random.seed(seed)
self.seeds = numpy.random.randint(0, 2000, max_epochs)
pl.seed_everything(seed)
@@ -66,23 +67,30 @@ class SaveCheckpoint(ModelCheckpoint):

if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates)
best_model_value = max([float(item) for item in list(self.best_k_models.values())])
if self.mode=='max':
best_model_value = max([float(item) for item in list(self.best_k_models.values())])
else:
best_model_value = min([float(item) for item in list(self.best_k_models.values())])
version_name = 'version_unkown'
for item in self.dirpath.split('\\'):
if 'version_' in item:
version_name = item
# 保存版本信息(准确率等)到txt中
if not os.path.exists('./logs/default/version_info.txt'):
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n')
f.write(version_name + ' ' + str(best_model_value) + '\n')
else:
with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f:
info_list = f.readlines()
info_list = [item.strip('\n').split(' ') for item in info_list]
# 对list进行转置, 现在行为版本号和其数据, 列为不同的版本
info_list = list(map(list, zip(*info_list)))
if self.dirpath.split('\\')[1] in info_list[0]:
if version_name in info_list[0]:
for cou in range(len(info_list[0])):
if self.dirpath.split('\\')[1] == info_list[0][cou]:
if version_name == info_list[0][cou]:
info_list[1][cou] = str(best_model_value)
else:
info_list[0].append(self.dirpath.split('\\')[1])
info_list[0].append(version_name)
info_list[1].append(str(best_model_value))
# 对list进行转置
info_list = list(map(list, zip(*info_list)))
@@ -92,10 +100,10 @@ class SaveCheckpoint(ModelCheckpoint):
f.write(line + '\n')
# 每次更新ckpt文件后, 将其存放到另一个位置
if self.path_final_save is not None:
zip_dir('./logs/default/' + self.dirpath.split('\\')[1], './' + self.dirpath.split('\\')[1] + '.zip')
if os.path.exists(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip'):
os.remove(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip')
shutil.move('./' + self.dirpath.split('\\')[1] + '.zip', self.path_final_save)
zip_dir('./logs/default/' + version_name, './' + version_name + '.zip')
if os.path.exists(self.path_final_save + '/' + version_name + '.zip'):
os.remove(self.path_final_save + '/' + version_name + '.zip')
shutil.move('./' + version_name + '.zip', self.path_final_save)
elif self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")


Loading…
Cancel
Save