From 267baec2244b1812fa3bdb01a66b7c05986352c2 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 7 Dec 2018 15:19:56 +0800
Subject: [PATCH] add dataloader register

---
 fastNLP/core/__init__.py     |  6 ++-
 fastNLP/core/dataset.py      | 34 +++++++++-----
 fastNLP/core/trainer.py      |  8 ++--
 fastNLP/io/base_loader.py    | 36 +++++++++++++++
 fastNLP/io/dataset_loader.py | 89 ++++++++++++++++++++++++++++++++----
 5 files changed, 147 insertions(+), 26 deletions(-)

diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 44f30fad..038ca12f 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -1,5 +1,5 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
\ No newline at end of file
+from .vocabulary import Vocabulary
+from ..io.dataset_loader import DataSet
+
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index d4d285d7..a08961fc 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -5,8 +5,7 @@ import numpy as np
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-
-_READERS = {}
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 class DataSet(object):
@@ -98,6 +97,24 @@ class DataSet(object):
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
 
+    def __getattr__(self, item):
+        if item == "field_arrays":
+            raise AttributeError
+        # TODO dataset.x
+        if item in self.field_arrays:
+            return self.field_arrays[item]
+        try:
+            reader = DataLoaderRegister.get_reader(item)
+            return reader
+        except AttributeError:
+            raise
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+    def __getstate__(self):
+        return self.__dict__
+
     def __len__(self):
         """Fetch the length of the dataset.
 
@@ -226,16 +243,6 @@ class DataSet(object):
         """
         return [name for name, field in self.field_arrays.items() if field.is_target]
 
-    @classmethod
-    def set_reader(cls, method_name):
-        assert isinstance(method_name, str)
-
-        def wrapper(read_cls):
-            _READERS[method_name] = read_cls
-            return read_cls
-
-        return wrapper
-
     def apply(self, func, new_field_name=None, **kwargs):
         """Apply a function to every instance of the DataSet.
 
@@ -347,6 +354,9 @@ class DataSet(object):
             _dict[header].append(content)
         return cls(_dict)
 
+    # def read_pos(self):
+    #     return DataLoaderRegister.get_reader('read_pos')
+
     def save(self, path):
         """Save the DataSet object as pickle.
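
Note on the new __getattr__ hook above: any read function registered with
DataLoaderRegister becomes reachable as an attribute of a DataSet instance.
A minimal sketch of the intended call path (illustrative only, not part of
the patch; the data file path is hypothetical):

    from fastNLP.core.dataset import DataSet

    ds = DataSet()
    # "read_pos" is neither an instance attribute nor a field name, so
    # __getattr__ falls through to DataLoaderRegister.get_reader, which
    # returns the bound POSDataSetLoader().load registered further below
    reader = ds.read_pos
    pos_data = reader("data/pos_sample.txt")  # hypothetical path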
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index c2bca3a2..6cb6b560 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -85,8 +85,8 @@ class Trainer(object):
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        else:
-            self.metric_key = None
+        elif metrics is not None:
+            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
 
         # prepare loss
         losser = _prepare_losser(loss)
@@ -147,7 +147,7 @@ class Trainer(object):
 
         self._mode(self.model, is_test=False)
 
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
         print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
@@ -260,7 +260,7 @@ class Trainer(object):
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else "None"
+            metric_key = self.metric_key if self.metric_key is not None else ""
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index b0b0d864..a3ce410b 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -29,3 +29,39 @@ class BaseLoader(object):
         with open(cache_path, 'wb') as f:
             pickle.dump(obj, f)
         return obj
+
+
+class ToyLoader0(BaseLoader):
+    """
+        For CharLM
+    """
+
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)
+
+    def load(self):
+        with open(self.data_path, 'r') as f:
+            corpus = f.read().lower()
+        import re
+        # replace the literal token "<unk>" with "unk"
+        corpus = re.sub(r"<unk>", "unk", corpus)
+        return corpus.split()
+
+
+class DataLoaderRegister:
+    """Register for data-set readers."""
+    _readers = {}
+
+    @classmethod
+    def set_reader(cls, reader_cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+        if hasattr(reader_cls, 'load'):
+            cls._readers[read_fn_name] = reader_cls().load
+        return reader_cls
+
+    @classmethod
+    def get_reader(cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            return cls._readers[read_fn_name]
+        raise AttributeError('no read function: {}'.format(read_fn_name))
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 0d30c6e8..a1cfe33f 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -2,7 +2,7 @@ import os
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.io.base_loader import BaseLoader
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 def convert_seq_dataset(data):
@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
     return dataset
 
 
-class DataSetLoader(BaseLoader):
+class DataSetLoader:
     """Loader for data sets."""
 
-    def __init__(self):
-        super(DataSetLoader, self).__init__()
-
     def load(self, path):
         """ load data in `path` into a dataset
         """
@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq_dataset(data)
 
+DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
 
-@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
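
Note: third-party loaders can hook into the same registry the patch now
uses for the built-in loaders. A minimal sketch, where MyLoader and
'read_my_format' are hypothetical names:

    from fastNLP.io.base_loader import DataLoaderRegister

    class MyLoader:
        def load(self, path):
            ...  # build and return a DataSet from the file at `path`

    # set_reader instantiates the class once and stores the bound load
    # method; registering the same name twice raises KeyError
    DataLoaderRegister.set_reader(MyLoader, 'read_my_format')
    reader = DataLoaderRegister.get_reader('read_my_format')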
@@ -174,9 +171,9 @@ class POSDataSetLoader(DataSetLoader):
         """Convert lists of strings into Instances with Fields.
         """
         return convert_seq2seq_dataset(data)
 
+DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
 
-@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
         Data set loader for tokenization data sets
@@ -236,7 +233,6 @@ class TokenizeDataSetLoader(DataSetLoader):
         return convert_seq2seq_dataset(data)
 
 
-@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""
 
@@ -275,6 +271,83 @@ class ClassDataSetLoader(DataSetLoader):
         return convert_seq2tag_dataset(data)
 
 
+class ConllLoader(DataSetLoader):
+    """Loader for CoNLL-format files."""
+
+    def __init__(self):
+        super(ConllLoader, self).__init__()
+
+    def load(self, data_path):
+        """
+        :param str data_path: the path to the conll data set
+        :return: the result of convert() on the parsed lines
+        """
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param list lines: a list containing all lines in a conll file.
+        :return: a 3D list
+        """
+        sentences = list()
+        tokens = list()
+        for line in lines:
+            if line[0] == "#":
+                # skip the comments
+                continue
+            if line == "\n":
+                sentences.append(tokens)
+                tokens = []
+                continue
+            tokens.append(line.split())
+        return sentences
+
+    def convert(self, data):
+        pass
+
+
+class LMDataSetLoader(DataSetLoader):
+    """Language Model Dataset Loader
+
+    This loader produces data for language model training in a supervised way.
+    That means it has X and Y.
+
+    """
+
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        tokens = text.strip().split()
+        data = self.sentence_cut(tokens)
+        return self.convert(data)
+
+    def sentence_cut(self, tokens, sentence_length=15):
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            # y is x shifted right by one token
+            start = idx * sentence_length
+            x = tokens[start: start + sentence_length]
+            y = tokens[start + 1: start + sentence_length + 1]
+            if start + sentence_length + 1 >= len(tokens):
+                # ad hoc: pad the final, one-token-short target
+                y.extend(["<pad>"])
+            data_set.append([x, y])
+        return data_set
+
+    def convert(self, data):
+        pass
+
+
 @DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
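
Note: a toy trace of sentence_cut's sliding window (illustrative only; the
token list and the "<pad>" filler follow the sketch above):

    tokens = ["the", "cat", "sat", "on", "the", "mat"]
    # sentence_length=3 yields two (x, y) pairs, with y shifted one
    # token to the right of x, as usual for next-token prediction:
    #   x = ["the", "cat", "sat"]   y = ["cat", "sat", "on"]
    #   x = ["on", "the", "mat"]    y = ["the", "mat", "<pad>"]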