|
|
@@ -5,6 +5,7 @@ import pytorch_lightning as pl |
|
|
import shutil |
|
|
import shutil |
|
|
from pytorch_lightning.utilities import rank_zero_info |
|
|
from pytorch_lightning.utilities import rank_zero_info |
|
|
from utils import zip_dir |
|
|
from utils import zip_dir |
|
|
|
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SaveCheckpoint(ModelCheckpoint): |
|
|
class SaveCheckpoint(ModelCheckpoint): |
|
|
@@ -67,14 +68,15 @@ class SaveCheckpoint(ModelCheckpoint): |
|
|
|
|
|
|
|
|
if self.check_monitor_top_k(trainer, current): |
|
|
if self.check_monitor_top_k(trainer, current): |
|
|
self._update_best_and_save(current, trainer, monitor_candidates) |
|
|
self._update_best_and_save(current, trainer, monitor_candidates) |
|
|
if self.mode=='max': |
|
|
|
|
|
|
|
|
if self.mode == 'max': |
|
|
best_model_value = max([float(item) for item in list(self.best_k_models.values())]) |
|
|
best_model_value = max([float(item) for item in list(self.best_k_models.values())]) |
|
|
else: |
|
|
else: |
|
|
best_model_value = min([float(item) for item in list(self.best_k_models.values())]) |
|
|
best_model_value = min([float(item) for item in list(self.best_k_models.values())]) |
|
|
version_name = 'version_unkown' |
|
|
version_name = 'version_unkown' |
|
|
for item in self.dirpath.split('\\'): |
|
|
|
|
|
|
|
|
for item in re.split(r'[/|\\]', self.dirpath): |
|
|
if 'version_' in item: |
|
|
if 'version_' in item: |
|
|
version_name = item |
|
|
version_name = item |
|
|
|
|
|
break |
|
|
# 保存版本信息(准确率等)到txt中 |
|
|
# 保存版本信息(准确率等)到txt中 |
|
|
if not os.path.exists('./logs/default/version_info.txt'): |
|
|
if not os.path.exists('./logs/default/version_info.txt'): |
|
|
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: |
|
|
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: |
|
|
|