| @@ -32,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint): | |||||
| :param no_save_before_epoch: | :param no_save_before_epoch: | ||||
| """ | """ | ||||
| super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) | super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) | ||||
| self.mode = mode | |||||
| numpy.random.seed(seed) | numpy.random.seed(seed) | ||||
| self.seeds = numpy.random.randint(0, 2000, max_epochs) | self.seeds = numpy.random.randint(0, 2000, max_epochs) | ||||
| pl.seed_everything(seed) | pl.seed_everything(seed) | ||||
| @@ -66,23 +67,30 @@ class SaveCheckpoint(ModelCheckpoint): | |||||
| if self.check_monitor_top_k(trainer, current): | if self.check_monitor_top_k(trainer, current): | ||||
| self._update_best_and_save(current, trainer, monitor_candidates) | self._update_best_and_save(current, trainer, monitor_candidates) | ||||
| best_model_value = max([float(item) for item in list(self.best_k_models.values())]) | |||||
| if self.mode=='max': | |||||
| best_model_value = max([float(item) for item in list(self.best_k_models.values())]) | |||||
| else: | |||||
| best_model_value = min([float(item) for item in list(self.best_k_models.values())]) | |||||
| version_name = 'version_unkown' | |||||
| for item in self.dirpath.split('\\'): | |||||
| if 'version_' in item: | |||||
| version_name = item | |||||
| # 保存版本信息(准确率等)到txt中 | # 保存版本信息(准确率等)到txt中 | ||||
| if not os.path.exists('./logs/default/version_info.txt'): | if not os.path.exists('./logs/default/version_info.txt'): | ||||
| with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: | with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: | ||||
| f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n') | |||||
| f.write(version_name + ' ' + str(best_model_value) + '\n') | |||||
| else: | else: | ||||
| with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: | with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: | ||||
| info_list = f.readlines() | info_list = f.readlines() | ||||
| info_list = [item.strip('\n').split(' ') for item in info_list] | info_list = [item.strip('\n').split(' ') for item in info_list] | ||||
| # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 | # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 | ||||
| info_list = list(map(list, zip(*info_list))) | info_list = list(map(list, zip(*info_list))) | ||||
| if self.dirpath.split('\\')[1] in info_list[0]: | |||||
| if version_name in info_list[0]: | |||||
| for cou in range(len(info_list[0])): | for cou in range(len(info_list[0])): | ||||
| if self.dirpath.split('\\')[1] == info_list[0][cou]: | |||||
| if version_name == info_list[0][cou]: | |||||
| info_list[1][cou] = str(best_model_value) | info_list[1][cou] = str(best_model_value) | ||||
| else: | else: | ||||
| info_list[0].append(self.dirpath.split('\\')[1]) | |||||
| info_list[0].append(version_name) | |||||
| info_list[1].append(str(best_model_value)) | info_list[1].append(str(best_model_value)) | ||||
| # 对list进行转置 | # 对list进行转置 | ||||
| info_list = list(map(list, zip(*info_list))) | info_list = list(map(list, zip(*info_list))) | ||||
| @@ -92,10 +100,10 @@ class SaveCheckpoint(ModelCheckpoint): | |||||
| f.write(line + '\n') | f.write(line + '\n') | ||||
| # 每次更新ckpt文件后, 将其存放到另一个位置 | # 每次更新ckpt文件后, 将其存放到另一个位置 | ||||
| if self.path_final_save is not None: | if self.path_final_save is not None: | ||||
| zip_dir('./logs/default/' + self.dirpath.split('\\')[1], './' + self.dirpath.split('\\')[1] + '.zip') | |||||
| if os.path.exists(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip'): | |||||
| os.remove(self.path_final_save + '/' + self.dirpath.split('\\')[1] + '.zip') | |||||
| shutil.move('./' + self.dirpath.split('\\')[1] + '.zip', self.path_final_save) | |||||
| zip_dir('./logs/default/' + version_name, './' + version_name + '.zip') | |||||
| if os.path.exists(self.path_final_save + '/' + version_name + '.zip'): | |||||
| os.remove(self.path_final_save + '/' + version_name + '.zip') | |||||
| shutil.move('./' + version_name + '.zip', self.path_final_save) | |||||
| elif self.verbose: | elif self.verbose: | ||||
| epoch = monitor_candidates.get("epoch") | epoch = monitor_candidates.get("epoch") | ||||
| step = monitor_candidates.get("step") | step = monitor_candidates.get("step") | ||||