diff --git a/requirements.txt b/requirements.txt index a9d4deb..d9e3485 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/save_checkpoint.py b/save_checkpoint.py index 341b594..0ced157 100644 --- a/save_checkpoint.py +++ b/save_checkpoint.py @@ -32,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint): :param no_save_before_epoch: """ super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) + self.mode = mode numpy.random.seed(seed) self.seeds = numpy.random.randint(0, 2000, max_epochs) pl.seed_everything(seed) @@ -66,23 +67,30 @@ class SaveCheckpoint(ModelCheckpoint): if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, trainer, monitor_candidates) - best_model_value = max([float(item) for item in list(self.best_k_models.values())]) + if self.mode=='max': + best_model_value = max([float(item) for item in list(self.best_k_models.values())]) + else: + best_model_value = min([float(item) for item in list(self.best_k_models.values())]) + version_name = 'version_unkown' + for item in self.dirpath.split('\\'): + if 'version_' in item: + version_name = item # 保存版本信息(准确率等)到txt中 if not os.path.exists('./logs/default/version_info.txt'): with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: - f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n') + f.write(version_name + ' ' + str(best_model_value) + '\n') else: with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: info_list = f.readlines() info_list = [item.strip('\n').split(' ') for item in info_list] # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 info_list = list(map(list, zip(*info_list))) - if self.dirpath.split('\\')[1] in info_list[0]: + if version_name in info_list[0]: for cou in range(len(info_list[0])): - if self.dirpath.split('\\')[1] == info_list[0][cou]: + if version_name == info_list[0][cou]: info_list[1][cou] = str(best_model_value) else: - info_list[0].append(self.dirpath.split('\\')[1]) + info_list[0].append(version_name) info_list[1].append(str(best_model_value)) # 对list进行转置 info_list = list(map(list, zip(*info_list))) @@ -92,10 +100,10 @@ class SaveCheckpoint(ModelCheckpoint): f.write(line + '\n') # 每次更新ckpt文件后, 将其存放到另一个位置 if self.path_final_save is not None: - zip_dir('./logs/default/' + self.dirpath.split('\\')[1], './' + self.dirpath.split('\\')[1] + '.zip') - if os.path.exists(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip'): - os.remove(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip') - shutil.move('./' + self.dirpath.split('\\')[1] + '.zip', self.path_final_save) + zip_dir('./logs/default/' + version_name, './' + version_name + '.zip') + if os.path.exists(self.path_final_save + '/' + version_name + '.zip'): + os.remove(self.path_final_save + '/' + version_name + '.zip') + shutil.move('./' + version_name + '.zip', self.path_final_save) elif self.verbose: epoch = monitor_candidates.get("epoch") step = monitor_candidates.get("step")