You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reasoning.py 8.8 kB

(Line numbers 1–219 of the original listing.)
  1. import pytest
  2. from abl.reasoning import PrologKB, Reasoner
  class TestKBBase(object):
      """Tests for the basic (non-ground, non-Prolog) addition knowledge base.

      The ``kb_add`` and ``kb_add_cache`` fixtures are defined outside this file
      (presumably in a ``conftest.py`` — confirm). Judging by the assertions
      below, they expose the pseudo labels ``0..9`` and ``logic_forward`` sums
      its inputs.
      """

      def test_init(self, kb_add):
          expected_labels = list(range(10))
          assert kb_add.pseudo_label_list == expected_labels

      def test_init_cache(self, kb_add_cache):
          expected_labels = list(range(10))
          assert kb_add_cache.pseudo_label_list == expected_labels
          # Require the boolean ``True`` itself, not merely a truthy value.
          assert kb_add_cache.use_cache is True

      def test_logic_forward(self, kb_add):
          assert kb_add.logic_forward([1, 2]) == 3

      def test_revise_at_idx(self, kb_add):
          # Each case: (pseudo labels, target, positions that may be revised)
          # followed by the expected list of consistent revisions. With no
          # revisable positions, an already-consistent input is returned as the
          # only candidate and an inconsistent one yields no candidates.
          cases = [
              (([0, 2], 2, []), [[0, 2]]),
              (([1, 2], 2, []), []),
              (([1, 2], 2, [0, 1]), [[0, 2], [1, 1], [2, 0]]),
          ]
          for call_args, expected in cases:
              assert kb_add.revise_at_idx(*call_args) == expected

      def test_abduce_candidates(self, kb_add):
          budget = {"max_revision_num": 2, "require_more_revision": 0}
          cases = [
              ([0, 1], 1, [[0, 1]]),
              ([1, 2], 1, [[1, 0]]),
          ]
          for labels, target, expected in cases:
              assert kb_add.abduce_candidates(labels, target, **budget) == expected
  class TestGroundKB(object):
      """Tests for the ground addition knowledge base (``kb_add_ground`` fixture).

      The fixture is defined outside this file; its exact construction is not
      visible here.
      """

      def test_init(self, kb_add_ground):
          kb = kb_add_ground
          assert kb.pseudo_label_list == list(range(10))
          # Presumably the input lengths the ground table was built for — confirm.
          assert kb.GKB_len_list == [2]
          # The ground table must exist and be truthy.
          assert kb.GKB

      def test_logic_forward_ground(self, kb_add_ground):
          assert kb_add_ground.logic_forward([1, 2]) == 3

      def test_abduce_candidates_ground(self, kb_add_ground):
          candidates = kb_add_ground.abduce_candidates(
              [1, 2], 1, max_revision_num=2, require_more_revision=0
          )
          # Expected candidates are tuples here (a list ``[1, 0]`` would not
          # compare equal), unlike the list-based results of the basic KB.
          assert candidates == [(1, 0)]
  class TestPrologKB(object):
      """Tests for ``PrologKB``-backed knowledge bases (addition and HED)."""

      def test_init_pl1(self, kb_add_prolog):
          kb = kb_add_prolog
          assert kb.pseudo_label_list == list(range(10))
          assert kb.pl_file == "examples/mnist_add/datasets/add.pl"

      def test_init_pl2(self, kb_hed):
          kb = kb_hed
          assert kb.pseudo_label_list == [1, 0, "+", "="]
          assert kb.pl_file == "examples/hed/datasets/learn_add.pl"

      def test_prolog_file_not_exist(self):
          missing_path = "path/to/non_existing_file.pl"
          with pytest.raises(FileNotFoundError) as excinfo:
              PrologKB(pseudo_label_list=[1, 2], pl_file=missing_path)
          # The error message must name the path that could not be found.
          assert missing_path in str(excinfo.value)

      def test_logic_forward_pl1(self, kb_add_prolog):
          assert kb_add_prolog.logic_forward([1, 2]) == 3

      def test_logic_forward_pl2(self, kb_hed):
          # These appear to be binary-number equations (e.g. 1 + 1 = 10).
          # ``logic_forward`` on the whole batch is True only if every equation
          # holds; appending the false ``0 + 0 = 1`` must flip it to False.
          consistent = [
              [1, 1, "+", 0, "=", 1, 1],
              [1, "+", 1, "=", 1, 0],
              [0, "+", 0, "=", 0],
          ]
          inconsistent = [
              [1, 1, "+", 0, "=", 1, 1],
              [1, "+", 1, "=", 1, 0],
              [0, "+", 0, "=", 0],
              [0, "+", 0, "=", 1],
          ]
          assert kb_hed.logic_forward(consistent) is True
          assert kb_hed.logic_forward(inconsistent) is False

      def test_revise_at_idx(self, kb_add_prolog):
          revisions = kb_add_prolog.revise_at_idx([1, 2], 2, [0])
          assert revisions == [[0, 2]]
  class TestReaonser(object):
      """Tests for constructing a ``Reasoner``.

      NOTE(review): the class name is misspelled ("Reaonser"). It is kept
      unchanged so existing pytest node IDs stay stable.
      """

      def test_reasoner_init(self, reasoner_instance):
          assert reasoner_instance.dist_func == "confidence"

      def test_invalid_dist_funce(self, kb_add):
          # ``self`` must come first. Without it, pytest treats ``kb_add`` as
          # the bound instance, never injects the fixture, and Reasoner would
          # receive the test-class instance instead of a knowledge base.
          # An unknown distance-function name must be rejected, and the error
          # message must list the supported options.
          with pytest.raises(NotImplementedError) as excinfo:
              Reasoner(kb_add, "invalid_dist_func")
          assert 'Valid options for dist_func include "hamming" and "confidence"' in str(
              excinfo.value
          )
  class TestBatchAbduce(object):
      """End-to-end tests for ``Reasoner.batch_abduce``.

      Each test builds several reasoners over the same knowledge-base fixture
      with different revision settings. It then compares the abduced pseudo
      labels for every sample in the data fixture. In the expected outputs, an
      empty list marks a sample for which that reasoner returned no candidate.

      The ``Test`` prefix is required: pytest's default class discovery
      (``python_classes = Test*``) ignores classes whose names lack it.
      """

      def test_batch_abduce_add(self, kb_add, data_samples_add):
          reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0)
          reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1)
          reasoner3 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=0)
          reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1)
          assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
          assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
          assert reasoner3.batch_abduce(data_samples_add) == [
              [1, 7],
              [7, 1],
              [8, 9],
              [1, 9],
          ]
          assert reasoner4.batch_abduce(data_samples_add) == [
              [1, 7],
              [7, 1],
              [8, 9],
              [7, 3],
          ]

      def test_batch_abduce_ground(self, kb_add_ground, data_samples_add):
          # Same scenario as ``test_batch_abduce_add``, but the ground KB's
          # expected candidates are tuples rather than lists.
          reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0)
          reasoner2 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=1)
          reasoner3 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=0)
          reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1)
          assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
          assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
          assert reasoner3.batch_abduce(data_samples_add) == [
              (1, 7),
              (7, 1),
              (8, 9),
              (1, 9),
          ]
          assert reasoner4.batch_abduce(data_samples_add) == [
              (1, 7),
              (7, 1),
              (8, 9),
              (7, 3),
          ]

      def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add):
          # Same scenario and expectations as ``test_batch_abduce_add``, backed
          # by the Prolog addition KB.
          reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
          reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1)
          reasoner3 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=0)
          reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1)
          assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
          assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
          assert reasoner3.batch_abduce(data_samples_add) == [
              [1, 7],
              [7, 1],
              [8, 9],
              [1, 9],
          ]
          assert reasoner4.batch_abduce(data_samples_add) == [
              [1, 7],
              [7, 1],
              [8, 9],
              [7, 3],
          ]

      def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add):
          # NOTE(review): ``use_zoopt=True`` presumably needs the optional
          # ZOOpt package to be installed — confirm in the test environment.
          reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
          reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)
          assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
          assert reasoner2.batch_abduce(data_samples_add) == [
              [1, 7],
              [7, 1],
              [8, 9],
              [7, 3],
          ]

      def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf):
          # Float ``max_revision`` values (0.5, 0.9) are presumably interpreted
          # as a fraction of each sample's length — confirm against Reasoner.
          reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0)
          reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0)
          reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0)
          res = reasoner1.batch_abduce(data_samples_hwf)
          assert res == [
              ["1", "+", "2"],
              ["8", "times", "8"],
              [],
              ["4", "-", "6", "div", "8"],
          ]
          res = reasoner2.batch_abduce(data_samples_hwf)
          assert res == [["1", "+", "2"], [], [], []]
          res = reasoner3.batch_abduce(data_samples_hwf)
          assert res == [
              ["1", "+", "2"],
              ["8", "times", "8"],
              [],
              ["4", "-", "6", "div", "8"],
          ]

      def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf):
          # Same reasoner settings as ``test_batch_abduce_hwf1`` against a
          # different HWF knowledge base, which yields different candidates.
          reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0)
          reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0)
          reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0)
          res = reasoner1.batch_abduce(data_samples_hwf)
          assert res == [
              ["1", "+", "2"],
              ["7", "times", "9"],
              ["8", "times", "8"],
              ["5", "-", "8", "div", "8"],
          ]
          res = reasoner2.batch_abduce(data_samples_hwf)
          assert res == [
              ["1", "+", "2"],
              ["7", "times", "9"],
              [],
              ["5", "-", "8", "div", "8"],
          ]
          res = reasoner3.batch_abduce(data_samples_hwf)
          assert res == [
              ["1", "+", "2"],
              ["7", "times", "9"],
              ["8", "times", "8"],
              ["5", "-", "8", "div", "8"],
          ]


  # Backward-compatible alias for the old class name. pytest does not collect
  # it (the name lacks the ``Test`` prefix), so the tests run exactly once.
  test_batch_abduce = TestBatchAbduce

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.