
conftest.py 6.0 kB

"""Shared pytest fixtures for the ABLkit test suite: a small LeNet-5 learner,
ListData example containers, and the knowledge bases (KBBase, GroundKB,
PrologKB) exercised across the tests."""

import platform

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.optim as optim

from ablkit.data.structures import ListData
from ablkit.learning import BasicNN
from ablkit.reasoning import GroundKB, KBBase, PrologKB, Reasoner


class LeNet5(nn.Module):
    def __init__(self, num_classes=10, image_size=(28, 28)):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU())
        feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2
        num_features = 16 * feature_map_size[0] * feature_map_size[1]
        self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


# Fixture for BasicNN instance
@pytest.fixture
def basic_nn_instance():
    model = LeNet5()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    return BasicNN(model, loss_fn, optimizer)


# Fixture for base_model instance
@pytest.fixture
def base_model_instance():
    model = LeNet5()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    return BasicNN(model, loss_fn, optimizer)


# Fixture for ListData instance
@pytest.fixture
def list_data_instance():
    data_examples = ListData()
    data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
    data_examples.Y = [1, 2, 3]
    data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
    return data_examples


@pytest.fixture
def data_examples_add():
    # favor label 1 in the first position
    prob1 = [
        [0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
    ]
    # favor label 7 in the first position
    prob2 = [
        [0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
    ]
    data_examples_add = ListData()
    data_examples_add.X = None
    data_examples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
    data_examples_add.pred_prob = [prob1, prob2, prob1, prob2]
    data_examples_add.Y = [8, 8, 17, 10]
    return data_examples_add


@pytest.fixture
def data_examples_hwf():
    data_examples_hwf = ListData()
    data_examples_hwf.X = None
    data_examples_hwf.pred_pseudo_label = [
        ["5", "+", "2"],
        ["5", "+", "9"],
        ["5", "+", "9"],
        ["5", "-", "8", "8", "8"],
    ]
    data_examples_hwf.pred_prob = [None, None, None, None]
    data_examples_hwf.Y = [3, 64, 65, 3.17]
    return data_examples_hwf


class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10)), use_cache=False):
        super().__init__(pseudo_label_list, use_cache=use_cache)

    def logic_forward(self, nums):
        return sum(nums)


class AddGroundKB(GroundKB):
    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        return sum(nums)


class HwfKB(KBBase):
    def __init__(
        self,
        pseudo_label_list=[
            "1",
            "2",
            "3",
            "4",
            "5",
            "6",
            "7",
            "8",
            "9",
            "+",
            "-",
            "times",
            "div",
        ],
        max_err=1e-3,
        use_cache=False,
    ):
        super().__init__(pseudo_label_list, max_err, use_cache)

    def _valid_candidate(self, formula):
        # A well-formed formula alternates digit, operator, digit, ... and
        # therefore has odd length.
        if len(formula) % 2 == 0:
            return False
        for i in range(len(formula)):
            if i % 2 == 0 and formula[i] not in [
                "1",
                "2",
                "3",
                "4",
                "5",
                "6",
                "7",
                "8",
                "9",
            ]:
                return False
            if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
                return False
        return True

    def logic_forward(self, formula):
        if not self._valid_candidate(formula):
            return None
        mapping = {str(i): str(i) for i in range(1, 10)}
        mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
        formula = [mapping[f] for f in formula]
        return eval("".join(formula))


class HedKB(PrologKB):
    def __init__(self, pseudo_label_list, pl_file):
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        rules = str(rules).replace("'", "")
        pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
        return len(list(self.prolog.query(pl_query))) != 0


@pytest.fixture
def kb_add():
    return AddKB()


@pytest.fixture
def kb_add_cache():
    return AddKB(use_cache=True)


@pytest.fixture
def kb_add_ground():
    return AddGroundKB()


@pytest.fixture
def kb_add_prolog():
    # Prolog-backed KBs are skipped on macOS; the fixture then returns None.
    if platform.system() == "Darwin":
        return
    kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl")
    return kb


@pytest.fixture
def kb_hwf1():
    return HwfKB(max_err=0.1)


@pytest.fixture
def kb_hwf2():
    return HwfKB(max_err=1)


@pytest.fixture
def kb_hed():
    # Prolog-backed KBs are skipped on macOS; the fixture then returns None.
    if platform.system() == "Darwin":
        return
    kb = HedKB(
        pseudo_label_list=[1, 0, "+", "="],
        pl_file="examples/hed/reasoning/learn_add.pl",
    )
    return kb


@pytest.fixture
def reasoner_instance(kb_add):
    return Reasoner(kb_add, "confidence")
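For orientation, here is a minimal sketch of how a test module alongside this conftest.py might consume some of these fixtures. The test names are hypothetical; the assertions only exercise behavior defined above (AddKB.logic_forward summing pseudo-labels, the attributes set on data_examples_add, and HwfKB rejecting malformed formulas).

def test_add_kb_logic_forward(kb_add):
    # AddKB.logic_forward sums the pseudo-labels of one example.
    assert kb_add.logic_forward([1, 7]) == 8


def test_data_examples_add_structure(data_examples_add):
    # Each example carries two pseudo-labels and one reasoning target.
    assert len(data_examples_add.pred_pseudo_label) == len(data_examples_add.Y) == 4
    assert all(len(p) == 2 for p in data_examples_add.pred_pseudo_label)


def test_hwf_kb_rejects_invalid_formula(kb_hwf1):
    # An even-length token sequence cannot be a well-formed formula,
    # so logic_forward returns None.
    assert kb_hwf1.logic_forward(["5", "+"]) is None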

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.