Browse Source

rename Batch to DataSetIter in enas_trainer

tags/v0.4.10
xuyige 5 years ago
parent
commit
6cf1a8562b
1 changed files with 7 additions and 7 deletions
  1. +7
    -7
      fastNLP/models/enas_trainer.py

+ 7
- 7
fastNLP/models/enas_trainer.py View File

@@ -14,7 +14,7 @@ except:
from ..core.utils import _pseudo_tqdm as tqdm

from ..core.trainer import Trainer
from ..core.batch import Batch
from ..core.batch import DataSetIter
from ..core.callback import CallbackManager, CallbackException
from ..core.dataset import DataSet
from ..core.utils import _move_dict_value_to_device
@@ -124,8 +124,8 @@ class ENASTrainer(Trainer):
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for epoch in range(1, self.n_epochs + 1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
@@ -209,8 +209,8 @@ class ENASTrainer(Trainer):
total_loss = 0
train_idx = 0
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
@@ -262,8 +262,8 @@ class ENASTrainer(Trainer):
if not isinstance(entropies, np.ndarray):
entropies = entropies.data.cpu().numpy()
data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = DataSetIter(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for inputs, targets in data_iterator:
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)


Loading…
Cancel
Save