Browse Source

[FIX] fix typo

pull/3/head
troyyyyy 1 year ago
parent
commit
2d83433388
1 changed files with 11 additions and 10 deletions
  1. +11
    -10
      abl/reasoning/reasoner.py

+ 11
- 10
abl/reasoning/reasoner.py View File

@@ -20,14 +20,13 @@ class ReasonerBase:
dist_func : str, optional
The distance function to be used when determining the cost list between each
candidate and the given prediction. Valid options include: `"hamming"` (default)
| `"confidence"`. Any other options will raise a `NotImplementedError`.
For detailed explanations of these options, refer to `_get_cost_list`.
mapping : dictt, optional
A mapping from label to index. If not provided, a default
order-based mapping is created.
| `"confidence"`. Any other options will raise a `NotImplementedError`. For
detailed explanations of these options, refer to `_get_cost_list`.
mapping : dict, optional
A mapping from label to index. If not provided, a default order-based mapping is
created.
use_zoopt : bool, optional
Whether to use the Zoopt library during abductive reasoning.
Default is False.
Whether to use the Zoopt library during abductive reasoning. Default to False.
"""
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
@@ -42,12 +41,14 @@ class ReasonerBase:
label: index for index, label in enumerate(self.kb.pseudo_label_list)
}
else:
if not isinstance(mapping, dict):
raise TypeError("mapping should be dict")
self.mapping = mapping

def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates):
"""
Due to the nondeterminism of abductive reasoning, there could be multiple candidates
satisfying the knowledge base. If this happens, return one candidate that has the
satisfying the knowledge base. When this happens, return one candidate that has the
minimum cost. If no candidates are provided, an empty list is returned.
Parameters
@@ -58,7 +59,7 @@ class ReasonerBase:
Predicted probabilities of the prediction (Each sublist contains the probability
values of all pseudo labels).
candidates : List[List[Any]]
Several candidate abduction results.
Multiple candidate abduction results.
"""
if len(candidates) == 0:
return []
@@ -86,7 +87,7 @@ class ReasonerBase:
Predicted probabilities of the prediction (Each sublist contains the probability
values of all pseudo labels). Used when distance function is "confidence".
candidates : List[List[Any]]
Several candidate abduction results.
Multiple candidate abduction results.
"""
if self.dist_func == "hamming":
return hamming_dist(pred_pseudo_label, candidates)


Loading…
Cancel
Save