|
- from itertools import chain
-
- import numpy as np
-
-
- def flatten(nested_list):
- """
- Flattens a nested list.
-
- Parameters
- ----------
- nested_list : list
- A list which might contain sublists or tuples.
-
- Returns
- -------
- list
- A flattened version of the input list.
-
- Raises
- ------
- TypeError
- If the input object is not a list.
- """
- # if not isinstance(nested_list, list):
- # raise TypeError("Input must be of type list.")
-
- if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)):
- return nested_list
-
- return list(chain.from_iterable(nested_list))
-
-
- def reform_list(flattened_list, structured_list):
- """
- Reform the index based on structured_list structure.
-
- Parameters
- ----------
- flattened_list : list
- A flattened list of predictions.
- structured_list : list
- A list containing saved predictions, which could be nested lists or tuples.
-
- Returns
- -------
- list
- A reformed list that mimics the structure of structured_list.
- """
- # if not isinstance(flattened_list, list):
- # raise TypeError("Input must be of type list.")
-
- if not isinstance(structured_list[0], (list, tuple)):
- return flattened_list
-
- reformed_list = []
- idx_start = 0
- for elem in structured_list:
- idx_end = idx_start + len(elem)
- reformed_list.append(flattened_list[idx_start:idx_end])
- idx_start = idx_end
-
- return reformed_list
-
-
- def hamming_dist(pred_pseudo_label, candidates):
- """
- Compute the Hamming distance between two arrays.
-
- Parameters
- ----------
- pred_pseudo_label : list
- First array to compare.
- candidates : list
- Second array to compare, expected to have shape (n, m)
- where n is the number of rows, m is the length of pred_pseudo_label.
-
- Returns
- -------
- numpy.ndarray
- Hamming distances.
- """
- pred_pseudo_label = np.array(pred_pseudo_label)
- candidates = np.array(candidates)
-
- # Ensuring that pred_pseudo_label is broadcastable to the shape of candidates
- pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0)
-
- return np.sum(pred_pseudo_label != candidates, axis=1)
-
-
- def confidence_dist(pred_prob, candidates):
- """
- Compute the confidence distance between prediction probabilities and candidates.
-
- Parameters
- ----------
- pred_prob : list of numpy.ndarray
- Prediction probability distributions, each element is an ndarray
- representing the probability distribution of a particular prediction.
- candidates : list of list of int
- Index of candidate labels, each element is a list of indexes being considered
- as a candidate correction.
-
- Returns
- -------
- numpy.ndarray
- Confidence distances computed for each candidate.
- """
- pred_prob = np.clip(pred_prob, 1e-9, 1)
- _, cols = np.indices((len(candidates), len(candidates[0])))
- return 1 - np.prod(pred_prob[cols, candidates], axis=1)
-
-
- def block_sample(X, Z, Y, sample_num, seg_idx):
- """
- Extract a block of samples from lists X, Z, and Y.
-
- Parameters
- ----------
- X, Z, Y : list
- Input lists from which to extract the samples.
- sample_num : int
- The number of samples per block.
- seg_idx : int
- The block index to extract.
-
- Returns
- -------
- tuple of lists
- The extracted block samples from X, Z, and Y.
-
- Example
- -------
- >>> X = [1, 2, 3, 4, 5, 6]
- >>> Z = ['a', 'b', 'c', 'd', 'e', 'f']
- >>> Y = [10, 11, 12, 13, 14, 15]
- >>> block_sample(X, Z, Y, 2, 1)
- ([3, 4], ['c', 'd'], [12, 13])
- """
- start_idx = sample_num * seg_idx
- end_idx = sample_num * (seg_idx + 1)
-
- return (data[start_idx:end_idx] for data in (X, Z, Y))
-
-
- def to_hashable(x):
- """
- Convert a nested list to a nested tuple so it is hashable.
-
- Parameters
- ----------
- x : list or other type
- A potentially nested list to convert to a tuple.
-
- Returns
- -------
- tuple or other type
- The input converted to a tuple if it was a list,
- otherwise the original input.
- """
- if isinstance(x, list):
- return tuple(to_hashable(item) for item in x)
- return x
-
-
- def restore_from_hashable(x):
- """
- Convert a nested tuple back to a nested list.
-
- Parameters
- ----------
- x : tuple or other type
- A potentially nested tuple to convert to a list.
-
- Returns
- -------
- list or other type
- The input converted to a list if it was a tuple,
- otherwise the original input.
- """
- if isinstance(x, tuple):
- return [restore_from_hashable(item) for item in x]
- return x
-
-
- def calculate_revision_num(parameter, total_length):
- """
- Convert a float parameter to an integer, based on a total length.
-
- Parameters
- ----------
- parameter : int or float
- The parameter to convert. If float, it should be between 0 and 1.
- If int, it should be non-negative. If -1, it will be replaced with total_length.
- total_length : int
- The total length to calculate the parameter from if it's a fraction.
-
- Returns
- -------
- int
- The calculated parameter.
-
- Raises
- ------
- TypeError
- If parameter is not an int or a float.
- ValueError
- If parameter is a float not in [0, 1] or an int below 0.
- """
- if not isinstance(parameter, (int, float)):
- raise TypeError("Parameter must be of type int or float.")
-
- if parameter == -1:
- return total_length
- elif isinstance(parameter, float):
- if not (0 <= parameter <= 1):
- raise ValueError("If parameter is a float, it must be between 0 and 1.")
- return round(total_length * parameter)
- else:
- if parameter < 0:
- raise ValueError("If parameter is an int, it must be non-negative.")
- return parameter
|