Browse Source

update trainer, tester, example model

tags/v0.2.0
yunfan 6 years ago
parent
commit
44e098e285
3 changed files with 42 additions and 20 deletions
  1. +9
    -7
      fastNLP/core/tester.py
  2. +19
    -12
      fastNLP/core/trainer.py
  3. +14
    -1
      fastNLP/models/cnn_text_classification.py

+ 9
- 7
fastNLP/core/tester.py View File

@@ -1,10 +1,11 @@
import itertools
from collections import defaultdict

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.utils import _build_args


class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """
@@ -40,7 +41,12 @@ class Tester(object):
output[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
eval_results = self.evaluate(**output, **truths)
for k, v in output.items():
output[k] = itertools.chain(*v)
for k, v in truths.items():
truths[k] = itertools.chain(*v)
args = _build_args(self._evaluator, **output, **truths)
eval_results = self._evaluator(**args)
print("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False)
self.metrics = eval_results
@@ -60,14 +66,10 @@ class Tester(object):

def data_forward(self, network, x):
"""A forward pass of the model. """
x = _build_args(network.forward, **x)
y = network(**x)
return y

def evaluate(self, **kwargs):
"""Compute evaluation metrics.
"""
return self._evaluator(**kwargs)

def print_eval_results(self, results):
"""Override this method to support more print formats.


+ 19
- 12
fastNLP/core/trainer.py View File

@@ -21,9 +21,8 @@ class Trainer(object):

"""
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
evaluator=Evaluator(),
**kwargs):
super(Trainer, self).__init__()
@@ -36,9 +35,16 @@ class Trainer(object):
self.save_path = str(save_path)
self.print_every = int(print_every)

self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
self.evaluator = evaluator
model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
self.loss_func = self.model.get_loss
if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
else:
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
self.evaluator = self.model.evaluate

if self.dev_data is not None:
valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
@@ -48,7 +54,10 @@ class Trainer(object):
for k, v in kwargs.items():
setattr(self, k, v)

self._summary_writer = SummaryWriter(os.path.join(self.save_path, 'tensorboard_logs'))
self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
if os.path.exists(self.tensorboard_path):
os.rmdir(self.tensorboard_path)
self._summary_writer = SummaryWriter(self.tensorboard_path)
self._graph_summaried = False
self.step = 0
self.start_time = None  # start timestamp
@@ -138,6 +147,7 @@ class Trainer(object):
self.optimizer.step()

def data_forward(self, network, x):
x = _build_args(network.forward, **x)
y = network(**x)
if not self._graph_summaried:
# self._summary_writer.add_graph(network, x, verbose=False)
@@ -161,12 +171,9 @@ class Trainer(object):
:param truth: ground truth label vector
:return: a scalar
"""
if isinstance(predict, dict) and isinstance(truth, dict):
return self.loss_func(**predict, **truth)
if len(truth) > 1:
raise NotImplementedError("Not ready to handle multi-labels.")
truth = list(truth.values())[0] if len(truth) > 0 else None
return self.loss_func(predict, truth)
assert isinstance(predict, dict) and isinstance(truth, dict)
args = _build_args(self.loss_func, **predict, **truth)
return self.loss_func(**args)

def save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)


+ 14
- 1
fastNLP/models/cnn_text_classification.py View File

@@ -46,5 +46,18 @@ class CNNText(torch.nn.Module):
x = self.fc(x)  # [N,C] -> [N, N_class]
return {'output':x}

def loss(self, output, label_seq):
def predict(self, word_seq):
output = self(word_seq)
_, predict = output.max(dim=1)
return {'predict': predict}

def get_loss(self, output, label_seq):
return self._loss(output, label_seq)

def evaluate(self, predict, label_seq):
predict, label_seq = torch.stack(predict, dim=0), torch.stack(label_seq, dim=0)
predict, label_seq = predict.squeeze(), label_seq.squeeze()
correct = (predict == label_seq).long().sum().item()
total = label_seq.size(0)
return 1.0 * correct / total


Loading…
Cancel
Save