@@ -72,7 +72,6 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.model_save_fn = model_save_fn
         self.model_load_fn = model_load_fn
         self.delete_after_after = delete_after_train
-        self.encounter_exception = False

     def on_after_trainer_initialized(self, trainer, driver):
         if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1:
@@ -85,6 +84,7 @@ class LoadBestModelCallback(HasMonitorCallback):
                                f"save best model when launch using module.")
         super().on_after_trainer_initialized(trainer, driver)
+        self.encounter_exception = False

     def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
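Note on the two hunks above: the `encounter_exception` reset moves from `__init__` into `on_after_trainer_initialized`, so a callback instance reused for a second training run starts from a clean flag. A minimal sketch of the intent (illustrative class, assuming the usual fastNLP `on_exception` hook; not the library source):

```python
class LoadBestModelSketch:
    """Illustrative only; mirrors the flag handling changed in the hunks above."""

    def on_after_trainer_initialized(self, trainer, driver):
        # reset once per trainer run instead of once per object, so a stale
        # True from an earlier failed run cannot leak into the next run
        self.encounter_exception = False

    def on_exception(self, trainer, exception):
        # remember that this run raised; later hooks can then decide not to
        # load a possibly incomplete "best" checkpoint
        self.encounter_exception = True
```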
@@ -429,7 +429,8 @@ class Trainer(TrainerEventTrigger):
         self.driver.set_optimizers(optimizers=optimizers)

         # choose the ProgressBarCallback according to the progress_bar argument
-        callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto'))
+        self.progress_bar = kwargs.get('progress_bar', 'auto')
+        callbacks = prepare_callbacks(callbacks, self.progress_bar)
         # initialize the callback manager;
         self.callback_manager = CallbackManager(callbacks)
         # add all function-style callbacks;
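Note: the hunk above stores the setting as `self.progress_bar` before building the callbacks, so later code can read it from the trainer instead of re-parsing `kwargs`. A small self-contained sketch of that lookup (the helper name is hypothetical, not fastNLP API):

```python
def resolve_progress_bar(trainer, override=None):
    # an explicit per-call override wins; otherwise fall back to whatever the
    # Trainer stored, defaulting to 'auto' if the attribute is missing
    if override is not None:
        return override
    return getattr(trainer, 'progress_bar', 'auto')
```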
@@ -272,7 +272,7 @@ def test_model_checkpoint_callback_2(
         trainer = Trainer(
             model=model_and_optimizers.model,
             driver="torch",
-            device=4,
+            device=0,
             optimizers=model_and_optimizers.optimizers,
             train_dataloader=model_and_optimizers.train_dataloader,
             evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
@@ -72,19 +72,6 @@ def model_and_optimizers(request):
     return trainer_params


-from fastNLP import Metric
-class CountMetrc(Metric):
-    def __init__(self):
-        super().__init__()
-        self.register_element('count', 0, aggregate_method='sum')
-
-    def update(self, pred):
-        self.count += len(pred)
-
-    def get_metric(self) -> dict:
-        return {'cnt': self.count.item()}
-
-
 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
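Note: the block deleted above defined a small counting metric whose class name carried a typo (`CountMetrc`). For reference only, a cleaned-up version of that removed helper, unchanged in behaviour, looks like this; it is no longer part of the test module:

```python
from fastNLP import Metric

class CountMetric(Metric):
    """Counts how many predictions were seen (mirrors the removed test helper)."""

    def __init__(self):
        super().__init__()
        # 'count' is registered as an element so it can be summed across processes
        self.register_element('count', 0, aggregate_method='sum')

    def update(self, pred):
        self.count += len(pred)

    def get_metric(self) -> dict:
        return {'cnt': self.count.item()}
```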
@@ -122,6 +109,7 @@ def test_load_best_model_callback(
                           progress_bar='rich', use_dist_sampler=False)
     results = evaluator.run()
     assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1'])
+    trainer.driver.barrier()
     if save_folder:
         import shutil
         shutil.rmtree(save_folder, ignore_errors=True)
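Note: the added `trainer.driver.barrier()` makes every rank finish its checks before the shared save folder is removed. The same cleanup pattern in isolation (function name is illustrative):

```python
import shutil

def cleanup_save_folder(driver, save_folder):
    # synchronise all ranks first, so no process deletes files that a slower
    # rank is still reading
    driver.barrier()
    if save_folder:
        shutil.rmtree(save_folder, ignore_errors=True)
```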
@@ -92,12 +92,12 @@ def model_and_optimizers(request):


 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_model_more_evaluate_callback_1(
         model_and_optimizers: TrainerParameters,
         driver,
-        device,
+        device
 ):
     for only_state_dict in [True, False]:
         for version in [0, 1]:
@@ -110,7 +110,7 @@ def test_model_more_evaluate_callback_1(
                     MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                          metrics=model_and_optimizers.more_metrics,
                                          evaluate_every=-1,
-                                         folder=path, topk=-1,
+                                         folder=path, topk=-1, progress_bar=None,
                                          topk_monitor='acc', only_state_dict=only_state_dict, save_object='model')
                 ]
             elif version == 1:
@@ -119,7 +119,7 @@ def test_model_more_evaluate_callback_1(
                                          metrics=model_and_optimizers.more_metrics,
                                          evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                          folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
-                                         save_object='model')
+                                         save_object='model', progress_bar=None)
                 ]
             n_epochs = 3
             trainer = Trainer(
@@ -167,6 +167,7 @@ def test_model_more_evaluate_callback_1(
                 trainer.run()
                 trainer.driver.barrier()
+                break
             finally:
                 rank_zero_rm(path)
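Note: the added `break` exits the loop after the first successful run, shortening the test, and the `finally` cleanup still executes on the way out. A tiny runnable sketch of that control flow (placeholder prints stand in for `trainer.run()` and `rank_zero_rm(path)`):

```python
def run_once_then_cleanup():
    for version in [0, 1]:
        try:
            print(f"pretend run, version={version}")  # stand-in for trainer.run()
            break                                     # leave the loop after the first version
        finally:
            print("cleanup runs even on break")       # stand-in for rank_zero_rm(path)

run_once_then_cleanup()
```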
@@ -175,7 +176,7 @@ def test_model_more_evaluate_callback_1(


 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_trainer_checkpoint_callback_1(
         model_and_optimizers: TrainerParameters,
@@ -241,7 +242,7 @@ def test_trainer_checkpoint_callback_1(
                 input_mapping=model_and_optimizers.input_mapping,
                 output_mapping=model_and_optimizers.output_mapping,
                 metrics=model_and_optimizers.metrics,
-                n_epochs=5,
+                n_epochs=2,
                 output_from_new_proc="all",
                 evaluate_fn='train_step'
             )
@@ -250,7 +251,7 @@ def test_trainer_checkpoint_callback_1(
             trainer.run()
             trainer.driver.barrier()
+            break
         finally:
             rank_zero_rm(path)