Browse Source

优化loss在missing和duplicate时报错的信息:返回loss初始化约定接受的key

tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
9acdb54fc8
2 changed files with 12 additions and 3 deletions
  1. +2
    -3
      fastNLP/core/losses.py
  2. +10
    -0
      test/core/test_loss.py

+ 2
- 3
fastNLP/core/losses.py View File

@@ -118,7 +118,7 @@ class LossBase(object):
if not self._checked:
for keys, val in pred_dict.items():
if keys in target_dict.keys():
duplicated.append(keys)
duplicated.append(param_map[keys])

param_val_dict = {}
for keys, val in pred_dict.items():
@@ -126,11 +126,10 @@ class LossBase(object):
for keys, val in target_dict.items():
param_val_dict.update({keys: val})

# TODO: use the origin key to raise error
if not self._checked:
for keys in args:
if param_map[keys] not in param_val_dict.keys():
missing.append(keys)
missing.append(param_map[keys])

if len(duplicated) > 0 or len(missing) > 0:
raise CheckError(


+ 10
- 0
test/core/test_loss.py View File

@@ -300,3 +300,13 @@ class TestLoss_v2(unittest.TestCase):
b = torch.tensor([1, 0, 4])
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))

def test_check_error(self):
l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
b = torch.tensor([1, 0, 4])
with self.assertRaises(Exception):
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b})

with self.assertRaises(Exception):
ans = l1({"my_predict": a}, {"truth": b, "my": a})

Loading…
Cancel
Save