@@ -72,7 +72,6 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
self.model_save_fn = model_save_fn | |||
self.model_load_fn = model_load_fn | |||
self.delete_after_after = delete_after_train | |||
self.encounter_exception = False | |||
def on_after_trainer_initialized(self, trainer, driver): | |||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | |||
@@ -85,6 +84,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
f"save best model when launch using module.") | |||
super().on_after_trainer_initialized(trainer, driver) | |||
self.encounter_exception = False | |||
def on_evaluate_end(self, trainer, results): | |||
if self.is_better_results(results, keep_if_better=True): | |||
@@ -429,7 +429,8 @@ class Trainer(TrainerEventTrigger): | |||
self.driver.set_optimizers(optimizers=optimizers) | |||
# 根据 progress_bar 参数选择 ProgressBarCallback | |||
callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto')) | |||
self.progress_bar = kwargs.get('progress_bar', 'auto') | |||
callbacks = prepare_callbacks(callbacks, self.progress_bar) | |||
# 初始化 callback manager; | |||
self.callback_manager = CallbackManager(callbacks) | |||
# 添加所有的函数式 callbacks; | |||
@@ -272,7 +272,7 @@ def test_model_checkpoint_callback_2( | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver="torch", | |||
device=4, | |||
device=0, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
@@ -72,19 +72,6 @@ def model_and_optimizers(request): | |||
return trainer_params | |||
from fastNLP import Metric | |||
class CountMetrc(Metric):
    """Metric that accumulates the number of predictions it has been given.

    A single element named ``count`` is registered with an initial value of 0
    and ``aggregate_method='sum'``; each ``update`` call adds ``len(pred)`` to it.
    """

    def __init__(self):
        """Initialise the base metric and register the ``count`` element."""
        super().__init__()
        self.register_element(
            'count',
            0,
            aggregate_method='sum',
        )

    def update(self, pred):
        """Add the length of ``pred`` to the running count."""
        n_seen = len(pred)
        # Keep the in-place form: the registered element may define its own
        # in-place addition, so it must not be rewritten as ``a = a + b``.
        self.count += n_seen

    def get_metric(self) -> dict:
        """Return the accumulated count under the key ``'cnt'``."""
        total = self.count.item()
        return {'cnt': total}
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||
@magic_argv_env_context | |||
@@ -122,6 +109,7 @@ def test_load_best_model_callback( | |||
progress_bar='rich', use_dist_sampler=False) | |||
results = evaluator.run() | |||
assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | |||
trainer.driver.barrier() | |||
if save_folder: | |||
import shutil | |||
shutil.rmtree(save_folder, ignore_errors=True) | |||
@@ -92,12 +92,12 @@ def model_and_optimizers(request): | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||
@magic_argv_env_context | |||
def test_model_more_evaluate_callback_1( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
device | |||
): | |||
for only_state_dict in [True, False]: | |||
for version in [0, 1]: | |||
@@ -110,7 +110,7 @@ def test_model_more_evaluate_callback_1( | |||
MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
metrics=model_and_optimizers.more_metrics, | |||
evaluate_every=-1, | |||
folder=path, topk=-1, | |||
folder=path, topk=-1, progress_bar=None, | |||
topk_monitor='acc', only_state_dict=only_state_dict, save_object='model') | |||
] | |||
elif version == 1: | |||
@@ -119,7 +119,7 @@ def test_model_more_evaluate_callback_1( | |||
metrics=model_and_optimizers.more_metrics, | |||
evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False, | |||
folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict, | |||
save_object='model') | |||
save_object='model', progress_bar=None) | |||
] | |||
n_epochs = 3 | |||
trainer = Trainer( | |||
@@ -167,6 +167,7 @@ def test_model_more_evaluate_callback_1( | |||
trainer.run() | |||
trainer.driver.barrier() | |||
break | |||
finally: | |||
rank_zero_rm(path) | |||
@@ -175,7 +176,7 @@ def test_model_more_evaluate_callback_1( | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||
@magic_argv_env_context | |||
def test_trainer_checkpoint_callback_1( | |||
model_and_optimizers: TrainerParameters, | |||
@@ -241,7 +242,7 @@ def test_trainer_checkpoint_callback_1( | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
n_epochs=5, | |||
n_epochs=2, | |||
output_from_new_proc="all", | |||
evaluate_fn='train_step' | |||
) | |||
@@ -250,7 +251,7 @@ def test_trainer_checkpoint_callback_1( | |||
trainer.run() | |||
trainer.driver.barrier() | |||
break | |||
finally: | |||
rank_zero_rm(path) | |||