@@ -14,7 +14,7 @@ except:
     from ..core.utils import _pseudo_tqdm as tqdm
 
 from ..core.trainer import Trainer
-from ..core.batch import Batch
+from ..core.batch import DataSetIter
 from ..core.callback import CallbackManager, CallbackException
 from ..core.dataset import DataSet
 from ..core.utils import _move_dict_value_to_device
@@ -124,8 +124,8 @@ class ENASTrainer(Trainer):
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  prefetch=self.prefetch)
+            data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+                                        prefetch=self.prefetch)
             for epoch in range(1, self.n_epochs + 1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
@@ -209,8 +209,8 @@ class ENASTrainer(Trainer):
         total_loss = 0
         train_idx = 0
         avg_loss = 0
-        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                              prefetch=self.prefetch)
+        data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+                                    prefetch=self.prefetch)
 
         for batch_x, batch_y in data_iterator:
             _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
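
For readers unfamiliar with the helper on the last context line above: `_move_dict_value_to_device` (imported in the first hunk) places each batch on the model's device before the forward pass. The sketch below shows the presumed behaviour; the function name and call pattern come from this diff, but the body is an illustrative assumption, not fastNLP's actual implementation.

    import torch

    def _move_dict_value_to_device(*batch_dicts, device):
        # Assumed behaviour: walk each batch dict and move every
        # torch.Tensor value onto the target device in place,
        # leaving non-tensor values untouched.
        for d in batch_dicts:
            for key, value in d.items():
                if isinstance(value, torch.Tensor):
                    d[key] = value.to(device)
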
@@ -262,8 +262,8 @@ class ENASTrainer(Trainer):
         if not isinstance(entropies, np.ndarray):
             entropies = entropies.data.cpu().numpy()
 
-        data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                              prefetch=self.prefetch)
+        data_iterator = DataSetIter(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+                                    prefetch=self.prefetch)
 
         for inputs, targets in data_iterator:
             valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
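
Taken together, the four hunks are a mechanical rename: every `Batch(...)` construction becomes `DataSetIter(...)` with an identical argument list, and the iteration protocol (`for batch_x, batch_y in data_iterator`) is untouched. A minimal sketch of the new call site, using only arguments that appear in this diff; the bare variable names stand in for the corresponding `ENASTrainer` attributes:

    from fastNLP.core.batch import DataSetIter

    # Same keyword arguments the trainer passes above; train_data,
    # batch_size, sampler and prefetch are placeholders for
    # self.train_data, self.batch_size, self.sampler, self.prefetch.
    data_iterator = DataSetIter(train_data, batch_size=batch_size,
                                sampler=sampler, as_numpy=False,
                                prefetch=prefetch)
    for batch_x, batch_y in data_iterator:
        pass  # batch_x: dict of inputs, batch_y: dict of targets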