|
|
@@ -666,8 +666,8 @@ class Trainer(object): |
|
|
|
# edit prediction |
|
|
|
self.callback_manager.on_loss_begin(batch_y, prediction) |
|
|
|
loss = self._compute_loss(prediction, batch_y).mean() |
|
|
|
avg_loss += loss.item() |
|
|
|
loss = loss / self.update_every |
|
|
|
avg_loss += loss.item() |
|
|
|
|
|
|
|
# Is loss NaN or inf? requires_grad = False |
|
|
|
self.callback_manager.on_backward_begin(loss) |
|
|
|