You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reasoning.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. import pytest
  2. import numpy as np
  3. from abl.reasoning import PrologKB, Reasoner
  class TestKBBase(object):
      """Tests driven by the ``kb_add`` and ``kb_add_cache`` fixtures.

      The fixtures are defined outside this file; only the attributes and
      return values asserted below are relied upon.
      """

      def test_init(self, kb_add):
          """The pseudo-label list equals the integers 0 through 9."""
          digits = list(range(10))
          assert kb_add.pseudo_label_list == digits

      def test_init_cache(self, kb_add_cache):
          """Same label list as above, and ``use_cache`` is exactly ``True``."""
          digits = list(range(10))
          assert kb_add_cache.pseudo_label_list == digits
          assert kb_add_cache.use_cache is True

      def test_logic_forward(self, kb_add):
          """``logic_forward([1, 2])`` gives 3; a second positional arg raises."""
          assert kb_add.logic_forward([1, 2]) == 3
          # Passing an extra positional argument is expected to be rejected.
          with pytest.raises(TypeError):
              kb_add.logic_forward([1, 2], [0.1, -0.2, 0.2, -0.3])

      def test_revise_at_idx(self, kb_add):
          """Exercise ``revise_at_idx`` with three different index lists."""
          # Each row: (first argument, last argument, expected return value).
          cases = (
              ([0, 2], [], ([[0, 2]], [2])),
              ([1, 2], [], ([], [])),
              ([1, 2], [0, 1], ([[0, 2], [1, 1], [2, 0]], [2, 2, 2])),
          )
          for labels, idxs, expected in cases:
              # A fresh list literal is passed on every call, as in the
              # straight-line version of this test.
              outcome = kb_add.revise_at_idx(labels, 2, [0.1, -0.2, 0.2, -0.3], idxs)
              assert outcome == expected

      def test_abduce_candidates(self, kb_add):
          """Exercise ``abduce_candidates`` with a target value of 1."""
          # Each row: (first argument, expected return value).
          cases = (
              ([0, 1], ([[0, 1]], [1])),
              ([1, 2], ([[1, 0]], [1])),
          )
          for labels, expected in cases:
              outcome = kb_add.abduce_candidates(
                  labels,
                  1,
                  [0.1, -0.2, 0.2, -0.3],
                  max_revision_num=2,
                  require_more_revision=0,
              )
              assert outcome == expected
  class TestGroundKB(object):
      """Tests driven by the ``kb_add_ground`` fixture (defined outside this file)."""

      def test_init(self, kb_add_ground):
          """Check the label list, ``GKB_len_list`` and that ``GKB`` is truthy."""
          digits = list(range(10))
          assert kb_add_ground.pseudo_label_list == digits
          assert kb_add_ground.GKB_len_list == [2]
          assert kb_add_ground.GKB

      def test_logic_forward_ground(self, kb_add_ground):
          """``logic_forward([1, 2])`` gives 3."""
          assert kb_add_ground.logic_forward([1, 2]) == 3

      def test_abduce_candidates_ground(self, kb_add_ground):
          """The single expected candidate is compared as a tuple, ``(1, 0)``."""
          expected = ([(1, 0)], [1])
          outcome = kb_add_ground.abduce_candidates(
              [1, 2],
              1,
              [0.1, -0.2, 0.2, -0.3],
              max_revision_num=2,
              require_more_revision=0,
          )
          assert outcome == expected
  class TestPrologKB(object):
      """Tests for ``PrologKB`` via the ``kb_add_prolog`` and ``kb_hed`` fixtures.

      The fixtures are defined outside this file.
      """

      def test_init_pl1(self, kb_add_prolog):
          """Check the label list and the ``pl_file`` path of ``kb_add_prolog``."""
          digits = list(range(10))
          assert kb_add_prolog.pseudo_label_list == digits
          assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl"

      def test_init_pl2(self, kb_hed):
          """Check the label list and the ``pl_file`` path of ``kb_hed``."""
          symbols = [1, 0, "+", "="]
          assert kb_hed.pseudo_label_list == symbols
          assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl"

      def test_prolog_file_not_exist(self):
          """A missing ``.pl`` path raises ``FileNotFoundError`` naming the path."""
          labels = [1, 2]
          missing_path = "path/to/non_existing_file.pl"
          with pytest.raises(FileNotFoundError) as excinfo:
              PrologKB(pseudo_label_list=labels, pl_file=missing_path)
          # The message is inspected only after the context manager exits.
          message = str(excinfo.value)
          assert missing_path in message

      def test_logic_forward_pl1(self, kb_add_prolog):
          """``logic_forward([1, 2])`` gives 3."""
          assert kb_add_prolog.logic_forward([1, 2]) == 3

      def test_logic_forward_pl2(self, kb_hed):
          """``logic_forward`` returns ``True`` / ``False`` for the two example sets."""
          accepted = [
              [1, 1, "+", 0, "=", 1, 1],
              [1, "+", 1, "=", 1, 0],
              [0, "+", 0, "=", 0],
          ]
          # Same three examples plus one extra final row; written out in full
          # so the two inputs share no list objects.
          rejected = [
              [1, 1, "+", 0, "=", 1, 1],
              [1, "+", 1, "=", 1, 0],
              [0, "+", 0, "=", 0],
              [0, "+", 0, "=", 1],
          ]
          assert kb_hed.logic_forward(accepted) is True
          assert kb_hed.logic_forward(rejected) is False

      def test_revise_at_idx(self, kb_add_prolog):
          """Exercise ``revise_at_idx`` with the index list ``[0]``."""
          expected = ([[0, 2]], [2])
          outcome = kb_add_prolog.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0])
          assert outcome == expected
  class TestReaonser(object):
      """Tests for ``Reasoner`` construction and ``dist_func`` validation.

      NOTE(review): the class name looks like a misspelling of "Reasoner";
      it is left unchanged so existing pytest node IDs keep working.
      """

      def test_reasoner_init(self, reasoner_instance):
          """The ``reasoner_instance`` fixture reports ``"confidence"``."""
          assert reasoner_instance.dist_func == "confidence"

      def test_invalid_predefined_dist_func(self, kb_add):
          """An unknown ``dist_func`` string raises ``NotImplementedError``."""
          with pytest.raises(NotImplementedError) as excinfo:
              Reasoner(kb_add, "invalid_dist_func")
          message = str(excinfo.value)
          assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in message

      def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results):
          """Four-parameter distance helper: one random cost per candidate, as a list."""
          return [np.random.rand() for _ in candidates]

      def test_user_defined_dist_func(self, kb_add):
          """A callable ``dist_func`` is kept on the reasoner."""
          custom = self.random_dist
          reasoner = Reasoner(kb_add, custom)
          assert reasoner.dist_func == self.random_dist

      def invalid_dist1(self, candidates):
          """Distance helper that takes only one parameter besides ``self``."""
          costs = [np.random.rand() for _ in candidates]
          return np.array(costs)

      def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results):
          """Distance helper returning one more cost than there are candidates."""
          costs = np.array([np.random.rand() for _ in candidates])
          surplus = np.random.rand()
          return np.append(costs, surplus)

      def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add):
          """Both invalid helpers are expected to raise ``ValueError``."""
          # Case 1: the error is expected at construction time.
          with pytest.raises(ValueError) as excinfo:
              Reasoner(kb_add, self.invalid_dist1)
          first_message = str(excinfo.value)
          assert 'User-defined dist_func must have exactly four parameters' in first_message

          # Case 2: construction and the batch_abduce call both sit inside
          # the raises block.
          with pytest.raises(ValueError) as excinfo:
              reasoner = Reasoner(kb_add, self.invalid_dist2)
              reasoner.batch_abduce(data_examples_add)
          second_message = str(excinfo.value)
          assert (
              'The length of the array returned by dist_func must be equal to the number of candidates'
              in second_message
          )
  class TestBatchAbduce(object):
      """Tests of ``Reasoner.batch_abduce`` against several knowledge-base fixtures.

      In every test all reasoners are built first, then ``batch_abduce`` is
      called on each one in construction order.
      """

      def test_batch_abduce_add(self, kb_add, data_examples_add):
          """``kb_add`` with the "confidence" distance and four settings."""
          # (max_revision, require_more_revision) pairs, in construction order.
          settings = ((1, 0), (1, 1), (2, 0), (2, 1))
          reasoners = [
              Reasoner(kb_add, "confidence", max_revision=limit, require_more_revision=extra)
              for limit, extra in settings
          ]
          expected = [
              [[1, 7], [7, 1], [], [1, 9]],
              [[1, 7], [7, 1], [], [1, 9]],
              [[1, 7], [7, 1], [8, 9], [1, 9]],
              [[1, 7], [7, 1], [8, 9], [7, 3]],
          ]
          for reasoner, want in zip(reasoners, expected):
              assert reasoner.batch_abduce(data_examples_add) == want

      def test_batch_abduce_ground(self, kb_add_ground, data_examples_add):
          """``kb_add_ground``: results are tuples, except the empty list entry."""
          settings = ((1, 0), (1, 1), (2, 0), (2, 1))
          reasoners = [
              Reasoner(kb_add_ground, "confidence", max_revision=limit, require_more_revision=extra)
              for limit, extra in settings
          ]
          expected = [
              [(1, 7), (7, 1), [], (1, 9)],
              [(1, 7), (7, 1), [], (1, 9)],
              [(1, 7), (7, 1), (8, 9), (1, 9)],
              [(1, 7), (7, 1), (8, 9), (7, 3)],
          ]
          for reasoner, want in zip(reasoners, expected):
              assert reasoner.batch_abduce(data_examples_add) == want

      def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add):
          """``kb_add_prolog`` with the same four settings as the plain-KB test."""
          settings = ((1, 0), (1, 1), (2, 0), (2, 1))
          reasoners = [
              Reasoner(kb_add_prolog, "confidence", max_revision=limit, require_more_revision=extra)
              for limit, extra in settings
          ]
          expected = [
              [[1, 7], [7, 1], [], [1, 9]],
              [[1, 7], [7, 1], [], [1, 9]],
              [[1, 7], [7, 1], [8, 9], [1, 9]],
              [[1, 7], [7, 1], [8, 9], [7, 3]],
          ]
          for reasoner, want in zip(reasoners, expected):
              assert reasoner.batch_abduce(data_examples_add) == want

      def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add):
          """``use_zoopt=True`` with ``max_revision`` of 1 and then 2."""
          reasoners = [
              Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=limit)
              for limit in (1, 2)
          ]
          expected = [
              [[1, 7], [7, 1], [], [1, 9]],
              [[1, 7], [7, 1], [8, 9], [7, 3]],
          ]
          for reasoner, want in zip(reasoners, expected):
              assert reasoner.batch_abduce(data_examples_add) == want

      def test_batch_abduce_hwf1(self, kb_hwf1, data_examples_hwf):
          """``kb_hwf1`` with the "hamming" distance; ``max_revision`` 3, 0.5, 0.9."""
          reasoners = [
              Reasoner(kb_hwf1, "hamming", max_revision=limit, require_more_revision=0)
              for limit in (3, 0.5, 0.9)
          ]
          expected = [
              [
                  ["1", "+", "2"],
                  ["8", "times", "8"],
                  [],
                  ["4", "-", "6", "div", "8"],
              ],
              [["1", "+", "2"], [], [], []],
              [
                  ["1", "+", "2"],
                  ["8", "times", "8"],
                  [],
                  ["4", "-", "6", "div", "8"],
              ],
          ]
          for reasoner, want in zip(reasoners, expected):
              res = reasoner.batch_abduce(data_examples_hwf)
              assert res == want

      def test_batch_abduce_hwf2(self, kb_hwf2, data_examples_hwf):
          """``kb_hwf2`` with the "hamming" distance; ``max_revision`` 3, 0.5, 0.9."""
          reasoners = [
              Reasoner(kb_hwf2, "hamming", max_revision=limit, require_more_revision=0)
              for limit in (3, 0.5, 0.9)
          ]
          expected = [
              [
                  ["1", "+", "2"],
                  ["7", "times", "9"],
                  ["8", "times", "8"],
                  ["5", "-", "8", "div", "8"],
              ],
              [
                  ["1", "+", "2"],
                  ["7", "times", "9"],
                  [],
                  ["5", "-", "8", "div", "8"],
              ],
              [
                  ["1", "+", "2"],
                  ["7", "times", "9"],
                  ["8", "times", "8"],
                  ["5", "-", "8", "div", "8"],
              ],
          ]
          for reasoner, want in zip(reasoners, expected):
              res = reasoner.batch_abduce(data_examples_hwf)
              assert res == want

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.