You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

kb.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import openml
  2. from z3 import If, Implies, Int, Not, Solver, Sum, sat # noqa: F401
  3. from ablkit.reasoning import KBBase
  4. class ZooKB(KBBase):
  5. def __init__(self):
  6. super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
  7. self.solver = Solver()
  8. # Load information of Zoo dataset
  9. dataset = openml.datasets.get_dataset(
  10. dataset_id=62,
  11. download_data=False,
  12. download_qualities=False,
  13. download_features_meta_data=False,
  14. )
  15. X, y, categorical_indicator, attribute_names = dataset.get_data(
  16. target=dataset.default_target_attribute
  17. )
  18. self.attribute_names = attribute_names
  19. self.target_names = y.cat.categories.tolist()
  20. # print("Attribute names are: ", self.attribute_names)
  21. # print("Target names are: ", self.target_names)
  22. # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] # noqa: E501
  23. # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] # noqa: E501
  24. # Define variables
  25. for name in self.attribute_names + self.target_names:
  26. exec(
  27. f"globals()['{name}'] = Int('{name}')"
  28. ) # or use dict to create var and modify rules
  29. # Define rules
  30. rules = [
  31. Implies(milk == 1, mammal == 1),
  32. Implies(mammal == 1, milk == 1),
  33. Implies(mammal == 1, backbone == 1),
  34. Implies(mammal == 1, breathes == 1),
  35. Implies(feathers == 1, bird == 1),
  36. Implies(bird == 1, feathers == 1),
  37. Implies(bird == 1, eggs == 1),
  38. Implies(bird == 1, backbone == 1),
  39. Implies(bird == 1, breathes == 1),
  40. Implies(bird == 1, legs == 2),
  41. Implies(bird == 1, tail == 1),
  42. Implies(reptile == 1, backbone == 1),
  43. Implies(reptile == 1, breathes == 1),
  44. Implies(reptile == 1, tail == 1),
  45. Implies(fish == 1, aquatic == 1),
  46. Implies(fish == 1, toothed == 1),
  47. Implies(fish == 1, backbone == 1),
  48. Implies(fish == 1, Not(breathes == 1)),
  49. Implies(fish == 1, fins == 1),
  50. Implies(fish == 1, legs == 0),
  51. Implies(fish == 1, tail == 1),
  52. Implies(amphibian == 1, eggs == 1),
  53. Implies(amphibian == 1, aquatic == 1),
  54. Implies(amphibian == 1, backbone == 1),
  55. Implies(amphibian == 1, breathes == 1),
  56. Implies(amphibian == 1, legs == 4),
  57. Implies(insect == 1, eggs == 1),
  58. Implies(insect == 1, Not(backbone == 1)),
  59. Implies(insect == 1, legs == 6),
  60. Implies(invertebrate == 1, Not(backbone == 1)),
  61. ]
  62. # Define weights and sum of violated weights
  63. self.weights = {rule: 1 for rule in rules}
  64. self.total_violation_weight = Sum(
  65. [If(Not(rule), self.weights[rule], 0) for rule in self.weights]
  66. )
  67. def logic_forward(self, pseudo_label, data_point):
  68. attribute_names, target_names = self.attribute_names, self.target_names
  69. solver = self.solver
  70. total_violation_weight = self.total_violation_weight
  71. pseudo_label, data_point = pseudo_label[0], data_point[0]
  72. self.solver.reset()
  73. for name, value in zip(attribute_names, data_point):
  74. solver.add(eval(f"{name} == {value}"))
  75. for cate, name in zip(self.pseudo_label_list, target_names):
  76. value = 1 if (cate == pseudo_label) else 0
  77. solver.add(eval(f"{name} == {value}"))
  78. if solver.check() == sat:
  79. model = solver.model()
  80. total_weight = model.evaluate(total_violation_weight)
  81. return total_weight.as_long()
  82. else:
  83. # No solution found
  84. return 1e10

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.