@@ -7,6 +7,7 @@ from typing import Optional, Callable, Union
 from .has_monitor_callback import HasMonitorCallback
 from io import BytesIO
 import shutil
+import pickle
 
 from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
 from fastNLP.core.log import logger
@@ -63,6 +64,7 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.model_save_fn = model_save_fn
         self.model_load_fn = model_load_fn
         self.delete_after_after = delete_after_train
+        self.meta = {'epoch': -1, 'batch': -1}
 
     def prepare_save_folder(self, trainer):
         if not hasattr(self, 'real_save_folder'):
@@ -87,6 +89,7 @@ class LoadBestModelCallback(HasMonitorCallback):
             else:  # create an in-memory buffer instead
                 self.real_save_folder = None
                 self.buffer = BytesIO()
 
     def on_after_trainer_initialized(self, trainer, driver):
         super().on_after_trainer_initialized(trainer, driver)
@@ -94,6 +97,8 @@ class LoadBestModelCallback(HasMonitorCallback):
 
     def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
+            self.meta['epoch'] = trainer.cur_epoch_idx
+            self.meta['batch'] = trainer.global_forward_batches
             self.prepare_save_folder(trainer)
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -102,17 +107,17 @@ class LoadBestModelCallback(HasMonitorCallback):
                 self.buffer.seek(0)
                 with all_rank_call_context():
                     trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)
 
     def on_train_end(self, trainer):
         if abs(self.monitor_value) != float('inf'):  # if it is still inf, evaluation never ran
             # if running distributed and an exception occurred, skip loading to avoid barrier issues
             if not (trainer.driver.is_distributed() and self.encounter_exception):
                 if self.real_save_folder:
-                    logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
+                    logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']})...")
                     trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                        model_load_fn=self.model_load_fn)
                 else:
-                    logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
+                    logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']})...")
                     self.buffer.seek(0)
                     trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
         if self.delete_after_after:
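
For context, below is a minimal usage sketch (not part of this diff) of how LoadBestModelCallback is typically attached to a fastNLP Trainer so the new epoch/batch bookkeeping shows up in the final restore message. It assumes Trainer and LoadBestModelCallback are importable from the top-level fastNLP package and that the Trainer keyword arguments match the fastNLP 1.x API; the model, optimizer, dataloader, and metric objects are placeholders supplied by the caller.

# Sketch only: placeholder objects, assumed Trainer kwargs (fastNLP 1.x style).
from fastNLP import Trainer, LoadBestModelCallback


def train_with_best_model(model, optimizer, train_dataloader, evaluate_dataloader, metric):
    callback = LoadBestModelCallback(
        monitor='acc',            # assumed metric key; must match one of the evaluation results
        save_folder=None,         # None -> keep the best weights in an in-memory BytesIO buffer
        only_state_dict=True,
        delete_after_train=True,  # drop the temporary copy once the best model is loaded back
    )
    trainer = Trainer(
        model=model,
        driver='torch',
        optimizers=optimizer,
        train_dataloader=train_dataloader,
        evaluate_dataloaders=evaluate_dataloader,
        metrics={'acc': metric},
        callbacks=[callback],
        n_epochs=10,
    )
    trainer.run()
    # With this change, the restore log also reports when the best result was reached,
    # e.g. "... (achieved in Epoch:3, Global Batch:1200)...", taken from callback.meta.
    return trainer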