From 09078d0250e18048d91d62eb77f2e58b53ebd6e8 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 19 Dec 2019 23:15:14 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8Trainer=E4=B8=ADdata=5Fiterat?= =?UTF-8?q?or=E5=B1=9E=E6=80=A7=E5=8F=82=E4=B8=8E=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=E4=BD=BF=E5=BE=97=E5=8F=AF=E4=BB=A5=E9=80=9A=E8=BF=87callback?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 721478f7..a0b93d9a 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -648,17 +648,16 @@ class Trainer(object): with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: self.pbar = pbar avg_loss = 0 - data_iterator = self.data_iterator - self.batch_per_epoch = data_iterator.num_batches + self.batch_per_epoch = self.data_iterator.num_batches for epoch in range(1, self.n_epochs + 1): self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping self.callback_manager.on_epoch_begin() - for batch_x, batch_y in data_iterator: + for batch_x, batch_y in self.data_iterator: self.step += 1 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) - indices = data_iterator.get_batch_indices() + indices = self.data_iterator.get_batch_indices() # negative sampling; replace unknown; re-weight batch_y self.callback_manager.on_batch_begin(batch_x, batch_y, indices) prediction = self._data_forward(self.model, batch_x) @@ -692,7 +691,7 @@ class Trainer(object): self.callback_manager.on_batch_end() if ((self.validate_every > 0 and self.step % self.validate_every == 0) or - (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ + (self.validate_every < 0 and self.step % len(self.data_iterator) == 0)) \ and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,