|
@@ -203,7 +203,7 @@ class Trainer(object): |
|
|
self._tqdm_train() |
|
|
self._tqdm_train() |
|
|
else: |
|
|
else: |
|
|
self._print_train() |
|
|
self._print_train() |
|
|
self.callback_manager.after_train() |
|
|
|
|
|
|
|
|
self.callback_manager.after_train(self.model) |
|
|
|
|
|
|
|
|
if self.dev_data is not None: |
|
|
if self.dev_data is not None: |
|
|
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + |
|
|
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + |
|
@@ -229,24 +229,35 @@ class Trainer(object): |
|
|
self.step = 0 |
|
|
self.step = 0 |
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, |
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, |
|
|
as_numpy=False) |
|
|
as_numpy=False) |
|
|
total_steps = data_iterator.num_batches*self.n_epochs |
|
|
|
|
|
|
|
|
total_steps = data_iterator.num_batches * self.n_epochs |
|
|
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
avg_loss = 0 |
|
|
avg_loss = 0 |
|
|
for epoch in range(1, self.n_epochs+1): |
|
|
for epoch in range(1, self.n_epochs+1): |
|
|
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) |
|
|
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) |
|
|
self.callback_manager.before_epoch() |
|
|
|
|
|
|
|
|
# early stopping |
|
|
|
|
|
self.callback_manager.before_epoch(epoch, self.n_epochs) |
|
|
for batch_x, batch_y in data_iterator: |
|
|
for batch_x, batch_y in data_iterator: |
|
|
self.callback_manager.before_batch() |
|
|
|
|
|
|
|
|
indices = data_iterator.get_batch_indices() |
|
|
|
|
|
# negative sampling; replace unknown; re-weight batch_y |
|
|
|
|
|
self.callback_manager.before_batch(batch_x, batch_y, indices) |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) |
|
|
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) |
|
|
prediction = self._data_forward(self.model, batch_x) |
|
|
prediction = self._data_forward(self.model, batch_x) |
|
|
|
|
|
|
|
|
self.callback_manager.before_loss() |
|
|
|
|
|
|
|
|
# edit prediction |
|
|
|
|
|
self.callback_manager.before_loss(batch_y, prediction) |
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
avg_loss += loss.item() |
|
|
avg_loss += loss.item() |
|
|
|
|
|
|
|
|
self.callback_manager.before_backward() |
|
|
|
|
|
|
|
|
# Is loss NaN or inf? requires_grad = False |
|
|
|
|
|
self.callback_manager.before_backward(loss, self.model) |
|
|
self._grad_backward(loss) |
|
|
self._grad_backward(loss) |
|
|
|
|
|
# gradient clipping |
|
|
|
|
|
self.callback_manager.after_backward(self.model) |
|
|
|
|
|
|
|
|
self._update() |
|
|
self._update() |
|
|
|
|
|
# lr scheduler; lr_finder; one_cycle |
|
|
|
|
|
self.callback_manager.after_step() |
|
|
|
|
|
|
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) |
|
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) |
|
|
for name, param in self.model.named_parameters(): |
|
|
for name, param in self.model.named_parameters(): |
|
|
if param.requires_grad: |
|
|
if param.requires_grad: |
|
@@ -258,23 +269,27 @@ class Trainer(object): |
|
|
avg_loss = 0 |
|
|
avg_loss = 0 |
|
|
pbar.update(self.print_every) |
|
|
pbar.update(self.print_every) |
|
|
self.step += 1 |
|
|
self.step += 1 |
|
|
|
|
|
# do nothing |
|
|
self.callback_manager.after_batch() |
|
|
self.callback_manager.after_batch() |
|
|
|
|
|
|
|
|
if self.validate_every > 0 and self.step % self.validate_every == 0 \ |
|
|
|
|
|
|
|
|
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or |
|
|
|
|
|
(self.validate_every < 0 and self.step % self.batch_size == len(data_iterator))) \ |
|
|
and self.dev_data is not None: |
|
|
and self.dev_data is not None: |
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
self.tester._format_eval_results(eval_res) |
|
|
self.tester._format_eval_results(eval_res) |
|
|
pbar.write(eval_str) |
|
|
pbar.write(eval_str) |
|
|
if self.validate_every < 0 and self.dev_data: |
|
|
|
|
|
eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
|
|
|
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
|
|
|
self.tester._format_eval_results(eval_res) |
|
|
|
|
|
pbar.write(eval_str) |
|
|
|
|
|
if epoch!=self.n_epochs: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# if self.validate_every < 0 and self.dev_data: |
|
|
|
|
|
# eval_res = self._do_validation(epoch=epoch, step=self.step) |
|
|
|
|
|
# eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ |
|
|
|
|
|
# self.tester._format_eval_results(eval_res) |
|
|
|
|
|
# pbar.write(eval_str) |
|
|
|
|
|
if epoch != self.n_epochs: |
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, |
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, |
|
|
as_numpy=False) |
|
|
as_numpy=False) |
|
|
self.callback_manager.after_epoch() |
|
|
|
|
|
|
|
|
# lr decay; early stopping |
|
|
|
|
|
self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer) |
|
|
pbar.close() |
|
|
pbar.close() |
|
|
|
|
|
|
|
|
def _print_train(self): |
|
|
def _print_train(self): |
|
@@ -340,6 +355,8 @@ class Trainer(object): |
|
|
self.best_dev_perf = res |
|
|
self.best_dev_perf = res |
|
|
self.best_dev_epoch = epoch |
|
|
self.best_dev_epoch = epoch |
|
|
self.best_dev_step = step |
|
|
self.best_dev_step = step |
|
|
|
|
|
# get validation results; adjust optimizer |
|
|
|
|
|
self.callback_manager.after_valid(res, self.metric_key, self.optimizer) |
|
|
return res |
|
|
return res |
|
|
|
|
|
|
|
|
def _mode(self, model, is_test=False): |
|
|
def _mode(self, model, is_test=False): |
|
|