Browse Source

Merge pull request #60 from KuNyaa/master

add tensorboardX for loss visualization
tags/v0.1.0
Coet GitHub 6 years ago
parent
commit
49ad966c5f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 0 deletions
  1. +6
    -0
      README.md
  2. +9
    -0
      fastNLP/core/trainer.py
  3. +1
    -0
      requirements.txt

+ 6
- 0
README.md View File

@@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
- numpy>=1.14.2
- torch==0.4.0
- torchvision>=0.1.8
- tensorboardX


## Resources
@@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch
pip3 install torch torchvision
```

### TensorboardX Installation

```shell
pip3 install tensorboardX
```

## Project Structure



+ 9
- 0
fastNLP/core/trainer.py View File

@@ -4,6 +4,8 @@ import time
from datetime import timedelta

import torch
import tensorboardX
from tensorboardX import SummaryWriter

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
@@ -86,6 +88,8 @@ class BaseTrainer(object):
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
self._optimizer = None
self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False

def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@@ -160,6 +164,11 @@ class BaseTrainer(object):
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)

if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time()


+ 1
- 0
requirements.txt View File

@@ -1,3 +1,4 @@
numpy>=1.14.2
torch==0.4.0
torchvision>=0.1.8
tensorboardX

Loading…
Cancel
Save