Browse Source

[ENH] Change dist_func to four parameters

pull/1/head
Tony-HYX 1 year ago
parent
commit
a96cdfd6c2
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      tests/test_reasoning.py

+ 3
- 3
tests/test_reasoning.py View File

@@ -101,7 +101,7 @@ class TestReaonser(object):
excinfo.value excinfo.value
) )
def random_dist(self, data_sample, candidates, reasoning_results):
def random_dist(self, data_sample, candidates, candidate_idxs, reasoning_results):
cost_list = [np.random.rand() for _ in candidates] cost_list = [np.random.rand() for _ in candidates]
return cost_list return cost_list
@@ -113,14 +113,14 @@ class TestReaonser(object):
cost_list = np.array([np.random.rand() for _ in candidates]) cost_list = np.array([np.random.rand() for _ in candidates])
return cost_list return cost_list
def invalid_dist2(self, data_sample, candidates, reasoning_results):
def invalid_dist2(self, data_sample, candidates, candidate_idxs, reasoning_results):
cost_list = np.array([np.random.rand() for _ in candidates]) cost_list = np.array([np.random.rand() for _ in candidates])
return np.append(cost_list, np.random.rand()) return np.append(cost_list, np.random.rand())
def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add): def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
Reasoner(kb_add, self.invalid_dist1) Reasoner(kb_add, self.invalid_dist1)
assert 'User-defined dist_func must have exactly three parameters' in str(
assert 'User-defined dist_func must have exactly four parameters' in str(
excinfo.value excinfo.value
) )
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:


Loading…
Cancel
Save