Browse Source

Add docstrings for BasicModel, BasicDataset, and XYDataset

pull/3/head
Gao Enhao 2 years ago
parent
commit
a424b44e91
1 changed file with 339 additions and 30 deletions
  1. +339
    -30
      abl/models/basic_model.py

+ 339
- 30
abl/models/basic_model.py View File

import sys

sys.path.append("..")

import os
from multiprocessing import Pool
from typing import Any, Callable, List, Optional, Tuple, TypeVar

import numpy
import torch
from torch.utils.data import DataLoader, Dataset

# ``typing.T`` was an undocumented internal alias that was removed in
# Python 3.9, so it can no longer be imported.  Declare the type variable
# explicitly so annotations such as ``Callable[[List[T]], Any]`` keep working.
T = TypeVar("T")


class BasicDataset(Dataset):
def __init__(self, X, Y):
def __init__(self, X: List[Any], Y: List[Any]):
"""Initialize a basic dataset.

Parameters
----------
X : List[Any]
A list of objects representing the input data.
Y : List[Any]
A list of objects representing the output data.
"""
self.X = X
self.Y = Y

def __len__(self):
"""Return the length of the dataset.

Returns
-------
int
The length of the dataset.
"""
return len(self.X)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Get an item from the dataset.

Parameters
----------
index : int
The index of the item to retrieve.

Returns
-------
Tuple[Any, Any]
A tuple containing the input and output data at the specified index.
"""
assert index < len(self), "index range error"

img = self.X[index]
@@ -39,17 +69,50 @@ class BasicDataset(Dataset):


class XYDataset(Dataset):
def __init__(self, X, Y, transform=None):
def __init__(self, X: List[Any], Y: List[int], transform: Callable[...] = None):
"""
Initialize the dataset used for classification task.

Parameters
----------
X : List[Any]
The input data.
Y : List[int]
The target data.
transform : callable, optional
A function/transform that takes in an object and returns a transformed version. Defaults to None.
"""
self.X = X
self.Y = torch.LongTensor(Y)

self.n_sample = len(X)
self.transform = transform

def __len__(self):
def __len__(self) -> int:
"""
Return the length of the dataset.

Returns
-------
int
The length of the dataset.
"""
return len(self.X)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
"""
Get the item at the given index.

Parameters
----------
index : int
The index of the item to get.

Returns
-------
Tuple[Any, torch.Tensor]
A tuple containing the object and its label.
"""
assert index < len(self), "index range error"

img = self.X[index]
@@ -70,20 +133,102 @@ class FakeRecorder:


class BasicModel:
"""
Wrap NN models into the form of an sklearn estimator

Parameters
----------
model : torch.nn.Module
The PyTorch model to be trained or used for prediction.
criterion : torch.nn.Module
The loss function used for training.
optimizer : torch.nn.Module
The optimizer used for training.
device : torch.device
The device on which the model will be trained or used for prediction.
batch_size : int, optional
The batch size used for training, by default 1.
num_epochs : int, optional
The number of epochs used for training, by default 1.
stop_loss : Optional[float], optional
The loss value at which to stop training, by default 0.01.
num_workers : int, optional
The number of workers used for loading data, by default 0.
save_interval : Optional[int], optional
The interval at which to save the model during training, by default None.
save_dir : Optional[str], optional
The directory in which to save the model during training, by default None.
transform : Callable[..., Any], optional
The transformation function used for data augmentation, by default None.
collate_fn : Callable[[List[T]], Any], optional
The function used to collate data, by default None.
recorder : Any, optional
The recorder used to record training progress, by default None.

Attributes
----------
model : torch.nn.Module
The PyTorch model to be trained or used for prediction.
batch_size : int
The batch size used for training.
num_epochs : int
The number of epochs used for training.
stop_loss : Optional[float]
The loss value at which to stop training.
num_workers : int
The number of workers used for loading data.
criterion : torch.nn.Module
The loss function used for training.
optimizer : torch.nn.Module
The optimizer used for training.
transform : Callable[..., Any]
The transformation function used for data augmentation.
device : torch.device
The device on which the model will be trained or used for prediction.
recorder : Any
The recorder used to record training progress.
save_interval : Optional[int]
The interval at which to save the model during training.
save_dir : Optional[str]
The directory in which to save the model during training.
collate_fn : Callable[[List[T]], Any]
The function used to collate data.

Methods
-------
fit(data_loader=None, X=None, y=None)
Train the model.
train_epoch(data_loader)
Train the model for one epoch.
predict(data_loader=None, X=None, print_prefix="")
Predict the class of the input data.
predict_proba(data_loader=None, X=None, print_prefix="")
Predict the probability of each class for the input data.
val(data_loader=None, X=None, y=None, print_prefix="")
Validate the model.
score(data_loader=None, X=None, y=None, print_prefix="")
Score the model.
_data_loader(X, y=None)
Load data.
save(epoch_id, save_dir="")
Save the model.
load(epoch_id, load_dir="")
Load the model.
"""
def __init__(
self,
model,
criterion,
optimizer,
device,
batch_size=1,
num_epochs=1,
stop_loss=0.01,
num_workers=0,
save_interval=None,
save_dir=None,
transform=None,
collate_fn=None,
model: torch.nn.Module,
criterion: torch.nn.Module,
# NOTE(review): annotated as nn.Module in source, but this is presumably a
# torch.optim.Optimizer — confirm against callers before tightening the type.
optimizer: torch.nn.Module,
device: torch.device,
batch_size: int = 1,
num_epochs: int = 1,
stop_loss: Optional[float] = 0.01,
num_workers: int = 0,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
# ``Callable[...]`` is not a valid typing subscription (it raises TypeError
# at def time); ``Optional[Callable]`` is the legal "callable or None" form.
transform: Optional[Callable] = None,
collate_fn: Optional[Callable[[List[T]], Any]] = None,
recorder=None,
):

@@ -106,7 +251,6 @@ class BasicModel:
self.save_interval = save_interval
self.save_dir = save_dir
self.collate_fn = collate_fn
pass

def _fit(self, data_loader, n_epoch, stop_loss):
recorder = self.recorder
@@ -126,12 +270,44 @@ class BasicModel:
recorder.print("Model fitted, minimal loss is ", min_loss)
return loss_value

def fit(self, data_loader=None, X=None, y=None):
def fit(
self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
) -> float:
"""
Train the model.

Parameters
----------
data_loader : DataLoader, optional
The data loader used for training, by default None
X : List[Any], optional
The input data, by default None
y : List[int], optional
The target data, by default None

Returns
-------
float
The loss value of the trained model.
"""
if data_loader is None:
data_loader = self._data_loader(X, y)
return self._fit(data_loader, self.num_epochs, self.stop_loss)

def train_epoch(self, data_loader):
def train_epoch(self, data_loader: DataLoader):
"""
Train the model for one epoch.

Parameters
----------
data_loader : DataLoader
The data loader used for training.

Returns
-------
float
The loss value of the trained model.
"""
model = self.model
criterion = self.criterion
optimizer = self.optimizer
@@ -169,7 +345,29 @@ class BasicModel:

return torch.cat(results, axis=0)

def predict(self, data_loader=None, X=None, print_prefix=""):
def predict(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
print_prefix: str = "",
) -> numpy.ndarray:
"""
Predict the class of the input data.

Parameters
----------
data_loader : DataLoader, optional
The data loader used for prediction, by default None
X : List[Any], optional
The input data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""

Returns
-------
numpy.ndarray
The predicted class of the input data.
"""
recorder = self.recorder
recorder.print("Start Predict Class ", print_prefix)

@@ -177,9 +375,31 @@ class BasicModel:
data_loader = self._data_loader(X)
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

def predict_proba(self, data_loader=None, X=None, print_prefix=""):
def predict_proba(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
print_prefix: str = "",
) -> numpy.ndarray:
"""
Predict the probability of each class for the input data.

Parameters
----------
data_loader : DataLoader, optional
The data loader used for prediction, by default None
X : List[Any], optional
The input data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""

Returns
-------
numpy.ndarray
The predicted probability of each class for the input data.
"""
recorder = self.recorder
# recorder.print('Start Predict Probability ', print_prefix)
recorder.print("Start Predict Probability ", print_prefix)

if data_loader is None:
data_loader = self._data_loader(X)
@@ -215,7 +435,32 @@ class BasicModel:

return mean_loss, accuracy

def val(self, data_loader=None, X=None, y=None, print_prefix=""):
def val(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
y: List[int] = None,
print_prefix: str = "",
) -> float:
"""
Validate the model.

Parameters
----------
data_loader : DataLoader, optional
The data loader used for validation, by default None
X : List[Any], optional
The input data, by default None
y : List[int], optional
The target data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""

Returns
-------
float
The accuracy of the model.
"""
recorder = self.recorder
recorder.print("Start val ", print_prefix)

@@ -227,10 +472,54 @@ class BasicModel:
)
return accuracy

def score(self, data_loader=None, X=None, y=None, print_prefix=""):
def score(
self,
data_loader: DataLoader = None,
X: List[Any] = None,
y: List[int] = None,
print_prefix: str = "",
) -> float:
"""
Score the model.

Parameters
----------
data_loader : DataLoader, optional
The data loader used for scoring, by default None
X : List[Any], optional
The input data, by default None
y : List[int], optional
The target data, by default None
print_prefix : str, optional
The prefix used for printing, by default ""

Returns
-------
float
The accuracy of the model.
"""
return self.val(data_loader, X, y, print_prefix)

def _data_loader(self, X, y=None):
def _data_loader(
self,
X: List[Any],
y: List[int] = None,
) -> DataLoader:
"""
Generate data_loader for user provided data.

Parameters
----------
X : List[Any]
The input data.
y : List[int], optional
The target data, by default None

Returns
-------
DataLoader
The data loader.
"""
collate_fn = self.collate_fn
transform = self.transform

@@ -238,7 +527,7 @@ class BasicModel:
y = [0] * len(X)
dataset = XYDataset(X, y, transform=transform)
sampler = None
data_loader = torch.utils.data.DataLoader(
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
@@ -248,7 +537,17 @@ class BasicModel:
)
return data_loader

def save(self, epoch_id, save_dir):
def save(self, epoch_id: int, save_dir: str = ""):
"""
Save the model and the optimizer.

Parameters
----------
epoch_id : int
The epoch id.
save_dir : str, optional
The directory to save the model, by default ""
"""
recorder = self.recorder
if not os.path.exists(save_dir):
os.makedirs(save_dir)
@@ -259,7 +558,17 @@ class BasicModel:
save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
torch.save(self.optimizer.state_dict(), save_path)

def load(self, epoch_id, load_dir):
def load(self, epoch_id: int, load_dir: str = ""):
"""
Load the model and the optimizer.

Parameters
----------
epoch_id : int
The epoch id.
load_dir : str, optional
The directory to load the model, by default ""
"""
recorder = self.recorder
recorder.print("Loading model and opter")
load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")


Loading…
Cancel
Save