| @@ -76,14 +76,15 @@ class NewLoss(LossBase): | |||||
| class LossInForward(LossBase): | class LossInForward(LossBase): | ||||
| def __init__(self, loss_key='loss'): | def __init__(self, loss_key='loss'): | ||||
| super().__init__() | super().__init__() | ||||
| self.loss_key = loss_key | self.loss_key = loss_key | ||||
| def get_loss(self, *args, **kwargs): | |||||
| pass | |||||
| def get_loss(self, **kwargs): | |||||
| if self.loss_key not in kwargs: | |||||
| pass | |||||
| def __call__(self, output_dict, predict_dict): | def __call__(self, output_dict, predict_dict): | ||||
| pass | |||||
| return self.get_loss(**output_dict) | |||||
| def _prepare_losser(losser): | def _prepare_losser(losser): | ||||