From 749415970ec256b2d0e7e2ea7748ee5afee95c29 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 13 Sep 2022 14:41:22 +0800 Subject: [PATCH] 1. fix bugs of DataSet.from_dataset 2. fix bugs of some tests 3. add to OneflowDriver.load_model --- fastNLP/core/dataset/dataset.py | 2 +- fastNLP/core/drivers/oneflow_driver/oneflow_driver.py | 3 ++- tests/core/collators/test_pakcer_unpacker.py | 2 +- tests/core/drivers/torch_driver/test_deepspeed.py | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 1a3afd6d..438d84b6 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -1048,7 +1048,7 @@ class DataSet: :param dataset 为实例化好的 huggingface Dataset 对象 """ from datasets import Dataset - if not isinstance(dataset, DataSet): + if not isinstance(dataset, Dataset): raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!") data_dict = dataset.to_dict() diff --git a/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py b/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py index 7e3f8e4a..35d8d8bf 100644 --- a/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py +++ b/fastNLP/core/drivers/oneflow_driver/oneflow_driver.py @@ -192,7 +192,8 @@ class OneflowDriver(Driver): f"`only_state_dict=False`") if not isinstance(res, dict): res = res.state_dict() - model.load_state_dict(res) + _strict = kwargs.get("strict") + model.load_state_dict(res, _strict) @rank_zero_call def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): diff --git a/tests/core/collators/test_pakcer_unpacker.py b/tests/core/collators/test_pakcer_unpacker.py index 8e5b5e09..84756a00 100644 --- a/tests/core/collators/test_pakcer_unpacker.py +++ b/tests/core/collators/test_pakcer_unpacker.py @@ -1,5 +1,5 @@ -from fastNLP.core.collators.packer_unpacker import * +from fastNLP.core.collators.packer_unpacker import MappingPackerUnpacker, NestedMappingPackerUnpacker, SequencePackerUnpacker def test_unpack_batch_mapping(): diff --git a/tests/core/drivers/torch_driver/test_deepspeed.py b/tests/core/drivers/torch_driver/test_deepspeed.py index 462648bd..41a0f796 100644 --- a/tests/core/drivers/torch_driver/test_deepspeed.py +++ b/tests/core/drivers/torch_driver/test_deepspeed.py @@ -97,6 +97,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= # if dist.is_initialized(): # dist.destroy_process_group() +@pytest.mark.deepspeed @magic_argv_env_context def test_multi_optimizers(): torch_model = TorchNormalModel_Classification_1(10, 10)