| @@ -32,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| :param no_save_before_epoch: | |||
| """ | |||
| super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) | |||
| self.mode = mode | |||
| numpy.random.seed(seed) | |||
| self.seeds = numpy.random.randint(0, 2000, max_epochs) | |||
| pl.seed_everything(seed) | |||
| @@ -66,23 +67,30 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| if self.check_monitor_top_k(trainer, current): | |||
| self._update_best_and_save(current, trainer, monitor_candidates) | |||
| best_model_value = max([float(item) for item in list(self.best_k_models.values())]) | |||
| if self.mode=='max': | |||
| best_model_value = max([float(item) for item in list(self.best_k_models.values())]) | |||
| else: | |||
| best_model_value = min([float(item) for item in list(self.best_k_models.values())]) | |||
| version_name = 'version_unkown' | |||
| for item in self.dirpath.split('\\'): | |||
| if 'version_' in item: | |||
| version_name = item | |||
| # 保存版本信息(准确率等)到txt中 | |||
| if not os.path.exists('./logs/default/version_info.txt'): | |||
| with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: | |||
| f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n') | |||
| f.write(version_name + ' ' + str(best_model_value) + '\n') | |||
| else: | |||
| with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: | |||
| info_list = f.readlines() | |||
| info_list = [item.strip('\n').split(' ') for item in info_list] | |||
| # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 | |||
| info_list = list(map(list, zip(*info_list))) | |||
| if self.dirpath.split('\\')[1] in info_list[0]: | |||
| if version_name in info_list[0]: | |||
| for cou in range(len(info_list[0])): | |||
| if self.dirpath.split('\\')[1] == info_list[0][cou]: | |||
| if version_name == info_list[0][cou]: | |||
| info_list[1][cou] = str(best_model_value) | |||
| else: | |||
| info_list[0].append(self.dirpath.split('\\')[1]) | |||
| info_list[0].append(version_name) | |||
| info_list[1].append(str(best_model_value)) | |||
| # 对list进行转置 | |||
| info_list = list(map(list, zip(*info_list))) | |||
| @@ -92,10 +100,10 @@ class SaveCheckpoint(ModelCheckpoint): | |||
| f.write(line + '\n') | |||
| # 每次更新ckpt文件后, 将其存放到另一个位置 | |||
| if self.path_final_save is not None: | |||
| zip_dir('./logs/default/' + self.dirpath.split('\\')[1], './' + self.dirpath.split('\\')[1] + '.zip') | |||
| if os.path.exists(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip'): | |||
| os.remove(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip') | |||
| shutil.move('./' + self.dirpath.split('\\')[1] + '.zip', self.path_final_save) | |||
| zip_dir('./logs/default/' + version_name, './' + version_name + '.zip') | |||
| if os.path.exists(self.path_final_save + '/' + version_name + '.zip'): | |||
| os.remove(self.path_final_save + '/' + version_name + '.zip') | |||
| shutil.move('./' + version_name + '.zip', self.path_final_save) | |||
| elif self.verbose: | |||
| epoch = monitor_candidates.get("epoch") | |||
| step = monitor_candidates.get("step") | |||