--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -1,5 +1,5 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
+from ..io.dataset_loader import DataSet
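
Note: DataSet is now re-exported from fastNLP.io.dataset_loader, so
"from fastNLP.core import DataSet" keeps working. A minimal sanity check
(a sketch; it assumes the dict-of-lists constructor that read_csv also
relies on via cls(_dict)):

    from fastNLP.core import DataSet
    ds = DataSet({"x": [[1, 2], [3, 4]], "y": [0, 1]})
    assert len(ds) == 2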
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -5,8 +5,7 @@ import numpy as np
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
 from fastNLP.core.utils import get_func_signature
-
-_READERS = {}
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 class DataSet(object):
@@ -98,6 +97,24 @@ class DataSet(object):
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
 
+    def __getattr__(self, item):
+        # Guard against infinite recursion when field_arrays itself is
+        # missing, e.g. while unpickling, before __dict__ is restored.
+        if item == "field_arrays":
+            raise AttributeError
+        # TODO dataset.x
+        if item in self.field_arrays:
+            # expose every field as an attribute, e.g. dataset.word_seq
+            return self.field_arrays[item]
+        # otherwise fall back to the reader registry, e.g. dataset.read_pos
+        try:
+            reader = DataLoaderRegister.get_reader(item)
+            return reader
+        except AttributeError:
+            raise
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+    def __getstate__(self):
+        return self.__dict__
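+
+    # Usage sketch for the attribute protocol above (field name "x" is
+    # illustrative): after ds = DataSet({"x": [[1, 2]]}), ds.x returns the
+    # FieldArray "x"; unknown names still raise AttributeError as usual.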
+
     def __len__(self):
         """Fetch the length of the dataset.
@@ -226,16 +243,6 @@ class DataSet(object):
         """
         return [name for name, field in self.field_arrays.items() if field.is_target]
 
-    @classmethod
-    def set_reader(cls, method_name):
-        assert isinstance(method_name, str)
-
-        def wrapper(read_cls):
-            _READERS[method_name] = read_cls
-            return read_cls
-
-        return wrapper
-
     def apply(self, func, new_field_name=None, **kwargs):
         """Apply a function to every instance of the DataSet.
@@ -347,6 +354,9 @@ class DataSet(object):
                 _dict[header].append(content)
         return cls(_dict)
 
+    # def read_pos(self):
+    #     return DataLoaderRegister.get_reader('read_pos')
+
     def save(self, path):
         """Save the DataSet object as pickle.
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -85,8 +85,8 @@ class Trainer(object):
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        else:
-            self.metric_key = None
+        elif metrics is not None:
+            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
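+            # e.g. an AccuracyMetric instance yields the default key "accuracy"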
 
         # prepare loss
         losser = _prepare_losser(loss)
@@ -147,7 +147,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)
 
-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
         print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
@@ -260,7 +260,7 @@ class Trainer(object):
                 self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                 global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
-            metric_key = self.metric_key if self.metric_key is not None else "None"
+            metric_key = self.metric_key if self.metric_key is not None else ""
             self._save_model(self.model,
                              "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
         return res
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -29,3 +29,39 @@ class BaseLoader(object):
         with open(cache_path, 'wb') as f:
             pickle.dump(obj, f)
         return obj
+
+
+class ToyLoader0(BaseLoader):
+    """
+        For CharLM
+    """
+
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)
+
+    def load(self):
+        with open(self.data_path, 'r') as f:
+            corpus = f.read().lower()
+        import re
+        corpus = re.sub(r"<unk>", "unk", corpus)
+        return corpus.split()
+
+
+class DataLoaderRegister:
+    """Register of data set readers, keyed by read-function name."""
+    _readers = {}
+
+    @classmethod
+    def set_reader(cls, reader_cls, read_fn_name):
+        # map a read-function name to the bound load() of a loader instance
+        if read_fn_name in cls._readers:
+            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
+        if hasattr(reader_cls, 'load'):
+            cls._readers[read_fn_name] = reader_cls().load
+        return reader_cls
+
+    @classmethod
+    def get_reader(cls, read_fn_name):
+        if read_fn_name in cls._readers:
+            return cls._readers[read_fn_name]
+        raise AttributeError('no read function: {}'.format(read_fn_name))
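
Note: a minimal usage sketch of the registry (reader names come from the
set_reader calls in fastNLP/io/dataset_loader.py below; the data path is
hypothetical):

    from fastNLP.io.base_loader import DataLoaderRegister
    import fastNLP.io.dataset_loader  # runs the module-level set_reader calls

    read_pos = DataLoaderRegister.get_reader('read_pos')
    dataset = read_pos('path/to/pos_data.txt')

get_reader raises AttributeError for unknown names, which is what lets
DataSet.__getattr__ above fall through to a normal attribute error.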
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -2,7 +2,7 @@ import os
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
-from fastNLP.io.base_loader import BaseLoader
+from fastNLP.io.base_loader import DataLoaderRegister
 
 
 def convert_seq_dataset(data):
@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
     return dataset
 
 
-class DataSetLoader(BaseLoader):
+class DataSetLoader:
     """"loader for data sets"""
 
-    def __init__(self):
-        super(DataSetLoader, self).__init__()
-
     def load(self, path):
         """ load data in `path` into a dataset
         """
@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq_dataset(data)
 
+DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')
 
-@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
@@ -174,9 +171,9 @@ class POSDataSetLoader(DataSetLoader):
         """Convert lists of strings into Instances with Fields.
         """
         return convert_seq2seq_dataset(data)
 
+DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
 
-@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -236,7 +233,6 @@ class TokenizeDataSetLoader(DataSetLoader):
         return convert_seq2seq_dataset(data)
 
-@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""
@@ -275,6 +271,83 @@ class ClassDataSetLoader(DataSetLoader):
         return convert_seq2tag_dataset(data)
+
+
+class ConllLoader(DataSetLoader):
+    """loader for conll format files"""
+
+    def __init__(self):
+        super(ConllLoader, self).__init__()
+
+    def load(self, data_path):
+        """
+        :param str data_path: the path to the conll data set
+        :return: the parsed data, converted by self.convert
+        """
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        data = self.parse(lines)
+        return self.convert(data)
+
+    @staticmethod
+    def parse(lines):
+        """
+        :param list lines: a list containing all lines in a conll file.
+        :return: a 3D list, indexed as [sentence][token][column]
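+            e.g. the token lines "1 The DT" and "2 dog NN" followed by a
+            blank line parse to [[['1', 'The', 'DT'], ['2', 'dog', 'NN']]]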
""" | |||||
sentences = list() | |||||
tokens = list() | |||||
for line in lines: | |||||
if line[0] == "#": | |||||
# skip the comments | |||||
continue | |||||
if line == "\n": | |||||
sentences.append(tokens) | |||||
tokens = [] | |||||
continue | |||||
tokens.append(line.split()) | |||||
return sentences | |||||
def convert(self, data): | |||||
pass | |||||
+
+
+class LMDataSetLoader(DataSetLoader):
+    """Language Model Dataset Loader
+
+    This loader produces data for language model training in a supervised way.
+    That means it has X and Y.
+    """
+
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        tokens = text.strip().split()
+        data = self.sentence_cut(tokens)
+        return self.convert(data)
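+
+    # sentence_cut slices the token stream into consecutive fixed-length
+    # windows: x = tokens[i*L : (i+1)*L] and y is x shifted right by one
+    # token, so the model is trained to predict the next token everywhere.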
+    def sentence_cut(self, tokens, sentence_length=15):
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            start = idx * sentence_length
+            x = tokens[start: start + sentence_length]
+            y = tokens[start + 1: start + sentence_length + 1]
+            if start + sentence_length + 1 >= len(tokens):
+                # ad hoc: pad the final target so len(y) == len(x)
+                y.append("<unk>")
+            data_set.append([x, y])
+        return data_set
+
+    def convert(self, data):
+        pass
+
+
 @DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """