@@ -118,7 +118,7 @@ class LossBase(object): | |||||
if not self._checked: | if not self._checked: | ||||
for keys, val in pred_dict.items(): | for keys, val in pred_dict.items(): | ||||
if keys in target_dict.keys(): | if keys in target_dict.keys(): | ||||
duplicated.append(keys) | |||||
duplicated.append(param_map[keys]) | |||||
param_val_dict = {} | param_val_dict = {} | ||||
for keys, val in pred_dict.items(): | for keys, val in pred_dict.items(): | ||||
@@ -126,11 +126,10 @@ class LossBase(object): | |||||
for keys, val in target_dict.items(): | for keys, val in target_dict.items(): | ||||
param_val_dict.update({keys: val}) | param_val_dict.update({keys: val}) | ||||
# TODO: use the origin key to raise error | |||||
if not self._checked: | if not self._checked: | ||||
for keys in args: | for keys in args: | ||||
if param_map[keys] not in param_val_dict.keys(): | if param_map[keys] not in param_val_dict.keys(): | ||||
missing.append(keys) | |||||
missing.append(param_map[keys]) | |||||
if len(duplicated) > 0 or len(missing) > 0: | if len(duplicated) > 0 or len(missing) > 0: | ||||
raise CheckError( | raise CheckError( | ||||
@@ -300,3 +300,13 @@ class TestLoss_v2(unittest.TestCase): | |||||
b = torch.tensor([1, 0, 4]) | b = torch.tensor([1, 0, 4]) | ||||
ans = l1({"my_predict": a}, {"my_truth": b}) | ans = l1({"my_predict": a}, {"my_truth": b}) | ||||
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) | self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) | ||||
def test_check_error(self): | |||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
b = torch.tensor([1, 0, 4]) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||||
with self.assertRaises(Exception): | |||||
ans = l1({"my_predict": a}, {"truth": b, "my": a}) |