Browse Source

1. fix bugs of DataSet.from_dataset 2. fix bugs of some tests 3. add to OneflowDriver.load_model

dev0.8.0
x54-729 2 years ago
parent
commit
749415970e
4 changed files with 5 additions and 3 deletions
  1. +1
    -1
      fastNLP/core/dataset/dataset.py
  2. +2
    -1
      fastNLP/core/drivers/oneflow_driver/oneflow_driver.py
  3. +1
    -1
      tests/core/collators/test_pakcer_unpacker.py
  4. +1
    -0
      tests/core/drivers/torch_driver/test_deepspeed.py

+ 1
- 1
fastNLP/core/dataset/dataset.py View File

@@ -1048,7 +1048,7 @@ class DataSet:
:param dataset 为实例化好的 huggingface Dataset 对象 :param dataset 为实例化好的 huggingface Dataset 对象
""" """
from datasets import Dataset from datasets import Dataset
if not isinstance(dataset, DataSet):
if not isinstance(dataset, Dataset):
raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!") raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")


data_dict = dataset.to_dict() data_dict = dataset.to_dict()

+ 2
- 1
fastNLP/core/drivers/oneflow_driver/oneflow_driver.py View File

@@ -192,7 +192,8 @@ class OneflowDriver(Driver):
f"`only_state_dict=False`") f"`only_state_dict=False`")
if not isinstance(res, dict): if not isinstance(res, dict):
res = res.state_dict() res = res.state_dict()
model.load_state_dict(res)
_strict = kwargs.get("strict")
model.load_state_dict(res, _strict)


@rank_zero_call @rank_zero_call
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):


+ 1
- 1
tests/core/collators/test_pakcer_unpacker.py View File

@@ -1,5 +1,5 @@


from fastNLP.core.collators.packer_unpacker import *
from fastNLP.core.collators.packer_unpacker import MappingPackerUnpacker, NestedMappingPackerUnpacker, SequencePackerUnpacker




def test_unpack_batch_mapping(): def test_unpack_batch_mapping():


+ 1
- 0
tests/core/drivers/torch_driver/test_deepspeed.py View File

@@ -97,6 +97,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
# if dist.is_initialized(): # if dist.is_initialized():
# dist.destroy_process_group() # dist.destroy_process_group()


@pytest.mark.deepspeed
@magic_argv_env_context @magic_argv_env_context
def test_multi_optimizers(): def test_multi_optimizers():
torch_model = TorchNormalModel_Classification_1(10, 10) torch_model = TorchNormalModel_Classification_1(10, 10)


Loading…
Cancel
Save