Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
6ddcedaaeb
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      fastNLP/core/controllers/trainer.py

+ 2
- 2
fastNLP/core/controllers/trainer.py View File

@@ -591,7 +591,7 @@ class Trainer(TrainerEventTrigger):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_load_fn)(folder)
model_load_fn(folder)
else:
if isinstance(folder, str):
folder = Path(folder)
@@ -668,7 +668,7 @@ class Trainer(TrainerEventTrigger):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
model_load_fn(folder)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)


Loading…
Cancel
Save