
fix some bugs

tag: v1.0.0alpha
author: yh, 2 years ago
commit: b5b48d58e9
5 changed files with 13 additions and 23 deletions:
  1. fastNLP/core/callbacks/load_best_model_callback.py (+1, -1)
  2. fastNLP/core/controllers/trainer.py (+2, -1)
  3. tests/core/callbacks/test_checkpoint_callback_torch.py (+1, -1)
  4. tests/core/callbacks/test_load_best_model_callback_torch.py (+1, -13)
  5. tests/core/callbacks/test_more_evaluate_callback.py (+8, -7)

fastNLP/core/callbacks/load_best_model_callback.py (+1, -1)

@@ -72,7 +72,6 @@ class LoadBestModelCallback(HasMonitorCallback):
        self.model_save_fn = model_save_fn
        self.model_load_fn = model_load_fn
        self.delete_after_after = delete_after_train
-       self.encounter_exception = False

    def on_after_trainer_initialized(self, trainer, driver):
        if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1:
@@ -85,6 +84,7 @@ class LoadBestModelCallback(HasMonitorCallback):
                             f"save best model when launch using module.")

        super().on_after_trainer_initialized(trainer, driver)
+       self.encounter_exception = False

    def on_evaluate_end(self, trainer, results):
        if self.is_better_results(results, keep_if_better=True):
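The change above moves the initialization of encounter_exception from __init__ into on_after_trainer_initialized. A minimal sketch of the pattern this supports is shown below; the class name, the placeholder action, and the exact hook bodies are illustrative assumptions, not code from this commit or from the fastNLP source:

class LoadBestModelSketch:
    def on_after_trainer_initialized(self, trainer, driver):
        self.encounter_exception = False       # reset once per Trainer run

    def on_exception(self, trainer, exception):
        self.encounter_exception = True        # remember that training crashed

    def on_train_end(self, trainer):
        if not self.encounter_exception:
            print("would load the best checkpoint here")  # placeholder action

Resetting the flag in on_after_trainer_initialized means a callback instance that is reused across several Trainer runs starts every run with a clean flag, which resetting only in __init__ cannot guarantee.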


fastNLP/core/controllers/trainer.py (+2, -1)

@@ -429,7 +429,8 @@ class Trainer(TrainerEventTrigger):
        self.driver.set_optimizers(optimizers=optimizers)

        # choose the ProgressBarCallback according to the progress_bar argument
-       callbacks = prepare_callbacks(callbacks, kwargs.get('progress_bar', 'auto'))
+       self.progress_bar = kwargs.get('progress_bar', 'auto')
+       callbacks = prepare_callbacks(callbacks, self.progress_bar)
        # initialize the callback manager;
        self.callback_manager = CallbackManager(callbacks)
        # add all functional callbacks;
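Storing the value on the Trainer, rather than only threading it through to prepare_callbacks, keeps the user's progress-bar choice available after construction. A hedged illustration of the kind of reuse this enables follows; Evaluator and its progress_bar argument exist in fastNLP, but the helper function itself is an assumption for illustration, not code from this commit:

from fastNLP import Evaluator

def evaluate_with_trainer_settings(trainer, dataloaders, metrics):
    # Reuse the driver and the progress-bar choice that Trainer.__init__ now
    # stores on the trainer instance.
    evaluator = Evaluator(model=trainer.model,
                          dataloaders=dataloaders,
                          metrics=metrics,
                          driver=trainer.driver,
                          progress_bar=trainer.progress_bar)
    return evaluator.run()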


tests/core/callbacks/test_checkpoint_callback_torch.py (+1, -1)

@@ -272,7 +272,7 @@ def test_model_checkpoint_callback_2(
            trainer = Trainer(
                model=model_and_optimizers.model,
                driver="torch",
-               device=4,
+               device=0,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,


tests/core/callbacks/test_load_best_model_callback_torch.py (+1, -13)

@@ -72,19 +72,6 @@ def model_and_optimizers(request):
    return trainer_params


-from fastNLP import Metric
-class CountMetrc(Metric):
-    def __init__(self):
-        super().__init__()
-        self.register_element('count', 0, aggregate_method='sum')
-
-    def update(self, pred):
-        self.count += len(pred)
-
-    def get_metric(self) -> dict:
-        return {'cnt': self.count.item()}
-
-
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context
@@ -122,6 +109,7 @@ def test_load_best_model_callback(
                             progress_bar='rich', use_dist_sampler=False)
    results = evaluator.run()
    assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1'])
+   trainer.driver.barrier()
    if save_folder:
        import shutil
        shutil.rmtree(save_folder, ignore_errors=True)
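The added trainer.driver.barrier() keeps distributed ranks in step before the shared save_folder is removed, so no rank deletes checkpoint files another rank is still reading. A sketch of the pattern, assuming a hypothetical cleanup helper (not part of the test):

import shutil

def cleanup_shared_folder(trainer, save_folder):
    # Wait until every rank has finished loading and evaluating the best model
    # before any file under the shared folder disappears.
    trainer.driver.barrier()
    if save_folder:
        shutil.rmtree(save_folder, ignore_errors=True)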


tests/core/callbacks/test_more_evaluate_callback.py (+8, -7)

@@ -92,12 +92,12 @@ def model_and_optimizers(request):


@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context
def test_model_more_evaluate_callback_1(
    model_and_optimizers: TrainerParameters,
    driver,
-   device,
+   device
):
    for only_state_dict in [True, False]:
        for version in [0, 1]:
@@ -110,7 +110,7 @@ def test_model_more_evaluate_callback_1(
                    MoreEvaluateCallback(dataloaders=model_and_optimizers.evaluate_dataloaders,
                                         metrics=model_and_optimizers.more_metrics,
                                         evaluate_every=-1,
-                                        folder=path, topk=-1,
+                                        folder=path, topk=-1, progress_bar=None,
                                         topk_monitor='acc', only_state_dict=only_state_dict, save_object='model')
                ]
            elif version == 1:
@@ -119,7 +119,7 @@ def test_model_more_evaluate_callback_1(
                                         metrics=model_and_optimizers.more_metrics,
                                         evaluate_every=None, watch_monitor='loss', watch_monitor_larger_better=False,
                                         folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
-                                        save_object='model')
+                                        save_object='model', progress_bar=None)
                ]
            n_epochs = 3
            trainer = Trainer(
@@ -167,6 +167,7 @@ def test_model_more_evaluate_callback_1(

                trainer.run()
                trainer.driver.barrier()
+               break
            finally:
                rank_zero_rm(path)


@@ -175,7 +176,7 @@ def test_model_more_evaluate_callback_1(


@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context
def test_trainer_checkpoint_callback_1(
    model_and_optimizers: TrainerParameters,
@@ -241,7 +242,7 @@ def test_trainer_checkpoint_callback_1(
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,
-               n_epochs=5,
+               n_epochs=2,
                output_from_new_proc="all",
                evaluate_fn='train_step'
            )
@@ -250,7 +251,7 @@ def test_trainer_checkpoint_callback_1(

            trainer.run()
            trainer.driver.barrier()
+           break
        finally:
            rank_zero_rm(path)
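Together with the trimmed parametrize lists, progress_bar=None and the smaller n_epochs, the two added break statements are a pure test speed-up: each inner version loop now stops after its first iteration. A runnable structural sketch of that control flow, with hypothetical stand-in helpers instead of the real test bodies:

def run_one_case(only_state_dict, version):
    print(f"running case only_state_dict={only_state_dict}, version={version}")

def cleanup_artifacts():
    pass  # stand-in for rank_zero_rm(path)

for only_state_dict in [True, False]:
    for version in [0, 1]:
        try:
            run_one_case(only_state_dict, version)
            break  # added by this commit: only version 0 runs for each setting
        finally:
            cleanup_artifacts()  # the finally block still runs before the break takes effect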



