Browse Source

CrossEntropyLoss增加class_in_dim选项控制target的维度

tags/v0.4.10
yh 6 years ago
parent
commit
89142d9dc5
1 changed files with 12 additions and 5 deletions
  1. +12
    -5
      fastNLP/core/losses.py

+ 12
- 5
fastNLP/core/losses.py View File

@@ -206,7 +206,11 @@ class CrossEntropyLoss(LossBase):
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。
:param seq_len: 句子的长度, 长度之外的token不会计算loss。
:param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes)
或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第
二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等,
那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
传入seq_len.
:param str reduction: 支持 `mean` ,`sum` 和 `none` .
@@ -217,18 +221,21 @@ class CrossEntropyLoss(LossBase):
"""
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'):
def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.padding_idx = padding_idx
assert reduction in ('mean', 'sum', 'none')
self.reduction = reduction
self.class_in_dim = class_in_dim
def get_loss(self, pred, target, seq_len=None):
if pred.dim() > 2:
if pred.size(1) != target.size(1): # 有可能顺序替换了
raise RuntimeError("It seems like that your prediction's shape is (batch_size, num_labels, max_len)."
" It should be (batch_size, max_len, num_labels).")
if self.class_in_dim == -1:
if pred.size(1) != target.size(1): # 有可能顺序替换了
pred = pred.transpose(1, 2)
else:
pred = pred.tranpose(-1, pred)
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)
if seq_len is not None:


Loading…
Cancel
Save