
Merge pull request #60 from KuNyaa/master

add tensorboardX for loss visualization
GitHub committed 6 years ago
commit 49ad966c5f
3 changed files with 16 additions and 0 deletions
  1. +6 -0   README.md
  2. +9 -0   fastNLP/core/trainer.py
  3. +1 -0   requirements.txt

README.md  +6 -0

@@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
 - numpy>=1.14.2
 - torch==0.4.0
 - torchvision>=0.1.8
+- tensorboardX
 
 ## Resources
@@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch
 pip3 install torch torchvision
 ```
 
+### TensorboardX Installation
+
+```shell
+pip3 install tensorboardX
+```
 
 ## Project Structure




fastNLP/core/trainer.py  +9 -0

@@ -4,6 +4,8 @@ import time
 from datetime import timedelta
 
 import torch
+import tensorboardX
+from tensorboardX import SummaryWriter
 
 from fastNLP.core.action import Action
 from fastNLP.core.action import RandomSampler, Batchifier
@@ -86,6 +88,8 @@ class BaseTrainer(object):
         self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
         self._optimizer = None
         self._optimizer_proto = default_args["optimizer"]
+        self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
+        self._graph_summaried = False
 
     def train(self, network, train_data, dev_data=None):
         """General Training Procedure
@@ -160,6 +164,11 @@ class BaseTrainer(object):
                 loss = self.get_loss(prediction, batch_y)
                 self.grad_backward(loss)
                 self.update()
+                self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
+
+                if not self._graph_summaried:
+                    self._summary_writer.add_graph(network, batch_x)
+                    self._graph_summaried = True
 
                 if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
                     end = time.time()
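The lines added above follow the usual tensorboardX pattern: one SummaryWriter per run, a scalar logged every training step, and the model graph exported once on the first batch. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and log directory are hypothetical stand-ins for fastNLP's own network, batches, and pickle_path:

```python
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

# Hypothetical toy model and data, standing in for fastNLP's network and batches.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# One writer per run; the directory is a placeholder for pickle_path + 'tensorboard_logs'.
writer = SummaryWriter('tensorboard_logs')

graph_summaried = False
for step in range(100):
    batch_x = torch.randn(32, 10)
    batch_y = torch.randint(0, 2, (32,))

    prediction = model(batch_x)
    loss = loss_fn(prediction, batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log the scalar loss for this step; TensorBoard plots it against global_step.
    writer.add_scalar("loss", loss.item(), global_step=step)

    # Export the computation graph once, on the first batch.
    if not graph_summaried:
        writer.add_graph(model, batch_x)
        graph_summaried = True

writer.close()
```

In the trainer itself, `network` and `batch_x` play these roles, and the writer is created in `BaseTrainer.__init__` as shown in the hunk above.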


requirements.txt  +1 -0

@@ -1,3 +1,4 @@
 numpy>=1.14.2
 torch==0.4.0
 torchvision>=0.1.8
+tensorboardX
