Browse Source

保存版本训练结果时输出备注的版本信息, 主要记录该版本的网络、数据集等.

master
shenyan 4 years ago
parent
commit
67ec993c27
2 changed files with 13 additions and 6 deletions
  1. +5
    -3
      main.py
  2. +8
    -3
      save_checkpoint.py

+ 5
- 3
main.py View File

@@ -24,6 +24,7 @@ def main(stage,
path_final_save=None,
every_n_epochs=1,
save_top_k=1,
version_info='无'
):
"""
框架的入口函数. 包含设置超参数, 划分数据集, 选择训练或测试等流程
@@ -50,6 +51,7 @@ def main(stage,
:param kth_fold_start: 从第几个fold开始, 若使用重载训练, 则kth_fold_start为重载第几个fold, 第一个值为0.
非重载训练的情况下, 可以通过调整该值控制训练的次数;
:param k_fold:
:param version_info: 版本信息, 主要记录该版本的网络数据集等
"""
# 处理输入数据
precision = 32 if (gpus is None and tpu_cores is None) else precision
@@ -71,7 +73,7 @@ def main(stage,
path_final_save=path_final_save,
every_n_epochs=every_n_epochs, verbose=True,
monitor='Validation acc', save_top_k=save_top_k,
mode='max')
mode='max', version_info=version_info)
training_module = TrainModule(config=config)
if kth_fold != kth_fold_start or load_checkpoint_path is None:
print('进行初始训练')
@@ -99,7 +101,7 @@ def main(stage,


if __name__ == "__main__":
main('fit', max_epochs=200, batch_size=128, precision=16, seed=1234, dataset_path='./dataset', k_fold=10,
kth_fold_start=9,
main('fit', max_epochs=1, batch_size=128, precision=16, seed=1234, dataset_path='./dataset', k_fold=10,
kth_fold_start=9, version_info='ResNet-RuLe-CIFAR100',
# version_nth=8,
)

+ 8
- 3
save_checkpoint.py View File

@@ -18,7 +18,8 @@ class SaveCheckpoint(ModelCheckpoint):
save_top_k=None,
verbose=False,
mode='min',
no_save_before_epoch=0):
no_save_before_epoch=0,
version_info='无'):
"""
通过回调实现checkpoint的保存逻辑, 同时具有回调函数中定义on_validation_end等功能.

@@ -31,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint):
:param verbose:
:param mode:
:param no_save_before_epoch:
:param version_info:
"""
super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
self.mode = mode
@@ -42,6 +44,7 @@ class SaveCheckpoint(ModelCheckpoint):
self.save_top_k = save_top_k
self.flag_sanity_check = 0
self.no_save_before_epoch = no_save_before_epoch
self.version_info = version_info

def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""
@@ -80,20 +83,22 @@ class SaveCheckpoint(ModelCheckpoint):
# 保存版本信息(准确率等)到txt中
if not os.path.exists('./logs/default/version_info.txt'):
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
f.write(version_name + ' ' + str(best_model_value) + '\n')
f.write(version_name + ' ' + str(best_model_value) + ' ' + self.version_info + '\n')
else:
with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f:
info_list = f.readlines()
info_list = [item.strip('\n').split(' ') for item in info_list]
# 对list进行转置, 现在行为版本号和其数据, 列为不同的版本
# 对list进行转置, 转置前行为版本号和其数据, 列为不同的版本
info_list = list(map(list, zip(*info_list)))
if version_name in info_list[0]:
for cou in range(len(info_list[0])):
if version_name == info_list[0][cou]:
info_list[1][cou] = str(best_model_value)
info_list[2][cou] = self.version_info
else:
info_list[0].append(version_name)
info_list[1].append(str(best_model_value))
info_list[2].append(self.version_info)
# 对list进行转置
info_list = list(map(list, zip(*info_list)))
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:


Loading…
Cancel
Save