diff --git a/fastNLP/core/loss.py b/fastNLP/core/loss.py
index 16b5eac2..ce388989 100644
--- a/fastNLP/core/loss.py
+++ b/fastNLP/core/loss.py
@@ -1,58 +1,197 @@
 import torch
+
+
+def squash(predict, truth, **kwargs):
+    '''Reshape the tensors so that they fit the loss functions in PyTorch.
+
+    :param predict: Tensor, model output
+    :param truth: Tensor, ground truth from the dataset
+    :param **kwargs: extra arguments (unused here)
+
+    :return predict, truth: predict & truth after processing
+    '''
+    return predict.view(-1, predict.size()[-1]), truth.view(-1,)
+
+
+def unpad(predict, truth, **kwargs):
+    '''Drop padded positions from a padded sequence output so that the true
+    loss is computed, via pack_padded_sequence(). The squash() effect is
+    included.
+
+    :param predict: Tensor, [batch_size, max_len, tag_size]
+    :param truth: Tensor, [batch_size, max_len]
+    :param **kwargs: extra arguments; kwargs["lens"] is expected to exist
+        kwargs["lens"]: list or LongTensor, [batch_size],
+        whose i-th element is the true length of the i-th sequence
+
+    :return predict, truth: predict & truth after processing
+    '''
+    if kwargs.get("lens") is None:
+        return predict, truth
+    lens = torch.LongTensor(kwargs["lens"])
+    # pack_padded_sequence requires sequences sorted by decreasing length;
+    # the reordering does not change the mean-reduced loss
+    lens, idx = torch.sort(lens, descending=True)
+    predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data
+    truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data
+    return predict, truth
+
+
+def unpad_mask(predict, truth, **kwargs):
+    '''Drop padded positions from a padded sequence output so that the true
+    loss is computed, via mask(). The squash() effect is included.
+
+    :param predict: Tensor, [batch_size, max_len, tag_size]
+    :param truth: Tensor, [batch_size, max_len]
+    :param **kwargs: extra arguments; kwargs["lens"] is expected to exist
+        kwargs["lens"]: list or LongTensor, [batch_size],
+        whose i-th element is the true length of the i-th sequence
+
+    :return predict, truth: predict & truth after processing
+    '''
+    if kwargs.get("lens") is None:
+        return predict, truth
+    mas = make_mask(kwargs["lens"], truth.size()[1])
+    return mask(predict, truth, mask=mas)
+
+
+def mask(predict, truth, **kwargs):
+    '''Select specific elements from the tensors. The squash() effect is
+    included.
+
+    :param predict: Tensor, [batch_size, max_len, tag_size]
+    :param truth: Tensor, [batch_size, max_len]
+    :param **kwargs: extra arguments; kwargs["mask"] is expected to exist
+        kwargs["mask"]: ByteTensor, [batch_size, max_len],
+        where the positions that are 1 will be selected
+
+    :return predict, truth: predict & truth after processing
+    '''
+    if kwargs.get("mask") is None:
+        return predict, truth
+    mask = kwargs["mask"]
+
+    predict, truth = squash(predict, truth)
+    mask = mask.view(-1,)
+
+    # the flat mask broadcasts over the tag dimension after the transpose,
+    # so whole rows of predict are kept or dropped together
+    predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0)
+    truth = torch.masked_select(truth, mask)
+
+    return predict, truth
+
+
+def make_mask(lens, tar_len):
+    '''Generate a mask that selects [:lens[i]] of the i-th sequence;
+    borrowed from fastNLP.models.sequence_modeling.seq_mask.
+
+    :param lens: list or LongTensor, [batch_size]
+    :param tar_len: int
+
+    :return mask: ByteTensor
+    '''
+    lens = torch.LongTensor(lens)
+    mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
+    mask = torch.stack(mask, 1)
+    return mask
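+
+# Usage sketch for the helpers above (illustrative; `out`, `gold` and `lens`
+# are placeholder names): given model output `out` of shape
+# [batch_size, max_len, tag_size], labels `gold` of shape
+# [batch_size, max_len] and true lengths `lens`,
+#
+#     predict, truth = unpad(out, gold, lens=lens)
+#     # predict: [sum(lens), tag_size], truth: [sum(lens)]
+#
+# unpad_mask(out, gold, lens=lens) and
+# mask(out, gold, mask=make_mask(lens, gold.size()[1])) select the same set
+# of positions, using a ByteTensor mask instead of pack_padded_sequence().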
+
+# map strings to functions, just for more elegant usage
+method_dict = {
+    "squash": squash,
+    "unpad": unpad,
+    "unpad_mask": unpad_mask,
+    "mask": mask,
+}
+
+loss_function_name = {
+    "L1Loss".lower(): torch.nn.L1Loss,
+    "BCELoss".lower(): torch.nn.BCELoss,
+    "MSELoss".lower(): torch.nn.MSELoss,
+    "NLLLoss".lower(): torch.nn.NLLLoss,
+    "KLDivLoss".lower(): torch.nn.KLDivLoss,
+    "NLLLoss2dLoss".lower(): torch.nn.NLLLoss2d,  # every key should end with "loss"
+    "SmoothL1Loss".lower(): torch.nn.SmoothL1Loss,
+    "SoftMarginLoss".lower(): torch.nn.SoftMarginLoss,
+    "PoissonNLLLoss".lower(): torch.nn.PoissonNLLLoss,
+    "MultiMarginLoss".lower(): torch.nn.MultiMarginLoss,
+    "CrossEntropyLoss".lower(): torch.nn.CrossEntropyLoss,
+    "BCEWithLogitsLoss".lower(): torch.nn.BCEWithLogitsLoss,
+    "MarginRankingLoss".lower(): torch.nn.MarginRankingLoss,
+    "TripletMarginLoss".lower(): torch.nn.TripletMarginLoss,
+    "HingeEmbeddingLoss".lower(): torch.nn.HingeEmbeddingLoss,
+    "CosineEmbeddingLoss".lower(): torch.nn.CosineEmbeddingLoss,
+    "MultiLabelMarginLoss".lower(): torch.nn.MultiLabelMarginLoss,
+    "MultiLabelSoftMarginLoss".lower(): torch.nn.MultiLabelSoftMarginLoss,
+}
+
+
 class Loss(object):
-    """Loss function of the algorithm,
-    either the wrapper of a loss function from framework, or a user-defined loss (need pytorch auto_grad support)
-
-    """
-
-    def __init__(self, args):
-        """
-
-        :param args: None or str, the name of a loss function.
-
-        """
-        if args is None:
-            # this is useful when Trainer.__init__ performs type check
-            self._loss = None
-        elif isinstance(args, str):
-            self._loss = self._borrow_from_pytorch(args)
-        else:
-            raise NotImplementedError
-
-    def get(self):
-        """
-
-        :return self._loss: the loss function
-        """
-        return self._loss
-
-    @staticmethod
-    def _borrow_from_pytorch(loss_name):
-        """Given a name of a loss function, return it from PyTorch.
-
-        :param loss_name: str, the name of a loss function
-
-        - cross_entropy: combines log softmax and nll loss in a single function.
-        - nll: negative log likelihood
-
-        :return loss: a PyTorch loss
-        """
-
-        class InnerCrossEntropy:
-            """A simple wrapper to guarantee input shapes."""
-
-            def __init__(self):
-                self.f = torch.nn.CrossEntropyLoss()
-
-            def __call__(self, predict, truth):
-                truth = truth.view(-1, )
-                return self.f(predict, truth)
-
-        if loss_name == "cross_entropy":
-            return InnerCrossEntropy()
-        elif loss_name == 'nll':
-            return torch.nn.NLLLoss()
-        else:
-            raise NotImplementedError
+    '''A Loss object is a callable that wraps a loss function.
+    '''
+
+    def __init__(self, loss_name, pre_pro=[squash], **kwargs):
+        '''
+
+        :param loss_name: str or None, the name of the loss function
+        :param pre_pro: list of functions or str, methods that reform the
+            parameters before the loss is calculated; strings are translated
+            to the pre-defined functions automatically
+        :param **kwargs: kwargs for the torch loss function
+
+        pre_pro functions should take three arguments: predict, truth and
+        **args, where predict and truth are the necessary parameters of the
+        loss function and args holds the extra parameters passed in when the
+        loss is called. They should return two objects: the processed
+        predict and truth, respectively.
+        '''
+
+        if loss_name is None:
+            # this is useful when Trainer.__init__ performs type check
+            self._loss = None
+        else:
+            if not isinstance(loss_name, str):
+                raise NotImplementedError
+            else:
+                self._loss = self._get_loss(loss_name, **kwargs)
+
+        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]
+
+    def add_pre_pro(self, func):
+        '''Add a pre_pro function.
+
+        :param func: a function or str, a method that reforms the parameters
+            before the loss is calculated; a string is translated to the
+            pre-defined function automatically
+        '''
+        if not callable(func):
+            func = method_dict.get(func)
+        if func is None:
+            return
+        self.pre_pro.append(func)
+
+    @staticmethod
+    def _get_loss(loss_name, **kwargs):
+        '''Get a loss function from torch.
+
+        :param loss_name: str, the name of the loss function
+        :param **kwargs: kwargs for the torch loss function
+        :return: a callable loss function object
+        '''
+        # normalise the name: case-insensitive, underscores dropped, and a
+        # "loss" suffix appended if missing
+        loss_name = loss_name.strip().lower()
+        loss_name = "".join(loss_name.split("_"))
+
+        if len(loss_name) < 4 or loss_name[-4:] != "loss":
+            loss_name += "loss"
+        return loss_function_name[loss_name](**kwargs)
+
+    def get(self):
+        '''This method exists only so that existing code calling get()
+        keeps running without errors.
+        '''
+        return self
+
+    def __call__(self, predict, truth, **kwargs):
+        '''Call the loss function; predict and truth are first processed by
+        the pre_pro methods, in their order of addition.
+
+        :param predict: Tensor, model output
+        :param truth: Tensor, ground truth from the dataset
+        :param **kwargs: extra arguments passed on to the pre_pro functions;
+            for example, if unpad_mask() is used in pre_pro, a kwarg named
+            lens is expected
+        '''
+        for f in self.pre_pro:
+            if f is None:
+                continue
+            predict, truth = f(predict, truth, **kwargs)
+
+        return self._loss(predict, truth)
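For orientation, a minimal sketch of how the rewritten Loss class is driven
(the names `logits`, `labels` and `seq_lens` are illustrative, not part of
the patch):

    import torch
    from fastNLP.core.loss import Loss

    # the lookup is case/underscore-insensitive and appends "loss",
    # so "cross_entropy" resolves to torch.nn.CrossEntropyLoss
    loss_func = Loss("cross_entropy", pre_pro=["unpad"])

    logits = torch.randn(2, 5, 4)                           # [batch_size, max_len, tag_size]
    labels = torch.randint(0, 4, (2, 5), dtype=torch.long)  # [batch_size, max_len]
    seq_lens = [5, 3]                                       # true length of each sequence

    # "unpad" drops the padded positions before the loss is computed
    loss_value = loss_func(logits, labels, lens=seq_lens)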
diff --git a/requirements.txt b/requirements.txt
index 954dd741..a775c8ed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 numpy>=1.14.2
-torch==0.4.0
+torch>=0.4.0
 torchvision>=0.1.8
 tensorboardX
diff --git a/test/core/test_loss.py b/test/core/test_loss.py
new file mode 100644
index 00000000..d6b43fc1
--- /dev/null
+++ b/test/core/test_loss.py
@@ -0,0 +1,300 @@
+import math
+import unittest
+
+import torch as tc
+
+import fastNLP.core.loss as loss
+
+
+class TestLoss(unittest.TestCase):
+
+    def test_case_1(self):
+        # verify the computation performed by NLLLoss
+        print("----------------------------------")
+
+        loss_func = loss.Loss("nll")
+
+        y = tc.Tensor(
+            [
+                [.3, .4, .3],
+                [.5, .3, .2],
+                [.3, .6, .1],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                0,
+                1,
+                2,
+            ]
+        )
+
+        y = tc.log(y)
+        los = loss_func(y, gy)
+
+        r = -math.log(.3) - math.log(.3) - math.log(.1)
+        r /= 3
+        print("loss = %f" % (los))
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
+
+    def test_case_2(self):
+        # verify the correctness of squash()
+        print("----------------------------------")
+
+        log = math.log
+
+        loss_func = loss.Loss("nll")
+
+        y = tc.Tensor(
+            [
+                [[.3, .4, .3], [.3, .4, .3], ],
+                [[.5, .3, .2], [.1, .2, .7], ],
+                [[.3, .6, .1], [.2, .1, .7], ],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                [0, 2],
+                [1, 2],
+                [2, 1],
+            ]
+        )
+
+        y = tc.log(y)
+        los = loss_func(y, gy)
+
+        r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
+        r /= 6
+        print("loss = %f" % (los))
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
+
+    def test_case_3(self):
+        # verify the correctness of pack_padded_sequence()
+        print("----------------------------------")
+
+        log = math.log
+
+        loss_func = loss.Loss("nll")
+
+        y = tc.Tensor(
+            [
+                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], ],
+                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], ],
+                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], ],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                [0, 2, 1, ],
+                [1, 2, 0, ],
+                [2, 0, 0, ],
+            ]
+        )
+
+        lens = [3, 2, 1]
+
+        y = tc.log(y)
+
+        yy = tc.nn.utils.rnn.pack_padded_sequence(y, lens, batch_first=True).data
+        gyy = tc.nn.utils.rnn.pack_padded_sequence(gy, lens, batch_first=True).data
+        los = loss_func(yy, gyy)
+        print("loss = %f" % (los))
+
+        r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
+        r /= 6
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
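+    # For reference: with batch_first=True and lens = [3, 2, 1],
+    # pack_padded_sequence orders the packed data time-major --
+    # t=0: batch 0, 1, 2; t=1: batch 0, 1; t=2: batch 0 --
+    # so yy above holds 6 rows and gyy 6 labels, with all padding dropped.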
+    def test_case_4(self):
+        # verify the correctness of unpad()
+        print("----------------------------------")
+
+        log = math.log
+
+        y = tc.Tensor(
+            [
+                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
+                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
+                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                [0, 2, 1, 2, ],
+                [1, 2, 0, 0, ],
+                [2, 0, 0, 0, ],
+            ]
+        )
+
+        lens = [4, 2, 1]
+
+        y = tc.log(y)
+
+        loss_func = loss.Loss("nll", pre_pro=["unpad"])
+        los = loss_func(y, gy, lens=lens)
+        print("loss = %f" % (los))
+
+        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
+        r /= 7
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
+
+    def test_case_5(self):
+        # verify the correctness of mask() and make_mask()
+        print("----------------------------------")
+
+        log = math.log
+
+        y = tc.Tensor(
+            [
+                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
+                [[.5, .4, .1], [.3, .2, .5], [.4, .5, .1, ], [.6, .1, .3, ], ],
+                [[.3, .6, .1], [.3, .2, .5], [.0, .0, .0, ], [.0, .0, .0, ], ],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                [1, 2, 0, 0, ],
+                [0, 2, 1, 2, ],
+                [2, 1, 0, 0, ],
+            ]
+        )
+
+        mask = tc.ByteTensor(
+            [
+                [1, 1, 0, 0, ],
+                [1, 1, 1, 1, ],
+                [1, 1, 0, 0, ],
+            ]
+        )
+
+        y = tc.log(y)
+
+        lens = [2, 4, 2]
+
+        loss_func = loss.Loss("nll", pre_pro=["mask"])
+        los = loss_func(y, gy, mask=mask)
+        print("loss = %f" % (los))
+
+        los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))
+        print("loss2 = %f" % (los2))
+
+        r = -log(.3) - log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
+        r /= 8
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
+        self.assertAlmostEqual(float(los2), r, places=5)
+
+    def test_case_6(self):
+        # verify the correctness of unpad_mask()
+        print("----------------------------------")
+
+        log = math.log
+
+        y = tc.Tensor(
+            [
+                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
+                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
+                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                [0, 2, 1, 2, ],
+                [1, 2, 0, 0, ],
+                [2, 0, 0, 0, ],
+            ]
+        )
+
+        lens = [4, 2, 1]
+
+        y = tc.log(y)
+
+        loss_func = loss.Loss("nll", pre_pro=["unpad_mask"])
+        los = loss_func(y, gy, lens=lens)
+        print("loss = %f" % (los))
+
+        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
+        r /= 7
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
+
+    def test_case_7(self):
+        # verify add_pre_pro() and the weight kwarg
+        print("----------------------------------")
+
+        log = math.log
+
+        y = tc.Tensor(
+            [
+                [[.3, .4, .3], [.3, .2, .5], [.4, .5, .1, ], [.6, .3, .1, ], ],
+                [[.5, .3, .2], [.1, .2, .7], [.0, .0, .0, ], [.0, .0, .0, ], ],
+                [[.3, .6, .1], [.0, .0, .0], [.0, .0, .0, ], [.0, .0, .0, ], ],
+            ]
+        )
+
+        gy = tc.LongTensor(
+            [
+                [0, 2, 1, 2, ],
+                [1, 2, 0, 0, ],
+                [2, 0, 0, 0, ],
+            ]
+        )
+
+        lens = [4, 2, 1]
+
+        y = tc.log(y)
+
+        # weight zeroes out class 2, so only the class-0/1 targets contribute
+        loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
+        loss_func.add_pre_pro("unpad_mask")
+        los = loss_func(y, gy, lens=lens)
+        print("loss = %f" % (los))
+
+        r = -log(.3) - log(.5) - log(.3)
+        r /= 3
+        print("r = %f" % (r))
+        self.assertAlmostEqual(float(los), r, places=5)
+
+
+if __name__ == "__main__":
+    unittest.main()
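For reference, a condensed end-to-end sketch of the API these tests exercise,
using the numbers from test_case_1:

    import math
    import torch
    import fastNLP.core.loss as loss

    y = torch.log(torch.Tensor([[.3, .4, .3], [.5, .3, .2], [.3, .6, .1]]))
    gy = torch.LongTensor([0, 1, 2])

    loss_func = loss.Loss("nll")      # the default pre_pro is [squash]
    expected = -(math.log(.3) + math.log(.3) + math.log(.1)) / 3
    assert abs(float(loss_func(y, gy)) - expected) < 1e-5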