From 89142d9dc5ad34b98a1d8d0db47bed4bab562fd9 Mon Sep 17 00:00:00 2001 From: yh Date: Sat, 17 Aug 2019 11:36:39 +0800 Subject: [PATCH] =?UTF-8?q?CrossEntropyLoss=E5=A2=9E=E5=8A=A0class=5Fin=5F?= =?UTF-8?q?dim=E9=80=89=E9=A1=B9=E6=8E=A7=E5=88=B6target=E7=9A=84=E7=BB=B4?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 05e5b440..d5549cec 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -206,7 +206,11 @@ class CrossEntropyLoss(LossBase): :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` - :param seq_len: 句子的长度, 长度之外的token不会计算loss。。 + :param seq_len: 句子的长度, 长度之外的token不会计算loss。 + :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes) + 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第 + 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等, + 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显式设置该值)。其它大于0的值则认为该维度是class的维度。 :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 传入seq_len. :param str reduction: 支持 `mean` ,`sum` 和 `none` . 
@@ -217,18 +221,21 @@ class CrossEntropyLoss(LossBase):
 
     """
 
-    def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'):
+    def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
         super(CrossEntropyLoss, self).__init__()
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)
         self.padding_idx = padding_idx
         assert reduction in ('mean', 'sum', 'none')
         self.reduction = reduction
+        self.class_in_dim = class_in_dim
 
     def get_loss(self, pred, target, seq_len=None):
         if pred.dim() > 2:
-            if pred.size(1) != target.size(1):  # 有可能顺序替换了
-                raise RuntimeError("It seems like that your prediction's shape is (batch_size, num_labels, max_len)."
-                                   " It should be (batch_size, max_len, num_labels).")
+            if self.class_in_dim == -1:
+                if pred.size(1) != target.size(1):  # dim 1 of pred does not match target's length axis: treat it as the class axis
+                    pred = pred.transpose(1, 2)
+            else:
+                pred = pred.transpose(-1, self.class_in_dim)  # swap the user-declared class axis into the last position
             pred = pred.reshape(-1, pred.size(-1))
             target = target.reshape(-1)
         if seq_len is not None: