|
|
@@ -221,8 +221,7 @@ class CrossEntropyLoss(LossBase): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pred=None, target=None, padding_idx=-100): |
|
|
|
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要 |
|
|
|
# TODO (16, 4) |
|
|
|
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际需要(16,4) |
|
|
|
super(CrossEntropyLoss, self).__init__() |
|
|
|
self._init_param_map(pred=pred, target=target) |
|
|
|
self.padding_idx = padding_idx |
|
|
|