- apply DataSet in Predictor; remove sub-predictors; add a "task" argument to specify which task to predict, as Trainer/Tester do
- remove the Action class
- add helper functions for DataSet, to create a DataSet easily
- more code comments
- clean up unnecessary code
- add unit tests for Batch, Predictor, Preprocessor, Trainer, Tester

tags/v0.1.0
@@ -1 +0,0 @@
@@ -4,88 +4,6 @@ import numpy as np
import torch


class Action(object):
    """Operations shared by Trainer, Tester, and Inference.
    This is designed to reduce duplicated code.
        - make_batch: produce a mini-batch of data. @staticmethod
        - pad: padding method used in sequence modeling. @staticmethod
        - mode: switch the network between train and test mode (for PyTorch). @staticmethod
    """

    def __init__(self):
        super(Action, self).__init__()
    @staticmethod
    def make_batch(iterator, use_cuda, output_length=True, max_len=None):
        """Batch and pad data.

        :param iterator: an iterator (an object that implements the __next__ method) which returns the next sample.
        :param use_cuda: bool, whether to use GPU
        :param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
        :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
        :return:
            if output_length is True:
                (batch_x, seq_len): tuple of two elements
                    batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
                    seq_len: list. The lengths of the sequences before padding.
                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
            if output_length is False:
                batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
        """
        for batch in iterator:
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            batch_x = Action.pad(batch_x)
            # pad batch_y only if it is a two-level list
            if len(batch_y) > 0 and isinstance(batch_y[0], list):
                batch_y = Action.pad(batch_y)
            # convert lists to tensors
            batch_x = convert_to_torch_tensor(batch_x, use_cuda)
            batch_y = convert_to_torch_tensor(batch_y, use_cuda)
            # trim data to max_len
            if max_len is not None and batch_x.size(1) > max_len:
                batch_x = batch_x[:, :max_len]
            if output_length:
                seq_len = [len(x) for x in batch_x]
                yield (batch_x, seq_len), batch_y
            else:
                yield batch_x, batch_y
    @staticmethod
    def pad(batch, fill=0):
        """Pad a mini-batch of sequence samples to the maximum length of this batch.

        :param batch: list of list
        :param fill: word index to pad with, default 0.
        :return batch: a padded mini-batch
        """
        max_length = max([len(x) for x in batch])
        for idx, sample in enumerate(batch):
            if len(sample) < max_length:
                batch[idx] = sample + ([fill] * (max_length - len(sample)))
        return batch
    @staticmethod
    def mode(model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.

        :param model: a PyTorch model
        :param is_test: bool, whether in test mode or not.
        """
        if is_test:
            model.eval()
        else:
            model.train()
def convert_to_torch_tensor(data_list, use_cuda):
    """Convert lists into (cuda) Tensors.
@@ -224,6 +142,7 @@ class BucketBatchifier(Batchifier):
    """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
    In sampling, first randomly choose a bucket, then sample data from it.
    The number of buckets is decided dynamically by the variance of sentence lengths.
    TODO: merge it into Batch
    """

    def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
@@ -1,8 +1,77 @@
from collections import defaultdict

from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
    if has_target is True:
        if label_vocab is None:
            raise RuntimeError("Must provide label vocabulary to transform labels.")
        return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab)
    else:
        return create_unlabeled_dataset_from_lists(str_lists, word_vocab)
def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab):
    """Create a DataSet instance that contains labels.

    :param str_lists: list of list of strings, [num_examples, 2, *].
        ::
            [
                [[word_11, word_12, ...], [label_11, label_12, ...]],
                ...
            ]

    :param word_vocab: dict of (str: int), which means (word: index).
    :param label_vocab: dict of (str: int), which means (label: index).
    :return data_set: a DataSet instance.
    """
    data_set = DataSet()
    for example in str_lists:
        word_seq, label_seq = example[0], example[1]
        x = TextField(word_seq, is_target=False)
        y = TextField(label_seq, is_target=True)
        data_set.append(Instance(word_seq=x, label_seq=y))
    data_set.index_field("word_seq", word_vocab)
    data_set.index_field("label_seq", label_vocab)
    return data_set
def create_unlabeled_dataset_from_lists(str_lists, word_vocab):
    """Create a DataSet instance that contains no labels.

    :param str_lists: list of list of strings, [num_examples, *].
        ::
            [
                [word_11, word_12, ...],
                ...
            ]

    :param word_vocab: dict of (str: int), which means (word: index).
    :return data_set: a DataSet instance.
    """
    data_set = DataSet()
    for word_seq in str_lists:
        x = TextField(word_seq, is_target=False)
        data_set.append(Instance(word_seq=x))
    data_set.index_field("word_seq", word_vocab)
    return data_set
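As a quick illustration of these helpers, a minimal usage sketch (the vocabularies and sentences here are hypothetical examples, not part of this change):

    word_vocab = {"hello": 0, "world": 1}
    label_vocab = {"B": 0, "O": 1}

    # unlabeled: one word sequence per example
    unlabeled_set = create_dataset_from_lists([["hello", "world"]], word_vocab)

    # labeled: each example pairs a word sequence with a label sequence
    labeled_set = create_dataset_from_lists([[["hello", "world"], ["B", "O"]]],
                                            word_vocab, has_target=True, label_vocab=label_vocab)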
class DataSet(list):
    """A DataSet object is a list of Instance objects."""

    def __init__(self, name="", instances=None):
        """
        :param name: str, the name of the dataset. (default: "")
        :param instances: list of Instance objects. (default: None)
        """
        list.__init__(self)
        self.name = name
        if instances is not None:
@@ -20,9 +20,10 @@ class Field(object):
class TextField(Field):
    def __init__(self, text: list, is_target):
    def __init__(self, text, is_target):
        """
        :param list text:
        :param text: list of strings
        :param is_target: bool
        """
        super(TextField, self).__init__(is_target)
        self.text = text
@@ -32,7 +33,7 @@ class TextField(Field):
        if self._index is None:
            self._index = [vocab[c] for c in self.text]
        else:
            print('error')
            raise RuntimeError("Replicate indexing of this field.")
        return self._index

    def get_length(self):
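To make the new failure mode concrete, a small sketch (hypothetical values; it assumes the hunk above is the body of TextField's index method, whose name is not shown in the diff):

    field = TextField(["hello", "world"], is_target=False)
    field.index({"hello": 0, "world": 1})  # first call builds the index: [0, 1]
    field.index({"hello": 0, "world": 1})  # second call now raises RuntimeError instead of printing 'error'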
@@ -41,7 +41,7 @@ class Instance(object):
        :param padding_length: dict of (str: int), which means (field name: padding_length of this field)
        :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
            tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
            If is_target is False for all fields, tensor_y would be an empty dict.
        """
        tensor_x = {}
        tensor_y = {}
@@ -1,53 +1,10 @@
import numpy as np
import torch

from fastNLP.core.action import Batchifier, SequentialSampler
from fastNLP.core.action import convert_to_torch_tensor
from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils
def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
    """Batch and pad data, only for inference.

    :param iterator: an iterable object that returns a list of indices representing a mini-batch of samples.
    :param use_cuda: bool, whether to use GPU
    :param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
    :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
    :param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
    :return:
    """
    for batch_x in iterator:
        batch_x = pad(batch_x)
        # convert list to tensor
        batch_x = convert_to_torch_tensor(batch_x, use_cuda)
        # trim data to max_len
        if max_len is not None and batch_x.size(1) > max_len:
            batch_x = batch_x[:, :max_len]
        if min_len is not None and batch_x.size(1) < min_len:
            pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
            batch_x = torch.cat((batch_x, pad_tensor), 1)
        if output_length:
            seq_len = [len(x) for x in batch_x]
            yield tuple([batch_x, seq_len])
        else:
            yield batch_x
def pad(batch, fill=0):
    """Pad a mini-batch of sequence samples to the maximum length of this batch.

    :param batch: list of list
    :param fill: word index to pad with, default 0.
    :return batch: a padded mini-batch
    """
    max_length = max([len(x) for x in batch])
    for idx, sample in enumerate(batch):
        if len(sample) < max_length:
            batch[idx] = sample + ([fill] * (max_length - len(sample)))
    return batch
from fastNLP.core.action import SequentialSampler
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import create_dataset_from_lists
from fastNLP.core.preprocess import load_pickle


class Predictor(object):
@@ -59,11 +16,17 @@ class Predictor(object):
    Currently, Predictor does not support GPU.
    """

    def __init__(self, pickle_path):
    def __init__(self, pickle_path, task):
        """
        :param pickle_path: str, the path to the pickle files.
        :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
        """
        self.batch_size = 1
        self.batch_output = []
        self.iterator = None
        self.pickle_path = pickle_path
        self._task = task  # one of ("seq_label", "text_classify")
        self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
        self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
@@ -71,19 +34,19 @@ class Predictor(object):
        """Perform inference using the trained model.

        :param network: a PyTorch model (cpu)
        :param data: list of list of strings
        :param data: list of list of strings, [num_examples, seq_len]
        :return: list of list of strings, [num_examples, tag_seq_length]
        """
        # transform strings into indices
        # transform strings into a DataSet object
        data = self.prepare_input(data)

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        self.batch_output.clear()

        data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)

        for batch_x in self.make_batch(data_iterator, use_cuda=False):
        for batch_x, _ in data_iterator:
            with torch.no_grad():
                prediction = self.data_forward(network, batch_x)
@@ -99,103 +62,61 @@ class Predictor(object):
    def data_forward(self, network, x):
        """Forward through network."""
        raise NotImplementedError

    def make_batch(self, iterator, use_cuda):
        raise NotImplementedError
        y = network(**x)
        if self._task == "seq_label":
            y = network.prediction(y)
        return y
    def prepare_input(self, data):
        """Transform a two-level list of strings into indices.
        """Transform a two-level list of strings into a DataSet object.
        In the training pipeline, this is done by Preprocessor. But at inference time, we do not call Preprocessor.

        :param data:
        :param data: list of list of strings.
            ::
                [
                    [word_11, word_12, ...],
                    [word_21, word_22, ...],
                    ...
                ]

        :return data_index: list of list of int.
        :return data_set: a DataSet instance.
        """
        assert isinstance(data, list)
        data_index = []
        default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL]
        for example in data:
            data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
        return data_index
        return create_dataset_from_lists(data, self.word2index, has_target=False)
    def prepare_output(self, data):
        """Transform list of batch outputs into strings."""
        raise NotImplementedError

class SeqLabelInfer(Predictor):
    """
    Inference on sequence labeling models.
    """

    def __init__(self, pickle_path):
        super(SeqLabelInfer, self).__init__(pickle_path)
        if self._task == "seq_label":
            return self._seq_label_prepare_output(data)
        elif self._task == "text_classify":
            return self._text_classify_prepare_output(data)
        else:
            raise NotImplementedError("Unknown task type {}".format(self._task))
    def data_forward(self, network, inputs):
        """
        This is only for sequence labeling with a CRF decoder.

        :param network: a PyTorch model
        :param inputs: tuple of (x, seq_len)
            x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
                after padding.
            seq_len: list of int, the lengths of the sequences before padding.
        :return prediction: Tensor of shape [batch_size, max_len]
        """
        if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
            raise RuntimeError("output_length must be True for sequence modeling.")
        # unpack the value returned from make_batch
        x, seq_len = inputs[0], inputs[1]
        batch_size, max_len = x.size(0), x.size(1)
        mask = utils.seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)
        y = network(x)
        prediction = network.prediction(y, mask)
        return torch.Tensor(prediction)

    def make_batch(self, iterator, use_cuda):
        return make_batch(iterator, use_cuda, output_length=True)
    def prepare_output(self, batch_outputs):
        """Transform a list of batch outputs into strings.

        :param batch_outputs: list of 2-D Tensors, of shape [num_batch, batch_size, tag_seq_length].
        :return results: 2-D list of strings, of shape [num_examples, tag_seq_length]
        """
    def _seq_label_prepare_output(self, batch_outputs):
        results = []
        for batch in batch_outputs:
            for example in np.array(batch):
                results.append([self.index2label[int(x)] for x in example])
        return results
class ClassificationInfer(Predictor):
    """
    Inference on classification models.
    """

    def __init__(self, pickle_path):
        super(ClassificationInfer, self).__init__(pickle_path)

    def data_forward(self, network, x):
        """Forward through network."""
        logits = network(x)
        return logits

    def make_batch(self, iterator, use_cuda):
        return make_batch(iterator, use_cuda, output_length=False, min_len=5)

    def prepare_output(self, batch_outputs):
        """Transform a list of batch outputs into strings.

        :param batch_outputs: list of 2-D Tensors, of shape [num_batch, batch_size, num_classes].
        :return results: list of strings
        """
    def _text_classify_prepare_output(self, batch_outputs):
        results = []
        for batch_out in batch_outputs:
            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
            results.extend([self.index2label[i] for i in idx])
        return results
class SeqLabelInfer(Predictor):
    def __init__(self, pickle_path):
        print(
            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument task='seq_label'.")
        super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")


class ClassificationInfer(Predictor):
    def __init__(self, pickle_path):
        print(
            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument task='text_classify'.")
        super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
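Putting the pieces together, a minimal sketch of the new single-Predictor workflow (paths and data are hypothetical; it assumes word2id.pkl and id2class.pkl were saved under the pickle path, as the unit test further below does):

    from fastNLP.core.predictor import Predictor
    from fastNLP.models.sequence_modeling import SeqLabeling

    model = SeqLabeling(model_args)  # model_args as in the tests below
    predictor = Predictor(pickle_path="./save/", task="seq_label")
    results = predictor.predict(network=model, data=[["hello", "world"]])
    # results: one list of label strings per input sentence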
@@ -1,7 +1,6 @@
import numpy as np
import torch

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.saver.logger import create_logger
@@ -79,7 +78,7 @@ class BaseTester(object):
        self._model = network

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        self.mode(network, is_test=True)
        self.eval_history.clear()
        self.batch_output.clear()
@@ -102,13 +101,17 @@ class BaseTester(object):
            print(self.make_eval_output(prediction, eval_results))
            step += 1

    def mode(self, model, test):
    def mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.

        :param model: a PyTorch model
        :param test: bool, whether in test mode.
        :param is_test: bool, whether in test mode or not.
        """
        Action.mode(model, test)
        if is_test:
            model.eval()
        else:
            model.train()

    def data_forward(self, network, x):
        """A forward pass of the model."""
@@ -6,7 +6,6 @@ from datetime import timedelta
import torch
from tensorboardX import SummaryWriter

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
@@ -126,7 +125,7 @@ class BaseTrainer(object):
            logger.info("training epoch {}".format(epoch))

            # turn on network training mode
            self.mode(network, test=False)
            self.mode(network, is_test=False)
            # prepare mini-batch iterator
            data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
                                  use_cuda=self.use_cuda)
@@ -201,8 +200,17 @@ class BaseTrainer(object):
            network_copy = copy.deepcopy(network)
            self.train(network_copy, train_data_cv[i], dev_data_cv[i])

    def mode(self, network, test):
        Action.mode(network, test)
    def mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.

        :param model: a PyTorch model
        :param is_test: bool, whether in test mode or not.
        """
        if is_test:
            model.eval()
        else:
            model.train()

    def define_optimizer(self):
        """Define a framework-specific optimizer, as specified by the models.
@@ -284,7 +292,7 @@ class BaseTrainer(object):
        :param validator: a Tester instance
        :return: bool, True means the current results on the dev set are the best so far.
        """
        loss, accuracy = validator.metrics()
        loss, accuracy = validator.metrics
        if accuracy > self._best_accuracy:
            self._best_accuracy = accuracy
            return True
@@ -62,6 +62,8 @@ class SeqLabeling(BaseModel):
        """
        x = x.float()
        y = y.long()
        assert x.shape[:2] == y.shape
        assert y.shape == self.mask.shape
        total_loss = self.Crf(x, y, self.mask)
        return torch.mean(total_loss)
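For reference, a minimal sketch of the tensor shapes the two new assertions enforce (dummy tensors; it assumes x holds per-class emission scores, matching how the CRF is called above):

    import torch

    batch_size, max_len, num_classes = 2, 5, 4
    x = torch.randn(batch_size, max_len, num_classes)  # emission scores
    y = torch.zeros(batch_size, max_len).long()        # gold label ids
    mask = torch.ones(batch_size, max_len)             # 1 for real tokens, 0 for padding
    assert x.shape[:2] == y.shape
    assert y.shape == mask.shape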
@@ -1,17 +0,0 @@
import unittest

from fastNLP.core.action import Action, Batchifier, SequentialSampler


class TestAction(unittest.TestCase):
    def test_case_1(self):
        x = [1, 2, 3, 4, 5, 6, 7, 8]
        y = [1, 1, 1, 1, 2, 2, 2, 2]
        data = []
        for i in range(len(x)):
            data.append([[x[i]], [y[i]]])
        data = Batchifier(SequentialSampler(data), batch_size=2, drop_last=False)
        action = Action()
        for batch_x in action.make_batch(data, use_cuda=False, output_length=True, max_len=None):
            print(batch_x)
@@ -0,0 +1,62 @@
import unittest

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

raw_texts = ["i am a cat",
             "this is a test of new batch",
             "ha ha",
             "I am a good boy .",
             "This is the most beautiful girl ."
             ]
texts = [text.strip().split() for text in raw_texts]
labels = [0, 1, 0, 0, 1]

# prepare the vocabulary
vocab = {}
for text in texts:
    for token in text:
        if token not in vocab:
            vocab[token] = len(vocab)
class TestCase1(unittest.TestCase):
    def test(self):
        data = DataSet()
        for text, label in zip(texts, labels):
            x = TextField(text, is_target=False)
            y = LabelField(label, is_target=True)
            ins = Instance(text=x, label=y)
            data.append(ins)

        # use the vocabulary to index the data
        data.index_field("text", vocab)

        # define a naive sampler for the Batch class
        class SeqSampler:
            def __call__(self, dataset):
                return list(range(len(dataset)))

        # use Batch to iterate over the dataset
        data_iterator = Batch(data, 2, SeqSampler(), False)
        for batch_x, batch_y in data_iterator:
            self.assertEqual(len(batch_x), 2)
            self.assertTrue(isinstance(batch_x, dict))
            self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
            self.assertTrue(isinstance(batch_y, dict))
            self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))
class TestCase2(unittest.TestCase):
    def test(self):
        data = DataSet()
        for text in texts:
            x = TextField(text, is_target=False)
            ins = Instance(text=x)
            data.append(ins)

        data_set = create_dataset_from_lists(texts, vocab, has_target=False)

        self.assertTrue(type(data) == type(data_set))
@@ -0,0 +1,51 @@
import os
import unittest

from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.models.sequence_modeling import SeqLabeling


class TestPredictor(unittest.TestCase):
    def test_seq_label(self):
        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }

        infer_data = [
            ['a', 'b', 'c', 'd', 'e'],
            ['a', '@', 'c', 'd', 'e'],
            ['a', 'b', '#', 'd', 'e'],
            ['a', 'b', 'c', '?', 'e'],
            ['a', 'b', 'c', 'd', '$'],
            ['!', 'b', 'c', 'd', 'e']
        ]
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}

        os.system("mkdir save")
        save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
        save_pickle(vocab, "./save/", "word2id.pkl")

        model = SeqLabeling(model_args)
        predictor = Predictor("./save/", task="seq_label")

        results = predictor.predict(network=model, data=infer_data)

        self.assertTrue(isinstance(results, list))
        self.assertGreater(len(results), 0)
        for res in results:
            self.assertTrue(isinstance(res, list))
            self.assertEqual(len(res), 5)
            self.assertTrue(isinstance(res[0], str))

        os.system("rm -rf save")
        print("pickle path deleted")


class TestPredictor2(unittest.TestCase):
    def test_text_classify(self):
        # TODO
        pass
@@ -1,24 +1,25 @@
import os
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import SeqLabelPreprocess

data = [
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
    [['Hello', 'world', '!'], ['a', 'n', '.']],
]

class TestSeqLabelPreprocess(unittest.TestCase):
    def test_case_1(self):
        data = [
            [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
            [['Hello', 'world', '!'], ['a', 'n', '.']],
            [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
            [['Hello', 'world', '!'], ['a', 'n', '.']],
            [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
            [['Hello', 'world', '!'], ['a', 'n', '.']],
            [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
            [['Hello', 'world', '!'], ['a', 'n', '.']],
            [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
            [['Hello', 'world', '!'], ['a', 'n', '.']],
        ]
class TestCase1(unittest.TestCase):
    def test(self):
        if os.path.exists("./save"):
            for root, dirs, files in os.walk("./save", topdown=False):
                for name in files:
@@ -27,17 +28,45 @@ class TestSeqLabelPreprocess(unittest.TestCase):
                    os.rmdir(os.path.join(root, name))
        result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
                                          pickle_path="./save")
        result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
                                          pickle_path="./save")
        self.assertEqual(len(result), 2)
        self.assertEqual(type(result[0]), DataSet)
        self.assertEqual(type(result[1]), DataSet)
        os.system("rm -rf save")
        print("pickle path deleted")
class TestCase2(unittest.TestCase):
    def test(self):
        if os.path.exists("./save"):
            for root, dirs, files in os.walk("./save", topdown=False):
                for name in files:
                    os.remove(os.path.join(root, name))
                for name in dirs:
                    os.rmdir(os.path.join(root, name))
        result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
                                          pickle_path="./save", train_dev_split=0.4,
                                          cross_val=True)
        result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
                                          pickle_path="./save", train_dev_split=0.4,
                                          cross_val=False)
        self.assertEqual(len(result), 3)
        self.assertEqual(type(result[0]), DataSet)
        self.assertEqual(type(result[1]), DataSet)
        self.assertEqual(type(result[2]), DataSet)
        os.system("rm -rf save")
        print("pickle path deleted")
class TestCase3(unittest.TestCase):
    def test(self):
        num_folds = 2
        result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
                                          pickle_path="./save", train_dev_split=0.4,
                                          cross_val=True, n_fold=num_folds)
        self.assertEqual(len(result), 2)
        self.assertEqual(len(result[0]), num_folds)
        self.assertEqual(len(result[1]), num_folds)
        for data_set in result[0] + result[1]:
            self.assertEqual(type(data_set), DataSet)
        os.system("rm -rf save")
        print("pickle path deleted")
@@ -1,37 +1,55 @@
from fastNLP.core.preprocess import SeqLabelPreprocess
import os
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.models.sequence_modeling import SeqLabeling

data_name = "pku_training.utf8"
pickle_path = "data_for_tests"


def foo():
    loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
    train_data = loader.load_pku()

    train_args = ConfigSection()
    ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})

    # Preprocessor
    p = SeqLabelPreprocess()
    train_data = p.run(train_data)
    train_args["vocab_size"] = p.vocab_size
    train_args["num_classes"] = p.num_classes

    model = SeqLabeling(train_args)

    valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                  "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
                  "use_cuda": True}
    validator = SeqLabelTester(**valid_args)

    print("start validation.")
    validator.test(model, train_data)
    print(validator.show_metrics())


if __name__ == "__main__":
    foo()
class TestTester(unittest.TestCase):
    def test_case_1(self):
        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }
        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                      "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
                      "use_cuda": False, "print_every_step": 1}

        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, is_target=False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, label_seq=y)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)

        model = SeqLabeling(model_args)

        tester = SeqLabelTester(**valid_args)
        tester.test(network=model, dev_data=data_set)
        # If this runs, everything is OK.

        os.system("rm -rf save")
        print("pickle path deleted")
@@ -1,33 +1,54 @@
import os

import torch.nn as nn
import unittest

from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling


class TestTrainer(unittest.TestCase):
    def test_case_1(self):
        args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
        args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
                "save_best_dev": True, "model_name": "default_model_name.pkl",
                "loss": Loss(None),
                "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                "vocab_size": 20,
                "vocab_size": 10,
                "word_emb_dim": 100,
                "rnn_hidden_units": 100,
                "num_classes": 3
                "num_classes": 5
                }

        trainer = SeqLabelTrainer()
        trainer = SeqLabelTrainer(**args)

        train_data = [
            [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
            [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
            [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
            [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
            [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
            [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        dev_data = train_data
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, is_target=False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, label_seq=y)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)

        model = SeqLabeling(args)

        trainer.train(network=model, train_data=train_data, dev_data=dev_data)
        trainer.train(network=model, train_data=data_set, dev_data=data_set)
        # If this runs, everything is OK.

        os.system("rm -rf save")
        print("pickle path deleted")
@@ -1,8 +0,0 @@
def test_charlm():
    pass


if __name__ == "__main__":
    test_charlm()
@@ -0,0 +1,96 @@
import argparse
import os

from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import SeqLabelPreprocess
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt",
                    help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt",
                    help="data used for inference")
args = parser.parse_args()
pickle_path = args.save
model_name = args.model_name
config_dir = args.config
data_path = args.train
data_infer_path = args.infer


def test_training():
    # Config Loader
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    # Data Loader
    pos_loader = POSDatasetLoader(data_path)
    train_data = pos_loader.load_lines()

    # Preprocessor
    p = SeqLabelPreprocess()
    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
    model_args["vocab_size"] = p.vocab_size
    model_args["num_classes"] = p.num_classes

    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],
        batch_size=trainer_args["batch_size"],
        validate=False,
        use_cuda=False,
        pickle_path=pickle_path,
        save_best_dev=trainer_args["save_best_dev"],
        model_name=model_name,
        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
    )

    # Model
    model = SeqLabeling(model_args)

    # Start training
    trainer.train(model, data_train, data_dev)

    # Saver
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)

    del model, trainer, pos_loader

    # Define the same model
    model = SeqLabeling(model_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))

    # Load test configuration
    tester_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(save_output=False,
                            save_loss=True,
                            save_best_dev=False,
                            batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            print_every_step=1
                            )

    # Start testing with validation data
    tester.test(model, data_dev)

    loss, accuracy = tester.metrics
    assert 0 < accuracy < 1