@@ -1,32 +1,35 @@ | |||
from typing import Any, List, Tuple | |||
from typing import Any, List, Optional, Tuple | |||
from torch.utils.data import Dataset | |||
class BridgeDataset(Dataset): | |||
def __init__(self, X: List[Any], Z: List[Any], Y: List[Any]): | |||
"""Initialize a basic dataset. | |||
"""Dataset used in ``BaseBridge``. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
A list of objects representing the input data. | |||
Z : List[Any] | |||
A list of objects representing the symbol. | |||
Y : List[Any] | |||
A list of objects representing the label. | |||
""" | |||
Parameters | |||
---------- | |||
X : List[List[Any]] | |||
A list of objects representing the input data. | |||
gt_pseudo_label : List[List[Any]], optional | |||
A list of objects representing the ground truth label of each element in ``X``. | |||
Y : List[Any] | |||
A list of objects representing the ground truth of the reasoning result of each instance in ``X``. | |||
""" | |||
def __init__( | |||
self, X: List[List[Any]], gt_pseudo_label: Optional[List[List[Any]]], Y: List[Any] | |||
): | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
if len(X) != len(Y): | |||
raise ValueError("Length of X and Y must be equal.") | |||
self.X = X | |||
self.Z = Z | |||
self.gt_pseudo_label = gt_pseudo_label | |||
self.Y = Y | |||
if self.Z is None: | |||
self.Z = [None] * len(self.X) | |||
if self.gt_pseudo_label is None: | |||
self.gt_pseudo_label = [None] * len(self.X) | |||
def __len__(self): | |||
"""Return the length of the dataset. | |||
@@ -38,7 +41,7 @@ class BridgeDataset(Dataset): | |||
""" | |||
return len(self.X) | |||
def __getitem__(self, index: int) -> Tuple[Any, Any]: | |||
def __getitem__(self, index: int) -> Tuple[List, List, Any]: | |||
"""Get an item from the dataset. | |||
Parameters | |||
@@ -48,14 +51,14 @@ class BridgeDataset(Dataset): | |||
Returns | |||
------- | |||
Tuple[Any, Any] | |||
Tuple[List, List, Any] | |||
A tuple containing the input and output data at the specified index. | |||
""" | |||
if index >= len(self): | |||
raise ValueError("index range error") | |||
X = self.X[index] | |||
Z = self.Z[index] | |||
gt_pseudo_label = self.gt_pseudo_label[index] | |||
Y = self.Y[index] | |||
return (X, Z, Y) | |||
return (X, gt_pseudo_label, Y) |
@@ -5,21 +5,21 @@ from torch.utils.data import Dataset | |||
class ClassificationDataset(Dataset): | |||
""" | |||
Dataset used for classification task. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
The input data. | |||
Y : List[int] | |||
The target data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
""" | |||
def __init__( | |||
self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None | |||
): | |||
""" | |||
Initialize the dataset used for classification task. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
The input data. | |||
Y : List[int] | |||
The target data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
""" | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
if len(X) != len(Y): | |||
@@ -5,17 +5,17 @@ from torch.utils.data import Dataset | |||
class PredictionDataset(Dataset): | |||
""" | |||
Dataset used for prediction. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
The input data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
""" | |||
def __init__(self, X: List[Any], transform: Callable[..., Any] = None): | |||
""" | |||
Initialize the dataset used for classification task. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
The input data. | |||
transform : Callable[..., Any], optional | |||
A function/transform that takes in an object and returns a transformed version. Defaults to None. | |||
""" | |||
if not isinstance(X, list): | |||
raise ValueError("X should be of type list.") | |||
@@ -5,16 +5,17 @@ from torch.utils.data import Dataset | |||
class RegressionDataset(Dataset): | |||
""" | |||
Dataset used for regression task. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
A list of objects representing the input data. | |||
Y : List[Any] | |||
A list of objects representing the output data. | |||
""" | |||
def __init__(self, X: List[Any], Y: List[Any]): | |||
"""Initialize a basic dataset. | |||
Parameters | |||
---------- | |||
X : List[Any] | |||
A list of objects representing the input data. | |||
Y : List[Any] | |||
A list of objects representing the output data. | |||
""" | |||
if (not isinstance(X, list)) or (not isinstance(Y, list)): | |||
raise ValueError("X and Y should be of type list.") | |||
if len(X) != len(Y): | |||