diff --git a/README.md b/README.md index 8ebd9d30..84d658fd 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa - numpy>=1.14.2 - torch==0.4.0 - torchvision>=0.1.8 +- tensorboardX ## Resources @@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch pip3 install torch torchvision ``` +### TensorboardX Installation + +```shell +pip3 install tensorboardX +``` ## Project Structure diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 4714131e..523a1763 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -4,6 +4,8 @@ import time from datetime import timedelta import torch +import tensorboardX +from tensorboardX import SummaryWriter from fastNLP.core.action import Action from fastNLP.core.action import RandomSampler, Batchifier @@ -86,6 +88,8 @@ class BaseTrainer(object): self._loss_func = default_args["loss"].get() # return a pytorch loss function or None self._optimizer = None self._optimizer_proto = default_args["optimizer"] + self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') + self._graph_summaried = False def train(self, network, train_data, dev_data=None): """General Training Procedure @@ -160,6 +164,11 @@ class BaseTrainer(object): loss = self.get_loss(prediction, batch_y) self.grad_backward(loss) self.update() + self._summary_writer.add_scalar("loss", loss.item(), global_step=step) + + if not self._graph_summaried: + self._summary_writer.add_graph(network, batch_x) + self._graph_summaried = True if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: end = time.time() diff --git a/requirements.txt b/requirements.txt index d961dd92..954dd741 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.14.2 torch==0.4.0 torchvision>=0.1.8 +tensorboardX