|
@@ -4,6 +4,8 @@ import time |
|
|
from datetime import timedelta |
|
|
from datetime import timedelta |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
|
|
|
import tensorboardX |
|
|
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
|
|
|
from fastNLP.core.action import Action |
|
|
from fastNLP.core.action import Action |
|
|
from fastNLP.core.action import RandomSampler, Batchifier |
|
|
from fastNLP.core.action import RandomSampler, Batchifier |
|
@@ -86,6 +88,8 @@ class BaseTrainer(object): |
|
|
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None |
|
|
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None |
|
|
self._optimizer = None |
|
|
self._optimizer = None |
|
|
self._optimizer_proto = default_args["optimizer"] |
|
|
self._optimizer_proto = default_args["optimizer"] |
|
|
|
|
|
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') |
|
|
|
|
|
self._graph_summaried = False |
|
|
|
|
|
|
|
|
def train(self, network, train_data, dev_data=None): |
|
|
def train(self, network, train_data, dev_data=None): |
|
|
"""General Training Procedure |
|
|
"""General Training Procedure |
|
@@ -160,6 +164,11 @@ class BaseTrainer(object): |
|
|
loss = self.get_loss(prediction, batch_y) |
|
|
loss = self.get_loss(prediction, batch_y) |
|
|
self.grad_backward(loss) |
|
|
self.grad_backward(loss) |
|
|
self.update() |
|
|
self.update() |
|
|
|
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) |
|
|
|
|
|
|
|
|
|
|
|
if not self._graph_summaried: |
|
|
|
|
|
self._summary_writer.add_graph(network, batch_x) |
|
|
|
|
|
self._graph_summaried = True |
|
|
|
|
|
|
|
|
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: |
|
|
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: |
|
|
end = time.time() |
|
|
end = time.time() |
|
|