|
@@ -648,17 +648,16 @@ class Trainer(object): |
|
|
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
self.pbar = pbar |
|
|
self.pbar = pbar |
|
|
avg_loss = 0 |
|
|
avg_loss = 0 |
|
|
data_iterator = self.data_iterator |
|
|
|
|
|
self.batch_per_epoch = data_iterator.num_batches |
|
|
|
|
|
|
|
|
self.batch_per_epoch = self.data_iterator.num_batches |
|
|
for epoch in range(1, self.n_epochs + 1): |
|
|
for epoch in range(1, self.n_epochs + 1): |
|
|
self.epoch = epoch |
|
|
self.epoch = epoch |
|
|
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) |
|
|
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) |
|
|
# early stopping |
|
|
# early stopping |
|
|
self.callback_manager.on_epoch_begin() |
|
|
self.callback_manager.on_epoch_begin() |
|
|
for batch_x, batch_y in data_iterator: |
|
|
|
|
|
|
|
|
for batch_x, batch_y in self.data_iterator: |
|
|
self.step += 1 |
|
|
self.step += 1 |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) |
|
|
indices = data_iterator.get_batch_indices() |
|
|
|
|
|
|
|
|
indices = self.data_iterator.get_batch_indices() |
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) |
|
|
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) |
|
|
prediction = self._data_forward(self.model, batch_x) |
|
|
prediction = self._data_forward(self.model, batch_x) |
|
@@ -692,7 +691,7 @@ class Trainer(object): |
|
|
self.callback_manager.on_batch_end() |
|
|
self.callback_manager.on_batch_end() |
|
|
|
|
|
|
|
|
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or |
|
|
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or |
|
|
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ |
|
|
|
|
|
|
|
|
(self.validate_every < 0 and self.step % len(self.data_iterator) == 0)) \ |
|
|
and self.dev_data is not None: |
|
|
and self.dev_data is not None: |
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, |
|
|
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, |
|
|