Browse Source

Introduce Fields concept to eliminate the use of different sub-trainers/sub-testers.

- update LabelField's to_tensor method to support int & str single label
- update preprocessor's convert_to_dataset method to support single label inputs
- introduce "task" in Trainer/Tester's data_forward, Tester's evaluate and metrics methods
- in cnn_text_classification.py, change the name of the argument of forward
- in sequence_modeling.py, change the name of the argument of forward
- minor adjustments in test codes
- text_classify.py works
tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
05af2e7544
8 changed files with 80 additions and 73 deletions
  1. +7
    -3
      fastNLP/core/field.py
  2. +9
    -1
      fastNLP/core/preprocess.py
  3. +35
    -33
      fastNLP/core/tester.py
  4. +13
    -24
      fastNLP/core/trainer.py
  5. +6
    -2
      fastNLP/models/cnn_text_classification.py
  6. +7
    -7
      fastNLP/models/sequence_modeling.py
  7. +1
    -1
      test/model/seq_labeling.py
  8. +2
    -2
      test/model/text_classify.py

+ 7
- 3
fastNLP/core/field.py View File

@@ -73,13 +73,17 @@ class LabelField(Field):
def index(self, vocab):
if self._index is None:
self._index = vocab[self.label]
else:
pass
return self._index

def to_tensor(self, padding_length):
if self._index is None:
return torch.LongTensor([self.label])
if isinstance(self.label, int):
return torch.LongTensor([self.label])
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
raise RuntimeError(
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label)))
else:
return torch.LongTensor([self._index])



+ 9
- 1
fastNLP/core/preprocess.py View File

@@ -251,6 +251,9 @@ class BasePreprocess(object):
"""
use_word_seq = False
use_label_seq = False
use_label_str = False

# construct a DataSet object and fill it with Instances
data_set = DataSet()
for example in data:
words, label = example[0], example[1]
@@ -270,14 +273,19 @@ class BasePreprocess(object):
elif isinstance(label, str):
y = LabelField(label, is_target=True)
instance.add_field("label", y)
use_label_str = True
else:
raise NotImplementedError("label is a {}".format(type(label)))

data_set.append(instance)

# convert strings to indices
if use_word_seq:
data_set.index_field("word_seq", vocab)
if use_label_seq:
data_set.index_field("label_seq", label_vocab)
if use_label_str:
data_set.index_field("label", label_vocab)

return data_set




+ 35
- 33
fastNLP/core/tester.py View File

@@ -122,15 +122,23 @@ class BaseTester(object):
:param truth: Tensor
:return eval_results: can be anything. It will be stored in self.eval_history
"""
batch_size, max_len = predict.size(0), predict.size(1)
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
loss = self._model.loss(predict, truth) / batch_size

if self._task == "seq_label":
return self._seq_label_evaluate(predict, truth)
elif self._task == "text_classify":
return self._text_classify_evaluate(predict, truth)
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

def _seq_label_evaluate(self, predict, truth):
batch_size, max_len = predict.size(0), predict.size(1)
loss = self._model.loss(predict, truth) / batch_size
prediction = self._model.prediction(predict)
# pad prediction to equal length
for pred in prediction:
@@ -143,6 +151,10 @@ class BaseTester(object):
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [float(loss), float(accuracy)]

def _text_classify_evaluate(self, y_logit, y_true):
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

@property
def metrics(self):
"""Compute and return metrics.
@@ -151,10 +163,28 @@ class BaseTester(object):

:return : variable number of outputs
"""
if self._task == "seq_label":
return self._seq_label_metrics
elif self._task == "text_classify":
return self._text_classify_metrics
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

@property
def _seq_label_metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy

@property
def _text_classify_metrics(self):
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

def show_metrics(self):
"""Customize evaluation outputs in Trainer.
Called by Trainer to print evaluation results on dev set during training.
@@ -176,13 +206,7 @@ class BaseTester(object):


class SeqLabelTester(BaseTester):
"""Tester for sequence labeling.

"""
def __init__(self, **test_args):
"""
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
test_args.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
@@ -190,30 +214,8 @@ class SeqLabelTester(BaseTester):


class ClassificationTester(BaseTester):
"""Tester for classification."""

def __init__(self, **test_args):
"""
:param test_args: a dict-like object that has __getitem__ method.
can be accessed by "test_args["key_str"]"
"""
test_args.update({"task": "seq_label"})
print(
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
super(ClassificationTester, self).__init__(**test_args)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def evaluate(self, y_logit, y_true):
"""Return y_pred and y_true."""
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

def metrics(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

+ 13
- 24
fastNLP/core/trainer.py View File

@@ -218,10 +218,18 @@ class BaseTrainer(object):
self._optimizer.step()

def data_forward(self, network, x):
y = network(**x)
if self._task == "seq_label":
y = network(x["word_seq"], x["word_seq_origin_len"])
elif self._task == "text_classify":
y = network(x["word_seq"])
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

if not self._graph_summaried:
if self._task == "seq_label":
self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
elif self._task == "text_classify":
self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
self._graph_summaried = True
return y

@@ -246,6 +254,7 @@ class BaseTrainer(object):
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
truth = truth.view((-1,))
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
return self._loss_func(predict, truth)
@@ -315,30 +324,10 @@ class ClassificationTrainer(BaseTrainer):
"""Trainer for text classification."""

def __init__(self, **train_args):
train_args.update({"task": "text_classify"})
print(
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
super(ClassificationTrainer, self).__init__(**train_args)

self.iterator = None
self.loss_func = None
self.optimizer = None
self.best_accuracy = 0

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)

def best_eval_result(self, validator):
_, _, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False

def _create_validator(self, valid_args):
return ClassificationTester(**valid_args)

+ 6
- 2
fastNLP/models/cnn_text_classification.py View File

@@ -35,8 +35,12 @@ class CNNText(torch.nn.Module):
self.dropout = nn.Dropout(drop_prob)
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)

def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
def forward(self, word_seq):
"""
:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return x: torch.LongTensor, [batch_size, num_classes]
"""
x = self.embed(word_seq) # [N,L] -> [N,L,C]
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]


+ 7
- 7
fastNLP/models/sequence_modeling.py View File

@@ -104,17 +104,17 @@ class AdvSeqLabel(SeqLabeling):

self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

def forward(self, x, seq_len):
def forward(self, word_seq, word_seq_origin_len):
"""
:param x: LongTensor, [batch_size, mex_len]
:param seq_len: list of int.
:param word_seq: LongTensor, [batch_size, mex_len]
:param word_seq_origin_len: list of int.
:return y: [batch_size, mex_len, tag_size]
"""
self.mask = self.make_mask(x, seq_len)
self.mask = self.make_mask(word_seq, word_seq_origin_len)

batch_size = x.size(0)
max_len = x.size(1)
x = self.Embedding(x)
batch_size = word_seq.size(0)
max_len = word_seq.size(1)
x = self.Embedding(word_seq)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]


+ 1
- 1
test/model/seq_labeling.py View File

@@ -121,7 +121,7 @@ def train_and_test():

# Tester
tester = SeqLabelTester(save_output=False,
save_loss=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
use_cuda=False,


+ 2
- 2
test/model/text_classify.py View File

@@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt",
help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")

args = parser.parse_args()


Loading…
Cancel
Save