diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 2a8d85da..b45dd148 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -265,7 +265,7 @@ class Trainer(object): # edit prediction self.callback_manager.on_loss_begin(batch_y, prediction) - loss = self._compute_loss(prediction, batch_y) + loss = self._compute_loss(prediction, batch_y).mean() avg_loss += loss.item() loss = loss/self.update_every