diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index aa1ffb89..9306f9f9 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -76,14 +76,15 @@ class NewLoss(LossBase): class LossInForward(LossBase): def __init__(self, loss_key='loss'): super().__init__() - self.loss_key = loss_key - def get_loss(self, *args, **kwargs): - pass + def get_loss(self, **kwargs): + if self.loss_key not in kwargs: + pass def __call__(self, output_dict, predict_dict): - pass + + return self.get_loss(**output_dict) def _prepare_losser(losser):