|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- from typing import Any, List, Mapping, Optional
-
- import numpy as np
-
- from ..structures import ListData
- from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist
- from .base_kb import BaseKB
- from .search_engine import BFS, BaseSearchEngine
-
-
- class ReasonerBase:
- def __init__(
- self,
- kb: BaseKB,
- dist_func: str = "confidence",
- mapping: Optional[Mapping] = None,
- search_engine: Optional[BaseSearchEngine] = None,
- use_cache: bool = False,
- cache_file: Optional[str] = None,
- cache_size: Optional[int] = 4096,
- ):
- """
- Base class for all reasoner in the ABL system.
-
- Parameters
- ----------
- kb : BaseKB
- The knowledge base to be used for reasoning.
- dist_func : str, optional
- The distance function to be used. Can be "hamming" or "confidence". Default is "confidence".
- mapping : dict, optional
- A mapping of indices to labels. If None, a default mapping is generated.
- use_zoopt : bool, optional
- Whether to use the Zoopt library for optimization. Default is False.
-
- Raises
- ------
- NotImplementedError
- If the specified distance function is neither "hamming" nor "confidence".
- """
-
- if not isinstance(kb, BaseKB):
- raise ValueError("The kb should be of type BaseKB.")
- self.kb = kb
-
- if dist_func not in ["hamming", "confidence"]:
- raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.")
- self.dist_func = dist_func
-
- if mapping is None:
- self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
- else:
- if not isinstance(mapping, dict):
- raise ValueError("mapping must be of type dict")
-
- for key, value in mapping.items():
- if not isinstance(key, int):
- raise ValueError("All keys in the mapping must be integers")
-
- if value not in self.kb.pseudo_label_list:
- raise ValueError("All values in the mapping must be in the pseudo_label_list")
-
- self.mapping = mapping
- self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))
-
- if search_engine is None:
- self.search_engine = BFS()
- else:
- if not isinstance(search_engine, BaseSearchEngine):
- raise ValueError("The search_engine should be of type BaseSearchEngine.")
- else:
- self.search_engine = search_engine
-
- self.use_cache = use_cache
- self.cache_file = cache_file
- if self.use_cache:
- if not hasattr(self, "get_key"):
- raise NotImplementedError("If use_cache is True, get_key should be implemented.")
- key_func = self.get_key
- else:
- key_func = lambda x: x
- self.cache = Cache[ListData, List[List[Any]]](
- func=self.abduce_candidates,
- cache=self.use_cache,
- cache_file=self.cache_file,
- key_func=key_func,
- max_size=cache_size,
- )
-
- def abduce(
- self,
- data_sample: ListData,
- max_revision: int = -1,
- require_more_revision: int = 0,
- ):
- """
- Perform revision by abduction on the given data.
-
- Parameters
- ----------
- pred_prob : list
- List of probabilities for predicted results.
- pred_pseudo_label : list
- List of predicted pseudo labels.
- y : any
- Ground truth for the predicted results.
- max_revision : int or float, optional
- Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
- If -1, any revisions are allowed. Defaults to -1.
- require_more_revision : int, optional
- Number of additional revisions to require. Defaults to 0.
-
- Returns
- -------
- list
- The abduced revisions.
- """
- symbol_num = data_sample.elements_num("pred_pseudo_label")
- max_revision_num = calculate_revision_num(max_revision, symbol_num)
- data_sample.set_metainfo(dict(symbol_num=symbol_num))
-
- candidates = self.cache.get(data_sample, max_revision_num, require_more_revision)
- candidate = self.select_one_candidate(data_sample, candidates)
- return candidate
-
- def abduce_candidates(
- self,
- data_sample: ListData,
- max_revision_num: int = -1,
- require_more_revision: int = 0,
- ):
- """
- Perform revision by abduction on the given data.
-
- Parameters
- ----------
- pred_prob : list
- List of probabilities for predicted results.
- pred_pseudo_label : list
- List of predicted pseudo labels.
- y : any
- Ground truth for the predicted results.
- max_revision : int or float, optional
- Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
- If -1, any revisions are allowed. Defaults to -1.
- require_more_revision : int, optional
- Number of additional revisions to require. Defaults to 0.
-
- Returns
- -------
- list
- The abduced revisions.
- """
-
- if hasattr(self.kb, "abduce_candidates"):
- candidates = self.kb.abduce_candidates(
- data_sample, max_revision_num, require_more_revision
- )
- elif hasattr(self.kb, "revise_at_idx"):
- candidates = []
- gen = self.search_engine.generator(
- data_sample,
- max_revision_num=max_revision_num,
- require_more_revision=require_more_revision,
- )
- send_signal = True
- for revision_idx in gen:
- candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
- if len(candidates) > 0 and send_signal:
- try:
- revision_idx = gen.send("success")
- candidates.extend(self.kb.revise_at_idx(data_sample, revision_idx))
- send_signal = False
- except StopIteration:
- break
- else:
- raise NotImplementedError(
- "The kb should either implement abduce_candidates or revise_at_idx."
- )
- return candidates
-
- def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]):
- """
- Get one candidate. If multiple candidates exist, return the one with minimum cost.
-
- Parameters
- ----------
- pred_pseudo_label : list
- The pseudo label to be used for selecting a candidate.
- pred_prob : list
- Probabilities of the predictions.
- candidates : list
- List of candidate abduction result.
-
- Returns
- -------
- list
- The chosen candidate based on minimum cost.
- If no candidates, an empty list is returned.
- """
- if len(candidates) == 0:
- return []
- elif len(candidates) == 1:
- return candidates[0]
- else:
- cost_array = self._get_dist_list(data_sample, candidates)
- candidate = candidates[np.argmin(cost_array)]
- return candidate
-
- def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]):
- """
- Get the list of costs between each pseudo label and candidate.
-
- Parameters
- ----------
- pred_pseudo_label : list
- The pseudo label to be used for computing costs of candidates.
- pred_prob : list
- Probabilities of the predictions. Used when distance function is "confidence".
- candidates : list
- List of candidate abduction result.
-
- Returns
- -------
- numpy.ndarray
- Array of computed costs for each candidate.
- """
- if self.dist_func == "hamming":
- return hamming_dist(data_sample["pred_pseudo_label"][0], candidates)
-
- elif self.dist_func == "confidence":
- candidates = [[self.remapping[x] for x in c] for c in candidates]
- return confidence_dist(data_sample["pred_prob"][0], candidates)
-
- def batch_abduce(
- self,
- data_samples: ListData,
- max_revision: int = -1,
- require_more_revision: int = 0,
- ):
- """
- Perform abduction on the given data in batches.
-
- Parameters
- ----------
- pred_prob : list
- List of probabilities for predicted results.
- pred_pseudo_label : list
- List of predicted pseudo labels.
- Y : list
- List of ground truths for the predicted results.
- max_revision : int or float, optional
- Maximum number of revisions to use. If float, represents the fraction of total revisions to use.
- If -1, use all revisions. Defaults to -1.
- require_more_revision : int, optional
- Number of additional revisions to require. Defaults to 0.
-
- Returns
- -------
- list
- The abduced revisions in batches.
- """
- abduced_pseudo_label = [
- self.abduce(
- data_sample,
- max_revision=max_revision,
- require_more_revision=require_more_revision,
- )
- for data_sample in data_samples
- ]
- data_samples.abduced_pseudo_label = abduced_pseudo_label
- return abduced_pseudo_label
|