Browse Source

add tensorboardX for loss visualization

tags/v0.1.0
KuNya 6 years ago
parent
commit
baf17892a7
3 changed files with 11 additions and 0 deletions
  1. +6
    -0
      README.md
  2. +4
    -0
      fastNLP/core/trainer.py
  3. +1
    -0
      requirements.txt

+ 6
- 0
README.md View File

@@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
- numpy>=1.14.2
- torch==0.4.0
- torchvision>=0.1.8
- tensorboardX


## Resources
@@ -45,6 +46,11 @@ conda install pytorch torchvision -c pytorch
pip3 install torch torchvision
```

### TensorboardX Installation

```shell
pip3 install tensorboardX
```

## Project Structure



+ 4
- 0
fastNLP/core/trainer.py View File

@@ -5,6 +5,8 @@ import time
from datetime import timedelta

import torch
import tensorboardX
from tensorboardX import SummaryWriter

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
@@ -91,6 +93,7 @@ class BaseTrainer(object):
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
self._optimizer = None
self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')

def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@@ -163,6 +166,7 @@ class BaseTrainer(object):
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time()


+ 1
- 0
requirements.txt View File

@@ -1,3 +1,4 @@
numpy>=1.14.2
torch==0.4.0
torchvision>=0.1.8
tensorboardX

Loading…
Cancel
Save