Browse Source

add graph summary in _train_step

tags/v0.1.0
KuNya 7 years ago
parent
commit
68b63fb071
1 changed files with 5 additions and 0 deletions
  1. +5
    -0
      fastNLP/core/trainer.py

+ 5
- 0
fastNLP/core/trainer.py View File

@@ -94,6 +94,7 @@ class BaseTrainer(object):
self._optimizer = None self._optimizer = None
self._optimizer_proto = default_args["optimizer"] self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False


def train(self, network, train_data, dev_data=None): def train(self, network, train_data, dev_data=None):
"""General Training Procedure """General Training Procedure
@@ -168,6 +169,10 @@ class BaseTrainer(object):
self.update() self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) self._summary_writer.add_scalar("loss", loss.item(), global_step=step)


if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time() end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"])) diff = timedelta(seconds=round(end - kwargs["start"]))


Loading…
Cancel
Save