|
@@ -10,6 +10,7 @@ from torch import nn
 from tqdm.autonotebook import tqdm
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.callback import CallbackManager
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
@@ -29,7 +30,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False,
+                 callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -109,6 +111,7 @@ class Trainer(object):
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
         self.sampler = sampler
+        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
@@ -194,10 +197,14 @@ class Trainer(object):
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
         self._summary_writer = SummaryWriter(path)
+
+        self.callback_manager.before_train()
         if self.use_tqdm:
             self._tqdm_train()
         else:
             self._print_train()
+
+        self.callback_manager.after_train()
         if self.dev_data is not None:
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                   self.tester._format_eval_results(self.best_dev_perf),)
@@ -227,11 +234,17 @@ class Trainer(object):
             avg_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
+                self.callback_manager.before_epoch()
                 for batch_x, batch_y in data_iterator:
+                    self.callback_manager.before_batch()
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     prediction = self._data_forward(self.model, batch_x)
+
+                    self.callback_manager.before_loss()
                     loss = self._compute_loss(prediction, batch_y)
                     avg_loss += loss.item()
+
+                    self.callback_manager.before_backward()
                     self._grad_backward(loss)
                     self._update()
                     self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
@@ -245,6 +258,8 @@ class Trainer(object):
                         avg_loss = 0
                         pbar.update(self.print_every)
                     self.step += 1
+                    self.callback_manager.after_batch()
+
                     if self.validate_every > 0 and self.step % self.validate_every == 0 \
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
@@ -259,23 +274,31 @@ class Trainer(object):
                 if epoch!=self.n_epochs:
                     data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                           as_numpy=False)
+                self.callback_manager.after_epoch()
             pbar.close()
 
     def _print_train(self):
         epoch = 1
         start = time.time()
         while epoch <= self.n_epochs:
+            self.callback_manager.before_epoch()
+
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                   as_numpy=False)
 
             for batch_x, batch_y in data_iterator:
+                self.callback_manager.before_batch()
                 # TODO: a possible problem here: if the user changes the device of `prediction` inside the model, this will break
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 prediction = self._data_forward(self.model, batch_x)
+
+                self.callback_manager.before_loss()
                 loss = self._compute_loss(prediction, batch_y)
+
+                self.callback_manager.before_backward()
                 self._grad_backward(loss)
                 self._update()
 
                 self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                 for name, param in self.model.named_parameters():
                     if param.requires_grad:
@@ -294,11 +317,13 @@ class Trainer(object):
                         self._do_validation(epoch=epoch, step=self.step)
 
                 self.step += 1
+                self.callback_manager.after_batch()
 
                 # validate_every override validation at end of epochs
                 if self.dev_data and self.validate_every <= 0:
                     self._do_validation(epoch=epoch, step=self.step)
             epoch += 1
+            self.callback_manager.after_epoch()
 
     def _do_validation(self, epoch, step):
         res = self.tester.test()
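Usage note (not part of the diff): a minimal sketch of how the new `callbacks` argument might be used. It assumes `fastNLP.core.callback` also exports a `Callback` base class whose hook names mirror the `CallbackManager` calls above; the callback class name, the hook signatures, and the `train_data`/`model`/`loss`/`metrics`/`dev_data` objects are illustrative assumptions, not part of this change.

# Hypothetical usage sketch for the new `callbacks` argument.
# Assumes a Callback base class with hooks named after the CallbackManager calls above.
from fastNLP.core.callback import Callback
from fastNLP.core.trainer import Trainer

class EpochLoggerCallback(Callback):
    """Illustrative callback: print a message around each epoch and do per-batch work."""
    def before_epoch(self, *args, **kwargs):
        print("epoch starting")

    def after_epoch(self, *args, **kwargs):
        print("epoch finished")

    def after_batch(self, *args, **kwargs):
        pass  # e.g. accumulate per-batch statistics here

# train_data, model, loss, metrics and dev_data are placeholders built elsewhere.
trainer = Trainer(train_data=train_data, model=model, loss=loss, metrics=metrics,
                  dev_data=dev_data, callbacks=[EpochLoggerCallback()])
trainer.train()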
|
|