You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conftest.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import numpy as np
  2. import pytest
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from abl.data.structures import ListData
  7. from abl.learning import BasicNN
  8. from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
  9. class LeNet5(nn.Module):
  10. def __init__(self, num_classes=10, image_size=(28, 28)):
  11. super(LeNet5, self).__init__()
  12. self.conv1 = nn.Sequential(
  13. nn.Conv2d(1, 6, 3, padding=1),
  14. nn.ReLU(),
  15. nn.MaxPool2d(kernel_size=2, stride=2),
  16. )
  17. self.conv2 = nn.Sequential(
  18. nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)
  19. )
  20. self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU())
  21. feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2
  22. num_features = 16 * feature_map_size[0] * feature_map_size[1]
  23. self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
  24. self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
  25. self.fc3 = nn.Linear(84, num_classes)
  26. def forward(self, x):
  27. x = self.conv1(x)
  28. x = self.conv2(x)
  29. x = self.conv3(x)
  30. x = torch.flatten(x, 1)
  31. x = self.fc1(x)
  32. x = self.fc2(x)
  33. x = self.fc3(x)
  34. return x
  35. # Fixture for BasicNN instance
  36. @pytest.fixture
  37. def basic_nn_instance():
  38. model = LeNet5()
  39. loss_fn = nn.CrossEntropyLoss()
  40. optimizer = optim.Adam(model.parameters())
  41. return BasicNN(model, loss_fn, optimizer)
  42. # Fixture for base_model instance
  43. @pytest.fixture
  44. def base_model_instance():
  45. model = LeNet5()
  46. loss_fn = nn.CrossEntropyLoss()
  47. optimizer = optim.Adam(model.parameters())
  48. return BasicNN(model, loss_fn, optimizer)
  49. # Fixture for ListData instance
  50. @pytest.fixture
  51. def list_data_instance():
  52. data_examples = ListData()
  53. data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
  54. data_examples.Y = [1, 2, 3]
  55. data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
  56. return data_examples
  57. @pytest.fixture
  58. def data_examples_add():
  59. # favor 1 in first one
  60. prob1 = [
  61. [0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
  62. [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
  63. ]
  64. # favor 7 in first one
  65. prob2 = [
  66. [0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
  67. [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
  68. ]
  69. data_examples_add = ListData()
  70. data_examples_add.X = None
  71. data_examples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
  72. data_examples_add.pred_prob = [prob1, prob2, prob1, prob2]
  73. data_examples_add.Y = [8, 8, 17, 10]
  74. return data_examples_add
  75. @pytest.fixture
  76. def data_examples_hwf():
  77. data_examples_hwf = ListData()
  78. data_examples_hwf.X = None
  79. data_examples_hwf.pred_pseudo_label = [
  80. ["5", "+", "2"],
  81. ["5", "+", "9"],
  82. ["5", "+", "9"],
  83. ["5", "-", "8", "8", "8"],
  84. ]
  85. data_examples_hwf.pred_prob = [None, None, None, None]
  86. data_examples_hwf.Y = [3, 64, 65, 3.17]
  87. return data_examples_hwf
  88. class AddKB(KBBase):
  89. def __init__(self, pseudo_label_list=list(range(10)), use_cache=False):
  90. super().__init__(pseudo_label_list, use_cache=use_cache)
  91. def logic_forward(self, nums):
  92. return sum(nums)
  93. class AddGroundKB(GroundKB):
  94. def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
  95. super().__init__(pseudo_label_list, GKB_len_list)
  96. def logic_forward(self, nums):
  97. return sum(nums)
  98. class HwfKB(KBBase):
  99. def __init__(
  100. self,
  101. pseudo_label_list=[
  102. "1",
  103. "2",
  104. "3",
  105. "4",
  106. "5",
  107. "6",
  108. "7",
  109. "8",
  110. "9",
  111. "+",
  112. "-",
  113. "times",
  114. "div",
  115. ],
  116. max_err=1e-3,
  117. use_cache=False,
  118. ):
  119. super().__init__(pseudo_label_list, max_err, use_cache)
  120. def _valid_candidate(self, formula):
  121. if len(formula) % 2 == 0:
  122. return False
  123. for i in range(len(formula)):
  124. if i % 2 == 0 and formula[i] not in [
  125. "1",
  126. "2",
  127. "3",
  128. "4",
  129. "5",
  130. "6",
  131. "7",
  132. "8",
  133. "9",
  134. ]:
  135. return False
  136. if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
  137. return False
  138. return True
  139. def logic_forward(self, formula):
  140. if not self._valid_candidate(formula):
  141. return None
  142. mapping = {str(i): str(i) for i in range(1, 10)}
  143. mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
  144. formula = [mapping[f] for f in formula]
  145. return eval("".join(formula))
  146. class HedKB(PrologKB):
  147. def __init__(self, pseudo_label_list, pl_file):
  148. super().__init__(pseudo_label_list, pl_file)
  149. def consist_rule(self, exs, rules):
  150. rules = str(rules).replace("'", "")
  151. pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
  152. return len(list(self.prolog.query(pl_query))) != 0
  153. @pytest.fixture
  154. def kb_add():
  155. return AddKB()
  156. @pytest.fixture
  157. def kb_add_cache():
  158. return AddKB(use_cache=True)
  159. @pytest.fixture
  160. def kb_add_ground():
  161. return AddGroundKB()
  162. @pytest.fixture
  163. def kb_add_prolog():
  164. kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl")
  165. return kb
  166. @pytest.fixture
  167. def kb_hwf1():
  168. return HwfKB(max_err=0.1)
  169. @pytest.fixture
  170. def kb_hwf2():
  171. return HwfKB(max_err=1)
  172. @pytest.fixture
  173. def kb_hed():
  174. kb = HedKB(
  175. pseudo_label_list=[1, 0, "+", "="],
  176. pl_file="examples/hed/reasoning/learn_add.pl",
  177. )
  178. return kb
  179. @pytest.fixture
  180. def reasoner_instance(kb_add):
  181. return Reasoner(kb_add, "confidence")

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.