Browse Source

[bugfix]修复了trainer在update_every大于1时loss的显示错误

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
81c71aceb8
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      fastNLP/core/trainer.py

+ 1
- 1
fastNLP/core/trainer.py View File

@@ -666,8 +666,8 @@ class Trainer(object):
# edit prediction # edit prediction
self.callback_manager.on_loss_begin(batch_y, prediction) self.callback_manager.on_loss_begin(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y).mean() loss = self._compute_loss(prediction, batch_y).mean()
avg_loss += loss.item()
loss = loss / self.update_every loss = loss / self.update_every
avg_loss += loss.item()


# Is loss NaN or inf? requires_grad = False # Is loss NaN or inf? requires_grad = False
self.callback_manager.on_backward_begin(loss) self.callback_manager.on_backward_begin(loss)


Loading…
Cancel
Save