diff --git a/fastNLP/core/action.py b/fastNLP/core/action.py index ef595cbb..d36d350d 100644 --- a/fastNLP/core/action.py +++ b/fastNLP/core/action.py @@ -168,19 +168,7 @@ class BaseSampler(object): """ - def __init__(self, data_set): - """ - - :param data_set: multi-level list, of shape [num_example, *] - - """ - self.data_set_length = len(data_set) - self.data = data_set - - def __len__(self): - return self.data_set_length - - def __iter__(self): + def __call__(self, *args, **kwargs): raise NotImplementedError @@ -189,16 +177,8 @@ class SequentialSampler(BaseSampler): """ - def __init__(self, data_set): - """ - - :param data_set: multi-level list - - """ - super(SequentialSampler, self).__init__(data_set) - - def __iter__(self): - return iter(self.data) + def __call__(self, data_set): + return list(range(len(data_set))) class RandomSampler(BaseSampler): @@ -206,17 +186,9 @@ class RandomSampler(BaseSampler): """ - def __init__(self, data_set): - """ + def __call__(self, data_set): + return list(np.random.permutation(len(data_set))) - :param data_set: multi-level list - - """ - super(RandomSampler, self).__init__(data_set) - self.order = np.random.permutation(self.data_set_length) - - def __iter__(self): - return iter((self.data[idx] for idx in self.order)) class Batchifier(object): diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py new file mode 100644 index 00000000..0a5e9712 --- /dev/null +++ b/fastNLP/core/batch.py @@ -0,0 +1,126 @@ +from collections import defaultdict + +import torch + +from fastNLP.core.dataset import DataSet +from fastNLP.core.field import TextField, LabelField +from fastNLP.core.instance import Instance + + +class Batch(object): + """Batch is an iterable object which iterates over mini-batches. + + :: + for batch_x, batch_y in Batch(data_set): + + """ + + def __init__(self, dataset, batch_size, sampler, use_cuda): + self.dataset = dataset + self.batch_size = batch_size + self.sampler = sampler + self.use_cuda = use_cuda + self.idx_list = None + self.curidx = 0 + + def __iter__(self): + self.idx_list = self.sampler(self.dataset) + self.curidx = 0 + self.lengths = self.dataset.get_length() + return self + + def __next__(self): + """ + + :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) + batch_x also contains an item (str: list of int) about origin lengths, + which means ("field_name_origin_len": origin lengths). + E.g. + :: + {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) + + batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) + All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. + The names of fields are defined in preprocessor's convert_to_dataset method. + + """ + if self.curidx >= len(self.idx_list): + raise StopIteration + else: + endidx = min(self.curidx + self.batch_size, len(self.idx_list)) + padding_length = {field_name: max(field_length[self.curidx: endidx]) + for field_name, field_length in self.lengths.items()} + origin_lengths = {field_name: field_length[self.curidx: endidx] + for field_name, field_length in self.lengths.items()} + + batch_x, batch_y = defaultdict(list), defaultdict(list) + for idx in range(self.curidx, endidx): + x, y = self.dataset.to_tensor(idx, padding_length) + for name, tensor in x.items(): + batch_x[name].append(tensor) + for name, tensor in y.items(): + batch_y[name].append(tensor) + + batch_origin_length = {} + # combine instances into a batch + for batch in (batch_x, batch_y): + for name, tensor_list in batch.items(): + if self.use_cuda: + batch[name] = torch.stack(tensor_list, dim=0).cuda() + else: + batch[name] = torch.stack(tensor_list, dim=0) + + # add origin lengths in batch_x + for name, tensor in batch_x.items(): + if self.use_cuda: + batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda() + else: + batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]) + batch_x.update(batch_origin_length) + + self.curidx += endidx + return batch_x, batch_y + + +if __name__ == "__main__": + """simple running example + """ + texts = ["i am a cat", + "this is a test of new batch", + "haha" + ] + labels = [0, 1, 0] + + # prepare vocabulary + vocab = {} + for text in texts: + for tokens in text.split(): + if tokens not in vocab: + vocab[tokens] = len(vocab) + print("vocabulary: ", vocab) + + # prepare input dataset + data = DataSet() + for text, label in zip(texts, labels): + x = TextField(text.split(), False) + y = LabelField(label, is_target=True) + ins = Instance(text=x, label=y) + data.append(ins) + + # use vocabulary to index data + data.index_field("text", vocab) + + + # define naive sampler for batch class + class SeqSampler: + def __call__(self, dataset): + return list(range(len(dataset))) + + + # use batch to iterate dataset + data_iterator = Batch(data, 2, SeqSampler(), False) + for epoch in range(1): + for batch_x, batch_y in data_iterator: + print(batch_x) + print(batch_y) + # do stuff diff --git a/fastNLP/data/dataset.py b/fastNLP/core/dataset.py similarity index 55% rename from fastNLP/data/dataset.py rename to fastNLP/core/dataset.py index ffe75494..5f749795 100644 --- a/fastNLP/data/dataset.py +++ b/fastNLP/core/dataset.py @@ -7,23 +7,36 @@ class DataSet(list): self.name = name if instances is not None: self.extend(instances) - + def index_all(self, vocab): for ins in self: ins.index_all(vocab) - + def index_field(self, field_name, vocab): for ins in self: ins.index_field(field_name, vocab) def to_tensor(self, idx: int, padding_length: dict): + """Convert an instance in a dataset to tensor. + + :param idx: int, the index of the instance in the dataset. + :param padding_length: int + :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) + tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) + + """ ins = self[idx] return ins.to_tensor(padding_length) - + def get_length(self): + """Fetch lengths of all fields in all instances in a dataset. + + :return lengths: dict of (str: list). The str is the field name. + The list contains lengths of this field in all instances. + + """ lengths = defaultdict(list) for ins in self: for field_name, field_length in ins.get_length().items(): lengths[field_name].append(field_length) return lengths - diff --git a/fastNLP/data/field.py b/fastNLP/core/field.py similarity index 65% rename from fastNLP/data/field.py rename to fastNLP/core/field.py index ada90857..eb2bc78e 100644 --- a/fastNLP/data/field.py +++ b/fastNLP/core/field.py @@ -1,18 +1,23 @@ import torch + class Field(object): + """A field defines a data type. + + """ + def __init__(self, is_target: bool): self.is_target = is_target def index(self, vocab): - pass - + raise NotImplementedError + def get_length(self): - pass + raise NotImplementedError def to_tensor(self, padding_length): - pass - + raise NotImplementedError + class TextField(Field): def __init__(self, text: list, is_target): @@ -31,25 +36,38 @@ class TextField(Field): return self._index def get_length(self): + """Fetch the length of the text field. + + :return length: int, the length of the text. + + """ return len(self.text) def to_tensor(self, padding_length: int): + """Convert text field to tensor. + + :param padding_length: int + :return tensor: torch.LongTensor, of shape [padding_length, ] + """ pads = [] if self._index is None: - print('error') + raise RuntimeError("Indexing not done before to_tensor in TextField.") if padding_length > self.get_length(): - pads = [0 for i in range(padding_length - self.get_length())] - # (length, ) + pads = [0] * (padding_length - self.get_length()) return torch.LongTensor(self._index + pads) - + class LabelField(Field): def __init__(self, label, is_target=True): super(LabelField, self).__init__(is_target) self.label = label self._index = None - + def get_length(self): + """Fetch the length of the label field. + + :return length: int, the length of the label, always 1. + """ return 1 def index(self, vocab): @@ -58,13 +76,13 @@ class LabelField(Field): else: pass return self._index - + def to_tensor(self, padding_length): if self._index is None: return torch.LongTensor([self.label]) else: return torch.LongTensor([self._index]) + if __name__ == "__main__": - tf = TextField("test the code".split()) - + tf = TextField("test the code".split(), is_target=False) diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py new file mode 100644 index 00000000..3322e576 --- /dev/null +++ b/fastNLP/core/instance.py @@ -0,0 +1,53 @@ +class Instance(object): + """An instance which consists of Fields is an example in the DataSet. + + """ + + def __init__(self, **fields): + self.fields = fields + self.has_index = False + self.indexes = {} + + def add_field(self, field_name, field): + self.fields[field_name] = field + + def get_length(self): + """Fetch the length of all fields in the instance. + + :return length: dict of (str: int), which means (field name: field length). + + """ + length = {name: field.get_length() for name, field in self.fields.items()} + return length + + def index_field(self, field_name, vocab): + """use `vocab` to index certain field + """ + self.indexes[field_name] = self.fields[field_name].index(vocab) + + def index_all(self, vocab): + """use `vocab` to index all fields + """ + if self.has_index: + print("error") + return self.indexes + indexes = {name: field.index(vocab) for name, field in self.fields.items()} + self.indexes = indexes + return indexes + + def to_tensor(self, padding_length: dict): + """Convert instance to tensor. + + :param padding_length: dict of (str: int), which means (field name: padding_length of this field) + :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) + tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) + + """ + tensor_x = {} + tensor_y = {} + for name, field in self.fields.items(): + if field.is_target: + tensor_y[name] = field.to_tensor(padding_length[name]) + else: + tensor_x[name] = field.to_tensor(padding_length[name]) + return tensor_x, tensor_y diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index f8142c36..b7c33f3b 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -3,6 +3,10 @@ import os import numpy as np +from fastNLP.core.dataset import DataSet +from fastNLP.core.field import TextField, LabelField +from fastNLP.core.instance import Instance + DEFAULT_PADDING_LABEL = '' # dict index = 0 DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 DEFAULT_RESERVED_LABEL = ['', @@ -84,7 +88,7 @@ class BasePreprocess(object): return len(self.label2index) def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): - """Main preprocessing pipeline. + """Main pre-processing pipeline. :param train_dev_data: three-level list, with either single label or multiple labels in a sample. :param test_data: three-level list, with either single label or multiple labels in a sample. (optional) @@ -92,7 +96,9 @@ class BasePreprocess(object): :param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set. :param cross_val: bool, whether to do cross validation. :param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True. - :return results: a tuple of datasets after preprocessing. + :return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset. + If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset + is a list of DataSet objects; Otherwise, each dataset is a DataSet object. """ if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): @@ -111,68 +117,87 @@ class BasePreprocess(object): index2label = self.build_reverse_dict(self.label2index) save_pickle(index2label, pickle_path, "id2class.pkl") - data_train = [] - data_dev = [] + train_set = [] + dev_set = [] if not cross_val: if not pickle_exist(pickle_path, "data_train.pkl"): - data_train.extend(self.to_index(train_dev_data)) if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): - split = int(len(data_train) * train_dev_split) - data_dev = data_train[: split] - data_train = data_train[split:] - save_pickle(data_dev, pickle_path, "data_dev.pkl") + split = int(len(train_dev_data) * train_dev_split) + data_dev = train_dev_data[: split] + data_train = train_dev_data[split:] + train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) + dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) + + save_pickle(dev_set, pickle_path, "data_dev.pkl") print("{} of the training data is split for validation. ".format(train_dev_split)) - save_pickle(data_train, pickle_path, "data_train.pkl") + else: + train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) + save_pickle(train_set, pickle_path, "data_train.pkl") else: - data_train = load_pickle(pickle_path, "data_train.pkl") + train_set = load_pickle(pickle_path, "data_train.pkl") if pickle_exist(pickle_path, "data_dev.pkl"): - data_dev = load_pickle(pickle_path, "data_dev.pkl") + dev_set = load_pickle(pickle_path, "data_dev.pkl") else: # cross_val is True if not pickle_exist(pickle_path, "data_train_0.pkl"): # cross validation - data_idx = self.to_index(train_dev_data) - data_cv = self.cv_split(data_idx, n_fold) + data_cv = self.cv_split(train_dev_data, n_fold) for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): + data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) + data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) save_pickle( data_train_cv, pickle_path, "data_train_{}.pkl".format(i)) save_pickle( data_dev_cv, pickle_path, "data_dev_{}.pkl".format(i)) - data_train.append(data_train_cv) - data_dev.append(data_dev_cv) + train_set.append(data_train_cv) + dev_set.append(data_dev_cv) print("{}-fold cross validation.".format(n_fold)) else: for i in range(n_fold): data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i)) data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i)) - data_train.append(data_train_cv) - data_dev.append(data_dev_cv) + train_set.append(data_train_cv) + dev_set.append(data_dev_cv) # prepare test data if provided - data_test = [] + test_set = [] if test_data is not None: if not pickle_exist(pickle_path, "data_test.pkl"): - data_test = self.to_index(test_data) - save_pickle(data_test, pickle_path, "data_test.pkl") + test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) + save_pickle(test_set, pickle_path, "data_test.pkl") # return preprocessed results - results = [data_train] + results = [train_set] if cross_val or train_dev_split > 0: - results.append(data_dev) + results.append(dev_set) if test_data: - results.append(data_test) + results.append(test_set) if len(results) == 1: return results[0] else: return tuple(results) def build_dict(self, data): - raise NotImplementedError + label2index = DEFAULT_WORD_TO_INDEX.copy() + word2index = DEFAULT_WORD_TO_INDEX.copy() + for example in data: + for word in example[0]: + if word not in word2index: + word2index[word] = len(word2index) + label = example[1] + if isinstance(label, str): + # label is a string + if label not in label2index: + label2index[label] = len(label2index) + elif isinstance(label, list): + # label is a list of strings + for single_label in label: + if single_label not in label2index: + label2index[single_label] = len(label2index) + return word2index, label2index - def to_index(self, data): - raise NotImplementedError def build_reverse_dict(self, word_dict): id2word = {word_dict[w]: w for w in word_dict} @@ -186,11 +211,23 @@ class BasePreprocess(object): return data_train, data_dev def cv_split(self, data, n_fold): - """Split data for cross validation.""" + """Split data for cross validation. + + :param data: list of string + :param n_fold: int + :return data_cv: + + :: + [ + (data_train, data_dev), # 1st fold + (data_train, data_dev), # 2nd fold + ... + ] + + """ data_copy = data.copy() np.random.shuffle(data_copy) fold_size = round(len(data_copy) / n_fold) - data_cv = [] for i in range(n_fold - 1): start = i * fold_size @@ -202,154 +239,62 @@ class BasePreprocess(object): data_dev = data_copy[start:] data_train = data_copy[:start] data_cv.append((data_train, data_dev)) - return data_cv + def convert_to_dataset(self, data, vocab, label_vocab): + """Convert list of indices into a DataSet object. -class SeqLabelPreprocess(BasePreprocess): - """Preprocess pipeline, including building mapping from words to index, from index to words, - from labels/classes to index, from index to labels/classes. - data of three-level list which have multiple labels in each sample. - :: - - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - - """ - - def __init__(self): - super(SeqLabelPreprocess, self).__init__() - - def build_dict(self, data): - """Add new words with indices into self.word_dict, new labels with indices into self.label_dict. - - :param data: three-level list - :: - - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - - :return word2index: dict of {str, int} - label2index: dict of {str, int} + :param data: list. Entries are strings. + :param vocab: a dict, mapping string (token) to index (int). + :param label_vocab: a dict, mapping string (label) to index (int). + :return data_set: a DataSet object """ - # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. - label2index = DEFAULT_WORD_TO_INDEX.copy() - word2index = DEFAULT_WORD_TO_INDEX.copy() + use_word_seq = False + use_label_seq = False + data_set = DataSet() for example in data: - for word, label in zip(example[0], example[1]): - if word not in word2index: - word2index[word] = len(word2index) - if label not in label2index: - label2index[label] = len(label2index) - return word2index, label2index + words, label = example[0], example[1] + instance = Instance() - def to_index(self, data): - """Convert word strings and label strings into indices. + if isinstance(words, list): + x = TextField(words, is_target=False) + instance.add_field("word_seq", x) + use_word_seq = True + else: + raise NotImplementedError("words is a {}".format(type(words))) + + if isinstance(label, list): + y = TextField(label, is_target=True) + instance.add_field("label_seq", y) + use_label_seq = True + elif isinstance(label, str): + y = LabelField(label, is_target=True) + instance.add_field("label", y) + else: + raise NotImplementedError("label is a {}".format(type(label))) - :param data: three-level list - :: + data_set.append(instance) + if use_word_seq: + data_set.index_field("word_seq", vocab) + if use_label_seq: + data_set.index_field("label_seq", label_vocab) + return data_set - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - :return data_index: the same shape as data, but each string is replaced by its corresponding index - """ - data_index = [] - for example in data: - word_list = [] - label_list = [] - for word, label in zip(example[0], example[1]): - word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) - label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) - data_index.append([word_list, label_list]) - return data_index +class SeqLabelPreprocess(BasePreprocess): + def __init__(self): + super(SeqLabelPreprocess, self).__init__() class ClassPreprocess(BasePreprocess): - """ Preprocess pipeline for classification datasets. - Preprocess pipeline, including building mapping from words to index, from index to words, - from labels/classes to index, from index to labels/classes. - design for data of three-level list which has a single label in each sample. - :: - - [ - [ [word_11, word_12, ...], label_1 ], - [ [word_21, word_22, ...], label_2 ], - ... - ] - - """ - def __init__(self): super(ClassPreprocess, self).__init__() - def build_dict(self, data): - """Build vocabulary.""" - # build vocabulary from scratch if nothing exists - word2index = DEFAULT_WORD_TO_INDEX.copy() - label2index = DEFAULT_WORD_TO_INDEX.copy() - - # collect every word and label - for sent, label in data: - if len(sent) <= 1: - continue - - if label not in label2index: - label2index[label] = len(label2index) - - for word in sent: - if word not in word2index: - word2index[word] = len(word2index) - return word2index, label2index - - def to_index(self, data): - """Convert word strings and label strings into indices. - - :param data: three-level list - :: - - [ - [ [word_11, word_12, ...], label_1 ], - [ [word_21, word_22, ...], label_2 ], - ... - ] - - :return data_index: the same shape as data, but each string is replaced by its corresponding index - """ - data_index = [] - for example in data: - word_list = [] - # example[0] is the word list, example[1] is the single label - for word in example[0]: - word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) - label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) - data_index.append([word_list, label_index]) - return data_index - - -def infer_preprocess(pickle_path, data): - """Preprocess over inference data. Transform three-level list of strings into that of index. - :: - - [ - [word_11, word_12, ...], - [word_21, word_22, ...], - ... - ] - - """ - word2index = load_pickle(pickle_path, "word2id.pkl") - data_index = [] - for example in data: - data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example]) - return data_index +if __name__ == "__main__": + p = BasePreprocess() + train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"], + [["You", "are", "pretty", "."], "1"] + ] + training_set = p.run(train_dev_data) + print(training_set) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index bcb6ba8c..cfbc918e 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -2,8 +2,8 @@ import numpy as np import torch from fastNLP.core.action import Action -from fastNLP.core.action import RandomSampler, Batchifier -from fastNLP.modules import utils +from fastNLP.core.action import RandomSampler +from fastNLP.core.batch import Batch from fastNLP.saver.logger import create_logger logger = create_logger(__name__, "./train_test.log") @@ -35,16 +35,16 @@ class BaseTester(object): """ "required_args" is the collection of arguments that users must pass to Trainer explicitly. This is used to warn users of essential settings in the training. - Obviously, "required_args" is the subset of "default_args". - The value in "default_args" to the keys in "required_args" is simply for type check. + Specially, "required_args" does not have default value, so they have nothing to do with "default_args". """ - # add required arguments here - required_args = {} + required_args = {"task" # one of ("seq_label", "text_classify") + } for req_key in required_args: if req_key not in kwargs: logger.error("Tester lacks argument {}".format(req_key)) raise ValueError("Tester lacks argument {}".format(req_key)) + self._task = kwargs["task"] for key in default_args: if key in kwargs: @@ -83,10 +83,10 @@ class BaseTester(object): self.eval_history.clear() self.batch_output.clear() - iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False)) + data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) step = 0 - for batch_x, batch_y in self.make_batch(iterator): + for batch_x, batch_y in data_iterator: with torch.no_grad(): prediction = self.data_forward(network, batch_x) eval_results = self.evaluate(prediction, batch_y) @@ -112,7 +112,8 @@ class BaseTester(object): def data_forward(self, network, x): """A forward pass of the model. """ - raise NotImplementedError + y = network(**x) + return y def evaluate(self, predict, truth): """Compute evaluation metrics. @@ -121,7 +122,26 @@ class BaseTester(object): :param truth: Tensor :return eval_results: can be anything. It will be stored in self.eval_history """ - raise NotImplementedError + batch_size, max_len = predict.size(0), predict.size(1) + if "label_seq" in truth: + truth = truth["label_seq"] + elif "label" in truth: + truth = truth["label"] + else: + raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) + loss = self._model.loss(predict, truth) / batch_size + + prediction = self._model.prediction(predict) + # pad prediction to equal length + for pred in prediction: + if len(pred) < max_len: + pred += [0] * (max_len - len(pred)) + results = torch.Tensor(prediction).view(-1, ) + + # make sure "results" is in the same device as "truth" + results = results.to(truth) + accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] + return [float(loss), float(accuracy)] @property def metrics(self): @@ -131,7 +151,9 @@ class BaseTester(object): :return : variable number of outputs """ - raise NotImplementedError + batch_loss = np.mean([x[0] for x in self.eval_history]) + batch_accuracy = np.mean([x[1] for x in self.eval_history]) + return batch_loss, batch_accuracy def show_metrics(self): """Customize evaluation outputs in Trainer. @@ -140,10 +162,8 @@ class BaseTester(object): :return print_str: str """ - raise NotImplementedError - - def make_batch(self, iterator): - raise NotImplementedError + loss, accuracy = self.metrics + return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) def make_eval_output(self, predictions, eval_results): """Customize Tester outputs. @@ -152,78 +172,21 @@ class BaseTester(object): :param eval_results: Tensor :return: str, to be printed. """ - raise NotImplementedError + return self.show_metrics() + class SeqLabelTester(BaseTester): """Tester for sequence labeling. """ - def __init__(self, **test_args): """ :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" """ + test_args.update({"task": "seq_label"}) + print( + "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.") super(SeqLabelTester, self).__init__(**test_args) - self.max_len = None - self.mask = None - self.seq_len = None - - def data_forward(self, network, inputs): - """This is only for sequence labeling with CRF decoder. - - :param network: a PyTorch model - :param inputs: tuple of (x, seq_len) - x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch - after padding. - seq_len: list of int, the lengths of sequences before padding. - :return y: Tensor of shape [batch_size, max_len] - """ - if not isinstance(inputs, tuple): - raise RuntimeError("output_length must be true for sequence modeling.") - # unpack the returned value from make_batch - x, seq_len = inputs[0], inputs[1] - batch_size, max_len = x.size(0), x.size(1) - mask = utils.seq_mask(seq_len, max_len) - mask = mask.byte().view(batch_size, max_len) - if torch.cuda.is_available() and self.use_cuda: - mask = mask.cuda() - self.mask = mask - self.seq_len = seq_len - y = network(x) - return y - - def evaluate(self, predict, truth): - """Compute metrics (or loss). - - :param predict: Tensor, [batch_size, max_len, tag_size] - :param truth: Tensor, [batch_size, max_len] - :return: - """ - batch_size, max_len = predict.size(0), predict.size(1) - loss = self._model.loss(predict, truth, self.mask) / batch_size - - prediction = self._model.prediction(predict, self.mask) - results = torch.Tensor(prediction).view(-1, ) - # make sure "results" is in the same device as "truth" - results = results.to(truth) - accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] - return [float(loss), float(accuracy)] - - def metrics(self): - batch_loss = np.mean([x[0] for x in self.eval_history]) - batch_accuracy = np.mean([x[1] for x in self.eval_history]) - return batch_loss, batch_accuracy - - def show_metrics(self): - """This is called by Trainer to print evaluation on dev set. - - :return print_str: str - """ - loss, accuracy = self.metrics() - return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) - - def make_batch(self, iterator): - return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) class ClassificationTester(BaseTester): @@ -236,9 +199,6 @@ class ClassificationTester(BaseTester): """ super(ClassificationTester, self).__init__(**test_args) - def make_batch(self, iterator, max_len=None): - return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) - def data_forward(self, network, x): """Forward through network.""" logits = network(x) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 523a1763..f4c3e8c1 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -4,15 +4,14 @@ import time from datetime import timedelta import torch -import tensorboardX from tensorboardX import SummaryWriter from fastNLP.core.action import Action -from fastNLP.core.action import RandomSampler, Batchifier +from fastNLP.core.action import RandomSampler +from fastNLP.core.batch import Batch from fastNLP.core.loss import Loss from fastNLP.core.optimizer import Optimizer from fastNLP.core.tester import SeqLabelTester, ClassificationTester -from fastNLP.modules import utils from fastNLP.saver.logger import create_logger from fastNLP.saver.model_saver import ModelSaver @@ -50,16 +49,16 @@ class BaseTrainer(object): """ "required_args" is the collection of arguments that users must pass to Trainer explicitly. This is used to warn users of essential settings in the training. - Obviously, "required_args" is the subset of "default_args". - The value in "default_args" to the keys in "required_args" is simply for type check. + Specially, "required_args" does not have default value, so they have nothing to do with "default_args". """ - # add required arguments here - required_args = {} + required_args = {"task" # one of ("seq_label", "text_classify") + } for req_key in required_args: if req_key not in kwargs: logger.error("Trainer lacks argument {}".format(req_key)) raise ValueError("Trainer lacks argument {}".format(req_key)) + self._task = kwargs["task"] for key in default_args: if key in kwargs: @@ -90,13 +89,14 @@ class BaseTrainer(object): self._optimizer_proto = default_args["optimizer"] self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') self._graph_summaried = False + self._best_accuracy = 0.0 def train(self, network, train_data, dev_data=None): """General Training Procedure :param network: a model - :param train_data: three-level list, the training set. - :param dev_data: three-level list, the validation data (optional) + :param train_data: a DataSet instance, the training data + :param dev_data: a DataSet instance, the validation data (optional) """ # transfer model to gpu if available if torch.cuda.is_available() and self.use_cuda: @@ -128,7 +128,8 @@ class BaseTrainer(object): # turn on network training mode self.mode(network, test=False) # prepare mini-batch iterator - data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False)) + data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), + use_cuda=self.use_cuda) logger.info("prepared data iterator") # one forward and backward pass @@ -157,7 +158,7 @@ class BaseTrainer(object): - epoch: int, """ step = 0 - for batch_x, batch_y in self.make_batch(data_iterator): + for batch_x, batch_y in data_iterator: prediction = self.data_forward(network, batch_x) @@ -166,10 +167,6 @@ class BaseTrainer(object): self.update() self._summary_writer.add_scalar("loss", loss.item(), global_step=step) - if not self._graph_summaried: - self._summary_writer.add_graph(network, batch_x) - self._graph_summaried = True - if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: end = time.time() diff = timedelta(seconds=round(end - kwargs["start"])) @@ -204,9 +201,6 @@ class BaseTrainer(object): network_copy = copy.deepcopy(network) self.train(network_copy, train_data_cv[i], dev_data_cv[i]) - def make_batch(self, iterator): - raise NotImplementedError - def mode(self, network, test): Action.mode(network, test) @@ -224,7 +218,12 @@ class BaseTrainer(object): self._optimizer.step() def data_forward(self, network, x): - raise NotImplementedError + y = network(**x) + if not self._graph_summaried: + if self._task == "seq_label": + self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False) + self._graph_summaried = True + return y def grad_backward(self, loss): """Compute gradient with link rules. @@ -243,6 +242,12 @@ class BaseTrainer(object): :param truth: ground truth label vector :return: a scalar """ + if "label_seq" in truth: + truth = truth["label_seq"] + elif "label" in truth: + truth = truth["label"] + else: + raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) return self._loss_func(predict, truth) def define_loss(self): @@ -270,7 +275,12 @@ class BaseTrainer(object): :param validator: a Tester instance :return: bool, True means current results on dev set is the best. """ - raise NotImplementedError + loss, accuracy = validator.metrics() + if accuracy > self._best_accuracy: + self._best_accuracy = accuracy + return True + else: + return False def save_model(self, network, model_name): """Save this model with such a name. @@ -291,55 +301,11 @@ class SeqLabelTrainer(BaseTrainer): """Trainer for Sequence Labeling """ - def __init__(self, **kwargs): + kwargs.update({"task": "seq_label"}) + print( + "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.") super(SeqLabelTrainer, self).__init__(**kwargs) - # self.vocab_size = kwargs["vocab_size"] - # self.num_classes = kwargs["num_classes"] - self.max_len = None - self.mask = None - self.best_accuracy = 0.0 - - def data_forward(self, network, inputs): - if not isinstance(inputs, tuple): - raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) - # unpack the returned value from make_batch - x, seq_len = inputs[0], inputs[1] - - batch_size, max_len = x.size(0), x.size(1) - mask = utils.seq_mask(seq_len, max_len) - mask = mask.byte().view(batch_size, max_len) - - if torch.cuda.is_available() and self.use_cuda: - mask = mask.cuda() - self.mask = mask - - y = network(x) - return y - - def get_loss(self, predict, truth): - """Compute loss given prediction and ground truth. - - :param predict: prediction label vector, [batch_size, max_len, tag_size] - :param truth: ground truth label vector, [batch_size, max_len] - :return loss: a scalar - """ - batch_size, max_len = predict.size(0), predict.size(1) - assert truth.shape == (batch_size, max_len) - - loss = self._model.loss(predict, truth, self.mask) - return loss - - def best_eval_result(self, validator): - loss, accuracy = validator.metrics() - if accuracy > self.best_accuracy: - self.best_accuracy = accuracy - return True - else: - return False - - def make_batch(self, iterator): - return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) def _create_validator(self, valid_args): return SeqLabelTester(**valid_args) @@ -361,9 +327,6 @@ class ClassificationTrainer(BaseTrainer): logits = network(x) return logits - def make_batch(self, iterator): - return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) - def get_acc(self, y_logit, y_true): """Compute accuracy.""" y_pred = torch.argmax(y_logit, dim=-1) diff --git a/fastNLP/data/batch.py b/fastNLP/data/batch.py deleted file mode 100644 index ef5f7d46..00000000 --- a/fastNLP/data/batch.py +++ /dev/null @@ -1,86 +0,0 @@ -from collections import defaultdict -import torch - -class Batch(object): - def __init__(self, dataset, sampler, batch_size): - self.dataset = dataset - self.sampler = sampler - self.batch_size = batch_size - - self.idx_list = None - self.curidx = 0 - - def __iter__(self): - self.idx_list = self.sampler(self.dataset) - self.curidx = 0 - self.lengths = self.dataset.get_length() - return self - - def __next__(self): - if self.curidx >= len(self.idx_list): - raise StopIteration - else: - endidx = min(self.curidx + self.batch_size, len(self.idx_list)) - padding_length = {field_name : max(field_length[self.curidx: endidx]) - for field_name, field_length in self.lengths.items()} - - batch_x, batch_y = defaultdict(list), defaultdict(list) - for idx in range(self.curidx, endidx): - x, y = self.dataset.to_tensor(idx, padding_length) - for name, tensor in x.items(): - batch_x[name].append(tensor) - for name, tensor in y.items(): - batch_y[name].append(tensor) - - for batch in (batch_x, batch_y): - for name, tensor_list in batch.items(): - print(name, " ", tensor_list) - batch[name] = torch.stack(tensor_list, dim=0) - self.curidx += endidx - return batch_x, batch_y - - -if __name__ == "__main__": - """simple running example - """ - from field import TextField, LabelField - from instance import Instance - from dataset import DataSet - - texts = ["i am a cat", - "this is a test of new batch", - "haha" - ] - labels = [0, 1, 0] - - # prepare vocabulary - vocab = {} - for text in texts: - for tokens in text.split(): - if tokens not in vocab: - vocab[tokens] = len(vocab) - - # prepare input dataset - data = DataSet() - for text, label in zip(texts, labels): - x = TextField(text.split(), False) - y = LabelField(label, is_target=True) - ins = Instance(text=x, label=y) - data.append(ins) - - # use vocabulary to index data - data.index_field("text", vocab) - - # define naive sampler for batch class - class SeqSampler: - def __call__(self, dataset): - return list(range(len(dataset))) - - # use bacth to iterate dataset - batcher = Batch(data, SeqSampler(), 2) - for epoch in range(3): - for batch_x, batch_y in batcher: - print(batch_x) - print(batch_y) - # do stuff - diff --git a/fastNLP/data/instance.py b/fastNLP/data/instance.py deleted file mode 100644 index 4b78dfc3..00000000 --- a/fastNLP/data/instance.py +++ /dev/null @@ -1,38 +0,0 @@ -class Instance(object): - def __init__(self, **fields): - self.fields = fields - self.has_index = False - self.indexes = {} - - def add_field(self, field_name, field): - self.fields[field_name] = field - - def get_length(self): - length = {name : field.get_length() for name, field in self.fields.items()} - return length - - def index_field(self, field_name, vocab): - """use `vocab` to index certain field - """ - self.indexes[field_name] = self.fields[field_name].index(vocab) - - def index_all(self, vocab): - """use `vocab` to index all fields - """ - if self.has_index: - print("error") - return self.indexes - indexes = {name : field.index(vocab) for name, field in self.fields.items()} - self.indexes = indexes - return indexes - - def to_tensor(self, padding_length: dict): - tensorX = {} - tensorY = {} - for name, field in self.fields.items(): - if field.is_target: - tensorY[name] = field.to_tensor(padding_length[name]) - else: - tensorX[name] = field.to_tensor(padding_length[name]) - - return tensorX, tensorY diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index 5addc73e..bed3f0a6 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -4,6 +4,20 @@ from fastNLP.models.base_model import BaseModel from fastNLP.modules import decoder, encoder +def seq_mask(seq_len, max_len): + """Create a mask for the sequences. + + :param seq_len: list or torch.LongTensor + :param max_len: int + :return mask: torch.LongTensor + """ + if isinstance(seq_len, list): + seq_len = torch.LongTensor(seq_len) + mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] + mask = torch.stack(mask, 1) + return mask + + class SeqLabeling(BaseModel): """ PyTorch Network for sequence labeling @@ -20,13 +34,17 @@ class SeqLabeling(BaseModel): self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) self.Linear = encoder.linear.Linear(hidden_dim, num_classes) self.Crf = decoder.CRF.ConditionalRandomField(num_classes) + self.mask = None - def forward(self, x): + def forward(self, word_seq, word_seq_origin_len): """ - :param x: LongTensor, [batch_size, mex_len] + :param word_seq: LongTensor, [batch_size, mex_len] + :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. :return y: [batch_size, mex_len, tag_size] """ - x = self.Embedding(x) + self.mask = self.make_mask(word_seq, word_seq_origin_len) + + x = self.Embedding(word_seq) # [batch_size, max_len, word_emb_dim] x = self.Rnn(x) # [batch_size, max_len, hidden_size * direction] @@ -34,27 +52,32 @@ class SeqLabeling(BaseModel): # [batch_size, max_len, num_classes] return x - def loss(self, x, y, mask): + def loss(self, x, y): """ Negative log likelihood loss. :param x: Tensor, [batch_size, max_len, tag_size] :param y: Tensor, [batch_size, max_len] - :param mask: ByteTensor, [batch_size, ,max_len] :return loss: a scalar Tensor """ x = x.float() y = y.long() - total_loss = self.Crf(x, y, mask) + total_loss = self.Crf(x, y, self.mask) return torch.mean(total_loss) - def prediction(self, x, mask): + def make_mask(self, x, seq_len): + batch_size, max_len = x.size(0), x.size(1) + mask = seq_mask(seq_len, max_len) + mask = mask.byte().view(batch_size, max_len) + mask = mask.to(x) + return mask + + def prediction(self, x): """ :param x: FloatTensor, [batch_size, max_len, tag_size] - :param mask: ByteTensor, [batch_size, max_len] :return prediction: list of [decode path(list)] """ - tag_seq = self.Crf.viterbi_decode(x, mask) + tag_seq = self.Crf.viterbi_decode(x, self.mask) return tag_seq @@ -81,11 +104,14 @@ class AdvSeqLabel(SeqLabeling): self.Crf = decoder.CRF.ConditionalRandomField(num_classes) - def forward(self, x): + def forward(self, x, seq_len): """ :param x: LongTensor, [batch_size, mex_len] + :param seq_len: list of int. :return y: [batch_size, mex_len, tag_size] """ + self.mask = self.make_mask(x, seq_len) + batch_size = x.size(0) max_len = x.size(1) x = self.Embedding(x) diff --git a/test/model/seq_labeling.py b/test/model/seq_labeling.py index 0f7a072b..dcfa8bb4 100644 --- a/test/model/seq_labeling.py +++ b/test/model/seq_labeling.py @@ -15,11 +15,11 @@ from fastNLP.core.optimizer import Optimizer parser = argparse.ArgumentParser() parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") -parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt", +parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", help="path to the training data") -parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") +parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") -parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt", +parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", help="data used for inference") args = parser.parse_args() @@ -86,7 +86,7 @@ def train_and_test(): trainer = SeqLabelTrainer( epochs=trainer_args["epochs"], batch_size=trainer_args["batch_size"], - validate=trainer_args["validate"], + validate=False, use_cuda=trainer_args["use_cuda"], pickle_path=pickle_path, save_best_dev=trainer_args["save_best_dev"], @@ -139,5 +139,5 @@ def train_and_test(): if __name__ == "__main__": - # train_and_test() - infer() + train_and_test() + # infer() diff --git a/test/model/text_classify.py b/test/model/text_classify.py index 6ff3c059..dd20505f 100644 --- a/test/model/text_classify.py +++ b/test/model/text_classify.py @@ -115,4 +115,4 @@ def train(): if __name__ == "__main__": train() - infer() + # infer()