| @@ -118,7 +118,7 @@ class LossBase(object): | |||||
| if not self._checked: | if not self._checked: | ||||
| for keys, val in pred_dict.items(): | for keys, val in pred_dict.items(): | ||||
| if keys in target_dict.keys(): | if keys in target_dict.keys(): | ||||
| duplicated.append(keys) | |||||
| duplicated.append(param_map[keys]) | |||||
| param_val_dict = {} | param_val_dict = {} | ||||
| for keys, val in pred_dict.items(): | for keys, val in pred_dict.items(): | ||||
| @@ -126,11 +126,10 @@ class LossBase(object): | |||||
| for keys, val in target_dict.items(): | for keys, val in target_dict.items(): | ||||
| param_val_dict.update({keys: val}) | param_val_dict.update({keys: val}) | ||||
| # TODO: use the origin key to raise error | |||||
| if not self._checked: | if not self._checked: | ||||
| for keys in args: | for keys in args: | ||||
| if param_map[keys] not in param_val_dict.keys(): | if param_map[keys] not in param_val_dict.keys(): | ||||
| missing.append(keys) | |||||
| missing.append(param_map[keys]) | |||||
| if len(duplicated) > 0 or len(missing) > 0: | if len(duplicated) > 0 or len(missing) > 0: | ||||
| raise CheckError( | raise CheckError( | ||||
| @@ -300,3 +300,13 @@ class TestLoss_v2(unittest.TestCase): | |||||
| b = torch.tensor([1, 0, 4]) | b = torch.tensor([1, 0, 4]) | ||||
| ans = l1({"my_predict": a}, {"my_truth": b}) | ans = l1({"my_predict": a}, {"my_truth": b}) | ||||
| self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) | self.assertEqual(ans, torch.nn.functional.nll_loss(a, b)) | ||||
| def test_check_error(self): | |||||
| l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||||
| a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||||
| b = torch.tensor([1, 0, 4]) | |||||
| with self.assertRaises(Exception): | |||||
| ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | |||||
| with self.assertRaises(Exception): | |||||
| ans = l1({"my_predict": a}, {"truth": b, "my": a}) | |||||