@@ -1048,7 +1048,7 @@ class DataSet: | |||||
:param dataset 为实例化好的 huggingface Dataset 对象 | :param dataset 为实例化好的 huggingface Dataset 对象 | ||||
""" | """ | ||||
from datasets import Dataset | from datasets import Dataset | ||||
if not isinstance(dataset, DataSet): | |||||
if not isinstance(dataset, Dataset): | |||||
raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!") | raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!") | ||||
data_dict = dataset.to_dict() | data_dict = dataset.to_dict() |
@@ -192,7 +192,8 @@ class OneflowDriver(Driver): | |||||
f"`only_state_dict=False`") | f"`only_state_dict=False`") | ||||
if not isinstance(res, dict): | if not isinstance(res, dict): | ||||
res = res.state_dict() | res = res.state_dict() | ||||
model.load_state_dict(res) | |||||
_strict = kwargs.get("strict") | |||||
model.load_state_dict(res, _strict) | |||||
@rank_zero_call | @rank_zero_call | ||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | ||||
@@ -1,5 +1,5 @@ | |||||
from fastNLP.core.collators.packer_unpacker import * | |||||
from fastNLP.core.collators.packer_unpacker import MappingPackerUnpacker, NestedMappingPackerUnpacker, SequencePackerUnpacker | |||||
def test_unpack_batch_mapping(): | def test_unpack_batch_mapping(): | ||||
@@ -97,6 +97,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= | |||||
# if dist.is_initialized(): | # if dist.is_initialized(): | ||||
# dist.destroy_process_group() | # dist.destroy_process_group() | ||||
@pytest.mark.deepspeed | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_multi_optimizers(): | def test_multi_optimizers(): | ||||
torch_model = TorchNormalModel_Classification_1(10, 10) | torch_model = TorchNormalModel_Classification_1(10, 10) | ||||