You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ground_kb.py 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from abc import ABC, abstractmethod
  2. from typing import Any, Hashable, List
  3. from abl.structures import ListData
  4. from .base_kb import BaseKB
  5. class GroundKB(BaseKB, ABC):
  6. def __init__(self, pseudo_label_list: List) -> None:
  7. super().__init__(pseudo_label_list)
  8. self.GKB = self.construct_base()
  9. @abstractmethod
  10. def construct_base(self) -> dict:
  11. pass
  12. @abstractmethod
  13. def get_key(self, data_sample: ListData) -> Hashable:
  14. pass
  15. def key2candidates(self, key: Hashable) -> List[List[Any]]:
  16. return self.GKB[key]
  17. def filter_candidates(
  18. self,
  19. data_sample: ListData,
  20. candidates: List[List[Any]],
  21. max_revision_num: int,
  22. require_more_revision: int = 0,
  23. ) -> List[List[Any]]:
  24. return candidates
  25. def abduce_candidates(
  26. self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
  27. ):
  28. return self._abduce_by_GKB(
  29. data_sample=data_sample,
  30. max_revision_num=max_revision_num,
  31. require_more_revision=require_more_revision,
  32. )
  33. def _abduce_by_GKB(
  34. self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0
  35. ):
  36. candidates = self.key2candidates(self.get_key(data_sample))
  37. return self.filter_candidates(
  38. data_sample=data_sample,
  39. max_revision_num=max_revision_num,
  40. require_more_revision=require_more_revision,
  41. candidates=candidates,
  42. )
  43. # TODO: When the output is excessively long, use ellipses as a substitute.
  44. def __repr__(self):
  45. return (
  46. f"<{self.__class__.__name__}(\n"
  47. f" pseudo_label_list: {self.pseudo_label_list!r}\n"
  48. f" GKB: {self.GKB!r}\n"
  49. f") at {hex(id(self))}>"
  50. )

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.