Browse Source

LossInForward update

tags/v0.2.0^2
yh 6 years ago
parent
commit
3daa889bb0
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      fastNLP/core/losses.py

+ 5
- 4
fastNLP/core/losses.py View File

@@ -76,14 +76,15 @@ class NewLoss(LossBase):
class LossInForward(LossBase): class LossInForward(LossBase):
def __init__(self, loss_key='loss'): def __init__(self, loss_key='loss'):
super().__init__() super().__init__()

self.loss_key = loss_key self.loss_key = loss_key


def get_loss(self, *args, **kwargs):
pass
def get_loss(self, **kwargs):
if self.loss_key not in kwargs:
pass


def __call__(self, output_dict, predict_dict): def __call__(self, output_dict, predict_dict):
pass

return self.get_loss(**output_dict)




def _prepare_losser(losser): def _prepare_losser(losser):


Loading…
Cancel
Save