You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reasoning.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import numpy as np
  2. import platform
  3. import pytest
  4. from ablkit.reasoning import PrologKB, Reasoner
  5. class TestKBBase(object):
  6. def test_init(self, kb_add):
  7. assert kb_add.pseudo_label_list == list(range(10))
  8. def test_init_cache(self, kb_add_cache):
  9. assert kb_add_cache.pseudo_label_list == list(range(10))
  10. assert kb_add_cache.use_cache is True
  11. def test_logic_forward(self, kb_add):
  12. result = kb_add.logic_forward([1, 2])
  13. assert result == 3
  14. with pytest.raises(TypeError):
  15. kb_add.logic_forward([1, 2], [0.1, -0.2, 0.2, -0.3])
  16. def test_revise_at_idx(self, kb_add):
  17. result = kb_add.revise_at_idx([0, 2], 2, [0.1, -0.2, 0.2, -0.3], [])
  18. assert result == ([[0, 2]], [2])
  19. result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [])
  20. assert result == ([], [])
  21. result = kb_add.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0, 1])
  22. assert result == ([[0, 2], [1, 1], [2, 0]], [2, 2, 2])
  23. def test_abduce_candidates(self, kb_add):
  24. result = kb_add.abduce_candidates(
  25. [0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
  26. )
  27. assert result == ([[0, 1]], [1])
  28. result = kb_add.abduce_candidates(
  29. [1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
  30. )
  31. assert result == ([[1, 0]], [1])
  32. class TestGroundKB(object):
  33. def test_init(self, kb_add_ground):
  34. assert kb_add_ground.pseudo_label_list == list(range(10))
  35. assert kb_add_ground.GKB_len_list == [2]
  36. assert kb_add_ground.GKB
  37. def test_logic_forward_ground(self, kb_add_ground):
  38. result = kb_add_ground.logic_forward([1, 2])
  39. assert result == 3
  40. def test_abduce_candidates_ground(self, kb_add_ground):
  41. result = kb_add_ground.abduce_candidates(
  42. [1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
  43. )
  44. assert result == ([(1, 0)], [1])
  45. class TestPrologKB(object):
  46. def test_init_pl1(self, kb_add_prolog):
  47. if platform.system() == "Darwin":
  48. return
  49. assert kb_add_prolog.pseudo_label_list == list(range(10))
  50. assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl"
  51. def test_init_pl2(self, kb_hed):
  52. if platform.system() == "Darwin":
  53. return
  54. assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
  55. assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl"
  56. def test_prolog_file_not_exist(self):
  57. if platform.system() == "Darwin":
  58. return
  59. pseudo_label_list = [1, 2]
  60. non_existing_file = "path/to/non_existing_file.pl"
  61. with pytest.raises(FileNotFoundError) as excinfo:
  62. PrologKB(pseudo_label_list=pseudo_label_list, pl_file=non_existing_file)
  63. assert non_existing_file in str(excinfo.value)
  64. def test_logic_forward_pl1(self, kb_add_prolog):
  65. if platform.system() == "Darwin":
  66. return
  67. result = kb_add_prolog.logic_forward([1, 2])
  68. assert result == 3
  69. def test_logic_forward_pl2(self, kb_hed):
  70. if platform.system() == "Darwin":
  71. return
  72. consist_exs = [
  73. [1, 1, "+", 0, "=", 1, 1],
  74. [1, "+", 1, "=", 1, 0],
  75. [0, "+", 0, "=", 0],
  76. ]
  77. inconsist_exs = [
  78. [1, 1, "+", 0, "=", 1, 1],
  79. [1, "+", 1, "=", 1, 0],
  80. [0, "+", 0, "=", 0],
  81. [0, "+", 0, "=", 1],
  82. ]
  83. assert kb_hed.logic_forward(consist_exs) is True
  84. assert kb_hed.logic_forward(inconsist_exs) is False
  85. def test_revise_at_idx(self, kb_add_prolog):
  86. if platform.system() == "Darwin":
  87. return
  88. result = kb_add_prolog.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0])
  89. assert result == ([[0, 2]], [2])
  90. class TestReaonser(object):
  91. def test_reasoner_init(self, reasoner_instance):
  92. assert reasoner_instance.dist_func == "confidence"
  93. def test_invalid_predefined_dist_func(self, kb_add):
  94. with pytest.raises(NotImplementedError) as excinfo:
  95. Reasoner(kb_add, "invalid_dist_func")
  96. assert (
  97. 'Valid options for predefined dist_func include "hamming", "confidence" '
  98. + 'and "avg_confidence"'
  99. in str(excinfo.value)
  100. )
  101. def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results):
  102. cost_list = [np.random.rand() for _ in candidates]
  103. return cost_list
  104. def test_user_defined_dist_func(self, kb_add):
  105. reasoner = Reasoner(kb_add, self.random_dist)
  106. assert reasoner.dist_func == self.random_dist
  107. def invalid_dist1(self, candidates):
  108. cost_list = np.array([np.random.rand() for _ in candidates])
  109. return cost_list
  110. def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results):
  111. cost_list = np.array([np.random.rand() for _ in candidates])
  112. return np.append(cost_list, np.random.rand())
  113. def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add):
  114. with pytest.raises(ValueError) as excinfo:
  115. Reasoner(kb_add, self.invalid_dist1)
  116. assert "User-defined dist_func must have exactly four parameters" in str(excinfo.value)
  117. with pytest.raises(ValueError) as excinfo:
  118. reasoner = Reasoner(kb_add, self.invalid_dist2)
  119. reasoner.batch_abduce(data_examples_add)
  120. assert (
  121. "The length of the array returned by dist_func must be "
  122. + "equal to the number of candidates"
  123. in str(excinfo.value)
  124. )
  125. class TestBatchAbduce(object):
  126. def test_batch_abduce_add(self, kb_add, data_examples_add):
  127. reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0)
  128. reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1)
  129. reasoner3 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=0)
  130. reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1)
  131. assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
  132. assert reasoner2.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
  133. assert reasoner3.batch_abduce(data_examples_add) == [
  134. [1, 7],
  135. [7, 1],
  136. [8, 9],
  137. [1, 9],
  138. ]
  139. assert reasoner4.batch_abduce(data_examples_add) == [
  140. [1, 7],
  141. [7, 1],
  142. [8, 9],
  143. [7, 3],
  144. ]
  145. def test_batch_abduce_ground(self, kb_add_ground, data_examples_add):
  146. reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0)
  147. reasoner2 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=1)
  148. reasoner3 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=0)
  149. reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1)
  150. assert reasoner1.batch_abduce(data_examples_add) == [(1, 7), (7, 1), [], (1, 9)]
  151. assert reasoner2.batch_abduce(data_examples_add) == [(1, 7), (7, 1), [], (1, 9)]
  152. assert reasoner3.batch_abduce(data_examples_add) == [
  153. (1, 7),
  154. (7, 1),
  155. (8, 9),
  156. (1, 9),
  157. ]
  158. assert reasoner4.batch_abduce(data_examples_add) == [
  159. (1, 7),
  160. (7, 1),
  161. (8, 9),
  162. (7, 3),
  163. ]
  164. def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add):
  165. if platform.system() == "Darwin":
  166. return
  167. reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
  168. reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1)
  169. reasoner3 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=0)
  170. reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1)
  171. assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
  172. assert reasoner2.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
  173. assert reasoner3.batch_abduce(data_examples_add) == [
  174. [1, 7],
  175. [7, 1],
  176. [8, 9],
  177. [1, 9],
  178. ]
  179. assert reasoner4.batch_abduce(data_examples_add) == [
  180. [1, 7],
  181. [7, 1],
  182. [8, 9],
  183. [7, 3],
  184. ]
  185. def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add):
  186. if platform.system() == "Darwin":
  187. return
  188. reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
  189. reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)
  190. assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
  191. assert reasoner2.batch_abduce(data_examples_add) == [
  192. [1, 7],
  193. [7, 1],
  194. [8, 9],
  195. [7, 3],
  196. ]
  197. def test_batch_abduce_hwf1(self, kb_hwf1, data_examples_hwf):
  198. reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0)
  199. reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0)
  200. reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0)
  201. res = reasoner1.batch_abduce(data_examples_hwf)
  202. assert res == [
  203. ["1", "+", "2"],
  204. ["8", "times", "8"],
  205. [],
  206. ["4", "-", "6", "div", "8"],
  207. ]
  208. res = reasoner2.batch_abduce(data_examples_hwf)
  209. assert res == [["1", "+", "2"], [], [], []]
  210. res = reasoner3.batch_abduce(data_examples_hwf)
  211. assert res == [
  212. ["1", "+", "2"],
  213. ["8", "times", "8"],
  214. [],
  215. ["4", "-", "6", "div", "8"],
  216. ]
  217. def test_batch_abduce_hwf2(self, kb_hwf2, data_examples_hwf):
  218. reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0)
  219. reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0)
  220. reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0)
  221. res = reasoner1.batch_abduce(data_examples_hwf)
  222. assert res == [
  223. ["1", "+", "2"],
  224. ["7", "times", "9"],
  225. ["8", "times", "8"],
  226. ["5", "-", "8", "div", "8"],
  227. ]
  228. res = reasoner2.batch_abduce(data_examples_hwf)
  229. assert res == [
  230. ["1", "+", "2"],
  231. ["7", "times", "9"],
  232. [],
  233. ["5", "-", "8", "div", "8"],
  234. ]
  235. res = reasoner3.batch_abduce(data_examples_hwf)
  236. assert res == [
  237. ["1", "+", "2"],
  238. ["7", "times", "9"],
  239. ["8", "times", "8"],
  240. ["5", "-", "8", "div", "8"],
  241. ]

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.