From a3adafea3449fd1f63c6c36164db85cc0b896efe Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Wed, 6 Jul 2022 22:27:28 +0800 Subject: [PATCH] add dataset method --- fastNLP/core/dataset/dataset.py | 16 +++++++++++++++- tests/core/dataset/test_dataset.py | 7 +++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index f91bc930..d75d742c 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -1037,4 +1037,18 @@ class DataSet: self.collator.set_ignore(*field_names) return self.collator else: - raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") \ No newline at end of file + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") + + @classmethod + def from_datasets(cls, dataset): + """ + 将 Huggingface Dataset 转为 fastNLP 的 DataSet + + :param dataset 为实例化好的 huggingface Dataset 对象 + """ + from datasets import Dataset + if not isinstance(dataset, DataSet): + raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!") + + data_dict = dataset.to_dict() + return DataSet(data_dict) \ No newline at end of file diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index 8fd0e726..5f342a00 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -522,3 +522,10 @@ class TestCase: ins = Instance(**fields) # simple print, that is enough. print(ins) + + def test_dataset(self): + from datasets import Dataset as HuggingfaceDataset + # ds = DataSet({"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100}) + ds = HuggingfaceDataset.from_dict({"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100}) + print(DataSet.from_datasets(ds)) + # print(ds.from_datasets())