@@ -72,7 +72,6 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
self.model_save_fn = model_save_fn | self.model_save_fn = model_save_fn | ||||
self.model_load_fn = model_load_fn | self.model_load_fn = model_load_fn | ||||
self.delete_after_after = delete_after_train | self.delete_after_after = delete_after_train | ||||
self.encounter_exception = False | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | ||||
@@ -85,6 +84,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
f"save best model when launch using module.") | f"save best model when launch using module.") | ||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
self.encounter_exception = False | |||||
def on_evaluate_end(self, trainer, results): | def on_evaluate_end(self, trainer, results): | ||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
@@ -429,7 +429,8 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.set_optimizers(optimizers=optimizers) | self.driver.set_optimizers(optimizers=optimizers) | ||||
# 根据 progress_bar 参数选择 ProgressBarCallback | # 根据 progress_bar 参数选择 ProgressBarCallback | ||||
callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto')) | |||||
self.progress_bar = kwargs.get('progress_bar', 'auto') | |||||
callbacks = prepare_callbacks(callbacks, self.progress_bar) | |||||
# 初始化 callback manager; | # 初始化 callback manager; | ||||
self.callback_manager = CallbackManager(callbacks) | self.callback_manager = CallbackManager(callbacks) | ||||
# 添加所有的函数式 callbacks; | # 添加所有的函数式 callbacks; | ||||
@@ -272,7 +272,7 @@ def test_model_checkpoint_callback_2( | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver="torch", | driver="torch", | ||||
device=4, | |||||
device=0, | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | ||||
@@ -72,19 +72,6 @@ def model_and_optimizers(request): | |||||
return trainer_params | return trainer_params | ||||
from fastNLP import Metric | |||||
class CountMetrc(Metric):
    """Test metric that tallies how many predictions it has been fed.

    A single element named ``count`` is registered with an initial value
    of 0 and ``aggregate_method='sum'``.
    NOTE(review): presumably the 'sum' aggregation combines the tallies of
    all processes in a distributed run -- confirm against the ``Metric``
    documentation.
    """

    def __init__(self):
        """Initialise the base metric and register the ``count`` element."""
        super().__init__()
        self.register_element('count', 0, aggregate_method='sum')

    def update(self, pred):
        """Add the length of ``pred`` to the running tally."""
        n_items = len(pred)
        self.count += n_items

    def get_metric(self) -> dict:
        """Return the current tally under the key ``'cnt'``."""
        tally = self.count.item()
        return {'cnt': tally}
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -122,6 +109,7 @@ def test_load_best_model_callback( | |||||
progress_bar='rich', use_dist_sampler=False) | progress_bar='rich', use_dist_sampler=False) | ||||
results = evaluator.run() | results = evaluator.run() | ||||
assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | ||||
trainer.driver.barrier() | |||||
if save_folder: | if save_folder: | ||||
import shutil | import shutil | ||||
shutil.rmtree(save_folder, ignore_errors=True) | shutil.rmtree(save_folder, ignore_errors=True) | ||||
@@ -92,12 +92,12 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_model_more_evaluate_callback_1( | def test_model_more_evaluate_callback_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | |||||
device | |||||
): | ): | ||||
for only_state_dict in [True, False]: | for only_state_dict in [True, False]: | ||||
for version in [0, 1]: | for version in [0, 1]: | ||||
@@ -110,7 +110,7 @@ def test_model_more_evaluate_callback_1( | |||||
MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | ||||
metrics=model_and_optimizers.more_metrics, | metrics=model_and_optimizers.more_metrics, | ||||
evaluate_every=-1, | evaluate_every=-1, | ||||
folder=path, topk=-1, | |||||
folder=path, topk=-1, progress_bar=None, | |||||
topk_monitor='acc', only_state_dict=only_state_dict, save_object='model') | topk_monitor='acc', only_state_dict=only_state_dict, save_object='model') | ||||
] | ] | ||||
elif version == 1: | elif version == 1: | ||||
@@ -119,7 +119,7 @@ def test_model_more_evaluate_callback_1( | |||||
metrics=model_and_optimizers.more_metrics, | metrics=model_and_optimizers.more_metrics, | ||||
evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, | evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, | ||||
folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, | folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, | ||||
save_object='model') | |||||
save_object='model', progress_bar=None) | |||||
] | ] | ||||
n_epochs = 3 | n_epochs = 3 | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -167,6 +167,7 @@ def test_model_more_evaluate_callback_1( | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
break | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@@ -175,7 +176,7 @@ def test_model_more_evaluate_callback_1( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_checkpoint_callback_1( | def test_trainer_checkpoint_callback_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -241,7 +242,7 @@ def test_trainer_checkpoint_callback_1( | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
n_epochs=5, | |||||
n_epochs=2, | |||||
output_from_new_proc="all", | output_from_new_proc="all", | ||||
evaluate_fn='train_step' | evaluate_fn='train_step' | ||||
) | ) | ||||
@@ -250,7 +251,7 @@ def test_trainer_checkpoint_callback_1( | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
break | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||