|
|
@@ -94,6 +94,7 @@ class BaseTrainer(object): |
|
|
|
self._optimizer = None |
|
|
|
self._optimizer_proto = default_args["optimizer"] |
|
|
|
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') |
|
|
|
self._graph_summaried = False |
|
|
|
|
|
|
|
def train(self, network, train_data, dev_data=None): |
|
|
|
"""General Training Procedure |
|
|
@@ -168,6 +169,10 @@ class BaseTrainer(object): |
|
|
|
self.update() |
|
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) |
|
|
|
|
|
|
|
if not self._graph_summaried: |
|
|
|
self._summary_writer.add_graph(network, batch_x) |
|
|
|
self._graph_summaried = True |
|
|
|
|
|
|
|
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: |
|
|
|
end = time.time() |
|
|
|
diff = timedelta(seconds=round(end - kwargs["start"])) |
|
|
|