@@ -1,5 +1,5 @@
from .batch import Batch
from .dataset import DataSet
# from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward

@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
from .vocabulary import Vocabulary
from ..io.dataset_loader import DataSet
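# Editor's note, not part of the patch: DataSet is now re-exported through
# fastNLP.io.dataset_loader, so the public import path keeps working while the
# module-level DataLoaderRegister.set_reader(...) calls in dataset_loader run at import time.
# Minimal sketch, assuming the package layout shown in this diff:
#     from fastNLP.core import DataSet   # also registers 'read_rawdata', 'read_pos', ...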
@@ -5,8 +5,7 @@ import numpy as np
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature
_READERS = {}
from fastNLP.io.base_loader import DataLoaderRegister


class DataSet(object):

@@ -98,6 +97,24 @@ class DataSet(object):
        else:
            raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

    def __getattr__(self, item):
        if item == "field_arrays":
            raise AttributeError
        # TODO dataset.x
        if item in self.field_arrays:
            return self.field_arrays[item]
        try:
            reader = DataLoaderRegister.get_reader(item)
            return reader
        except AttributeError:
            raise
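    # Editor's sketch, not part of the patch: with this __getattr__, both field access and
    # registered readers resolve as attributes ("train.pos" below is an assumed example path):
    #     ds = DataSet({"x": [1, 2]})
    #     ds.x                      # -> the "x" FieldArray
    #     ds.read_pos("train.pos")  # -> POSDataSetLoader().load("train.pos"),
    #                               #    found via DataLoaderRegister.get_reader('read_pos')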
    def __setstate__(self, state):
        self.__dict__ = state

    def __getstate__(self):
        return self.__dict__

    def __len__(self):
        """Fetch the length of the dataset.

@@ -226,16 +243,6 @@ class DataSet(object):
        """
        return [name for name, field in self.field_arrays.items() if field.is_target]

    @classmethod
    def set_reader(cls, method_name):
        assert isinstance(method_name, str)

        def wrapper(read_cls):
            _READERS[method_name] = read_cls
            return read_cls

        return wrapper

    def apply(self, func, new_field_name=None, **kwargs):
        """Apply a function to every instance of the DataSet.

@@ -347,6 +354,9 @@ class DataSet(object):
                    _dict[header].append(content)
        return cls(_dict)

    # def read_pos(self):
    #     return DataLoaderRegister.get_reader('read_pos')

    def save(self, path):
        """Save the DataSet object as pickle.
@@ -85,8 +85,8 @@ class Trainer(object):
        if metric_key is not None:
            self.increase_better = False if metric_key[0] == "-" else True
            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
        else:
            self.metric_key = None
        elif metrics is not None:
            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

        # prepare loss
        losser = _prepare_losser(loss)
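        # Editor's note, not part of the patch: the elif branch above derives a default
        # metric_key from the first metric's class name, e.g. for a metric class named
        # AccuracyMetric: "AccuracyMetric".lower() == "accuracymetric", and .strip('metric')
        # removes the trailing characters m/e/t/r/i/c, leaving "accuracy". str.strip removes
        # a character set, not the literal suffix "metric", so unusual class names may be
        # trimmed more than expected.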
@@ -147,7 +147,7 @@ class Trainer(object):
        self._mode(self.model, is_test=False)

        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
        print("training epochs started " + self.start_time, flush=True)
        if self.save_path is None:
            class psudoSW:

@@ -260,7 +260,7 @@ class Trainer(object):
                self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                global_step=self.step)
        if self.save_path is not None and self._better_eval_result(res):
            metric_key = self.metric_key if self.metric_key is not None else "None"
            metric_key = self.metric_key if self.metric_key is not None else ""
            self._save_model(self.model,
                             "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
        return res
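    # Editor's note, not part of the patch; "MyModel" is an assumed class name: with
    # metric_key "accuracy" the checkpoint saved above gets a name like
    #     best_MyModel_accuracy_2018-12-01 10-00-00
    # The "" fallback keeps the literal string "None" out of the file name, and the
    # '%H-%M-%S' start_time format above keeps ':' characters out of it as well.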
@@ -29,3 +29,39 @@ class BaseLoader(object):
            with open(cache_path, 'wb') as f:
                pickle.dump(obj, f)
            return obj


class ToyLoader0(BaseLoader):
    """
        For CharLM
    """

    def __init__(self, data_path):
        super(ToyLoader0, self).__init__(data_path)

    def load(self):
        with open(self.data_path, 'r') as f:
            corpus = f.read().lower()
        import re
        corpus = re.sub(r"<unk>", "unk", corpus)
        return corpus.split()


class DataLoaderRegister:
    """Register of readers for data sets."""
    _readers = {}

    @classmethod
    def set_reader(cls, reader_cls, read_fn_name):
        # def wrapper(reader_cls):
        if read_fn_name in cls._readers:
            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
        if hasattr(reader_cls, 'load'):
            cls._readers[read_fn_name] = reader_cls().load
        return reader_cls

    @classmethod
    def get_reader(cls, read_fn_name):
        if read_fn_name in cls._readers:
            return cls._readers[read_fn_name]
        raise AttributeError('no read function: {}'.format(read_fn_name))
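# Editor's usage sketch, not part of the patch ("FakeLoader" and 'read_fake' are assumed names):
# registering a loader class stores its bound load method, which get_reader hands back later.
#
#     class FakeLoader:
#         def load(self, path):
#             return path.upper()
#
#     DataLoaderRegister.set_reader(FakeLoader, 'read_fake')   # stores FakeLoader().load
#     reader = DataLoaderRegister.get_reader('read_fake')
#     reader("toy.txt")                                        # -> "TOY.TXT"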
@@ -2,7 +2,7 @@ import os

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.base_loader import BaseLoader
from fastNLP.io.base_loader import DataLoaderRegister


def convert_seq_dataset(data):

@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
    return dataset


class DataSetLoader(BaseLoader):
class DataSetLoader:
    """Loader for data sets."""

    def __init__(self):
        super(DataSetLoader, self).__init__()

    def load(self, path):
        """ load data in `path` into a dataset
        """

@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):
    def convert(self, data):
        return convert_seq_dataset(data)


DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
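# Editor's note, not part of the patch: importing this module executes the set_reader call
# above, so the reader also becomes reachable through DataSet.__getattr__, e.g.
#     DataSet().read_rawdata("some_file.txt")   # calls RawDataSetLoader().load("some_file.txt")
# ("some_file.txt" is an assumed example path.)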
@DataSet.set_reader('read_pos')
class POSDataSetLoader(DataSetLoader):
    """Dataset Loader for POS Tag datasets.

@@ -174,9 +171,9 @@ class POSDataSetLoader(DataSetLoader):
        """Convert lists of strings into Instances with Fields.
        """
        return convert_seq2seq_dataset(data)


DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')


@DataSet.set_reader('read_tokenize')
class TokenizeDataSetLoader(DataSetLoader):
    """
    Data set loader for tokenization data sets

@@ -236,7 +233,6 @@ class TokenizeDataSetLoader(DataSetLoader):
        return convert_seq2seq_dataset(data)


@DataSet.set_reader('read_class')
class ClassDataSetLoader(DataSetLoader):
    """Loader for classification data sets"""

@@ -275,6 +271,83 @@ class ClassDataSetLoader(DataSetLoader):
        return convert_seq2tag_dataset(data)


class ConllLoader(DataSetLoader):
    """loader for conll format files"""

    def __init__(self):
        """
        :param str data_path: the path to the conll data set
        """
        super(ConllLoader, self).__init__()

    def load(self, data_path):
        """
        :return: the contents of the conll file at ``data_path``, parsed into a 3D list and passed through ``self.convert``
        """
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        data = self.parse(lines)
        return self.convert(data)

    @staticmethod
    def parse(lines):
        """
        :param list lines: a list containing all lines in a conll file.
        :return: a 3D list
        """
        sentences = list()
        tokens = list()
        for line in lines:
            if line[0] == "#":
                # skip the comments
                continue
            if line == "\n":
                sentences.append(tokens)
                tokens = []
                continue
            tokens.append(line.split())
        return sentences
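    # Editor's example, not part of the patch: for the input lines
    #     ["# comment\n", "1 The DET\n", "2 dog NOUN\n", "\n"]
    # parse() returns [[['1', 'The', 'DET'], ['2', 'dog', 'NOUN']]] -- one list of token rows
    # per sentence (sentences are closed by blank lines), i.e. a 3D list.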
    def convert(self, data):
        pass


class LMDataSetLoader(DataSetLoader):
    """Language Model Dataset Loader

    This loader produces data for language model training in a supervised way.
    That means it has X and Y.
    """

    def __init__(self):
        super(LMDataSetLoader, self).__init__()

    def load(self, data_path):
        if not os.path.exists(data_path):
            raise FileNotFoundError("file {} not found.".format(data_path))
        with open(data_path, "r", encoding="utf-8") as f:
            text = " ".join(f.readlines())
        tokens = text.strip().split()
        data = self.sentence_cut(tokens)
        return self.convert(data)

    def sentence_cut(self, tokens, sentence_length=15):
        data_set = []
        for idx in range(len(tokens) // sentence_length):
            # cut consecutive windows of sentence_length tokens; y is x shifted by one token
            start_idx = idx * sentence_length
            x = tokens[start_idx: start_idx + sentence_length]
            y = tokens[start_idx + 1: start_idx + sentence_length + 1]
            if start_idx + sentence_length + 1 >= len(tokens):
                # ad hoc: pad the target when the shifted window runs past the last token
                y.extend(["<unk>"])
            data_set.append([x, y])
        return data_set
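    # Editor's example, not part of the patch: with sentence_length=3 and
    # tokens = ["a", "b", "c", "d", "e", "f", "g", "h"], sentence_cut yields
    #     [[['a', 'b', 'c'], ['b', 'c', 'd']], [['d', 'e', 'f'], ['e', 'f', 'g']]]
    # i.e. each target y is the input x shifted right by one token.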
    def convert(self, data):
        pass


@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
    """