add dataset read functions (tags/v0.2.0)
@@ -7,6 +7,8 @@ from fastNLP.core.field import TextField, LabelField | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
_READERS = {} | |||||
class DataSet(list): | class DataSet(list): | ||||
"""A DataSet object is a list of Instance objects. | """A DataSet object is a list of Instance objects. | ||||
@@ -135,3 +137,24 @@ class DataSet(list): | |||||
self.origin_len = (origin_field + "_origin_len", origin_field) \ | self.origin_len = (origin_field + "_origin_len", origin_field) \ | ||||
if origin_len_name is None else (origin_len_name, origin_field) | if origin_len_name is None else (origin_len_name, origin_field) | ||||
return self | return self | ||||
def __getattribute__(self, name):
    """Resolve attribute access, dynamically exposing registered ``read_*`` methods.

    If ``name`` matches a reader registered via ``set_reader``, return a
    closure that instantiates that loader, loads data with the given
    arguments, extends this DataSet in place, and returns ``self`` so calls
    can be chained. Any other name falls back to normal attribute lookup.
    """
    if name not in _READERS:
        return object.__getattribute__(self, name)

    def reader_proxy(*args, **kwargs):
        # Instantiate the registered loader class and append everything it loads.
        loaded = _READERS[name]().load(*args, **kwargs)
        self.extend(loaded)
        return self

    return reader_proxy
@classmethod
def set_reader(cls, method_name):
    """Class-decorator factory that registers a DataSetLoader under *method_name*.

    After registration, ``DataSet().<method_name>(...)`` loads data with the
    decorated loader class and extends the dataset in place (see
    ``__getattribute__``).

    :param method_name: attribute name (e.g. ``"read_pos"``) under which the
        loader becomes callable on DataSet instances.
    :return: a decorator that records the loader class in ``_READERS`` and
        returns the class unchanged.
    :raises TypeError: if ``method_name`` is not a string.
    """
    if not isinstance(method_name, str):
        # Raise explicitly instead of using ``assert`` so the validation
        # still runs under ``python -O`` (asserts are stripped there).
        raise TypeError("method_name must be a str, got %r" % type(method_name))

    def wrapper(read_cls):
        _READERS[method_name] = read_cls
        return read_cls

    return wrapper
@@ -70,6 +70,7 @@ class Vocabulary(object): | |||||
else: | else: | ||||
self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
self.word2idx = None | self.word2idx = None | ||||
return self | |||||
def build_vocab(self): | def build_vocab(self): | ||||
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` | """build 'word to index' dict, and filter the word using `max_size` and `min_freq` | ||||
@@ -88,6 +88,7 @@ class DataSetLoader(BaseLoader): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
@DataSet.set_reader('read_raw') | |||||
class RawDataSetLoader(DataSetLoader): | class RawDataSetLoader(DataSetLoader): | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -103,6 +104,7 @@ class RawDataSetLoader(DataSetLoader): | |||||
return convert_seq_dataset(data) | return convert_seq_dataset(data) | ||||
@DataSet.set_reader('read_pos') | |||||
class POSDataSetLoader(DataSetLoader): | class POSDataSetLoader(DataSetLoader): | ||||
"""Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
@@ -172,6 +174,7 @@ class POSDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_tokenize') | |||||
class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
""" | """ | ||||
Data set loader for tokenization data sets | Data set loader for tokenization data sets | ||||
@@ -231,6 +234,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
@DataSet.set_reader('read_class') | |||||
class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
"""Loader for classification data sets""" | """Loader for classification data sets""" | ||||
@@ -269,6 +273,7 @@ class ClassDataSetLoader(DataSetLoader): | |||||
return convert_seq2tag_dataset(data) | return convert_seq2tag_dataset(data) | ||||
@DataSet.set_reader('read_conll') | |||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
"""loader for conll format files""" | """loader for conll format files""" | ||||
@@ -310,6 +315,7 @@ class ConllLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_lm') | |||||
class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
"""Language Model Dataset Loader | """Language Model Dataset Loader | ||||
@@ -346,6 +352,7 @@ class LMDataSetLoader(DataSetLoader): | |||||
pass | pass | ||||
@DataSet.set_reader('read_people_daily') | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | People Daily Corpus: Chinese word segmentation, POS tag, NER | ||||
@@ -3,7 +3,7 @@ import unittest | |||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | ||||
PeopleDailyCorpusLoader, ConllLoader | PeopleDailyCorpusLoader, ConllLoader | ||||
from fastNLP.core.dataset import DataSet | |||||
class TestDatasetLoader(unittest.TestCase): | class TestDatasetLoader(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
@@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase): | |||||
def test_case_TokenizeDatasetLoader(self):
    """Both the direct loader and DataSet.read_tokenize load the CWS corpus."""
    loader = TokenizeDataSetLoader()
    filepath = "./test/data_for_tests/cws_pku_utf_8"

    # Direct load through the loader instance.
    data = loader.load(filepath, max_seq_len=32)
    assert len(data) > 0

    # Same corpus through the dynamically registered read_tokenize method.
    data1 = DataSet()
    data1.read_tokenize(filepath, max_seq_len=32)
    assert len(data1) > 0

    print("pass TokenizeDataSetLoader test!")
def test_case_POSDatasetLoader(self):
    """Both the direct loader and DataSet.read_pos load the POS-tag corpus.

    Fix: ``filepath`` was defined but the path literal was still repeated in
    the ``load``/``load_lines`` calls; reuse the single variable instead.
    """
    loader = POSDataSetLoader()
    filepath = "./test/data_for_tests/people.txt"

    # Direct loads through the loader instance.
    data = loader.load(filepath)
    datas = loader.load_lines(filepath)

    # Same corpus through the dynamically registered read_pos method.
    data1 = DataSet().read_pos(filepath)
    assert len(data1) > 0

    print("pass POSDataSetLoader test!")
def test_case_LMDatasetLoader(self): | def test_case_LMDatasetLoader(self): | ||||