|
@@ -578,7 +578,10 @@ class Trainer(object): |
|
|
self.step = 0 |
|
|
self.step = 0 |
|
|
self.epoch = 0 |
|
|
self.epoch = 0 |
|
|
start = time.time() |
|
|
start = time.time() |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(self.model, nn.DataParallel): |
|
|
|
|
|
self._forward_func = self.model.module.forward |
|
|
|
|
|
else: |
|
|
|
|
|
self._forward_func = self.model.forward |
|
|
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: |
|
|
self.pbar = pbar |
|
|
self.pbar = pbar |
|
|
avg_loss = 0 |
|
|
avg_loss = 0 |
|
@@ -682,11 +685,11 @@ class Trainer(object): |
|
|
self.optimizer.step() |
|
|
self.optimizer.step() |
|
|
|
|
|
|
|
|
def _data_forward(self, network, x): |
|
|
def _data_forward(self, network, x): |
|
|
x = _build_args(network.forward, **x) |
|
|
|
|
|
|
|
|
x = _build_args(self._forward_func, **x) |
|
|
y = network(**x) |
|
|
y = network(**x) |
|
|
if not isinstance(y, dict): |
|
|
if not isinstance(y, dict): |
|
|
raise TypeError( |
|
|
raise TypeError( |
|
|
f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.") |
|
|
|
|
|
|
|
|
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") |
|
|
return y |
|
|
return y |
|
|
|
|
|
|
|
|
def _grad_backward(self, loss): |
|
|
def _grad_backward(self, loss): |
|
@@ -845,8 +848,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ |
|
|
print(info_str) |
|
|
print(info_str) |
|
|
_check_forward_error(forward_func=model.forward, dataset=dataset, |
|
|
_check_forward_error(forward_func=model.forward, dataset=dataset, |
|
|
batch_x=batch_x, check_level=check_level) |
|
|
batch_x=batch_x, check_level=check_level) |
|
|
|
|
|
|
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
|
|
|
|
|
|
if isinstance(model, nn.DataParallel): |
|
|
|
|
|
forward_func = model.module.forward |
|
|
|
|
|
else: |
|
|
|
|
|
forward_func = model.forward |
|
|
|
|
|
refined_batch_x = _build_args(forward_func, **batch_x) |
|
|
pred_dict = model(**refined_batch_x) |
|
|
pred_dict = model(**refined_batch_x) |
|
|
func_signature = _get_func_signature(model.forward) |
|
|
func_signature = _get_func_signature(model.forward) |
|
|
if not isinstance(pred_dict, dict): |
|
|
if not isinstance(pred_dict, dict): |
|
|