
add dataloader register

tags/v0.2.0^2
yunfan, 6 years ago · commit 267baec224
5 changed files with 147 additions and 26 deletions
  1. fastNLP/core/__init__.py (+4, -2)
  2. fastNLP/core/dataset.py (+22, -12)
  3. fastNLP/core/trainer.py (+4, -4)
  4. fastNLP/io/base_loader.py (+36, -0)
  5. fastNLP/io/dataset_loader.py (+81, -8)

fastNLP/core/__init__.py (+4, -2)

@@ -1,5 +1,5 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
+from ..io.dataset_loader import DataSet
+

fastNLP/core/dataset.py (+22, -12)

@@ -5,8 +5,7 @@ import numpy as np
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-
-_READERS = {}
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 class DataSet(object):
@@ -98,6 +97,24 @@ class DataSet(object):
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
 
+    def __getattr__(self, item):
+        if item == "field_arrays":
+            raise AttributeError
+        # TODO dataset.x
+        if item in self.field_arrays:
+            return self.field_arrays[item]
+        try:
+            reader = DataLoaderRegister.get_reader(item)
+            return reader
+        except AttributeError:
+            raise
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+    def __getstate__(self):
+        return self.__dict__
+
     def __len__(self):
         """Fetch the length of the dataset.
 
@@ -226,16 +243,6 @@ class DataSet(object):
         """
         return [name for name, field in self.field_arrays.items() if field.is_target]
 
-    @classmethod
-    def set_reader(cls, method_name):
-        assert isinstance(method_name, str)
-
-        def wrapper(read_cls):
-            _READERS[method_name] = read_cls
-            return read_cls
-
-        return wrapper
-
     def apply(self, func, new_field_name=None, **kwargs):
         """Apply a function to every instance of the DataSet.
 
@@ -347,6 +354,9 @@ class DataSet(object):
                 _dict[header].append(content)
         return cls(_dict)
 
+    # def read_pos(self):
+    #     return DataLoaderRegister.get_reader('read_pos')
+
     def save(self, path):
         """Save the DataSet object as pickle.


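Not part of the commit, but a minimal sketch of what the new __getattr__ fallback enables. It assumes fastNLP.io.dataset_loader has been imported (its module-level set_reader calls, shown further down, populate the registry); the file path is hypothetical.

    from fastNLP.core.dataset import DataSet
    import fastNLP.io.dataset_loader  # importing this module registers the readers defined below

    ds = DataSet()
    # unknown attributes first resolve against the dataset's own fields,
    # then fall back to DataLoaderRegister.get_reader(item)
    read_pos = ds.read_pos                    # bound POSDataSetLoader().load from the registry
    pos_data = read_pos("pos_tag_data.txt")   # hypothetical path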

fastNLP/core/trainer.py (+4, -4)

@@ -85,8 +85,8 @@ class Trainer(object):
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        else:
-            self.metric_key = None
+        elif metrics is not None:
+            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
 
         # prepare loss
         losser = _prepare_losser(loss)
@@ -147,7 +147,7 @@ class Trainer(object):
 
         self._mode(self.model, is_test=False)
 
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
         print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
@@ -260,7 +260,7 @@ class Trainer(object):
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else "None"
+            metric_key = self.metric_key if self.metric_key is not None else ""
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res

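A side note on the new metric_key fallback (not in the diff): str.strip('metric') removes any of the characters m, e, t, r, i, c from both ends of the lowercased class name rather than a literal suffix, so a class named AccuracyMetric yields "accuracy". A quick check with a stand-in class:

    class AccuracyMetric:  # stand-in for a fastNLP metric class
        pass

    metrics = [AccuracyMetric()]
    key = metrics[0].__class__.__name__.lower().strip('metric')
    print(key)  # -> 'accuracy'  (strip works on the character set {m, e, t, r, i, c})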

fastNLP/io/base_loader.py (+36, -0)

@@ -29,3 +29,39 @@ class BaseLoader(object):
             with open(cache_path, 'wb') as f:
                 pickle.dump(obj, f)
             return obj
+
+
+class ToyLoader0(BaseLoader):
+    """
+        For CharLM
+    """
+
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)
+
+    def load(self):
+        with open(self.data_path, 'r') as f:
+            corpus = f.read().lower()
+        import re
+        corpus = re.sub(r"<unk>", "unk", corpus)
+        return corpus.split()
+
+
+class DataLoaderRegister:
+    """register for data sets"""
+    _readers = {}
+
+    @classmethod
+    def set_reader(cls, reader_cls, read_fn_name):
+        # def wrapper(reader_cls):
+        if read_fn_name in cls._readers:
+            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+        if hasattr(reader_cls, 'load'):
+            cls._readers[read_fn_name] = reader_cls().load
+        return reader_cls
+
+    @classmethod
+    def get_reader(cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            return cls._readers[read_fn_name]
+        raise AttributeError('no read function: {}'.format(read_fn_name))

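A small usage sketch of the new registry; the loader class, read-function name, and file below are hypothetical, not part of fastNLP.

    from fastNLP.io.base_loader import DataLoaderRegister
    from fastNLP.io.dataset_loader import DataSetLoader


    class MyCsvLoader(DataSetLoader):       # hypothetical loader
        def load(self, path):
            with open(path, "r", encoding="utf-8") as f:
                return [line.strip().split(",") for line in f]


    # set_reader stores a bound load() of a fresh instance under the given name ...
    DataLoaderRegister.set_reader(MyCsvLoader, 'read_my_csv')
    # ... and get_reader returns it; DataSet.__getattr__ (above) performs exactly this lookup
    read_my_csv = DataLoaderRegister.get_reader('read_my_csv')
    rows = read_my_csv("some_file.csv")     # hypothetical file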
fastNLP/io/dataset_loader.py (+81, -8)

@@ -2,7 +2,7 @@ import os
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.io.base_loader import BaseLoader
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 def convert_seq_dataset(data):
@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
     return dataset
 
 
-class DataSetLoader(BaseLoader):
+class DataSetLoader:
     """loader for data sets"""
 
-    def __init__(self):
-        super(DataSetLoader, self).__init__()
-
     def load(self, path):
         """ load data in `path` into a dataset
         """
@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):
 
     def convert(self, data):
         return convert_seq_dataset(data)
+DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
 
 
-@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
 
@@ -174,9 +171,9 @@ class POSDataSetLoader(DataSetLoader):
         """Convert lists of strings into Instances with Fields.
         """
         return convert_seq2seq_dataset(data)
+DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
 
 
-@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -236,7 +233,6 @@ class TokenizeDataSetLoader(DataSetLoader):
         return convert_seq2seq_dataset(data)
 
 
-@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""
 
@@ -275,6 +271,83 @@ class ClassDataSetLoader(DataSetLoader):
         return convert_seq2tag_dataset(data)
 
 
+class ConllLoader(DataSetLoader):
+    """loader for conll format files"""
+
+    def __init__(self):
+        """
+        :param str data_path: the path to the conll data set
+        """
+        super(ConllLoader, self).__init__()
+
+    def load(self, data_path):
+        """
+        :return: list lines: all lines in a conll file
+        """
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param list lines: a list containing all lines in a conll file.
+        :return: a 3D list
+        """
+        sentences = list()
+        tokens = list()
+        for line in lines:
+            if line[0] == "#":
+                # skip the comments
+                continue
+            if line == "\n":
+                sentences.append(tokens)
+                tokens = []
+                continue
+            tokens.append(line.split())
+        return sentences
+
+    def convert(self, data):
+        pass
+
+
+class LMDataSetLoader(DataSetLoader):
+    """Language Model Dataset Loader
+
+    This loader produces data for language model training in a supervised way.
+    That means it has X and Y.
+
+    """
+
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        tokens = text.strip().split()
+        data = self.sentence_cut(tokens)
+        return self.convert(data)
+
+    def sentence_cut(self, tokens, sentence_length=15):
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            start_idx = idx * sentence_length
+            x = tokens[start_idx: start_idx + sentence_length]
+            y = tokens[start_idx + 1: start_idx + sentence_length + 1]
+            if start_idx + sentence_length + 1 >= len(tokens):
+                # the shifted target runs past the end; pad with <unk>
+                y.extend(["<unk>"])
+            data_set.append([x, y])
+        return data_set
+
+    def convert(self, data):
+        pass
+
+
 @DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """


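For orientation, a toy illustration (not from the repository) of the 3D list ConllLoader.parse produces: sentences, each a list of tokens, each a list of column values.

    from fastNLP.io.dataset_loader import ConllLoader

    lines = [
        "# comment lines are skipped\n",
        "1 The DET\n",
        "2 dog NOUN\n",
        "\n",                  # a blank line closes the current sentence
        "1 Dogs NOUN\n",
        "2 bark VERB\n",
        "\n",
    ]
    sentences = ConllLoader.parse(lines)
    # [[['1', 'The', 'DET'], ['2', 'dog', 'NOUN']],
    #  [['1', 'Dogs', 'NOUN'], ['2', 'bark', 'VERB']]]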