mnist_add_kb.py

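from typing import Any

from abl.reasoning import SearchBasedKB
from abl.structures import ListData


class AddKB(SearchBasedKB):
    """Knowledge base for MNIST addition: the label Y is the sum of the digits."""

    def __init__(self, pseudo_label_list=list(range(10)), use_cache=True, cache_size=4096):
        # Pseudo-labels are the digits 0-9; use_cache/cache_size control the KB's result cache.
        super().__init__(
            pseudo_label_list=pseudo_label_list, use_cache=use_cache, cache_size=cache_size
        )

    def get_key(self, data_sample: ListData):
        # Hashable cache key: (predicted pseudo-labels as a tuple, target sum Y).
        return (data_sample.to_tuple("pred_pseudo_label"), data_sample["Y"][0])

    def check_equal(self, data_sample: ListData, y: Any):
        # A candidate is consistent if its reasoning result equals y.
        return self.logic_forward(data_sample) == y

    def logic_forward(self, data_sample):
        # Reasoning step: add up the predicted digits of the first entry in data_sample.
        return sum(data_sample["pred_pseudo_label"][0])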
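
To make the knowledge base's job concrete without depending on the `abl` package, here is a minimal, self-contained sketch of the same idea in plain Python. It is not the `SearchBasedKB` implementation: the function `abduce`, the `max_revision` parameter, the brute-force search, and the Hamming-distance tie-breaking are illustrative assumptions. `logic_forward` and `check_equal` mirror the methods above, and `functools.lru_cache` stands in for `use_cache`/`cache_size`, keyed (like `get_key`) on the predicted labels and the target sum.

```python
from functools import lru_cache
from itertools import product
from typing import Optional, Sequence, Tuple

# Hypothetical stand-in for AddKB: NOT the abl API, just the same idea in plain Python.
PSEUDO_LABELS: Tuple[int, ...] = tuple(range(10))  # digits 0-9, like pseudo_label_list


def logic_forward(pseudo_label: Sequence[int]) -> int:
    """Mirror of AddKB.logic_forward: the 'reasoning' is just addition."""
    return sum(pseudo_label)


def check_equal(pseudo_label: Sequence[int], y: int) -> bool:
    """Mirror of AddKB.check_equal: does the reasoning result match the label y?"""
    return logic_forward(pseudo_label) == y


@lru_cache(maxsize=4096)  # cache keyed on (pred, y, max_revision), analogous to get_key
def abduce(pred: Tuple[int, ...], y: int, max_revision: int = 1) -> Optional[Tuple[int, ...]]:
    """Return the label tuple closest to pred (fewest changed positions) whose sum is y.

    Hypothetical helper: brute-force search over every tuple of the same length.
    Ties are broken by enumeration order; returns None if no consistent candidate
    needs at most max_revision changes.
    """
    best: Optional[Tuple[int, ...]] = None
    best_cost: Optional[int] = None
    for cand in product(PSEUDO_LABELS, repeat=len(pred)):
        if not check_equal(cand, y):
            continue  # inconsistent with the knowledge base
        cost = sum(a != b for a, b in zip(cand, pred))  # Hamming distance to prediction
        if cost <= max_revision and (best_cost is None or cost < best_cost):
            best, best_cost = cand, cost
    return best


if __name__ == "__main__":
    # The model read the image pair as (3, 5), but the true sum is 9.
    print(abduce((3, 5), 9))  # → (3, 6)
```

The exhaustive search checks 10^n candidates for n digits, which is trivial for two-digit addition but grows quickly; that cost is one reason a result cache keyed on (predictions, target) is worthwhile when the same pairs recur during training.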

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.