diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 7eb5bbac..8a888c2e 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -591,7 +591,7 @@ class Trainer(TrainerEventTrigger): if model_load_fn is not None: if not callable(model_load_fn): raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") - rank_zero_call(model_load_fn)(folder) + model_load_fn(folder) else: if isinstance(folder, str): folder = Path(folder) @@ -668,7 +668,7 @@ class Trainer(TrainerEventTrigger): if model_load_fn is not None: if not callable(model_load_fn): raise ValueError("Parameter `model_save_fn` should be `Callable`.") - rank_zero_call(model_load_fn)(folder) + model_load_fn(folder) states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) else: states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)