diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index f1f3f2a8..b63cb878 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -7,6 +7,9 @@ from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.vocabulary import Vocabulary
 
+# Maps a reader method name (e.g. "read_pos") to its registered loader class.
+_READERS = {}
+
 
 class DataSet(list):
     """A DataSet object is a list of Instance objects.
@@ -125,3 +128,49 @@ class DataSet(list):
             self.origin_len = (origin_field + "_origin_len", origin_field) \
                 if origin_len_name is None else (origin_len_name, origin_field)
         return self
+
+    def __getattr__(self, name):
+        """Provide the ``read_*`` loader methods registered via ``set_reader``.
+
+        ``__getattr__`` only runs after normal attribute lookup has failed, so
+        ordinary attribute access is neither slowed down nor shadowed by a
+        registered reader name.
+
+        :param str name: attribute name that was not found on the instance.
+        :return: a function that instantiates the registered loader class,
+            forwards its arguments to ``load()``, extends this DataSet with
+            the result and returns the DataSet itself.
+        :raises AttributeError: if ``name`` is not a registered reader.
+        """
+        if name not in _READERS:
+            raise AttributeError("'{}' object has no attribute '{}'".format(
+                type(self).__name__, name))
+
+        def _read(*args, **kwargs):
+            # Look the loader class up at call time so a later re-registration
+            # under the same name is honoured.
+            data = _READERS[name]().load(*args, **kwargs)
+            self.extend(data)
+            return self
+        return _read
+
+    @classmethod
+    def set_reader(cls, method_name):
+        """Class decorator that registers a loader class as a DataSet reader.
+
+        After ``@DataSet.set_reader('read_xxx')`` is applied to a loader class,
+        ``DataSet().read_xxx(*args, **kwargs)`` loads data through that class.
+
+        :param str method_name: name of the reader method to expose.
+        :return: a decorator that records the class and returns it unchanged.
+        :raises TypeError: if ``method_name`` is not a string.
+        """
+        if not isinstance(method_name, str):
+            # Raise explicitly: an ``assert`` would vanish under ``python -O``.
+            raise TypeError("method_name must be a str, got {}".format(
+                type(method_name).__name__))
+
+        def wrapper(read_cls):
+            _READERS[method_name] = read_cls
+            return read_cls
+        return wrapper
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 77b27b92..e4b17bf4 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -69,6 +69,7 @@ class Vocabulary(object):
             else:
                 self.word_count[word] += 1
             self.word2idx = None
+        return self
 
 
     def build_vocab(self):
diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index 5feb62a6..a76eb960 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -84,6 +84,7 @@ class DataSetLoader(BaseLoader):
         """
         raise NotImplementedError
 
+@DataSet.set_reader('read_raw')
 class RawDataSetLoader(DataSetLoader):
     def __init__(self):
         super(RawDataSetLoader, self).__init__()
@@ -98,6 +99,7 @@ class RawDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq_dataset(data)
 
+@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
 
@@ -166,6 +168,7 @@ class POSDataSetLoader(DataSetLoader):
         """
         return convert_seq2seq_dataset(data)
 
+@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -224,7 +227,7 @@ class TokenizeDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq2seq_dataset(data)
 
-
+@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""
 
@@ -262,7 +265,7 @@ class ClassDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq2tag_dataset(data)
 
-
+@DataSet.set_reader('read_conll')
 class ConllLoader(DataSetLoader):
     """loader for conll format files"""
 
@@ -303,7 +306,7 @@ class ConllLoader(DataSetLoader):
     def convert(self, data):
         pass
 
-
+@DataSet.set_reader('read_lm')
 class LMDataSetLoader(DataSetLoader):
     """Language Model Dataset Loader
 
@@ -339,6 +342,7 @@ class LMDataSetLoader(DataSetLoader):
     def convert(self, data):
         pass
 
+@DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
     People Daily Corpus: Chinese word segmentation, POS tag, NER
diff --git a/test/loader/test_dataset_loader.py b/test/loader/test_dataset_loader.py
index 94a7fa71..1914bce9 100644
--- a/test/loader/test_dataset_loader.py
+++ b/test/loader/test_dataset_loader.py
@@ -3,7 +3,7 @@ import unittest
 
 from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
     PeopleDailyCorpusLoader, ConllLoader
-
+from fastNLP.core.dataset import DataSet
 
 class TestDatasetLoader(unittest.TestCase):
     def test_case_1(self):
@@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase):
 
     def test_case_TokenizeDatasetLoader(self):
         loader = TokenizeDataSetLoader()
-        data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
+        filepath = "./test/data_for_tests/cws_pku_utf_8"
+        data = loader.load(filepath, max_seq_len=32)
+        assert len(data) > 0
+
+        data1 = DataSet()
+        data1.read_tokenize(filepath, max_seq_len=32)
+        assert len(data1) > 0
         print("pass TokenizeDataSetLoader test!")
 
     def test_case_POSDatasetLoader(self):
         loader = POSDataSetLoader()
+        filepath = "./test/data_for_tests/people.txt"
         data = loader.load("./test/data_for_tests/people.txt")
         datas = loader.load_lines("./test/data_for_tests/people.txt")
+
+        data1 = DataSet().read_pos(filepath)
+        assert len(data1) > 0
         print("pass POSDataSetLoader test!")
 
     def test_case_LMDatasetLoader(self):