diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 9b32babb..92f2f364 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -352,82 +352,3 @@ def _prepare_losser(losser): return losser else: raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}") - - -def squash(predict, truth, **kwargs): - """To reshape tensors in order to fit loss functions in PyTorch. - - :param predict: Tensor, model output - :param truth: Tensor, truth from dataset - :param kwargs: extra arguments - :return predict , truth: predict & truth after processing - """ - return predict.view(-1, predict.size()[-1]), truth.view(-1, ) - - -def unpad(predict, truth, **kwargs): - """To process padded sequence output to get true loss. - - :param predict: Tensor, [batch_size , max_len , tag_size] - :param truth: Tensor, [batch_size , max_len] - :param kwargs: kwargs["lens"] is a list or LongTensor, with size [batch_size]. The i-th element is true lengths of i-th sequence. - - :return predict , truth: predict & truth after processing - """ - if kwargs.get("lens") is None: - return predict, truth - lens = torch.LongTensor(kwargs["lens"]) - lens, idx = torch.sort(lens, descending=True) - predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data - truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data - return predict, truth - - -def unpad_mask(predict, truth, **kwargs): - """To process padded sequence output to get true loss. - - :param predict: Tensor, [batch_size , max_len , tag_size] - :param truth: Tensor, [batch_size , max_len] - :param kwargs: kwargs["lens"] is a list or LongTensor, with size [batch_size]. The i-th element is true lengths of i-th sequence. - - :return predict , truth: predict & truth after processing - """ - if kwargs.get("lens") is None: - return predict, truth - mas = make_mask(kwargs["lens"], truth.size()[1]) - return mask(predict, truth, mask=mas) - - -def mask(predict, truth, **kwargs): - """To select specific elements from Tensor. This method calls ``squash()``. - - :param predict: Tensor, [batch_size , max_len , tag_size] - :param truth: Tensor, [batch_size , max_len] - :param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. - - :return predict , truth: predict & truth after processing - """ - if kwargs.get("mask") is None: - return predict, truth - mask = kwargs["mask"] - - predict, truth = squash(predict, truth) - mask = mask.view(-1, ) - - predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0) - truth = torch.masked_select(truth, mask) - - return predict, truth - - -def make_mask(lens, tar_len): - """To generate a mask over a sequence. - - :param lens: list or LongTensor, [batch_size] - :param tar_len: int - :return mask: ByteTensor - """ - lens = torch.LongTensor(lens) - mask = [torch.ge(lens, i + 1) for i in range(tar_len)] - mask = torch.stack(mask, 1) - return mask diff --git a/test/core/test_loss.py b/test/core/test_loss.py index 8db54615..9ba8159f 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -4,7 +4,6 @@ import torch import torch.nn.functional as F import fastNLP as loss -from fastNLP.core.losses import squash, unpad class TestLoss(unittest.TestCase): @@ -73,15 +72,3 @@ class TestLosserError(unittest.TestCase): with self.assertRaises(Exception): ans = l1({"my_predict": a}, {"truth": b, "my": a}) - - -class TestLossUtils(unittest.TestCase): - def test_squash(self): - a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) - self.assertEqual(tuple(a.size()), (3, 5)) - self.assertEqual(tuple(b.size()), (15,)) - - def test_unpad(self): - a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) - self.assertEqual(tuple(a.size()), (5, 8, 3)) - self.assertEqual(tuple(b.size()), (5, 8))