|
|
@@ -20,14 +20,13 @@ class ReasonerBase: |
|
|
|
dist_func : str, optional |
|
|
|
The distance function to be used when determining the cost list between each |
|
|
|
candidate and the given prediction. Valid options include: `"hamming"` (default) |
|
|
|
| `"confidence"`. Any other options will raise a `NotImplementedError`. |
|
|
|
For detailed explanations of these options, refer to `_get_cost_list`. |
|
|
|
mapping : dictt, optional |
|
|
|
A mapping from label to index. If not provided, a default |
|
|
|
order-based mapping is created. |
|
|
|
| `"confidence"`. Any other options will raise a `NotImplementedError`. For |
|
|
|
detailed explanations of these options, refer to `_get_cost_list`. |
|
|
|
mapping : dict, optional |
|
|
|
A mapping from label to index. If not provided, a default order-based mapping is |
|
|
|
created. |
|
|
|
use_zoopt : bool, optional |
|
|
|
Whether to use the Zoopt library during abductive reasoning. |
|
|
|
Default is False. |
|
|
|
Whether to use the Zoopt library during abductive reasoning. Default to False. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): |
|
|
@@ -42,12 +41,14 @@ class ReasonerBase: |
|
|
|
label: index for index, label in enumerate(self.kb.pseudo_label_list) |
|
|
|
} |
|
|
|
else: |
|
|
|
if not isinstance(mapping, dict): |
|
|
|
raise TypeError("mapping should be dict") |
|
|
|
self.mapping = mapping |
|
|
|
|
|
|
|
def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): |
|
|
|
""" |
|
|
|
Due to the nondeterminism of abductive reasoning, there could be multiple candidates |
|
|
|
satisfying the knowledge base. If this happens, return one candidate that has the |
|
|
|
satisfying the knowledge base. When this happens, return one candidate that has the |
|
|
|
minimum cost. If no candidates are provided, an empty list is returned. |
|
|
|
|
|
|
|
Parameters |
|
|
@@ -58,7 +59,7 @@ class ReasonerBase: |
|
|
|
Predicted probabilities of the prediction (Each sublist contains the probability |
|
|
|
values of all pseudo labels). |
|
|
|
candidates : List[List[Any]] |
|
|
|
Several candidate abduction results. |
|
|
|
Multiple candidate abduction results. |
|
|
|
""" |
|
|
|
if len(candidates) == 0: |
|
|
|
return [] |
|
|
@@ -86,7 +87,7 @@ class ReasonerBase: |
|
|
|
Predicted probabilities of the prediction (Each sublist contains the probability |
|
|
|
values of all pseudo labels). Used when distance function is "confidence". |
|
|
|
candidates : List[List[Any]] |
|
|
|
Several candidate abduction results. |
|
|
|
Multiple candidate abduction results. |
|
|
|
""" |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(pred_pseudo_label, candidates) |
|
|
|