diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 721478f7..a0b93d9a 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -648,17 +648,16 @@ class Trainer(object): with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: self.pbar = pbar avg_loss = 0 - data_iterator = self.data_iterator - self.batch_per_epoch = data_iterator.num_batches + self.batch_per_epoch = self.data_iterator.num_batches for epoch in range(1, self.n_epochs + 1): self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping self.callback_manager.on_epoch_begin() - for batch_x, batch_y in data_iterator: + for batch_x, batch_y in self.data_iterator: self.step += 1 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) - indices = data_iterator.get_batch_indices() + indices = self.data_iterator.get_batch_indices() # negative sampling; replace unknown; re-weight batch_y self.callback_manager.on_batch_begin(batch_x, batch_y, indices) prediction = self._data_forward(self.model, batch_x) @@ -692,7 +691,7 @@ class Trainer(object): self.callback_manager.on_batch_end() if ((self.validate_every > 0 and self.step % self.validate_every == 0) or - (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ + (self.validate_every < 0 and self.step % len(self.data_iterator) == 0)) \ and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,