|
|
@@ -206,7 +206,11 @@ class CrossEntropyLoss(LossBase): |
|
|
|
|
|
|
|
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` |
|
|
|
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` |
|
|
|
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。 |
|
|
|
:param seq_len: 句子的长度, 长度之外的token不会计算loss。 |
|
|
|
:param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes) |
|
|
|
或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第 |
|
|
|
二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等, |
|
|
|
那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。 |
|
|
|
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 |
|
|
|
传入seq_len. |
|
|
|
:param str reduction: 支持 `mean` ,`sum` 和 `none` . |
|
|
@@ -217,18 +221,21 @@ class CrossEntropyLoss(LossBase): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'): |
|
|
|
def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'): |
|
|
|
super(CrossEntropyLoss, self).__init__() |
|
|
|
self._init_param_map(pred=pred, target=target, seq_len=seq_len) |
|
|
|
self.padding_idx = padding_idx |
|
|
|
assert reduction in ('mean', 'sum', 'none') |
|
|
|
self.reduction = reduction |
|
|
|
self.class_in_dim = class_in_dim |
|
|
|
|
|
|
|
def get_loss(self, pred, target, seq_len=None): |
|
|
|
if pred.dim() > 2: |
|
|
|
if pred.size(1) != target.size(1): # 有可能顺序替换了 |
|
|
|
raise RuntimeError("It seems like that your prediction's shape is (batch_size, num_labels, max_len)." |
|
|
|
" It should be (batch_size, max_len, num_labels).") |
|
|
|
if self.class_in_dim == -1: |
|
|
|
if pred.size(1) != target.size(1): # 有可能顺序替换了 |
|
|
|
pred = pred.transpose(1, 2) |
|
|
|
else: |
|
|
|
pred = pred.tranpose(-1, pred) |
|
|
|
pred = pred.reshape(-1, pred.size(-1)) |
|
|
|
target = target.reshape(-1) |
|
|
|
if seq_len is not None: |
|
|
|