|
|
@@ -209,7 +209,7 @@ class LossFunc(LossBase):


class CrossEntropyLoss(LossBase):
    r"""
    Cross-entropy loss function.

    :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
    :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
    :param seq_len: sentence lengths; tokens beyond each sentence's length do not contribute to the loss.
@@ -219,24 +219,27 @@ class CrossEntropyLoss(LossBase):
        then the second dimension of pred may also be the length dimension (this heuristic can misjudge; if it does, set this value explicitly). Any other value greater than 0 marks that dimension as the class dimension.
    :param padding_idx: the index used for padding; when computing the loss, entries in target equal to padding_idx are ignored. This can be used instead of
        passing seq_len.
    :param label_smoothing: the eps coefficient for label smoothing, which encodes a label as a vector of the form [1-eps, eps/(C-1), ..., eps/(C-1)] (rather than a one-hot vector), where C is
        the number of label classes. None means no label smoothing is applied.
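        For example, with label_smoothing=0.1 and C=5, label 2 is encoded as [0.025, 0.025, 0.9, 0.025, 0.025].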
    :param str reduction: one of `mean`, `sum`, and `none`.

    Example::

        loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0)
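        loss = CrossEntropyLoss(pred='pred', target='label', padding_idx=0, label_smoothing=0.1)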
    """

    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, label_smoothing=None, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
        self.padding_idx = padding_idx
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction
        self.class_in_dim = class_in_dim
        self.label_smoothing = label_smoothing

    def get_loss(self, pred, target, seq_len=None):
        if seq_len is not None and target.dim()>1:
        if seq_len is not None and target.dim() > 1:
            # positions beyond each sequence's length are set to padding_idx so they are ignored
            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
            target = target.masked_fill(mask, self.padding_idx)
@@ -249,6 +252,20 @@ class CrossEntropyLoss(LossBase): |
        pred = pred.reshape(-1, pred.size(-1))
        target = target.reshape(-1)

        if self.label_smoothing is not None:
            n_class = pred.size(1)
            # scatter cannot take negative indices, so padded targets (padding_idx is
            # -100 by default) are temporarily mapped to class 0; they are masked out below
            safe_target = target.masked_fill(target.eq(self.padding_idx), 0)
            # smoothed target distribution: 1 - eps on the gold class, eps/(C-1) elsewhere
            one_hot = torch.zeros_like(pred).scatter(1, safe_target.view(-1, 1), 1)
            one_hot = one_hot * (1 - self.label_smoothing) + (1 - one_hot) * self.label_smoothing / (n_class - 1)
            log_prb = F.log_softmax(pred, dim=1)

            loss = -(one_hot * log_prb).sum(dim=1)
            non_pad_mask = target.ne(self.padding_idx)
            if self.reduction == 'none':
                # mirror F.cross_entropy(reduction='none'): zero loss at padded positions
                return loss.masked_fill(~non_pad_mask, 0)
            loss = loss.masked_select(non_pad_mask)
            return loss.mean() if self.reduction == 'mean' else loss.sum()

        return F.cross_entropy(input=pred, target=target,
                               ignore_index=self.padding_idx, reduction=self.reduction)
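
For reference, the smoothing scheme the new parameter implements can be exercised outside the class. Below is a minimal standalone sketch in plain PyTorch; the eps value, shapes, and tensors are illustrative and not part of this patch:

    import torch
    import torch.nn.functional as F

    eps, n_class = 0.1, 5
    pred = torch.randn(3, n_class)    # (batch_size, C) logits
    target = torch.tensor([0, 2, 4])  # gold class per sample

    # encode targets as [1-eps, eps/(C-1), ..., eps/(C-1)], as described in the docstring
    one_hot = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
    smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)

    # cross-entropy against the smoothed distribution, averaged over the batch
    loss = -(smoothed * F.log_softmax(pred, dim=1)).sum(dim=1).mean()

Note that this scheme is not identical to the label_smoothing argument of F.cross_entropy in recent PyTorch releases, which mixes the target with a uniform distribution (placing eps/C on every class, including the gold one), so the two give slightly different values.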