|
|
@@ -352,82 +352,3 @@ def _prepare_losser(losser): |
|
|
|
return losser |
|
|
|
else: |
|
|
|
raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}") |
|
|
|
|
|
|
|
|
|
|
|
def squash(predict, truth, **kwargs): |
|
|
|
"""To reshape tensors in order to fit loss functions in PyTorch. |
|
|
|
|
|
|
|
:param predict: Tensor, model output |
|
|
|
:param truth: Tensor, truth from dataset |
|
|
|
:param kwargs: extra arguments |
|
|
|
:return predict , truth: predict & truth after processing |
|
|
|
""" |
|
|
|
return predict.view(-1, predict.size()[-1]), truth.view(-1, ) |
|
|
|
|
|
|
|
|
|
|
|
def unpad(predict, truth, **kwargs): |
|
|
|
"""To process padded sequence output to get true loss. |
|
|
|
|
|
|
|
:param predict: Tensor, [batch_size , max_len , tag_size] |
|
|
|
:param truth: Tensor, [batch_size , max_len] |
|
|
|
:param kwargs: kwargs["lens"] is a list or LongTensor, with size [batch_size]. The i-th element is true lengths of i-th sequence. |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
|
""" |
|
|
|
if kwargs.get("lens") is None: |
|
|
|
return predict, truth |
|
|
|
lens = torch.LongTensor(kwargs["lens"]) |
|
|
|
lens, idx = torch.sort(lens, descending=True) |
|
|
|
predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data |
|
|
|
truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data |
|
|
|
return predict, truth |
|
|
|
|
|
|
|
|
|
|
|
def unpad_mask(predict, truth, **kwargs): |
|
|
|
"""To process padded sequence output to get true loss. |
|
|
|
|
|
|
|
:param predict: Tensor, [batch_size , max_len , tag_size] |
|
|
|
:param truth: Tensor, [batch_size , max_len] |
|
|
|
:param kwargs: kwargs["lens"] is a list or LongTensor, with size [batch_size]. The i-th element is true lengths of i-th sequence. |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
|
""" |
|
|
|
if kwargs.get("lens") is None: |
|
|
|
return predict, truth |
|
|
|
mas = make_mask(kwargs["lens"], truth.size()[1]) |
|
|
|
return mask(predict, truth, mask=mas) |
|
|
|
|
|
|
|
|
|
|
|
def mask(predict, truth, **kwargs): |
|
|
|
"""To select specific elements from Tensor. This method calls ``squash()``. |
|
|
|
|
|
|
|
:param predict: Tensor, [batch_size , max_len , tag_size] |
|
|
|
:param truth: Tensor, [batch_size , max_len] |
|
|
|
:param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. |
|
|
|
|
|
|
|
:return predict , truth: predict & truth after processing |
|
|
|
""" |
|
|
|
if kwargs.get("mask") is None: |
|
|
|
return predict, truth |
|
|
|
mask = kwargs["mask"] |
|
|
|
|
|
|
|
predict, truth = squash(predict, truth) |
|
|
|
mask = mask.view(-1, ) |
|
|
|
|
|
|
|
predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0) |
|
|
|
truth = torch.masked_select(truth, mask) |
|
|
|
|
|
|
|
return predict, truth |
|
|
|
|
|
|
|
|
|
|
|
def make_mask(lens, tar_len): |
|
|
|
"""To generate a mask over a sequence. |
|
|
|
|
|
|
|
:param lens: list or LongTensor, [batch_size] |
|
|
|
:param tar_len: int |
|
|
|
:return mask: ByteTensor |
|
|
|
""" |
|
|
|
lens = torch.LongTensor(lens) |
|
|
|
mask = [torch.ge(lens, i + 1) for i in range(tar_len)] |
|
|
|
mask = torch.stack(mask, 1) |
|
|
|
return mask |