From 05af2e754431bccf356dbe8eb2a04344e89031a8 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 15 Sep 2018 19:23:10 +0800
Subject: [PATCH] Introduce Fields concept to eliminate the use of different
 sub-trainers/sub-testers.

- update LabelField's to_tensor method to support both int and str single labels
- update the preprocessor's convert_to_dataset method to support single-label inputs
- introduce "task" in Trainer's/Tester's data_forward and in Tester's evaluate and metrics methods
- rename the forward argument in cnn_text_classification.py
- rename the forward argument in sequence_modeling.py
- minor adjustments in the test code
- text_classify.py works
---
 fastNLP/core/field.py                     | 10 +++-
 fastNLP/core/preprocess.py                |  9 ++++
 fastNLP/core/tester.py                    | 68 ++++++++++++-----------
 fastNLP/core/trainer.py                   | 37 +++++-------
 fastNLP/models/cnn_text_classification.py |  8 ++-
 fastNLP/models/sequence_modeling.py       | 14 ++---
 test/model/seq_labeling.py                |  2 +-
 test/model/text_classify.py               |  4 +-
 8 files changed, 80 insertions(+), 72 deletions(-)

diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index eb2bc78e..1efa759a 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -73,13 +73,17 @@ class LabelField(Field):
     def index(self, vocab):
         if self._index is None:
             self._index = vocab[self.label]
-        else:
-            pass
         return self._index
 
     def to_tensor(self, padding_length):
         if self._index is None:
-            return torch.LongTensor([self.label])
+            if isinstance(self.label, int):
+                return torch.LongTensor([self.label])
+            elif isinstance(self.label, str):
+                raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
+            else:
+                raise RuntimeError(
+                    "Unsupported type for LabelField. Expect str or int, got {}.".format(type(self.label)))
         else:
             return torch.LongTensor([self._index])
diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py
index b7c33f3b..cbd194d1 100644
--- a/fastNLP/core/preprocess.py
+++ b/fastNLP/core/preprocess.py
@@ -251,6 +251,9 @@ class BasePreprocess(object):
         """
         use_word_seq = False
         use_label_seq = False
+        use_label_str = False
+
+        # construct a DataSet object and fill it with Instances
         data_set = DataSet()
         for example in data:
             words, label = example[0], example[1]
@@ -270,14 +273,20 @@ class BasePreprocess(object):
             elif isinstance(label, str):
                 y = LabelField(label, is_target=True)
                 instance.add_field("label", y)
+                use_label_str = True
             else:
                 raise NotImplementedError("label is a {}".format(type(label)))
             data_set.append(instance)
+
+        # convert strings to indices
         if use_word_seq:
             data_set.index_field("word_seq", vocab)
         if use_label_seq:
             data_set.index_field("label_seq", label_vocab)
+        if use_label_str:
+            data_set.index_field("label", label_vocab)
+
         return data_set
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index cfbc918e..f23ab704 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -122,15 +122,23 @@ class BaseTester(object):
         :param truth: Tensor
         :return eval_results: can be anything. It will be stored in self.eval_history
         """
-        batch_size, max_len = predict.size(0), predict.size(1)
         if "label_seq" in truth:
             truth = truth["label_seq"]
         elif "label" in truth:
             truth = truth["label"]
         else:
             raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
-        loss = self._model.loss(predict, truth) / batch_size
 
+        if self._task == "seq_label":
+            return self._seq_label_evaluate(predict, truth)
+        elif self._task == "text_classify":
+            return self._text_classify_evaluate(predict, truth)
+        else:
+            raise NotImplementedError("Unknown task type {}.".format(self._task))
+
+    def _seq_label_evaluate(self, predict, truth):
+        batch_size, max_len = predict.size(0), predict.size(1)
+        loss = self._model.loss(predict, truth) / batch_size
         prediction = self._model.prediction(predict)
         # pad prediction to equal length
         for pred in prediction:
@@ -143,6 +151,10 @@ class BaseTester(object):
         accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
         return [float(loss), float(accuracy)]
 
+    def _text_classify_evaluate(self, y_logit, y_true):
+        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
+        return [y_prob, y_true]
+
     @property
     def metrics(self):
         """Compute and return metrics.
@@ -151,10 +163,28 @@ class BaseTester(object):
 
         :return : variable number of outputs
         """
+        if self._task == "seq_label":
+            return self._seq_label_metrics
+        elif self._task == "text_classify":
+            return self._text_classify_metrics
+        else:
+            raise NotImplementedError("Unknown task type {}.".format(self._task))
+
+    @property
+    def _seq_label_metrics(self):
         batch_loss = np.mean([x[0] for x in self.eval_history])
         batch_accuracy = np.mean([x[1] for x in self.eval_history])
         return batch_loss, batch_accuracy
 
+    @property
+    def _text_classify_metrics(self):
+        y_prob, y_true = zip(*self.eval_history)
+        y_prob = torch.cat(y_prob, dim=0)
+        y_pred = torch.argmax(y_prob, dim=-1)
+        y_true = torch.cat(y_true, dim=0)
+        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
+        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
+
     def show_metrics(self):
         """Customize evaluation outputs in Trainer.
         Called by Trainer to print evaluation results on dev set during training.
@@ -176,13 +206,7 @@ class BaseTester(object):
 
 class SeqLabelTester(BaseTester):
-    """Tester for sequence labeling.
-
-    """
     def __init__(self, **test_args):
-        """
-        :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
-        """
         test_args.update({"task": "seq_label"})
         print(
             "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
         super(SeqLabelTester, self).__init__(**test_args)
@@ -190,30 +214,8 @@ class SeqLabelTester(BaseTester):
 
 
 class ClassificationTester(BaseTester):
-    """Tester for classification."""
-
     def __init__(self, **test_args):
-        """
-        :param test_args: a dict-like object that has __getitem__ method.
-            can be accessed by "test_args["key_str"]"
-        """
+        test_args.update({"task": "text_classify"})
+        print(
+            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
         super(ClassificationTester, self).__init__(**test_args)
-
-    def data_forward(self, network, x):
-        """Forward through network."""
-        logits = network(x)
-        return logits
-
-    def evaluate(self, y_logit, y_true):
-        """Return y_pred and y_true."""
-        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
-        return [y_prob, y_true]
-
-    def metrics(self):
-        """Compute accuracy."""
-        y_prob, y_true = zip(*self.eval_history)
-        y_prob = torch.cat(y_prob, dim=0)
-        y_pred = torch.argmax(y_prob, dim=-1)
-        y_true = torch.cat(y_true, dim=0)
-        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
-        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index f4c3e8c1..1405f156 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -218,10 +218,18 @@ class BaseTrainer(object):
             self._optimizer.step()
 
     def data_forward(self, network, x):
-        y = network(**x)
+        if self._task == "seq_label":
+            y = network(x["word_seq"], x["word_seq_origin_len"])
+        elif self._task == "text_classify":
+            y = network(x["word_seq"])
+        else:
+            raise NotImplementedError("Unknown task type {}.".format(self._task))
+
         if not self._graph_summaried:
             if self._task == "seq_label":
                 self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
+            elif self._task == "text_classify":
+                self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
             self._graph_summaried = True
         return y
 
@@ -246,6 +254,7 @@
             truth = truth["label_seq"]
         elif "label" in truth:
             truth = truth["label"]
+            truth = truth.view((-1,))
         else:
             raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
         return self._loss_func(predict, truth)
@@ -315,30 +324,10 @@ class ClassificationTrainer(BaseTrainer):
     """Trainer for text classification."""
 
     def __init__(self, **train_args):
+        train_args.update({"task": "text_classify"})
+        print(
+            "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
         super(ClassificationTrainer, self).__init__(**train_args)
-        self.iterator = None
-        self.loss_func = None
-        self.optimizer = None
-        self.best_accuracy = 0
-
-    def data_forward(self, network, x):
-        """Forward through network."""
-        logits = network(x)
-        return logits
-
-    def get_acc(self, y_logit, y_true):
-        """Compute accuracy."""
-        y_pred = torch.argmax(y_logit, dim=-1)
-        return int(torch.sum(y_true == y_pred)) / len(y_true)
-
-    def best_eval_result(self, validator):
-        _, _, accuracy = validator.metrics()
-        if accuracy > self.best_accuracy:
-            self.best_accuracy = accuracy
-            return True
-        else:
-            return False
-
     def _create_validator(self, valid_args):
         return ClassificationTester(**valid_args)
 
diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py
index fc7388a5..15a65221 100644
--- a/fastNLP/models/cnn_text_classification.py
+++ b/fastNLP/models/cnn_text_classification.py
@@ -35,8 +35,12 @@ class CNNText(torch.nn.Module):
         self.dropout = nn.Dropout(drop_prob)
         self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
 
-    def forward(self, x):
-        x = self.embed(x)  # [N,L] -> [N,L,C]
+    def forward(self, word_seq):
+        """
+        :param word_seq: torch.LongTensor, [batch_size, seq_len]
+        :return x: torch.FloatTensor, [batch_size, num_classes]
+        """
+        x = self.embed(word_seq)  # [N,L] -> [N,L,C]
         x = self.conv_pool(x)  # [N,L,C] -> [N,C]
         x = self.dropout(x)
         x = self.fc(x)  # [N,C] -> [N, N_class]
diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py
index bed3f0a6..8d194947 100644
--- a/fastNLP/models/sequence_modeling.py
+++ b/fastNLP/models/sequence_modeling.py
@@ -104,17 +104,17 @@ class AdvSeqLabel(SeqLabeling):
 
         self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
 
-    def forward(self, x, seq_len):
+    def forward(self, word_seq, word_seq_origin_len):
         """
-        :param x: LongTensor, [batch_size, mex_len]
-        :param seq_len: list of int.
+        :param word_seq: LongTensor, [batch_size, mex_len]
+        :param word_seq_origin_len: list of int.
         :return y: [batch_size, mex_len, tag_size]
         """
-        self.mask = self.make_mask(x, seq_len)
+        self.mask = self.make_mask(word_seq, word_seq_origin_len)
 
-        batch_size = x.size(0)
-        max_len = x.size(1)
-        x = self.Embedding(x)
+        batch_size = word_seq.size(0)
+        max_len = word_seq.size(1)
+        x = self.Embedding(word_seq)
         # [batch_size, max_len, word_emb_dim]
         x = self.Rnn(x)
         # [batch_size, max_len, hidden_size * direction]
diff --git a/test/model/seq_labeling.py b/test/model/seq_labeling.py
index dcfa8bb4..d7750b17 100644
--- a/test/model/seq_labeling.py
+++ b/test/model/seq_labeling.py
@@ -121,7 +121,7 @@ def train_and_test():
 
     # Tester
     tester = SeqLabelTester(save_output=False,
-                            save_loss=False,
+                            save_loss=True,
                             save_best_dev=False,
                             batch_size=4,
                             use_cuda=False,
diff --git a/test/model/text_classify.py b/test/model/text_classify.py
index dd20505f..381a768e 100644
--- a/test/model/text_classify.py
+++ b/test/model/text_classify.py
@@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss
 
 parser = argparse.ArgumentParser()
 parser.add_argument("-s", "--save", type=str, default="./test_classification/",
                     help="path to save pickle files")
-parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
+parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt",
                     help="path to the training data")
-parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
+parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
 parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")
 args = parser.parse_args()
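
Usage sketch (illustrative, not part of the diff): the snippet below shows how the unified, task-aware tester introduced by this patch could be driven once a model is trained and the data preprocessed. It assumes BaseTester accepts the "task" key exactly as the deprecated wrappers now inject it, reuses the keyword arguments visible in test/model/seq_labeling.py, and treats `model`, `test_set`, and the `test(network=..., dev_data=...)` call as placeholders to be checked against the actual fastNLP API of this version.

# Illustrative only; `model` and `test_set` stand for a trained network and a
# preprocessed DataSet created elsewhere (e.g. by the test scripts above).
from fastNLP.core.tester import BaseTester

# The deprecated SeqLabelTester/ClassificationTester wrappers now only inject
# the "task" key, so passing it directly exercises the same BaseTester code path.
tester = BaseTester(task="text_classify",
                    save_output=False,
                    save_loss=True,
                    save_best_dev=False,
                    batch_size=4,
                    use_cuda=False)

# Assumed entry point; verify the signature against the installed fastNLP.
tester.test(network=model, dev_data=test_set)

# For task="text_classify", the metrics property (see _text_classify_metrics)
# returns y_true and y_prob as numpy arrays plus the accuracy as a float.
y_true, y_prob, acc = tester.metrics
print("test accuracy:", acc)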