|
|
@@ -265,7 +265,7 @@ class Trainer(object): |
|
|
|
|
|
|
|
# edit prediction |
|
|
|
self.callback_manager.on_loss_begin(batch_y, prediction) |
|
|
|
loss = self._compute_loss(prediction, batch_y) |
|
|
|
loss = self._compute_loss(prediction, batch_y).mean() |
|
|
|
avg_loss += loss.item() |
|
|
|
loss = loss/self.update_every |
|
|
|
|
|
|
|