|
|
@@ -192,7 +192,8 @@ class OneflowDriver(Driver): |
|
|
|
f"`only_state_dict=False`") |
|
|
|
if not isinstance(res, dict): |
|
|
|
res = res.state_dict() |
|
|
|
model.load_state_dict(res) |
|
|
|
_strict = kwargs.get("strict") |
|
|
|
model.load_state_dict(res, _strict) |
|
|
|
|
|
|
|
@rank_zero_call |
|
|
|
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): |
|
|
|