Browse Source

[MNT] modify utils docstring

pull/1/head
troyyyyy 1 year ago
parent
commit
467bc25319
3 changed files with 54 additions and 58 deletions
  1. +2
    -1
      .gitignore
  2. +4
    -4
      abl/reasoning/reasoner.py
  3. +48
    -53
      abl/utils/utils.py

+ 2
- 1
.gitignore View File

@@ -10,4 +10,5 @@ abl.egg-info/
examples/**/*.jpg
.idea/
build/
docs/API/generated/
docs/API/generated/
.history

+ 4
- 4
abl/reasoning/reasoner.py View File

@@ -118,7 +118,7 @@ class Reasoner:
data_example : ListData
Data example.
candidates : List[List[Any]]
Multiple compatible candidates.
Multiple possible candidates.
reasoning_results : List[Any]
Corresponding reasoning results of the candidates.

@@ -150,7 +150,7 @@ class Reasoner:
data_example : ListData
Data example.
candidates : List[List[Any]]
Multiple compatible candidates.
Multiple possible candidates.
reasoning_results : List[Any]
Corresponding reasoning results of the candidates.

@@ -162,8 +162,8 @@ class Reasoner:
if self.dist_func == "hamming":
return hamming_dist(data_example.pred_pseudo_label, candidates)
elif self.dist_func == "confidence":
candidates = [[self.label_to_idx[x] for x in c] for c in candidates]
return confidence_dist(data_example.pred_prob, candidates)
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
return confidence_dist(data_example.pred_prob, candidates_idxs)
else:
candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)


+ 48
- 53
abl/utils/utils.py View File

@@ -1,58 +1,54 @@
from itertools import chain
from typing import List, Any, Union, Tuple

import numpy as np


def flatten(nested_list):
def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[Any]:
"""
Flattens a nested list.
Flattens a nested list at the first level.

Parameters
----------
nested_list : list
A list which might contain sublists or tuples.
nested_list : List[Union[Any, List[Any], Tuple[Any, ...]]]
A list which might contain sublists or tuples at the first level.

Returns
-------
list
A flattened version of the input list.

Raises
------
TypeError
If the input object is not a list.
List[Any]
A flattened version of the input list, where only the first
level of sublists and tuples are reduced.
"""
# if not isinstance(nested_list, list):
# raise TypeError("Input must be of type list.")

if isinstance(nested_list, list) and len(nested_list) == 0:
if not isinstance(nested_list, list):
return nested_list

if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)):
return nested_list
flattened_list = []
for item in nested_list:
if isinstance(item, (list, tuple)):
flattened_list.extend(item)
else:
flattened_list.append(item)

return list(chain.from_iterable(nested_list))
return flattened_list


def reform_list(flattened_list, structured_list):
def reform_list(
flattened_list: List[Any],
structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]]
) -> List[List[Any]]:
"""
Reform the index based on structured_list structure.
Reform the list based on the structure of ``structured_list``.

Parameters
----------
flattened_list : list
A flattened list of predictions.
structured_list : list
A list containing saved predictions, which could be nested lists or tuples.
flattened_list : List[Any]
A flattened list of elements.
structured_list : List[Union[Any, List[Any], Tuple[Any, ...]]]
A list that reflects the desired structure, which may contain sublists or tuples.

Returns
-------
list
A reformed list that mimics the structure of structured_list.
List[List[Any]]
A reformed list that mimics the structure of ``structured_list``.
"""
# if not isinstance(flattened_list, list):
# raise TypeError("Input must be of type list.")

if not isinstance(structured_list[0], (list, tuple)):
return flattened_list

@@ -72,16 +68,15 @@ def hamming_dist(pred_pseudo_label, candidates):

Parameters
----------
pred_pseudo_label : list
First array to compare.
candidates : list
Second array to compare, expected to have shape (n, m)
where n is the number of rows, m is the length of pred_pseudo_label.
pred_pseudo_label : List[Any]
Pseudo-labels of an example.
candidates : List[List[Any]]
Multiple possible candidates.

Returns
-------
numpy.ndarray
Hamming distances.
np.ndarray
Hamming distances computed for each candidate.
"""
pred_pseudo_label = np.array(pred_pseudo_label)
candidates = np.array(candidates)
@@ -92,27 +87,26 @@ def hamming_dist(pred_pseudo_label, candidates):
return np.sum(pred_pseudo_label != candidates, axis=1)


def confidence_dist(pred_prob, candidates):
def confidence_dist(pred_prob, candidates_idxs):
"""
Compute the confidence distance between prediction probabilities and candidates.

Parameters
----------
pred_prob : list of numpy.ndarray
pred_prob : List[np.ndarray]
Prediction probability distributions, each element is an ndarray
representing the probability distribution of a particular prediction.
candidates : list of list of int
Index of candidate labels, each element is a list of indexes being considered
as a candidate correction.
candidates_idxs : List[List[Any]]
Multiple possible candidates' indices.

Returns
-------
numpy.ndarray
np.ndarray
Confidence distances computed for each candidate.
"""
pred_prob = np.clip(pred_prob, 1e-9, 1)
_, cols = np.indices((len(candidates), len(candidates[0])))
return 1 - np.prod(pred_prob[cols, candidates], axis=1)
_, cols = np.indices((len(candidates_idxs), len(candidates_idxs[0])))
return 1 - np.prod(pred_prob[cols, candidates_idxs], axis=1)


def to_hashable(x):
@@ -121,12 +115,12 @@ def to_hashable(x):

Parameters
----------
x : list or other type
x : Union[List[Any], Any]
A potentially nested list to convert to a tuple.

Returns
-------
tuple or other type
Union[Tuple[Any, ...], Any]
The input converted to a tuple if it was a list,
otherwise the original input.
"""
@@ -141,12 +135,12 @@ def restore_from_hashable(x):

Parameters
----------
x : tuple or other type
x : Union[Tuple[Any, ...], Any]
A potentially nested tuple to convert to a list.

Returns
-------
list or other type
Union[List[Any], Any]
The input converted to a list if it was a tuple,
otherwise the original input.
"""
@@ -156,14 +150,15 @@ def restore_from_hashable(x):

def tab_data_to_tuple(X, y, reasoning_result = 0):
'''
Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result.
Convert a tabular data to a tuple by adding a dimension to each element of
X and y. The tuple contains three elements: data, label, and reasoning result.
If X is None, return None.
Parameters
----------
X : list or other type
X : Union[List[Any], Any]
The data.
y : list or other type
y : Union[List[Any], Any]
The label.
reasoning_result : Any, optional
The reasoning result, by default 0.


Loading…
Cancel
Save