@@ -1,5 +1,5 @@
from .batch import Batch
from .dataset import DataSet
# from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -8,4 +8,6 @@ from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
from .vocabulary import Vocabulary
from ..io.dataset_loader import DataSet
@@ -5,8 +5,7 @@ import numpy as np
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature
_READERS = {}
from fastNLP.io.base_loader import DataLoaderRegister

class DataSet(object):
@@ -98,6 +97,24 @@ class DataSet(object):
        else:
            raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

    def __getattr__(self, item):
        if item == "field_arrays":
            raise AttributeError
        # TODO dataset.x
        if item in self.field_arrays:
            return self.field_arrays[item]
        try:
            reader = DataLoaderRegister.get_reader(item)
            return reader
        except AttributeError:
            raise
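With this hook, any attribute that is neither "field_arrays" nor an existing field falls through to the reader registry. A minimal usage sketch (not part of the patch; it assumes fastNLP is importable and that "sample.pos" is a hypothetical POS-formatted file):

from fastNLP.core.dataset import DataSet
import fastNLP.io.dataset_loader  # importing this module registers the readers shown further below

ds = DataSet()
read_pos = ds.read_pos             # __getattr__ above -> DataLoaderRegister.get_reader('read_pos')
pos_data = read_pos("sample.pos")  # hypothetical file; returns a DataSet built by POSDataSetLoader.load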
    def __setstate__(self, state):
        self.__dict__ = state

    def __getstate__(self):
        return self.__dict__

    def __len__(self):
        """Fetch the length of the dataset.
@@ -226,16 +243,6 @@ class DataSet(object):
        """
        return [name for name, field in self.field_arrays.items() if field.is_target]

    @classmethod
    def set_reader(cls, method_name):
        assert isinstance(method_name, str)

        def wrapper(read_cls):
            _READERS[method_name] = read_cls
            return read_cls

        return wrapper

    def apply(self, func, new_field_name=None, **kwargs):
        """Apply a function to every instance of the DataSet.
@@ -347,6 +354,9 @@ class DataSet(object):
                _dict[header].append(content)
        return cls(_dict)

    # def read_pos(self):
    #     return DataLoaderRegister.get_reader('read_pos')

    def save(self, path):
        """Save the DataSet object as pickle.
@@ -85,8 +85,8 @@ class Trainer(object):
        if metric_key is not None:
            self.increase_better = False if metric_key[0] == "-" else True
            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
        else:
            self.metric_key = None
        elif metrics is not None:
            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

        # prepare loss
        losser = _prepare_losser(loss)
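A small illustration (not from the patch) of what the new fallback produces. Note that str.strip('metric') removes leading and trailing characters drawn from the set {m, e, t, r, i, c}, so a metric class named AccuracyMetric maps to the key 'accuracy':

name = "AccuracyMetric"
print(name.lower().strip('metric'))  # -> accuracy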
@@ -147,7 +147,7 @@ class Trainer(object):
        self._mode(self.model, is_test=False)

        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
        print("training epochs started " + self.start_time, flush=True)
        if self.save_path is None:
            class psudoSW:
@@ -260,7 +260,7 @@ class Trainer(object):
                self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                global_step=self.step)
        if self.save_path is not None and self._better_eval_result(res):
            metric_key = self.metric_key if self.metric_key is not None else "None"
            metric_key = self.metric_key if self.metric_key is not None else ""
            self._save_model(self.model,
                             "best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
        return res
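For reference, a sketch of the checkpoint name this builds (the model name and timestamp below are hypothetical). With metric_key == "" the literal string "None" no longer appears in the file name, and the '%H-%M-%S' time format above avoids ':' characters, which are not allowed in Windows file names:

model_name, metric_key, start_time = "CNNText", "", "2019-01-01 12-00-00"
print("best_" + "_".join([model_name, metric_key, start_time]))
# -> best_CNNText__2019-01-01 12-00-00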
@@ -29,3 +29,39 @@ class BaseLoader(object):
            with open(cache_path, 'wb') as f:
                pickle.dump(obj, f)
            return obj

class ToyLoader0(BaseLoader):
    """
        For CharLM
    """

    def __init__(self, data_path):
        super(ToyLoader0, self).__init__(data_path)

    def load(self):
        with open(self.data_path, 'r') as f:
            corpus = f.read().lower()
        import re
        corpus = re.sub(r"<unk>", "unk", corpus)
        return corpus.split()

class DataLoaderRegister:
    """register for data sets"""
    _readers = {}

    @classmethod
    def set_reader(cls, reader_cls, read_fn_name):
        # def wrapper(reader_cls):
        if read_fn_name in cls._readers:
            raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
        if hasattr(reader_cls, 'load'):
            cls._readers[read_fn_name] = reader_cls().load
        return reader_cls

    @classmethod
    def get_reader(cls, read_fn_name):
        if read_fn_name in cls._readers:
            return cls._readers[read_fn_name]
        raise AttributeError('no read function: {}'.format(read_fn_name))
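A minimal sketch of how the registry is meant to be used (MyToyLoader, 'read_toy', and "toy.txt" are hypothetical names, not part of fastNLP):

class MyToyLoader:
    def load(self, path):
        with open(path, encoding="utf-8") as f:
            return f.read().split()

DataLoaderRegister.set_reader(MyToyLoader, 'read_toy')
read_toy = DataLoaderRegister.get_reader('read_toy')  # the bound MyToyLoader().load method
print(read_toy("toy.txt"))                            # hypothetical whitespace-tokenized file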
@@ -2,7 +2,7 @@ import os
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.base_loader import BaseLoader
from fastNLP.io.base_loader import DataLoaderRegister

def convert_seq_dataset(data):
@@ -61,12 +61,9 @@ def convert_seq2seq_dataset(data):
    return dataset

class DataSetLoader(BaseLoader):
class DataSetLoader:
    """loader for data sets"""

    def __init__(self):
        super(DataSetLoader, self).__init__()

    def load(self, path):
        """ load data in `path` into a dataset
        """
@@ -104,9 +101,9 @@ class RawDataSetLoader(DataSetLoader):
    def convert(self, data):
        return convert_seq_dataset(data)

DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')

@DataSet.set_reader('read_pos')
class POSDataSetLoader(DataSetLoader):
    """Dataset Loader for POS Tag datasets.
@@ -174,9 +171,9 @@ class POSDataSetLoader(DataSetLoader):
        """Convert lists of strings into Instances with Fields.
        """
        return convert_seq2seq_dataset(data)

DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')

@DataSet.set_reader('read_tokenize')
class TokenizeDataSetLoader(DataSetLoader):
    """
    Data set loader for tokenization data sets
@@ -236,7 +233,6 @@ class TokenizeDataSetLoader(DataSetLoader):
        return convert_seq2seq_dataset(data)

@DataSet.set_reader('read_class')
class ClassDataSetLoader(DataSetLoader):
    """Loader for classification data sets"""
@@ -275,6 +271,83 @@ class ClassDataSetLoader(DataSetLoader):
        return convert_seq2tag_dataset(data)

class ConllLoader(DataSetLoader):
    """loader for conll format files"""

    def __init__(self):
        """
        :param str data_path: the path to the conll data set
        """
        super(ConllLoader, self).__init__()

    def load(self, data_path):
        """
        :return: list lines: all lines in a conll file
        """
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        data = self.parse(lines)
        return self.convert(data)

    @staticmethod
    def parse(lines):
        """
        :param list lines: a list containing all lines in a conll file.
        :return: a 3D list
        """
        sentences = list()
        tokens = list()
        for line in lines:
            if line[0] == "#":
                # skip the comments
                continue
            if line == "\n":
                sentences.append(tokens)
                tokens = []
                continue
            tokens.append(line.split())
        return sentences

    def convert(self, data):
        pass
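A tiny illustration of what parse produces (the input lines are made up): comment lines are skipped, each remaining line is whitespace-split into columns, and a blank line closes the current sentence, giving a sentences x tokens x columns list:

lines = ["# sent_id = 1\n", "1\tHello\tINTJ\n", "2\tworld\tNOUN\n", "\n"]
print(ConllLoader.parse(lines))
# -> [[['1', 'Hello', 'INTJ'], ['2', 'world', 'NOUN']]]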

class LMDataSetLoader(DataSetLoader):
    """Language Model Dataset Loader

    This loader produces data for language model training in a supervised way.
    That means it has X and Y.
    """

    def __init__(self):
        super(LMDataSetLoader, self).__init__()

    def load(self, data_path):
        if not os.path.exists(data_path):
            raise FileNotFoundError("file {} not found.".format(data_path))
        with open(data_path, "r", encoding="utf-8") as f:
            text = " ".join(f.readlines())
        tokens = text.strip().split()
        data = self.sentence_cut(tokens)
        return self.convert(data)

    def sentence_cut(self, tokens, sentence_length=15):
        start_idx = 0
        data_set = []
        for idx in range(len(tokens) // sentence_length):
            x = tokens[start_idx * idx: start_idx * idx + sentence_length]
            y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
            if start_idx * idx + sentence_length + 1 >= len(tokens):
                # ad hoc
                y.extend(["<unk>"])
            data_set.append([x, y])
        return data_set

    def convert(self, data):
        pass
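sentence_cut is presumably meant to step through the corpus in fixed windows where y is x shifted right by one token (next-token prediction); note that with start_idx fixed at 0, start_idx * idx stays 0 as written. A sketch of that sliding windowing on a toy corpus (the indexing below is an assumption, not part of the patch):

tokens = "the quick brown fox jumps over the lazy dog".split() * 4  # toy corpus
sentence_length = 5
for idx in range(len(tokens) // sentence_length):
    start = idx * sentence_length                        # advance one window per step
    x = tokens[start: start + sentence_length]           # input window
    y = tokens[start + 1: start + sentence_length + 1]   # next-token targets
    print(x, y)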

@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
    """