diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
@@ -1,5 +1,5 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
+from ..io.dataset_loader import DataSet
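The package now re-exports DataSet through fastNLP.io.dataset_loader instead of importing it from .dataset directly, presumably so the reader registration in the io package runs on import. A minimal sanity check of the re-routed name, assuming the package paths implied by the imports above:

from fastNLP.core import DataSet                      # now resolved via fastNLP.io.dataset_loader
from fastNLP.core.dataset import DataSet as _DataSet  # the class itself is still defined here

assert DataSet is _DataSet  # dataset_loader only re-imports the class, so both names match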
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
@@ -5,8 +5,7 @@ import numpy as np
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-
-_READERS = {}
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 class DataSet(object):
@@ -98,6 +97,24 @@ class DataSet(object):
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
 
+    def __getattr__(self, item):
+        if item == "field_arrays":
+            raise AttributeError
+        # TODO dataset.x
+        if item in self.field_arrays:
+            return self.field_arrays[item]
+        try:
+            reader = DataLoaderRegister.get_reader(item)
+            return reader
+        except AttributeError:
+            raise
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+    def __getstate__(self):
+        return self.__dict__
+
     def __len__(self):
         """Fetch the length of the dataset.
@@ -226,16 +243,6 @@ class DataSet(object):
         """
         return [name for name, field in self.field_arrays.items() if field.is_target]
 
-    @classmethod
-    def set_reader(cls, method_name):
-        assert isinstance(method_name, str)
-
-        def wrapper(read_cls):
-            _READERS[method_name] = read_cls
-            return read_cls
-
-        return wrapper
-
     def apply(self, func, new_field_name=None, **kwargs):
         """Apply a function to every instance of the DataSet.
@@ -347,6 +354,9 @@ class DataSet(object):
                 _dict[header].append(content)
         return cls(_dict)
 
+    # def read_pos(self):
+    #     return DataLoaderRegister.get_reader('read_pos')
+
     def save(self, path):
         """Save the DataSet object as pickle.
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
@@ -85,8 +85,8 @@ class Trainer(object):
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        else:
-            self.metric_key = None
+        elif metrics is not None:
+            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
 
         # prepare loss
         losser = _prepare_losser(loss)
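When no explicit metric_key is passed, the key is now derived from the first metric's class name. Note that str.strip('metric') removes the characters m/e/t/r/i/c from both ends rather than the literal suffix; that happens to work for names ending in "Metric". A small illustration with a hypothetical stand-in class:

class AccuracyMetric:  # hypothetical metric class, used only to show the name handling
    pass

metrics = [AccuracyMetric()]
key = metrics[0].__class__.__name__.lower().strip('metric')
print(key)  # 'accuracy' -- 'accuracymetric' loses its trailing m/e/t/r/i/c characters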
@@ -147,7 +147,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)
 
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
         print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
@@ -260,7 +260,7 @@ class Trainer(object):
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else "None"
+            metric_key = self.metric_key if self.metric_key is not None else ""
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res
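Taken together, the hyphenated timestamp and the empty-string fallback for metric_key give checkpoint names without colons, which keeps them valid as filenames on platforms that reject ':' in paths. For a hypothetical model class CNNText the saved name would look like:

print("best_" + "_".join(["CNNText", "", "2018-12-05 18-30-05"]))
# best_CNNText__2018-12-05 18-30-05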
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
@@ -29,3 +29,39 @@ class BaseLoader(object):
            with open(cache_path, 'wb') as f:
                pickle.dump(obj, f)
        return obj
+
+
+class ToyLoader0(BaseLoader):
+    """
+        For CharLM
+    """
+
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)
+
+    def load(self):
+        with open(self.data_path, 'r') as f:
+            corpus = f.read().lower()
+        import re
+        corpus = re.sub(r"<unk>", "unk", corpus)
+        return corpus.split()
+
+
+class DataLoaderRegister:
+    """"register for data sets"""
+    _readers = {}
+
+    @classmethod
+    def set_reader(cls, reader_cls, read_fn_name):
+        # def wrapper(reader_cls):
+        if read_fn_name in cls._readers:
+            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+        if hasattr(reader_cls, 'load'):
+            cls._readers[read_fn_name] = reader_cls().load
+        return reader_cls
+
+    @classmethod
+    def get_reader(cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            return cls._readers[read_fn_name]
+        raise AttributeError('no read function: {}'.format(read_fn_name))
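DataLoaderRegister replaces the removed DataSet.set_reader decorator: a loader class registers under a read-function name, and get_reader returns the bound load method of a fresh instance. A minimal sketch with a made-up loader class:

class MyLoader:  # hypothetical loader; it only needs a no-argument constructor and load()
    def load(self, path):
        return ["loaded", path]

DataLoaderRegister.set_reader(MyLoader, 'read_my_format')  # stores MyLoader().load
reader = DataLoaderRegister.get_reader('read_my_format')
print(reader("some_file.txt"))  # ['loaded', 'some_file.txt']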
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
@@ -2,7 +2,7 @@ import os
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.io.base_loader import BaseLoader
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 def convert_seq_dataset(data):
@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
     return dataset
 
 
-class DataSetLoader(BaseLoader):
+class DataSetLoader:
     """"loader for data sets"""
 
-    def __init__(self):
-        super(DataSetLoader, self).__init__()
-
     def load(self, path):
         """ load data in `path` into a dataset
         """
@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq_dataset(data)
 
+DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
 
 
-@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
@@ -174,9 +171,9 @@ class POSDataSetLoader(DataSetLoader):
         """Convert lists of strings into Instances with Fields.
         """
         return convert_seq2seq_dataset(data)
 
+DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
 
 
-@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -236,7 +233,6 @@ class TokenizeDataSetLoader(DataSetLoader):
         return convert_seq2seq_dataset(data)
 
 
-@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""
@@ -275,6 +271,83 @@ class ClassDataSetLoader(DataSetLoader):
         return convert_seq2tag_dataset(data)
 
 
+class ConllLoader(DataSetLoader):
+    """loader for conll format files"""
+
+    def __init__(self):
+        """
+        :param str data_path: the path to the conll data set
+        """
+        super(ConllLoader, self).__init__()
+
+    def load(self, data_path):
+        """
+        :return: list lines: all lines in a conll file
+        """
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param list lines: a list containing all lines in a conll file.
+        :return: a 3D list
+        """
+        sentences = list()
+        tokens = list()
+        for line in lines:
+            if line[0] == "#":
+                # skip the comments
+                continue
+            if line == "\n":
+                sentences.append(tokens)
+                tokens = []
+                continue
+            tokens.append(line.split())
+        return sentences
+
+    def convert(self, data):
+        pass
+
+
+class LMDataSetLoader(DataSetLoader):
+    """Language Model Dataset Loader
+
+    This loader produces data for language model training in a supervised way.
+    That means it has X and Y.
+    """
+
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        tokens = text.strip().split()
+        data = self.sentence_cut(tokens)
+        return self.convert(data)
+
+    def sentence_cut(self, tokens, sentence_length=15):
+        start_idx = 0
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            x = tokens[start_idx * idx: start_idx * idx + sentence_length]
+            y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
+            if start_idx * idx + sentence_length + 1 >= len(tokens):
+                # ad hoc
+                y.extend(["<unk>"])
+            data_set.append([x, y])
+        return data_set
+
+    def convert(self, data):
+        pass
+
+
 @DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """