|
|
@@ -5,6 +5,8 @@ import time |
|
|
|
from datetime import timedelta |
|
|
|
|
|
|
|
import torch |
|
|
|
import tensorboardX |
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
|
|
from fastNLP.core.action import Action |
|
|
|
from fastNLP.core.action import RandomSampler, Batchifier |
|
|
@@ -91,6 +93,7 @@ class BaseTrainer(object): |
|
|
|
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None |
|
|
|
self._optimizer = None |
|
|
|
self._optimizer_proto = default_args["optimizer"] |
|
|
|
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') |
|
|
|
|
|
|
|
def train(self, network, train_data, dev_data=None): |
|
|
|
"""General Training Procedure |
|
|
@@ -163,6 +166,7 @@ class BaseTrainer(object): |
|
|
|
loss = self.get_loss(prediction, batch_y) |
|
|
|
self.grad_backward(loss) |
|
|
|
self.update() |
|
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) |
|
|
|
|
|
|
|
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: |
|
|
|
end = time.time() |
|
|
|