Browse Source

add graph summary in _train_step

tags/v0.1.0
KuNya 6 years ago
parent
commit
68b63fb071
1 changed files with 5 additions and 0 deletions
  1. +5
    -0
      fastNLP/core/trainer.py

+ 5
- 0
fastNLP/core/trainer.py View File

@@ -94,6 +94,7 @@ class BaseTrainer(object):
self._optimizer = None
self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False

def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@@ -168,6 +169,10 @@ class BaseTrainer(object):
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)

if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))


Loading…
Cancel
Save