@@ -7,6 +7,8 @@ from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
_READERS = {} | |||||
class DataSet(list): | class DataSet(list): | ||||
"""A DataSet object is a list of Instance objects. | """A DataSet object is a list of Instance objects. | ||||
@@ -125,3 +127,24 @@ class DataSet(list): | |||||
self.origin_len = (origin_field + "_origin_len", origin_field) \ | self.origin_len = (origin_field + "_origin_len", origin_field) \ | ||||
if origin_len_name is None else (origin_len_name, origin_field) | if origin_len_name is None else (origin_len_name, origin_field) | ||||
return self | return self | ||||
def __getattribute__(self, name):
    """Intercept attribute lookup to expose registered ``read_*`` loaders.

    Any name registered in the module-level ``_READERS`` table (via
    ``DataSet.set_reader``) resolves to an on-the-fly loader function;
    every other name falls through to normal attribute access.

    :param name: attribute name being looked up.
    :return: a loader callable for registered reader names, otherwise the
        regular attribute value.
    """
    # Registered readers take priority over ordinary attributes.
    if name in _READERS:
        # Synthesize a ``read_*(...)`` method: instantiate the registered
        # loader class, load the data, and extend this DataSet in place.
        def _read(*args, **kwargs):
            loaded = _READERS[name]().load(*args, **kwargs)
            self.extend(loaded)
            return self  # fluent: allows DataSet().read_x(...)
        return _read
    # Fall back to the default lookup for everything else.
    return object.__getattribute__(self, name)
@classmethod
def set_reader(cls, method_name):
    """Decorator factory that registers a DataSetLoader under ``method_name``.

    The registered class becomes callable on DataSet instances as
    ``DataSet().method_name(...)`` (resolved through ``__getattribute__``),
    e.g. ``@DataSet.set_reader('read_pos')`` enables ``DataSet().read_pos(...)``.

    :param method_name: name (str) under which the loader is exposed.
    :return: a class decorator that records the loader class in ``_READERS``
        and returns it unchanged.
    :raises TypeError: if ``method_name`` is not a string.
    """
    # ``assert`` is stripped under ``python -O``; validate explicitly so the
    # check survives optimized runs and reports a proper error type.
    if not isinstance(method_name, str):
        raise TypeError("method_name must be a str, got %r" % type(method_name).__name__)

    def wrapper(read_cls):
        # Record the loader class; the decorated class is returned untouched.
        _READERS[method_name] = read_cls
        return read_cls
    return wrapper
@@ -69,6 +69,7 @@ class Vocabulary(object): | |||||
else: | else: | ||||
self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
self.word2idx = None | self.word2idx = None | ||||
return self | |||||
def build_vocab(self): | def build_vocab(self): | ||||
@@ -84,6 +84,7 @@ class DataSetLoader(BaseLoader): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader('read_raw') | |||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -98,6 +99,7 @@ class RawDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
@DataSet.set_reader('read_pos') | |||||
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -166,6 +168,7 @@ class POSDataSetLoader(DataSetLoader): | |||||
""" | """ | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_tokenize') | |||||
class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
""" | """ | ||||
Data set loader for tokenization data sets | Data set loader for tokenization data sets | ||||
@@ -224,7 +227,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_class') | |||||
class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
"""Loader for classification data sets""" | """Loader for classification data sets""" | ||||
@@ -262,7 +265,7 @@ class ClassDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
@DataSet.set_reader('read_conll') | |||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
"""loader for conll format files""" | """loader for conll format files""" | ||||
@@ -303,7 +306,7 @@ class ConllLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | pass | ||||
@DataSet.set_reader('read_lm') | |||||
class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
"""Language Model Dataset Loader | """Language Model Dataset Loader | ||||
@@ -339,6 +342,7 @@ class LMDataSetLoader(DataSetLoader): | |||||
def convert(self, data): | def convert(self, data): | ||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | People Daily Corpus: Chinese word segmentation, POS tag, NER | ||||
@@ -3,7 +3,7 @@ import unittest | |||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | ||||
PeopleDailyCorpusLoader, ConllLoader | PeopleDailyCorpusLoader, ConllLoader | ||||
from fastNLP.core.dataset import DataSet | |||||
class TestDatasetLoader(unittest.TestCase): | class TestDatasetLoader(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
@@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase): | |||||
def test_case_TokenizeDatasetLoader(self):
    """Exercise tokenization loading directly and via DataSet.read_tokenize."""
    filepath = "./test/data_for_tests/cws_pku_utf_8"

    # Direct use of the loader class.
    loader = TokenizeDataSetLoader()
    data = loader.load(filepath, max_seq_len=32)
    assert len(data) > 0

    # Same data through the registered-reader shortcut on DataSet.
    data1 = DataSet()
    data1.read_tokenize(filepath, max_seq_len=32)
    assert len(data1) > 0

    print("pass TokenizeDataSetLoader test!")
def test_case_POSDatasetLoader(self):
    """Exercise POS-tag loading directly and via DataSet.read_pos."""
    # Use the variable consistently instead of repeating the literal path
    # (the original duplicated the string in the load/load_lines calls).
    filepath = "./test/data_for_tests/people.txt"

    loader = POSDataSetLoader()
    data = loader.load(filepath)
    datas = loader.load_lines(filepath)

    # Registered-reader shortcut: read_pos extends the DataSet and returns it.
    data1 = DataSet().read_pos(filepath)
    assert len(data1) > 0

    print("pass POSDataSetLoader test!")
def test_case_LMDatasetLoader(self): | def test_case_LMDatasetLoader(self): | ||||