@@ -7,6 +7,8 @@ from fastNLP.core.field import TextField, LabelField | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.vocabulary import Vocabulary | |||
# Registry mapping "read_*" method names (e.g. "read_pos") to DataSetLoader
# classes; populated by the DataSet.set_reader decorator and consulted by
# DataSet.__getattribute__ to synthesize loader methods dynamically.
_READERS = {}
class DataSet(list): | |||
"""A DataSet object is a list of Instance objects. | |||
@@ -125,3 +127,24 @@ class DataSet(list): | |||
self.origin_len = (origin_field + "_origin_len", origin_field) \ | |||
if origin_len_name is None else (origin_len_name, origin_field) | |||
return self | |||
def __getattribute__(self, name): | |||
if name in _READERS: | |||
# add read_*data() support | |||
def _read(*args, **kwargs): | |||
data = _READERS[name]().load(*args, **kwargs) | |||
self.extend(data) | |||
return self | |||
return _read | |||
else: | |||
return object.__getattribute__(self, name) | |||
@classmethod | |||
def set_reader(cls, method_name): | |||
"""decorator to add dataloader support | |||
""" | |||
assert isinstance(method_name, str) | |||
def wrapper(read_cls): | |||
_READERS[method_name] = read_cls | |||
return read_cls | |||
return wrapper |
@@ -69,6 +69,7 @@ class Vocabulary(object): | |||
else: | |||
self.word_count[word] += 1 | |||
self.word2idx = None | |||
return self | |||
def build_vocab(self): | |||
@@ -84,6 +84,7 @@ class DataSetLoader(BaseLoader): | |||
""" | |||
raise NotImplementedError | |||
@DataSet.set_reader('read_raw') | |||
class RawDataSetLoader(DataSetLoader): | |||
    def __init__(self):
        # No extra state beyond the DataSetLoader base initialization.
        super(RawDataSetLoader, self).__init__()
@@ -98,6 +99,7 @@ class RawDataSetLoader(DataSetLoader): | |||
    def convert(self, data):
        """Delegate conversion of the loaded raw data to ``convert_seq_dataset``."""
        return convert_seq_dataset(data)
@DataSet.set_reader('read_pos') | |||
class POSDataSetLoader(DataSetLoader): | |||
"""Dataset Loader for POS Tag datasets. | |||
@@ -166,6 +168,7 @@ class POSDataSetLoader(DataSetLoader): | |||
""" | |||
return convert_seq2seq_dataset(data) | |||
@DataSet.set_reader('read_tokenize') | |||
class TokenizeDataSetLoader(DataSetLoader): | |||
""" | |||
Data set loader for tokenization data sets | |||
@@ -224,7 +227,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||
    def convert(self, data):
        """Delegate conversion of the loaded data to ``convert_seq2seq_dataset``."""
        return convert_seq2seq_dataset(data)
@DataSet.set_reader('read_class') | |||
class ClassDataSetLoader(DataSetLoader): | |||
"""Loader for classification data sets""" | |||
@@ -262,7 +265,7 @@ class ClassDataSetLoader(DataSetLoader): | |||
    def convert(self, data):
        """Delegate conversion of the loaded data to ``convert_seq2tag_dataset``."""
        return convert_seq2tag_dataset(data)
@DataSet.set_reader('read_conll') | |||
class ConllLoader(DataSetLoader): | |||
"""loader for conll format files""" | |||
@@ -303,7 +306,7 @@ class ConllLoader(DataSetLoader): | |||
    def convert(self, data):
        # Not implemented: conversion for CoNLL data is a no-op (returns None).
        pass
@DataSet.set_reader('read_lm') | |||
class LMDataSetLoader(DataSetLoader): | |||
"""Language Model Dataset Loader | |||
@@ -339,6 +342,7 @@ class LMDataSetLoader(DataSetLoader): | |||
    def convert(self, data):
        # Not implemented: conversion for LM data is a no-op (returns None).
        pass
@DataSet.set_reader('read_people_daily') | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
""" | |||
People Daily Corpus: Chinese word segmentation, POS tag, NER | |||
@@ -3,7 +3,7 @@ import unittest | |||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||
PeopleDailyCorpusLoader, ConllLoader | |||
from fastNLP.core.dataset import DataSet | |||
class TestDatasetLoader(unittest.TestCase): | |||
def test_case_1(self): | |||
@@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase): | |||
def test_case_TokenizeDatasetLoader(self): | |||
loader = TokenizeDataSetLoader() | |||
data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32) | |||
filepath = "./test/data_for_tests/cws_pku_utf_8" | |||
data = loader.load(filepath, max_seq_len=32) | |||
assert len(data) > 0 | |||
data1 = DataSet() | |||
data1.read_tokenize(filepath, max_seq_len=32) | |||
assert len(data1) > 0 | |||
print("pass TokenizeDataSetLoader test!") | |||
def test_case_POSDatasetLoader(self): | |||
loader = POSDataSetLoader() | |||
filepath = "./test/data_for_tests/people.txt" | |||
data = loader.load("./test/data_for_tests/people.txt") | |||
datas = loader.load_lines("./test/data_for_tests/people.txt") | |||
data1 = DataSet().read_pos(filepath) | |||
assert len(data1) > 0 | |||
print("pass POSDataSetLoader test!") | |||
def test_case_LMDatasetLoader(self): | |||