Browse Source

使用Trainer中data_iterator属性参与循环使得可以通过callback修改数据

tags/v0.5.5
yh_cc 5 years ago
parent
commit
09078d0250
1 changed files with 4 additions and 5 deletions
  1. +4
    -5
      fastNLP/core/trainer.py

+ 4
- 5
fastNLP/core/trainer.py View File

@@ -648,17 +648,16 @@ class Trainer(object):
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
self.pbar = pbar
avg_loss = 0
data_iterator = self.data_iterator
self.batch_per_epoch = data_iterator.num_batches
self.batch_per_epoch = self.data_iterator.num_batches
for epoch in range(1, self.n_epochs + 1):
self.epoch = epoch
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping
self.callback_manager.on_epoch_begin()
for batch_x, batch_y in data_iterator:
for batch_x, batch_y in self.data_iterator:
self.step += 1
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
indices = data_iterator.get_batch_indices()
indices = self.data_iterator.get_batch_indices()
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
prediction = self._data_forward(self.model, batch_x)
@@ -692,7 +691,7 @@ class Trainer(object):
self.callback_manager.on_batch_end()

if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
(self.validate_every < 0 and self.step % len(self.data_iterator) == 0)) \
and self.dev_data is not None:
eval_res = self._do_validation(epoch=epoch, step=self.step)
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,


Loading…
Cancel
Save