|
|
|
@@ -18,7 +18,8 @@ class SaveCheckpoint(ModelCheckpoint): |
|
|
|
save_top_k=None, |
|
|
|
verbose=False, |
|
|
|
mode='min', |
|
|
|
no_save_before_epoch=0): |
|
|
|
no_save_before_epoch=0, |
|
|
|
version_info='无'): |
|
|
|
""" |
|
|
|
通过回调实现checkpoint的保存逻辑, 同时具有回调函数中定义on_validation_end等功能. |
|
|
|
|
|
|
|
@@ -31,6 +32,7 @@ class SaveCheckpoint(ModelCheckpoint): |
|
|
|
:param verbose: |
|
|
|
:param mode: |
|
|
|
:param no_save_before_epoch: |
|
|
|
:param version_info: |
|
|
|
""" |
|
|
|
super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) |
|
|
|
self.mode = mode |
|
|
|
@@ -42,6 +44,7 @@ class SaveCheckpoint(ModelCheckpoint): |
|
|
|
self.save_top_k = save_top_k |
|
|
|
self.flag_sanity_check = 0 |
|
|
|
self.no_save_before_epoch = no_save_before_epoch |
|
|
|
self.version_info = version_info |
|
|
|
|
|
|
|
def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: |
|
|
|
""" |
|
|
|
@@ -80,20 +83,22 @@ class SaveCheckpoint(ModelCheckpoint): |
|
|
|
# 保存版本信息(准确率等)到txt中 |
|
|
|
if not os.path.exists('./logs/default/version_info.txt'): |
|
|
|
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: |
|
|
|
f.write(version_name + ' ' + str(best_model_value) + '\n') |
|
|
|
f.write(version_name + ' ' + str(best_model_value) + ' ' + self.version_info + '\n') |
|
|
|
else: |
|
|
|
with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: |
|
|
|
info_list = f.readlines() |
|
|
|
info_list = [item.strip('\n').split(' ') for item in info_list] |
|
|
|
# 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 |
|
|
|
# 对list进行转置, 转置前行为版本号和其数据, 列为不同的版本 |
|
|
|
info_list = list(map(list, zip(*info_list))) |
|
|
|
if version_name in info_list[0]: |
|
|
|
for cou in range(len(info_list[0])): |
|
|
|
if version_name == info_list[0][cou]: |
|
|
|
info_list[1][cou] = str(best_model_value) |
|
|
|
info_list[2][cou] = self.version_info |
|
|
|
else: |
|
|
|
info_list[0].append(version_name) |
|
|
|
info_list[1].append(str(best_model_value)) |
|
|
|
info_list[2].append(self.version_info) |
|
|
|
# 对list进行转置 |
|
|
|
info_list = list(map(list, zip(*info_list))) |
|
|
|
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: |
|
|
|
|