
add dataloader register

tags/v0.2.0^2
yunfan 6 years ago
parent commit 267baec224
5 changed files with 147 additions and 26 deletions
  1. fastNLP/core/__init__.py      +4  -2
  2. fastNLP/core/dataset.py       +22 -12
  3. fastNLP/core/trainer.py       +4  -4
  4. fastNLP/io/base_loader.py     +36 -0
  5. fastNLP/io/dataset_loader.py  +81 -8

fastNLP/core/__init__.py  (+4 -2)

@@ -1,5 +1,5 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
 from .vocabulary import Vocabulary
+from ..io.dataset_loader import DataSet
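The upshot of this file's two hunks: DataSet is no longer exported straight from fastNLP.core.dataset but re-exported through fastNLP.io.dataset_loader, so importing it also runs the reader registrations added below. A minimal sketch of what user code sees after the change (assuming the package is importable):

# Both names refer to the same class object; importing via fastNLP.core
# now goes through fastNLP.io.dataset_loader first, so the
# DataLoaderRegister.set_reader calls happen as an import side effect.
from fastNLP.core import DataSet
from fastNLP.core.dataset import DataSet as _DataSet

assert DataSet is _DataSet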


fastNLP/core/dataset.py  (+22 -12)

@@ -5,8 +5,7 @@ import numpy as np
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-
-_READERS = {}
+from fastNLP.io.base_loader import DataLoaderRegister


 class DataSet(object):
@@ -98,6 +97,24 @@
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

+    def __getattr__(self, item):
+        if item == "field_arrays":
+            raise AttributeError
+        # TODO dataset.x
+        if item in self.field_arrays:
+            return self.field_arrays[item]
+        try:
+            reader = DataLoaderRegister.get_reader(item)
+            return reader
+        except AttributeError:
+            raise
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+    def __getstate__(self):
+        return self.__dict__
+
     def __len__(self):
         """Fetch the length of the dataset.

@@ -226,16 +243,6 @@
         """
        return [name for name, field in self.field_arrays.items() if field.is_target]

-    @classmethod
-    def set_reader(cls, method_name):
-        assert isinstance(method_name, str)
-
-        def wrapper(read_cls):
-            _READERS[method_name] = read_cls
-            return read_cls
-
-        return wrapper
-
     def apply(self, func, new_field_name=None, **kwargs):
         """Apply a function to every instance of the DataSet.


@@ -347,6 +354,9 @@
             _dict[header].append(content)
         return cls(_dict)

+    # def read_pos(self):
+    #     return DataLoaderRegister.get_reader('read_pos')
+
     def save(self, path):
         """Save the DataSet object as pickle.




fastNLP/core/trainer.py  (+4 -4)

@@ -85,8 +85,8 @@
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        else:
-            self.metric_key = None
+        elif metrics is not None:
+            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

         # prepare loss
         losser = _prepare_losser(loss)
@@ -147,7 +147,7 @@
         self._mode(self.model, is_test=False)

-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
         print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
@@ -260,7 +260,7 @@
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else "None"
+            metric_key = self.metric_key if self.metric_key is not None else ""
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res
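The new elif derives a default metric_key from the first metric's class name instead of leaving it None; the other two hunks make the saved-model filename cleaner (no colons from the timestamp, no literal "None" segment). A quick check of what the fallback expression produces; AccuracyMetric here is only a stand-in class name, not an import from fastNLP:

class AccuracyMetric:
    pass

key = AccuracyMetric().__class__.__name__.lower().strip('metric')
print(key)  # 'accuracy'

# Caveat: str.strip removes a *set of characters*, not a suffix, so some
# names degrade, e.g. 'RecallMetric'.lower().strip('metric') == 'all'.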


fastNLP/io/base_loader.py  (+36 -0)

@@ -29,3 +29,39 @@ class BaseLoader(object):
         with open(cache_path, 'wb') as f:
             pickle.dump(obj, f)
         return obj
+
+
+class ToyLoader0(BaseLoader):
+    """
+        For CharLM
+    """
+
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)
+
+    def load(self):
+        with open(self.data_path, 'r') as f:
+            corpus = f.read().lower()
+        import re
+        corpus = re.sub(r"<unk>", "unk", corpus)
+        return corpus.split()
+
+
+class DataLoaderRegister:
+    """"register for data sets"""
+    _readers = {}
+
+    @classmethod
+    def set_reader(cls, reader_cls, read_fn_name):
+        # def wrapper(reader_cls):
+        if read_fn_name in cls._readers:
+            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+        if hasattr(reader_cls, 'load'):
+            cls._readers[read_fn_name] = reader_cls().load
+        return reader_cls
+
+    @classmethod
+    def get_reader(cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            return cls._readers[read_fn_name]
+        raise AttributeError('no read function: {}'.format(read_fn_name))
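A minimal sketch of the registry contract, using a hypothetical loader class and a temp file: set_reader stores the bound load method of a fresh instance, get_reader hands it back, duplicate names raise KeyError, unknown names raise AttributeError.

import os
import tempfile

from fastNLP.io.base_loader import DataLoaderRegister


class MyLoader:  # hypothetical loader, not part of fastNLP
    def load(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            return f.read().splitlines()


# stores MyLoader().load under the name; a second registration with the
# same name would raise KeyError
DataLoaderRegister.set_reader(MyLoader, 'read_my_format')

path = os.path.join(tempfile.gettempdir(), 'demo.txt')
with open(path, 'w', encoding='utf-8') as f:
    f.write('hello\nworld\n')

print(DataLoaderRegister.get_reader('read_my_format')(path))  # ['hello', 'world']
# DataLoaderRegister.get_reader('unknown') raises AttributeError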

fastNLP/io/dataset_loader.py  (+81 -8)

@@ -2,7 +2,7 @@ import os

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.io.base_loader import BaseLoader
+from fastNLP.io.base_loader import DataLoaderRegister


 def convert_seq_dataset(data):
@@ -61,12 +61,9 @@
     return dataset


-class DataSetLoader(BaseLoader):
+class DataSetLoader:
     """"loader for data sets"""

-    def __init__(self):
-        super(DataSetLoader, self).__init__()
-
     def load(self, path):
         """ load data in `path` into a dataset
         """
@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):

     def convert(self, data):
         return convert_seq_dataset(data)
+DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


-@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.


@@ -174,9 +171,9 @@
         """Convert lists of strings into Instances with Fields.
         """
         return convert_seq2seq_dataset(data)
+DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')


-@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -236,7 +233,6 @@
         return convert_seq2seq_dataset(data)


-@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""


@@ -275,6 +271,83 @@
         return convert_seq2tag_dataset(data)


+class ConllLoader(DataSetLoader):
+    """loader for conll format files"""
+
+    def __init__(self):
+        """
+            :param str data_path: the path to the conll data set
+        """
+        super(ConllLoader, self).__init__()
+
+    def load(self, data_path):
+        """
+            :return: list lines: all lines in a conll file
+        """
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    @staticmethod
+    def parse(lines):
+        """
+            :param list lines: a list containing all lines in a conll file.
+            :return: a 3D list
+        """
+        sentences = list()
+        tokens = list()
+        for line in lines:
+            if line[0] == "#":
+                # skip the comments
+                continue
+            if line == "\n":
+                sentences.append(tokens)
+                tokens = []
+                continue
+            tokens.append(line.split())
+        return sentences
+
+    def convert(self, data):
+        pass
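For reference, a toy run of ConllLoader.parse on made-up input (note that convert is still a stub, so load currently returns None):

from fastNLP.io.dataset_loader import ConllLoader

lines = [
    "# comment lines are skipped\n",
    "1 The DT\n",
    "2 cat NN\n",
    "\n",
    "1 Hi UH\n",
    "\n",
]
print(ConllLoader.parse(lines))
# [[['1', 'The', 'DT'], ['2', 'cat', 'NN']], [['1', 'Hi', 'UH']]]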


+class LMDataSetLoader(DataSetLoader):
+    """Language Model Dataset Loader
+
+        This loader produces data for language model training in a supervised way.
+        That means it has X and Y.
+    """
+
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf=8") as f:
+            text = " ".join(f.readlines())
+        tokens = text.strip().split()
+        data = self.sentence_cut(tokens)
+        return self.convert(data)
+
+    def sentence_cut(self, tokens, sentence_length=15):
+        start_idx = 0
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            x = tokens[start_idx * idx: start_idx * idx + sentence_length]
+            y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
+            if start_idx * idx + sentence_length + 1 >= len(tokens):
+                # ad hoc
+                y.extend(["<unk>"])
+            data_set.append([x, y])
+        return data_set
+
+    def convert(self, data):
+        pass
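A toy run of sentence_cut on made-up tokens: because start_idx stays 0, start_idx * idx is always 0, so every window is the same first slice; presumably the offset was meant to advance by sentence_length per step (convert is also still a stub here).

from fastNLP.io.dataset_loader import LMDataSetLoader

tokens = [str(i) for i in range(8)]
print(LMDataSetLoader().sentence_cut(tokens, sentence_length=4))
# [[['0', '1', '2', '3'], ['1', '2', '3', '4']],
#  [['0', '1', '2', '3'], ['1', '2', '3', '4']]]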


 @DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """

