- update LabelField's to_tensor method to support a single int or str label - update the preprocessor's convert_to_dataset method to support single-label inputs - introduce "task" in Trainer/Tester's data_forward and in Tester's evaluate and metrics methods - in cnn_text_classification.py, rename the argument of forward - in sequence_modeling.py, rename the arguments of forward - minor adjustments in test code - text_classify.py works (tags/v0.1.0)
| @@ -73,13 +73,17 @@ class LabelField(Field): | |||||
| def index(self, vocab): | def index(self, vocab): | ||||
| if self._index is None: | if self._index is None: | ||||
| self._index = vocab[self.label] | self._index = vocab[self.label] | ||||
| else: | |||||
| pass | |||||
| return self._index | return self._index | ||||
| def to_tensor(self, padding_length): | def to_tensor(self, padding_length): | ||||
| if self._index is None: | if self._index is None: | ||||
| return torch.LongTensor([self.label]) | |||||
| if isinstance(self.label, int): | |||||
| return torch.LongTensor([self.label]) | |||||
| elif isinstance(self.label, str): | |||||
| raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
| else: | |||||
| raise RuntimeError( | |||||
| "Not support type for LabelField. Expect str or int, got {}.".format(type(self.label))) | |||||
| else: | else: | ||||
| return torch.LongTensor([self._index]) | return torch.LongTensor([self._index]) | ||||
| @@ -251,6 +251,9 @@ class BasePreprocess(object): | |||||
| """ | """ | ||||
| use_word_seq = False | use_word_seq = False | ||||
| use_label_seq = False | use_label_seq = False | ||||
| use_label_str = False | |||||
| # construct a DataSet object and fill it with Instances | |||||
| data_set = DataSet() | data_set = DataSet() | ||||
| for example in data: | for example in data: | ||||
| words, label = example[0], example[1] | words, label = example[0], example[1] | ||||
| @@ -270,14 +273,19 @@ class BasePreprocess(object): | |||||
| elif isinstance(label, str): | elif isinstance(label, str): | ||||
| y = LabelField(label, is_target=True) | y = LabelField(label, is_target=True) | ||||
| instance.add_field("label", y) | instance.add_field("label", y) | ||||
| use_label_str = True | |||||
| else: | else: | ||||
| raise NotImplementedError("label is a {}".format(type(label))) | raise NotImplementedError("label is a {}".format(type(label))) | ||||
| data_set.append(instance) | data_set.append(instance) | ||||
| # convert strings to indices | |||||
| if use_word_seq: | if use_word_seq: | ||||
| data_set.index_field("word_seq", vocab) | data_set.index_field("word_seq", vocab) | ||||
| if use_label_seq: | if use_label_seq: | ||||
| data_set.index_field("label_seq", label_vocab) | data_set.index_field("label_seq", label_vocab) | ||||
| if use_label_str: | |||||
| data_set.index_field("label", label_vocab) | |||||
| return data_set | return data_set | ||||
| @@ -122,15 +122,23 @@ class BaseTester(object): | |||||
| :param truth: Tensor | :param truth: Tensor | ||||
| :return eval_results: can be anything. It will be stored in self.eval_history | :return eval_results: can be anything. It will be stored in self.eval_history | ||||
| """ | """ | ||||
| batch_size, max_len = predict.size(0), predict.size(1) | |||||
| if "label_seq" in truth: | if "label_seq" in truth: | ||||
| truth = truth["label_seq"] | truth = truth["label_seq"] | ||||
| elif "label" in truth: | elif "label" in truth: | ||||
| truth = truth["label"] | truth = truth["label"] | ||||
| else: | else: | ||||
| raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | ||||
| loss = self._model.loss(predict, truth) / batch_size | |||||
| if self._task == "seq_label": | |||||
| return self._seq_label_evaluate(predict, truth) | |||||
| elif self._task == "text_classify": | |||||
| return self._text_classify_evaluate(predict, truth) | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| def _seq_label_evaluate(self, predict, truth): | |||||
| batch_size, max_len = predict.size(0), predict.size(1) | |||||
| loss = self._model.loss(predict, truth) / batch_size | |||||
| prediction = self._model.prediction(predict) | prediction = self._model.prediction(predict) | ||||
| # pad prediction to equal length | # pad prediction to equal length | ||||
| for pred in prediction: | for pred in prediction: | ||||
| @@ -143,6 +151,10 @@ class BaseTester(object): | |||||
| accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | ||||
| return [float(loss), float(accuracy)] | return [float(loss), float(accuracy)] | ||||
| def _text_classify_evaluate(self, y_logit, y_true): | |||||
| y_prob = torch.nn.functional.softmax(y_logit, dim=-1) | |||||
| return [y_prob, y_true] | |||||
| @property | @property | ||||
| def metrics(self): | def metrics(self): | ||||
| """Compute and return metrics. | """Compute and return metrics. | ||||
| @@ -151,10 +163,28 @@ class BaseTester(object): | |||||
| :return : variable number of outputs | :return : variable number of outputs | ||||
| """ | """ | ||||
| if self._task == "seq_label": | |||||
| return self._seq_label_metrics | |||||
| elif self._task == "text_classify": | |||||
| return self._text_classify_metrics | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| @property | |||||
| def _seq_label_metrics(self): | |||||
| batch_loss = np.mean([x[0] for x in self.eval_history]) | batch_loss = np.mean([x[0] for x in self.eval_history]) | ||||
| batch_accuracy = np.mean([x[1] for x in self.eval_history]) | batch_accuracy = np.mean([x[1] for x in self.eval_history]) | ||||
| return batch_loss, batch_accuracy | return batch_loss, batch_accuracy | ||||
| @property | |||||
| def _text_classify_metrics(self): | |||||
| y_prob, y_true = zip(*self.eval_history) | |||||
| y_prob = torch.cat(y_prob, dim=0) | |||||
| y_pred = torch.argmax(y_prob, dim=-1) | |||||
| y_true = torch.cat(y_true, dim=0) | |||||
| acc = float(torch.sum(y_pred == y_true)) / len(y_true) | |||||
| return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | |||||
| def show_metrics(self): | def show_metrics(self): | ||||
| """Customize evaluation outputs in Trainer. | """Customize evaluation outputs in Trainer. | ||||
| Called by Trainer to print evaluation results on dev set during training. | Called by Trainer to print evaluation results on dev set during training. | ||||
| @@ -176,13 +206,7 @@ class BaseTester(object): | |||||
| class SeqLabelTester(BaseTester): | class SeqLabelTester(BaseTester): | ||||
| """Tester for sequence labeling. | |||||
| """ | |||||
| def __init__(self, **test_args): | def __init__(self, **test_args): | ||||
| """ | |||||
| :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||||
| """ | |||||
| test_args.update({"task": "seq_label"}) | test_args.update({"task": "seq_label"}) | ||||
| print( | print( | ||||
| "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.") | "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.") | ||||
| @@ -190,30 +214,8 @@ class SeqLabelTester(BaseTester): | |||||
| class ClassificationTester(BaseTester): | class ClassificationTester(BaseTester): | ||||
| """Tester for classification.""" | |||||
| def __init__(self, **test_args): | def __init__(self, **test_args): | ||||
| """ | |||||
| :param test_args: a dict-like object that has __getitem__ method. | |||||
| can be accessed by "test_args["key_str"]" | |||||
| """ | |||||
| test_args.update({"task": "seq_label"}) | |||||
| print( | |||||
| "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.") | |||||
| super(ClassificationTester, self).__init__(**test_args) | super(ClassificationTester, self).__init__(**test_args) | ||||
| def data_forward(self, network, x): | |||||
| """Forward through network.""" | |||||
| logits = network(x) | |||||
| return logits | |||||
| def evaluate(self, y_logit, y_true): | |||||
| """Return y_pred and y_true.""" | |||||
| y_prob = torch.nn.functional.softmax(y_logit, dim=-1) | |||||
| return [y_prob, y_true] | |||||
| def metrics(self): | |||||
| """Compute accuracy.""" | |||||
| y_prob, y_true = zip(*self.eval_history) | |||||
| y_prob = torch.cat(y_prob, dim=0) | |||||
| y_pred = torch.argmax(y_prob, dim=-1) | |||||
| y_true = torch.cat(y_true, dim=0) | |||||
| acc = float(torch.sum(y_pred == y_true)) / len(y_true) | |||||
| return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | |||||
| @@ -218,10 +218,18 @@ class BaseTrainer(object): | |||||
| self._optimizer.step() | self._optimizer.step() | ||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| y = network(**x) | |||||
| if self._task == "seq_label": | |||||
| y = network(x["word_seq"], x["word_seq_origin_len"]) | |||||
| elif self._task == "text_classify": | |||||
| y = network(x["word_seq"]) | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| if not self._graph_summaried: | if not self._graph_summaried: | ||||
| if self._task == "seq_label": | if self._task == "seq_label": | ||||
| self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False) | self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False) | ||||
| elif self._task == "text_classify": | |||||
| self._summary_writer.add_graph(network, x["word_seq"], verbose=False) | |||||
| self._graph_summaried = True | self._graph_summaried = True | ||||
| return y | return y | ||||
| @@ -246,6 +254,7 @@ class BaseTrainer(object): | |||||
| truth = truth["label_seq"] | truth = truth["label_seq"] | ||||
| elif "label" in truth: | elif "label" in truth: | ||||
| truth = truth["label"] | truth = truth["label"] | ||||
| truth = truth.view((-1,)) | |||||
| else: | else: | ||||
| raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | ||||
| return self._loss_func(predict, truth) | return self._loss_func(predict, truth) | ||||
| @@ -315,30 +324,10 @@ class ClassificationTrainer(BaseTrainer): | |||||
| """Trainer for text classification.""" | """Trainer for text classification.""" | ||||
| def __init__(self, **train_args): | def __init__(self, **train_args): | ||||
| train_args.update({"task": "text_classify"}) | |||||
| print( | |||||
| "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.") | |||||
| super(ClassificationTrainer, self).__init__(**train_args) | super(ClassificationTrainer, self).__init__(**train_args) | ||||
| self.iterator = None | |||||
| self.loss_func = None | |||||
| self.optimizer = None | |||||
| self.best_accuracy = 0 | |||||
| def data_forward(self, network, x): | |||||
| """Forward through network.""" | |||||
| logits = network(x) | |||||
| return logits | |||||
| def get_acc(self, y_logit, y_true): | |||||
| """Compute accuracy.""" | |||||
| y_pred = torch.argmax(y_logit, dim=-1) | |||||
| return int(torch.sum(y_true == y_pred)) / len(y_true) | |||||
| def best_eval_result(self, validator): | |||||
| _, _, accuracy = validator.metrics() | |||||
| if accuracy > self.best_accuracy: | |||||
| self.best_accuracy = accuracy | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| return ClassificationTester(**valid_args) | return ClassificationTester(**valid_args) | ||||
| @@ -35,8 +35,12 @@ class CNNText(torch.nn.Module): | |||||
| self.dropout = nn.Dropout(drop_prob) | self.dropout = nn.Dropout(drop_prob) | ||||
| self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) | self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) | ||||
| def forward(self, x): | |||||
| x = self.embed(x) # [N,L] -> [N,L,C] | |||||
| def forward(self, word_seq): | |||||
| """ | |||||
| :param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
| :return x: torch.FloatTensor, [batch_size, num_classes] | |||||
| """ | |||||
| x = self.embed(word_seq) # [N,L] -> [N,L,C] | |||||
| x = self.conv_pool(x) # [N,L,C] -> [N,C] | x = self.conv_pool(x) # [N,L,C] -> [N,C] | ||||
| x = self.dropout(x) | x = self.dropout(x) | ||||
| x = self.fc(x) # [N,C] -> [N, N_class] | x = self.fc(x) # [N,C] -> [N, N_class] | ||||
| @@ -104,17 +104,17 @@ class AdvSeqLabel(SeqLabeling): | |||||
| self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
| def forward(self, x, seq_len): | |||||
| def forward(self, word_seq, word_seq_origin_len): | |||||
| """ | """ | ||||
| :param x: LongTensor, [batch_size, max_len] | |||||
| :param seq_len: list of int. | |||||
| :param word_seq: LongTensor, [batch_size, max_len] | |||||
| :param word_seq_origin_len: list of int. | |||||
| :return y: [batch_size, max_len, tag_size] | :return y: [batch_size, max_len, tag_size] | ||||
| """ | """ | ||||
| self.mask = self.make_mask(x, seq_len) | |||||
| self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||||
| batch_size = x.size(0) | |||||
| max_len = x.size(1) | |||||
| x = self.Embedding(x) | |||||
| batch_size = word_seq.size(0) | |||||
| max_len = word_seq.size(1) | |||||
| x = self.Embedding(word_seq) | |||||
| # [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
| x = self.Rnn(x) | x = self.Rnn(x) | ||||
| # [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
| @@ -121,7 +121,7 @@ def train_and_test(): | |||||
| # Tester | # Tester | ||||
| tester = SeqLabelTester(save_output=False, | tester = SeqLabelTester(save_output=False, | ||||
| save_loss=False, | |||||
| save_loss=True, | |||||
| save_best_dev=False, | save_best_dev=False, | ||||
| batch_size=4, | batch_size=4, | ||||
| use_cuda=False, | use_cuda=False, | ||||
| @@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss | |||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | ||||
| parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt", | |||||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt", | |||||
| help="path to the training data") | help="path to the training data") | ||||
| parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||