From 831aa855e7bdda1152898c6a2d1e43056d0ff538 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 12 Nov 2023 22:49:32 +0800 Subject: [PATCH] [FIX] fix bugs and run mnist and hwf successfully --- abl/reasoning/__init__.py | 2 +- abl/reasoning/base_kb.py | 9 +- abl/reasoning/reasoner.py | 134 +++++++++++++-------- abl/reasoning/search_engine/__init__.py | 1 - abl/reasoning/search_engine/bfs.py | 2 +- examples/mnist_add/mnist_add_example.ipynb | 8 +- examples/mnist_add/mnist_add_kb.py | 7 +- 7 files changed, 94 insertions(+), 69 deletions(-) diff --git a/abl/reasoning/__init__.py b/abl/reasoning/__init__.py index def9522..231741e 100644 --- a/abl/reasoning/__init__.py +++ b/abl/reasoning/__init__.py @@ -3,4 +3,4 @@ from .ground_kb import GroundKB from .prolog_based_kb import PrologBasedKB from .reasoner import ReasonerBase from .search_based_kb import SearchBasedKB -from .search_engine import BFS, BaseSearchEngine, Zoopt +from .search_engine import BFS, BaseSearchEngine diff --git a/abl/reasoning/base_kb.py b/abl/reasoning/base_kb.py index 848d641..b240aa8 100644 --- a/abl/reasoning/base_kb.py +++ b/abl/reasoning/base_kb.py @@ -1,17 +1,10 @@ -from abc import ABC, abstractmethod - -from ..structures import ListData +from abc import ABC class BaseKB(ABC): def __init__(self, pseudo_label_list) -> None: self.pseudo_label_list = pseudo_label_list - @abstractmethod - def abduce_candidates(self, data_sample: ListData): - """Placeholder for abduction of the knowledge base.""" - pass - # TODO: When the output is excessively long, use ellipses as a substitute. def __repr__(self): return ( diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index e3bc60a..7fdebb0 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -3,8 +3,7 @@ from typing import Any, List, Mapping, Optional import numpy as np from ..structures import ListData -from ..utils import (Cache, calculate_revision_num, confidence_dist, - hamming_dist) +from ..utils import Cache, calculate_revision_num, confidence_dist, hamming_dist from .base_kb import BaseKB from .search_engine import BFS, BaseSearchEngine @@ -81,70 +80,53 @@ class ReasonerBase: else: key_func = lambda x: x self.cache = Cache[ListData, List[List[Any]]]( - func=self.abduce, + func=self.abduce_candidates, cache=self.use_cache, cache_file=self.cache_file, key_func=key_func, max_size=cache_size, ) - def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): + def abduce( + self, + data_sample: ListData, + max_revision: int = -1, + require_more_revision: int = 0, + ): """ - Get the list of costs between each pseudo label and candidate. + Perform revision by abduction on the given data. Parameters ---------- - pred_pseudo_label : list - The pseudo label to be used for computing costs of candidates. pred_prob : list - Probabilities of the predictions. Used when distance function is "confidence". - candidates : list - List of candidate abduction result. - - Returns - ------- - numpy.ndarray - Array of computed costs for each candidate. - """ - if self.dist_func == "hamming": - return hamming_dist(data_sample["pred_pseudo_label"][0], candidates) - - elif self.dist_func == "confidence": - candidates = [[self.remapping[x] for x in c] for c in candidates] - return confidence_dist(data_sample["pred_prob"][0], candidates) - - def select(self, data_sample: ListData, candidates: List[List[Any]]): - """ - Get one candidate. If multiple candidates exist, return the one with minimum cost. - - Parameters - ---------- + List of probabilities for predicted results. pred_pseudo_label : list - The pseudo label to be used for selecting a candidate. - pred_prob : list - Probabilities of the predictions. - candidates : list - List of candidate abduction result. + List of predicted pseudo labels. + y : any + Ground truth for the predicted results. + max_revision : int or float, optional + Maximum number of revisions to use. If float, represents the fraction of total revisions to use. + If -1, any revisions are allowed. Defaults to -1. + require_more_revision : int, optional + Number of additional revisions to require. Defaults to 0. Returns ------- list - The chosen candidate based on minimum cost. - If no candidates, an empty list is returned. + The abduced revisions. """ - if len(candidates) == 0: - return [] - elif len(candidates) == 1: - return candidates[0] - else: - cost_array = self._get_dist_list(data_sample, candidates) - candidate = candidates[np.argmin(cost_array)] - return candidate + symbol_num = data_sample.elements_num("pred_pseudo_label") + max_revision_num = calculate_revision_num(max_revision, symbol_num) + data_sample.set_metainfo(dict(symbol_num=symbol_num)) - def abduce( + candidates = self.cache.get(data_sample, max_revision_num, require_more_revision) + candidate = self.select_one_candidate(data_sample, candidates) + return candidate + + def abduce_candidates( self, data_sample: ListData, - max_revision: int = -1, + max_revision_num: int = -1, require_more_revision: int = 0, ): """ @@ -169,9 +151,6 @@ class ReasonerBase: list The abduced revisions. """ - symbol_num = data_sample.elements_num("pred_pseudo_label") - max_revision_num = calculate_revision_num(max_revision, symbol_num) - data_sample.set_metainfo(dict(symbol_num=symbol_num)) if hasattr(self.kb, "abduce_candidates"): candidates = self.kb.abduce_candidates( @@ -198,9 +177,60 @@ class ReasonerBase: raise NotImplementedError( "The kb should either implement abduce_candidates or revise_at_idx." ) + return candidates - candidate = self.select(data_sample, candidates) - return candidate + def select_one_candidate(self, data_sample: ListData, candidates: List[List[Any]]): + """ + Get one candidate. If multiple candidates exist, return the one with minimum cost. + + Parameters + ---------- + pred_pseudo_label : list + The pseudo label to be used for selecting a candidate. + pred_prob : list + Probabilities of the predictions. + candidates : list + List of candidate abduction result. + + Returns + ------- + list + The chosen candidate based on minimum cost. + If no candidates, an empty list is returned. + """ + if len(candidates) == 0: + return [] + elif len(candidates) == 1: + return candidates[0] + else: + cost_array = self._get_dist_list(data_sample, candidates) + candidate = candidates[np.argmin(cost_array)] + return candidate + + def _get_dist_list(self, data_sample: ListData, candidates: List[List[Any]]): + """ + Get the list of costs between each pseudo label and candidate. + + Parameters + ---------- + pred_pseudo_label : list + The pseudo label to be used for computing costs of candidates. + pred_prob : list + Probabilities of the predictions. Used when distance function is "confidence". + candidates : list + List of candidate abduction result. + + Returns + ------- + numpy.ndarray + Array of computed costs for each candidate. + """ + if self.dist_func == "hamming": + return hamming_dist(data_sample["pred_pseudo_label"][0], candidates) + + elif self.dist_func == "confidence": + candidates = [[self.remapping[x] for x in c] for c in candidates] + return confidence_dist(data_sample["pred_prob"][0], candidates) def batch_abduce( self, @@ -231,7 +261,7 @@ class ReasonerBase: The abduced revisions in batches. """ abduced_pseudo_label = [ - self.cache.get( + self.abduce( data_sample, max_revision=max_revision, require_more_revision=require_more_revision, diff --git a/abl/reasoning/search_engine/__init__.py b/abl/reasoning/search_engine/__init__.py index 45f5442..c7fae26 100644 --- a/abl/reasoning/search_engine/__init__.py +++ b/abl/reasoning/search_engine/__init__.py @@ -1,3 +1,2 @@ from .base_search_engine import BaseSearchEngine from .bfs import BFS -from .zoopt import Zoopt diff --git a/abl/reasoning/search_engine/bfs.py b/abl/reasoning/search_engine/bfs.py index 104470a..4596c99 100644 --- a/abl/reasoning/search_engine/bfs.py +++ b/abl/reasoning/search_engine/bfs.py @@ -12,7 +12,7 @@ class BFS(BaseSearchEngine): pass def generator( - data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 + self, data_sample: ListData, max_revision_num: int, require_more_revision: int = 0 ) -> Union[List, Tuple, numpy.ndarray]: symbol_num = data_sample["symbol_num"] max_revision_num = min(max_revision_num, symbol_num) diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 99cab22..1045fc5 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -51,7 +51,13 @@ "source": [ "# Initialize knowledge base and abducer\n", "kb = AddKB()\n", - "abducer = ReasonerBase(kb, dist_func=\"confidence\")" + "\n", + "# If use cache, get_key should be implemented in the abducer\n", + "class AddAbducer(ReasonerBase):\n", + " def get_key(self, data_sample):\n", + " return (data_sample.to_tuple(\"pred_pseudo_label\"), data_sample[\"Y\"][0])\n", + "\n", + "abducer = AddAbducer(kb, dist_func=\"confidence\", use_cache=True)" ] }, { diff --git a/examples/mnist_add/mnist_add_kb.py b/examples/mnist_add/mnist_add_kb.py index 8a29927..21808ca 100644 --- a/examples/mnist_add/mnist_add_kb.py +++ b/examples/mnist_add/mnist_add_kb.py @@ -5,14 +5,11 @@ from abl.structures import ListData class AddKB(SearchBasedKB): - def __init__(self, pseudo_label_list=list(range(10)), use_cache=True, cache_size=4096): + def __init__(self, pseudo_label_list=list(range(10))): super().__init__( - pseudo_label_list=pseudo_label_list, use_cache=use_cache, cache_size=cache_size + pseudo_label_list=pseudo_label_list ) - def get_key(self, data_sample: ListData): - return (data_sample.to_tuple("pred_pseudo_label"), data_sample["Y"][0]) - def check_equal(self, data_sample: ListData, y: Any): return self.logic_forward(data_sample) == y