Browse Source

add dataset read functions

tags/v0.2.0
yunfan 5 years ago
parent
commit
ebbfcb7829
4 changed files with 43 additions and 5 deletions
  1. +23
    -0
      fastNLP/core/dataset.py
  2. +1
    -0
      fastNLP/core/vocabulary.py
  3. +7
    -3
      fastNLP/loader/dataset_loader.py
  4. +12
    -2
      test/loader/test_dataset_loader.py

+ 23
- 0
fastNLP/core/dataset.py View File

@@ -7,6 +7,8 @@ from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary

_READERS = {}

class DataSet(list):
"""A DataSet object is a list of Instance objects.

@@ -125,3 +127,24 @@ class DataSet(list):
self.origin_len = (origin_field + "_origin_len", origin_field) \
if origin_len_name is None else (origin_len_name, origin_field)
return self

def __getattribute__(self, name):
    """Intercept attribute lookup to expose registered ``read_*`` loaders.

    If *name* was registered through ``DataSet.set_reader``, return a
    helper that instantiates the registered loader class, loads data with
    it, extends this DataSet in place, and returns ``self`` (so calls can
    be chained). Any other attribute falls through to normal lookup.
    """
    if name not in _READERS:
        return object.__getattribute__(self, name)

    def _reader_proxy(*args, **kwargs):
        # Instantiate the registered loader and absorb its output.
        loaded = _READERS[name]().load(*args, **kwargs)
        self.extend(loaded)
        return self  # fluent API: DataSet().read_x(...) yields the DataSet

    return _reader_proxy

@classmethod
def set_reader(cls, method_name):
    """Decorator factory registering a loader class under *method_name*.

    Usage::

        @DataSet.set_reader('read_pos')
        class POSDataSetLoader(DataSetLoader): ...

    Afterwards ``DataSet().read_pos(...)`` loads data with that class
    (dispatched via ``__getattribute__``).

    :param method_name: name under which the loader becomes callable on
        DataSet instances (conventionally ``read_*``).
    :return: a class decorator that records its argument in ``_READERS``
        and returns it unchanged.
    :raises TypeError: if *method_name* is not a str.
    """
    if not isinstance(method_name, str):
        # Raise a real exception: `assert` is silently stripped under -O.
        raise TypeError("method_name must be a str, got %r" % type(method_name))

    def wrapper(read_cls):
        _READERS[method_name] = read_cls
        return read_cls  # decoration is transparent to the loader class
    return wrapper

+ 1
- 0
fastNLP/core/vocabulary.py View File

@@ -69,6 +69,7 @@ class Vocabulary(object):
else:
self.word_count[word] += 1
self.word2idx = None
return self


def build_vocab(self):


+ 7
- 3
fastNLP/loader/dataset_loader.py View File

@@ -84,6 +84,7 @@ class DataSetLoader(BaseLoader):
"""
raise NotImplementedError

@DataSet.set_reader('read_raw')
class RawDataSetLoader(DataSetLoader):
def __init__(self):
super(RawDataSetLoader, self).__init__()
@@ -98,6 +99,7 @@ class RawDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq_dataset(data)

@DataSet.set_reader('read_pos')
class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.

@@ -166,6 +168,7 @@ class POSDataSetLoader(DataSetLoader):
"""
return convert_seq2seq_dataset(data)

@DataSet.set_reader('read_tokenize')
class TokenizeDataSetLoader(DataSetLoader):
"""
Data set loader for tokenization data sets
@@ -224,7 +227,7 @@ class TokenizeDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq2seq_dataset(data)

@DataSet.set_reader('read_class')
class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""

@@ -262,7 +265,7 @@ class ClassDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq2tag_dataset(data)

@DataSet.set_reader('read_conll')
class ConllLoader(DataSetLoader):
"""loader for conll format files"""

@@ -303,7 +306,7 @@ class ConllLoader(DataSetLoader):
def convert(self, data):
pass

@DataSet.set_reader('read_lm')
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader

@@ -339,6 +342,7 @@ class LMDataSetLoader(DataSetLoader):
def convert(self, data):
pass

@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
"""
People Daily Corpus: Chinese word segmentation, POS tag, NER


+ 12
- 2
test/loader/test_dataset_loader.py View File

@@ -3,7 +3,7 @@ import unittest

from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
PeopleDailyCorpusLoader, ConllLoader
from fastNLP.core.dataset import DataSet

class TestDatasetLoader(unittest.TestCase):
def test_case_1(self):
@@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase):

def test_case_TokenizeDatasetLoader(self):
    """Load the CWS corpus via the loader API and via DataSet.read_tokenize."""
    loader = TokenizeDataSetLoader()
    filepath = "./test/data_for_tests/cws_pku_utf_8"
    # Direct loader API. (The duplicated hard-coded load call that preceded
    # the `filepath` variable has been removed — it loaded the file twice.)
    data = loader.load(filepath, max_seq_len=32)
    assert len(data) > 0

    # Reader registered by @DataSet.set_reader('read_tokenize'); extends in place.
    data1 = DataSet()
    data1.read_tokenize(filepath, max_seq_len=32)
    assert len(data1) > 0
    print("pass TokenizeDataSetLoader test!")

def test_case_POSDatasetLoader(self):
    """Load the POS corpus via the loader API and via DataSet.read_pos."""
    loader = POSDataSetLoader()
    filepath = "./test/data_for_tests/people.txt"
    # Use the filepath variable instead of repeating the literal path.
    data = loader.load(filepath)           # smoke-test whole-file load
    datas = loader.load_lines(filepath)    # smoke-test line-wise load

    # Reader registered by @DataSet.set_reader('read_pos'); returns self,
    # so construction and loading chain on one line.
    data1 = DataSet().read_pos(filepath)
    assert len(data1) > 0
    print("pass POSDataSetLoader test!")

def test_case_LMDatasetLoader(self):


Loading…
Cancel
Save