|
|
@@ -591,7 +591,7 @@ class Trainer(TrainerEventTrigger): |
|
|
|
if model_load_fn is not None: |
|
|
|
if not callable(model_load_fn): |
|
|
|
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") |
|
|
|
rank_zero_call(model_load_fn)(folder) |
|
|
|
model_load_fn(folder) |
|
|
|
else: |
|
|
|
if isinstance(folder, str): |
|
|
|
folder = Path(folder) |
|
|
@@ -668,7 +668,7 @@ class Trainer(TrainerEventTrigger): |
|
|
|
if model_load_fn is not None: |
|
|
|
if not callable(model_load_fn): |
|
|
|
raise ValueError("Parameter `model_save_fn` should be `Callable`.") |
|
|
|
rank_zero_call(model_load_fn)(folder) |
|
|
|
model_load_fn(folder) |
|
|
|
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) |
|
|
|
else: |
|
|
|
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) |
|
|
|