Browse Source

[MNT] reformt docstring of in datasets

pull/1/head
Gao Enhao 1 year ago
parent
commit
e24e3f5a7a
4 changed files with 55 additions and 51 deletions
  1. +23
    -20
      abl/dataset/bridge_dataset.py
  2. +12
    -12
      abl/dataset/classification_dataset.py
  3. +10
    -10
      abl/dataset/prediction_dataset.py
  4. +10
    -9
      abl/dataset/regression_dataset.py

+ 23
- 20
abl/dataset/bridge_dataset.py View File

@@ -1,32 +1,35 @@
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

from torch.utils.data import Dataset


class BridgeDataset(Dataset):
def __init__(self, X: List[Any], Z: List[Any], Y: List[Any]):
"""Initialize a basic dataset.
"""Dataset used in ``BaseBridge``.

Parameters
----------
X : List[Any]
A list of objects representing the input data.
Z : List[Any]
A list of objects representing the symbol.
Y : List[Any]
A list of objects representing the label.
"""
Parameters
----------
X : List[List[Any]]
A list of objects representing the input data.
gt_pseudo_label : List[List[Any]], optional
A list of objects representing the ground truth label of each element in ``X``.
Y : List[Any]
A list of objects representing the ground truth of the reasoning result of each instance in ``X``.
"""

def __init__(
self, X: List[List[Any]], gt_pseudo_label: Optional[List[List[Any]]], Y: List[Any]
):
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):
raise ValueError("Length of X and Y must be equal.")
self.X = X
self.Z = Z
self.gt_pseudo_label = gt_pseudo_label
self.Y = Y

if self.Z is None:
self.Z = [None] * len(self.X)
if self.gt_pseudo_label is None:
self.gt_pseudo_label = [None] * len(self.X)

def __len__(self):
"""Return the length of the dataset.
@@ -38,7 +41,7 @@ class BridgeDataset(Dataset):
"""
return len(self.X)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
def __getitem__(self, index: int) -> Tuple[List, List, Any]:
"""Get an item from the dataset.

Parameters
@@ -48,14 +51,14 @@ class BridgeDataset(Dataset):

Returns
-------
Tuple[Any, Any]
Tuple[List, List, Any]
A tuple containing the input and output data at the specified index.
"""
if index >= len(self):
raise ValueError("index range error")

X = self.X[index]
Z = self.Z[index]
gt_pseudo_label = self.gt_pseudo_label[index]
Y = self.Y[index]

return (X, Z, Y)
return (X, gt_pseudo_label, Y)

+ 12
- 12
abl/dataset/classification_dataset.py View File

@@ -5,21 +5,21 @@ from torch.utils.data import Dataset


class ClassificationDataset(Dataset):
"""
Dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
Y : List[int]
The target data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
def __init__(
self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None
):
"""
Initialize the dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
Y : List[int]
The target data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):


+ 10
- 10
abl/dataset/prediction_dataset.py View File

@@ -5,17 +5,17 @@ from torch.utils.data import Dataset


class PredictionDataset(Dataset):
"""
Dataset used for prediction.

Parameters
----------
X : List[Any]
The input data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
def __init__(self, X: List[Any], transform: Callable[..., Any] = None):
"""
Initialize the dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
if not isinstance(X, list):
raise ValueError("X should be of type list.")



+ 10
- 9
abl/dataset/regression_dataset.py View File

@@ -5,16 +5,17 @@ from torch.utils.data import Dataset


class RegressionDataset(Dataset):
"""
Dataset used for regression task.

Parameters
----------
X : List[Any]
A list of objects representing the input data.
Y : List[Any]
A list of objects representing the output data.
"""
def __init__(self, X: List[Any], Y: List[Any]):
"""Initialize a basic dataset.

Parameters
----------
X : List[Any]
A list of objects representing the input data.
Y : List[Any]
A list of objects representing the output data.
"""
if (not isinstance(X, list)) or (not isinstance(Y, list)):
raise ValueError("X and Y should be of type list.")
if len(X) != len(Y):


Loading…
Cancel
Save