
update trainer, tester, example model

tags/v0.2.0
yunfan, 6 years ago
commit 44e098e285
3 changed files with 42 additions and 20 deletions
  1. fastNLP/core/tester.py (+9, -7)
  2. fastNLP/core/trainer.py (+19, -12)
  3. fastNLP/models/cnn_text_classification.py (+14, -1)

fastNLP/core/tester.py (+9, -7)

@@ -1,10 +1,11 @@
import itertools
from collections import defaultdict

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.utils import _build_args

class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """
@@ -40,7 +41,12 @@ class Tester(object):
                output[k].append(v)
            for k, v in batch_y.items():
                truths[k].append(v)
        eval_results = self.evaluate(**output, **truths)
        for k, v in output.items():
            output[k] = itertools.chain(*v)
        for k, v in truths.items():
            truths[k] = itertools.chain(*v)
        args = _build_args(self._evaluator, **output, **truths)
        eval_results = self._evaluator(**args)
        print("[tester] {}".format(self.print_eval_results(eval_results)))
        self.mode(network, is_test=False)
        self.metrics = eval_results
@@ -60,14 +66,10 @@ class Tester(object):

    def data_forward(self, network, x):
        """A forward pass of the model. """
        x = _build_args(network.forward, **x)
        y = network(**x)
        return y

    def evaluate(self, **kwargs):
        """Compute evaluation metrics.
        """
        return self._evaluator(**kwargs)

    def print_eval_results(self, results):
        """Override this method to support more print formats.



fastNLP/core/trainer.py (+19, -12)

@@ -21,9 +21,8 @@ class Trainer(object):

"""
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
evaluator=Evaluator(),
**kwargs):
super(Trainer, self).__init__()

@@ -36,9 +35,16 @@ class Trainer(object):
        self.save_path = str(save_path)
        self.print_every = int(print_every)

        self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
        self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
        self.evaluator = evaluator
        model_name = model.__class__.__name__
        assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
        self.loss_func = self.model.get_loss
        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        else:
            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

        assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
        self.evaluator = self.model.evaluate

        if self.dev_data is not None:
            valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
@@ -48,7 +54,10 @@ class Trainer(object):
        for k, v in kwargs.items():
            setattr(self, k, v)

        self._summary_writer = SummaryWriter(os.path.join(self.save_path, 'tensorboard_logs'))
        self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
        if os.path.exists(self.tensorboard_path):
            os.rmdir(self.tensorboard_path)
        self._summary_writer = SummaryWriter(self.tensorboard_path)
        self._graph_summaried = False
        self.step = 0
        self.start_time = None  # start timestamp
@@ -138,6 +147,7 @@ class Trainer(object):
        self.optimizer.step()

    def data_forward(self, network, x):
        x = _build_args(network.forward, **x)
        y = network(**x)
        if not self._graph_summaried:
            # self._summary_writer.add_graph(network, x, verbose=False)
@@ -161,12 +171,9 @@ class Trainer(object):
        :param truth: ground truth label vector
        :return: a scalar
        """
        if isinstance(predict, dict) and isinstance(truth, dict):
            return self.loss_func(**predict, **truth)
        if len(truth) > 1:
            raise NotImplementedError("Not ready to handle multi-labels.")
        truth = list(truth.values())[0] if len(truth) > 0 else None
        return self.loss_func(predict, truth)
        assert isinstance(predict, dict) and isinstance(truth, dict)
        args = _build_args(self.loss_func, **predict, **truth)
        return self.loss_func(**args)

    def save_model(self, model, model_name, only_param=False):
        model_name = os.path.join(self.save_path, model_name)
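
Note: Trainer no longer takes loss/evaluator objects; it now asserts that the
model itself provides get_loss and evaluate, and it accepts either a ready-made
torch.optim.Optimizer or the fastNLP Optimizer wrapper. A minimal sketch of the
model interface these assertions imply (ToyModel is a hypothetical example, not
part of fastNLP):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyModel(nn.Module):
        def __init__(self, vocab_size=100, n_classes=2, dim=16):
            super(ToyModel, self).__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.fc = nn.Linear(dim, n_classes)

        def forward(self, word_seq):
            x = self.embed(word_seq).mean(dim=1)  # [N, L, dim] -> [N, dim]
            return {"output": self.fc(x)}         # forward returns a dict

        def get_loss(self, output, label_seq):
            # Matched against the predict/truth dicts via _build_args in Trainer.get_loss.
            return F.cross_entropy(output, label_seq)

        def evaluate(self, predict, label_seq):
            # Used as the evaluator; arguments are matched by name the same way.
            return (predict == label_seq).float().mean().item()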


fastNLP/models/cnn_text_classification.py (+14, -1)

@@ -46,5 +46,18 @@ class CNNText(torch.nn.Module):
        x = self.fc(x) # [N,C] -> [N, N_class]
        return {'output':x}

    def loss(self, output, label_seq):
    def predict(self, word_seq):
        output = self(word_seq)
        _, predict = output.max(dim=1)
        return {'predict': predict}

    def get_loss(self, output, label_seq):
        return self._loss(output, label_seq)

    def evaluate(self, predict, label_seq):
        predict, label_seq = torch.stack(predict, dim=0), torch.stack(label_seq, dim=0)
        predict, label_seq = predict.squeeze(), label_seq.squeeze()
        correct = (predict == label_seq).long().sum().item()
        total = label_seq.size(0)
        return 1.0 * correct / total
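
Note: the new CNNText.evaluate stacks the accumulated prediction/label tensors
and returns plain accuracy. A small worked example of that computation on toy
per-example tensors (illustration only, not fastNLP code):

    import torch

    predict = [torch.tensor(1), torch.tensor(0), torch.tensor(2), torch.tensor(2)]
    label_seq = [torch.tensor(1), torch.tensor(0), torch.tensor(1), torch.tensor(2)]

    p = torch.stack(predict, dim=0).squeeze()   # shape [4]
    t = torch.stack(label_seq, dim=0).squeeze()
    correct = (p == t).long().sum().item()      # 3 of 4 match
    print(1.0 * correct / t.size(0))            # 0.75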

