diff --git a/main.py b/main.py index 81bdc16..8cd366b 100644 --- a/main.py +++ b/main.py @@ -24,6 +24,7 @@ def main(stage, path_final_save=None, every_n_epochs=1, save_top_k=1, + version_info='无' ): """ 框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程 @@ -50,6 +51,7 @@ def main(stage, :param kth_fold_start: 从第几个fold开始, 若使用重载训练, 则kth_fold_start为重载第几个fold, 第一个值为0. 非重载训练的情况下, 可以通过调整该值控制训练的次数; :param k_fold: + :param version_info: 版本信息, 主要记录该版本的网络数据集等 """ # 处理输入数据 precision = 32 if (gpus is None and tpu_cores is None) else precision @@ -71,7 +73,7 @@ def main(stage, path_final_save=path_final_save, every_n_epochs=every_n_epochs, verbose=True, monitor='Validation acc', save_top_k=save_top_k, - mode='max') + mode='max', version_info=version_info) training_module = TrainModule(config=config) if kth_fold != kth_fold_start or load_checkpoint_path is None: print('进行初始训练') @@ -99,7 +101,7 @@ def main(stage, if __name__ == "__main__": - main('fit', max_epochs=200, batch_size=128, precision=16, seed=1234, dataset_path='./dataset', k_fold=10, - kth_fold_start=9, + main('fit', max_epochs=1, batch_size=128, precision=16, seed=1234, dataset_path='./dataset', k_fold=10, + kth_fold_start=9, version_info='ResNet-RuLe-CIFAR100', # version_nth=8, ) diff --git a/save_checkpoint.py b/save_checkpoint.py index 24c1a0c..3eda737 100644 --- a/save_checkpoint.py +++ b/save_checkpoint.py @@ -18,7 +18,8 @@ class SaveCheckpoint(ModelCheckpoint): save_top_k=None, verbose=False, mode='min', - no_save_before_epoch=0): + no_save_before_epoch=0, + version_info='无'): """ 通过回调实现checkpoint的保存逻辑, 同时具有回调函数中定义on_validation_end等功能. @@ -31,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint): :param verbose: :param mode: :param no_save_before_epoch: + :param version_info: """ super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) self.mode = mode @@ -42,6 +44,7 @@ class SaveCheckpoint(ModelCheckpoint): self.save_top_k = save_top_k self.flag_sanity_check = 0 self.no_save_before_epoch = no_save_before_epoch + self.version_info = version_info def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """ @@ -80,20 +83,22 @@ class SaveCheckpoint(ModelCheckpoint): # 保存版本信息(准确率等)到txt中 if not os.path.exists('./logs/default/version_info.txt'): with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: - f.write(version_name + ' ' + str(best_model_value) + '\n') + f.write(version_name + ' ' + str(best_model_value) + ' ' + self.version_info + '\n') else: with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: info_list = f.readlines() info_list = [item.strip('\n').split(' ') for item in info_list] - # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 + # 对list进行转置, 转置前行为版本号和其数据, 列为不同的版本 info_list = list(map(list, zip(*info_list))) if version_name in info_list[0]: for cou in range(len(info_list[0])): if version_name == info_list[0][cou]: info_list[1][cou] = str(best_model_value) + info_list[2][cou] = self.version_info else: info_list[0].append(version_name) info_list[1].append(str(best_model_value)) + info_list[2].append(self.version_info) # 对list进行转置 info_list = list(map(list, zip(*info_list))) with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: