@@ -72,7 +72,6 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.model_save_fn = model_save_fn
         self.model_load_fn = model_load_fn
         self.delete_after_after = delete_after_train
-        self.encounter_exception = False

     def on_after_trainer_initialized(self, trainer, driver):
         if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1:
@@ -85,6 +84,7 @@ class LoadBestModelCallback(HasMonitorCallback):
                                f"save best model when launch using module.")
         super().on_after_trainer_initialized(trainer, driver)
+        self.encounter_exception = False

     def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
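Taken together, the two hunks above move the `encounter_exception` flag from `__init__` into `on_after_trainer_initialized`, so the flag is reset every time the callback is attached to a trainer rather than only once at construction. A minimal sketch of the pattern in plain Python (a simplified stand-in, not fastNLP's actual class; `on_exception` mirrors the callback hook that would set the flag):

class LoadBestModelLike:
    """Plain-Python stand-in for LoadBestModelCallback (illustration only)."""

    def __init__(self):
        # construction-time configuration only; no per-run state here
        self.save_folder = None

    def on_after_trainer_initialized(self, trainer, driver):
        # reset at the start of every run, so a callback object reused
        # across several trainers never carries a stale True value
        self.encounter_exception = False

    def on_exception(self, trainer, exception):
        self.encounter_exception = True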
@@ -429,7 +429,8 @@ class Trainer(TrainerEventTrigger):
         self.driver.set_optimizers(optimizers=optimizers)

         # choose the ProgressBarCallback according to the progress_bar argument
-        callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto'))
+        self.progress_bar = kwargs.get('progress_bar', 'auto')
+        callbacks = prepare_callbacks(callbacks, self.progress_bar)
         # initialize the callback manager;
         self.callback_manager = CallbackManager(callbacks)
         # add all the function-style callbacks;
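The Trainer hunk resolves `progress_bar` from `kwargs` exactly once and keeps it on `self` before handing it to `prepare_callbacks`, so the setting stays visible to anything that needs it after construction. A hypothetical sketch of that read-once pattern (the helper here is made up for illustration, standing in for fastNLP's `prepare_callbacks`):

class TrainerLike:
    def __init__(self, callbacks=None, **kwargs):
        # resolve the option exactly once and keep it as an attribute ...
        self.progress_bar = kwargs.get('progress_bar', 'auto')
        # ... so later code reuses self.progress_bar instead of re-reading kwargs
        self.callbacks = self._prepare_callbacks(callbacks, self.progress_bar)

    @staticmethod
    def _prepare_callbacks(callbacks, progress_bar):
        # hypothetical helper standing in for fastNLP's prepare_callbacks
        return list(callbacks or []) + [('progress_bar', progress_bar)]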
@@ -272,7 +272,7 @@ def test_model_checkpoint_callback_2(
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver="torch",
-        device=4,
+        device=0,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
         evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
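`device=4` asks for the fifth GPU and fails on any host with fewer than five; `device=0` assumes only one. A guard in the spirit of the fix (plain PyTorch, assuming `torch` is installed):

import torch

# pick the first GPU when one exists, otherwise fall back to the CPU,
# instead of hard-coding an index the host may not have
device = 0 if torch.cuda.device_count() > 0 else 'cpu'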
@@ -72,19 +72,6 @@ def model_and_optimizers(request):
     return trainer_params

-from fastNLP import Metric
-
-class CountMetrc(Metric):
-    def __init__(self):
-        super().__init__()
-        self.register_element('count', 0, aggregate_method='sum')
-
-    def update(self, pred):
-        self.count += len(pred)
-
-    def get_metric(self) -> dict:
-        return {'cnt': self.count.item()}

 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
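The deleted helper (note the `CountMetrc` typo) was an unused illustration of fastNLP's element-based metric API: `register_element` creates an aggregatable counter, `update` mutates it, and `get_metric` reads it back with `.item()`. For reference, a cleaned-up version of the same metric, assuming `Metric` and `register_element` behave exactly as the deleted code used them:

from fastNLP import Metric

class CountMetric(Metric):
    def __init__(self):
        super().__init__()
        # 'count' becomes an element that is summed across workers
        self.register_element('count', 0, aggregate_method='sum')

    def update(self, pred):
        self.count += len(pred)  # accumulate the number of predictions seen

    def get_metric(self) -> dict:
        return {'cnt': self.count.item()}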
@@ -122,6 +109,7 @@ def test_load_best_model_callback(
                               progress_bar='rich', use_dist_sampler=False)
     results = evaluator.run()
     assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1'])
+    trainer.driver.barrier()
     if save_folder:
         import shutil
         shutil.rmtree(save_folder, ignore_errors=True)
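The added `trainer.driver.barrier()` closes a race in the distributed runs of this test: without it, the first rank past the assert could `rmtree` the shared `save_folder` while another rank was still reading the best model from it. The shape of the fix, as a sketch (a hypothetical helper, not test code from the PR):

def assert_then_cleanup(driver, save_folder, check):
    check()                   # every rank runs its assertions first
    driver.barrier()          # ... then waits until all ranks are done ...
    if save_folder:           # ... before anyone touches the shared folder
        import shutil
        shutil.rmtree(save_folder, ignore_errors=True)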
@@ -92,12 +92,12 @@ def model_and_optimizers(request):

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_model_more_evaluate_callback_1(
         model_and_optimizers: TrainerParameters,
         driver,
-        device,
+        device
 ):
     for only_state_dict in [True, False]:
         for version in [0, 1]:
@@ -110,7 +110,7 @@ def test_model_more_evaluate_callback_1(
                 MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                      metrics=model_and_optimizers.more_metrics,
                                      evaluate_every=-1,
-                                     folder=path, topk=-1,
+                                     folder=path, topk=-1, progress_bar=None,
                                      topk_monitor='acc', only_state_dict=only_state_dict, save_object='model')
             ]
         elif version == 1:
@@ -119,7 +119,7 @@ def test_model_more_evaluate_callback_1(
                                      metrics=model_and_optimizers.more_metrics,
                                      evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                      folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
-                                     save_object='model')
+                                     save_object='model', progress_bar=None)
             ]
         n_epochs = 3
         trainer = Trainer(
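Both `MoreEvaluateCallback` call sites now pass `progress_bar=None`, presumably so the inner Evaluator the callback spawns mid-training does not draw its own bar over the trainer's output. A sketch of the behaviour the argument selects, under the assumption that `None` is an explicit opt-out and `'auto'` resolves to the best available bar:

def resolve_progress_bar(progress_bar):
    # assumption: None means "no bar at all", 'auto' picks the best available
    if progress_bar is None:
        return None                    # inner evaluator stays silent
    if progress_bar == 'auto':
        progress_bar = 'rich'          # assumed default when rich is installed
    return progress_bar

assert resolve_progress_bar(None) is None   # the behaviour these tests opt into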
@@ -167,6 +167,7 @@ def test_model_more_evaluate_callback_1(
             trainer.run()
+            trainer.driver.barrier()
             break
         finally:
             rank_zero_rm(path)
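`rank_zero_rm(path)` sits in a `finally:`, so the checkpoint folder is removed even when the test body raises, and (as the name suggests) presumably deletes only on the main process. A self-contained sketch of that cleanup discipline, with a stand-in for fastNLP's helper:

import os
import shutil

def rank_zero_rm_sketch(path):
    # stand-in for fastNLP's rank_zero_rm: delete on the main rank only
    if int(os.environ.get('RANK', '0')) == 0:
        shutil.rmtree(path, ignore_errors=True)

def run_and_cleanup(trainer, path):
    try:
        trainer.run()
        trainer.driver.barrier()      # all ranks finish before deletion
    finally:
        rank_zero_rm_sketch(path)     # executes even if trainer.run() raised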
@@ -175,7 +176,7 @@ def test_model_more_evaluate_callback_1(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_trainer_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,
@@ -241,7 +242,7 @@ def test_trainer_checkpoint_callback_1(
             input_mapping=model_and_optimizers.input_mapping,
             output_mapping=model_and_optimizers.output_mapping,
             metrics=model_and_optimizers.metrics,
-            n_epochs=5,
+            n_epochs=2,
             output_from_new_proc="all",
             evaluate_fn='train_step'
         )
@@ -250,7 +251,7 @@ def test_trainer_checkpoint_callback_1(
             trainer.run()
+            trainer.driver.barrier()
             break
         finally:
             rank_zero_rm(path)