Browse Source

rename Batch to DataSetIter in enas_trainer

tags/v0.4.10
xuyige 6 years ago
parent
commit
6cf1a8562b
1 changed files with 7 additions and 7 deletions
  1. +7
    -7
      fastNLP/models/enas_trainer.py

+ 7
- 7
fastNLP/models/enas_trainer.py View File

@@ -14,7 +14,7 @@ except:
from ..core.utils import _pseudo_tqdm as tqdm


from ..core.trainer import Trainer
from ..core.batch import Batch
from ..core.batch import DataSetIter
from ..core.callback import CallbackManager, CallbackException
from ..core.dataset import DataSet
from ..core.utils import _move_dict_value_to_device
@@ -124,8 +124,8 @@ class ENASTrainer(Trainer):
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for epoch in range(1, self.n_epochs + 1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
@@ -209,8 +209,8 @@ class ENASTrainer(Trainer):
total_loss = 0
train_idx = 0
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
@@ -262,8 +262,8 @@ class ENASTrainer(Trainer):
if not isinstance(entropies, np.ndarray):
entropies = entropies.data.cpu().numpy()
data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = DataSetIter(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
for inputs, targets in data_iterator:
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)


Loading…
Cancel
Save