@@ -15,21 +15,51 @@ import sys
sys.path.append("..")

import torch
from torch.utils.data import Dataset
import numpy
from torch.utils.data import Dataset, DataLoader

import os
from multiprocessing import Pool
from typing import List, Any, Tuple, Optional, Callable, TypeVar

T = TypeVar("T")  # generic element type used in the collate_fn annotation


class BasicDataset(Dataset):
    def __init__(self, X, Y):
    def __init__(self, X: List[Any], Y: List[Any]):
        """Initialize a basic dataset.

        Parameters
        ----------
        X : List[Any]
            A list of objects representing the input data.
        Y : List[Any]
            A list of objects representing the output data.
        """
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get an item from the dataset.

        Parameters
        ----------
        index : int
            The index of the item to retrieve.

        Returns
        -------
        Tuple[Any, Any]
            A tuple containing the input and output data at the specified index.
        """
        assert index < len(self), "index range error"

        img = self.X[index]
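
# Illustrative usage sketch (not part of the diff): BasicDataset simply pairs two
# parallel lists, so it behaves like any map-style torch Dataset. The sample data
# below is hypothetical.
xs = [torch.randn(3) for _ in range(4)]  # four made-up 3-dim inputs
ys = ["a", "b", "a", "b"]                # arbitrary per-sample outputs
ds = BasicDataset(xs, ys)
assert len(ds) == 4
x0, y0 = ds[0]  # per the docstring, the (input, output) pair at index 0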
|
|
@@ -39,17 +69,50 @@ class BasicDataset(Dataset):


class XYDataset(Dataset):
    def __init__(self, X, Y, transform=None):
    def __init__(self, X: List[Any], Y: List[int], transform: Callable[..., Any] = None):
        """
        Initialize the dataset used for a classification task.

        Parameters
        ----------
        X : List[Any]
            The input data.
        Y : List[int]
            The target data.
        transform : callable, optional
            A function/transform that takes in an object and returns a transformed version. Defaults to None.
        """
        self.X = X
        self.Y = torch.LongTensor(Y)

        self.n_sample = len(X)
        self.transform = transform

    def __len__(self):
    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index):
    def __getitem__(self, index: int) -> Tuple[Any, torch.Tensor]:
        """
        Get the item at the given index.

        Parameters
        ----------
        index : int
            The index of the item to get.

        Returns
        -------
        Tuple[Any, torch.Tensor]
            A tuple containing the object and its label.
        """
        assert index < len(self), "index range error"

        img = self.X[index]
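
# Illustrative usage sketch (not part of the diff): XYDataset is the classification
# variant; labels are stored as a LongTensor and, per the docstring, the optional
# transform maps each raw object to its transformed version. The transform and data
# below are hypothetical.
def to_tensor(obj):
    return torch.as_tensor(obj, dtype=torch.float32)

xy_ds = XYDataset(X=[[0.0, 1.0], [1.0, 0.0]], Y=[0, 1], transform=to_tensor)
item, label = xy_ds[0]             # (transformed object, label tensor)
assert label.dtype == torch.int64  # labels come from torch.LongTensor(Y)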
|
|
@@ -70,20 +133,102 @@ class FakeRecorder:


class BasicModel:
    """
    Wrap NN models into the form of an sklearn estimator.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to be trained or used for prediction.
    criterion : torch.nn.Module
        The loss function used for training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    device : torch.device
        The device on which the model will be trained or used for prediction.
    batch_size : int, optional
        The batch size used for training, by default 1.
    num_epochs : int, optional
        The number of epochs used for training, by default 1.
    stop_loss : Optional[float], optional
        The loss value at which to stop training, by default 0.01.
    num_workers : int, optional
        The number of workers used for loading data, by default 0.
    save_interval : Optional[int], optional
        The interval at which to save the model during training, by default None.
    save_dir : Optional[str], optional
        The directory in which to save the model during training, by default None.
    transform : Callable[..., Any], optional
        The transformation function used for data augmentation, by default None.
    collate_fn : Callable[[List[T]], Any], optional
        The function used to collate data, by default None.
    recorder : Any, optional
        The recorder used to record training progress, by default None.

    Attributes
    ----------
    model : torch.nn.Module
        The PyTorch model to be trained or used for prediction.
    batch_size : int
        The batch size used for training.
    num_epochs : int
        The number of epochs used for training.
    stop_loss : Optional[float]
        The loss value at which to stop training.
    num_workers : int
        The number of workers used for loading data.
    criterion : torch.nn.Module
        The loss function used for training.
    optimizer : torch.optim.Optimizer
        The optimizer used for training.
    transform : Callable[..., Any]
        The transformation function used for data augmentation.
    device : torch.device
        The device on which the model will be trained or used for prediction.
    recorder : Any
        The recorder used to record training progress.
    save_interval : Optional[int]
        The interval at which to save the model during training.
    save_dir : Optional[str]
        The directory in which to save the model during training.
    collate_fn : Callable[[List[T]], Any]
        The function used to collate data.

    Methods
    -------
    fit(data_loader=None, X=None, y=None)
        Train the model.
    train_epoch(data_loader)
        Train the model for one epoch.
    predict(data_loader=None, X=None, print_prefix="")
        Predict the class of the input data.
    predict_proba(data_loader=None, X=None, print_prefix="")
        Predict the probability of each class for the input data.
    val(data_loader=None, X=None, y=None, print_prefix="")
        Validate the model.
    score(data_loader=None, X=None, y=None, print_prefix="")
        Score the model.
    _data_loader(X, y=None)
        Load data.
    save(epoch_id, save_dir="")
        Save the model.
    load(epoch_id, load_dir="")
        Load the model.
    """
    def __init__(
        self,
        model,
        criterion,
        optimizer,
        device,
        batch_size=1,
        num_epochs=1,
        stop_loss=0.01,
        num_workers=0,
        save_interval=None,
        save_dir=None,
        transform=None,
        collate_fn=None,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        batch_size: int = 1,
        num_epochs: int = 1,
        stop_loss: Optional[float] = 0.01,
        num_workers: int = 0,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
        transform: Callable[..., Any] = None,
        collate_fn: Callable[[List[T]], Any] = None,
        recorder=None,
    ):
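
# Illustrative construction sketch (not part of the diff): BasicModel wraps an
# ordinary nn.Module together with its criterion and optimizer into an
# sklearn-style estimator. The tiny classifier below is hypothetical, and the
# recorder argument assumes FakeRecorder (defined earlier in this file) can be
# constructed without arguments.
net = torch.nn.Sequential(torch.nn.Linear(3, 2))
base_model = BasicModel(
    model=net,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.SGD(net.parameters(), lr=0.01),
    device=torch.device("cpu"),
    batch_size=2,
    num_epochs=1,
    recorder=FakeRecorder(),
)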
@@ -106,7 +251,6 @@ class BasicModel:
        self.save_interval = save_interval
        self.save_dir = save_dir
        self.collate_fn = collate_fn
        pass

    def _fit(self, data_loader, n_epoch, stop_loss):
        recorder = self.recorder

@@ -126,12 +270,44 @@ class BasicModel:
        recorder.print("Model fitted, minimal loss is ", min_loss)
        return loss_value

    def fit(self, data_loader=None, X=None, y=None):
    def fit(
        self, data_loader: DataLoader = None, X: List[Any] = None, y: List[int] = None
    ) -> float:
        """
        Train the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for training, by default None
        X : List[Any], optional
            The input data, by default None
        y : List[int], optional
            The target data, by default None

        Returns
        -------
        float
            The loss value of the trained model.
        """
        if data_loader is None:
            data_loader = self._data_loader(X, y)
        return self._fit(data_loader, self.num_epochs, self.stop_loss)
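
# Illustrative sketch (not part of the diff): fit() accepts either a ready-made
# DataLoader or raw X/y lists, in which case _data_loader builds the loader
# internally. The data is hypothetical and continues the base_model sketch above;
# it assumes the internal training loop consumes the (input, label) batches
# produced by the default collate_fn.
train_X = [torch.randn(3) for _ in range(8)]
train_y = [0, 1, 0, 1, 0, 1, 0, 1]
loss = base_model.fit(X=train_X, y=train_y)  # loss value of the trained model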
    def train_epoch(self, data_loader):
    def train_epoch(self, data_loader: DataLoader):
        """
        Train the model for one epoch.

        Parameters
        ----------
        data_loader : DataLoader
            The data loader used for training.

        Returns
        -------
        float
            The loss value of the trained model.
        """
        model = self.model
        criterion = self.criterion
        optimizer = self.optimizer

@@ -169,7 +345,29 @@ class BasicModel:

        return torch.cat(results, axis=0)

    def predict(self, data_loader=None, X=None, print_prefix=""):
    def predict(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        print_prefix: str = "",
    ) -> numpy.ndarray:
        """
        Predict the class of the input data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for prediction, by default None
        X : List[Any], optional
            The input data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        numpy.ndarray
            The predicted class of the input data.
        """
        recorder = self.recorder
        recorder.print("Start Predict Class ", print_prefix)

@@ -177,9 +375,31 @@ class BasicModel:
            data_loader = self._data_loader(X)
        return self._predict(data_loader).argmax(axis=1).cpu().numpy()
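
# Illustrative sketch (not part of the diff): predict() returns hard class labels as
# a numpy array, while predict_proba() (documented next) returns per-class scores.
# The test data is hypothetical and reuses the base_model sketch above.
test_X = [torch.randn(3) for _ in range(4)]
pred_labels = base_model.predict(X=test_X)        # numpy.ndarray of argmax classes
pred_scores = base_model.predict_proba(X=test_X)  # numpy.ndarray of class scores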
    def predict_proba(self, data_loader=None, X=None, print_prefix=""):
    def predict_proba(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        print_prefix: str = "",
    ) -> numpy.ndarray:
        """
        Predict the probability of each class for the input data.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for prediction, by default None
        X : List[Any], optional
            The input data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        numpy.ndarray
            The predicted probability of each class for the input data.
        """
        recorder = self.recorder
        # recorder.print('Start Predict Probability ', print_prefix)
        recorder.print("Start Predict Probability ", print_prefix)

        if data_loader is None:
            data_loader = self._data_loader(X)

@@ -215,7 +435,32 @@ class BasicModel:

        return mean_loss, accuracy

    def val(self, data_loader=None, X=None, y=None, print_prefix=""):
    def val(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        y: List[int] = None,
        print_prefix: str = "",
    ) -> float:
        """
        Validate the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for validation, by default None
        X : List[Any], optional
            The input data, by default None
        y : List[int], optional
            The target data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        float
            The accuracy of the model.
        """
        recorder = self.recorder
        recorder.print("Start val ", print_prefix)

@@ -227,10 +472,54 @@ class BasicModel:
        )
        return accuracy

    def score(self, data_loader=None, X=None, y=None, print_prefix=""):
    def score(
        self,
        data_loader: DataLoader = None,
        X: List[Any] = None,
        y: List[int] = None,
        print_prefix: str = "",
    ) -> float:
        """
        Score the model.

        Parameters
        ----------
        data_loader : DataLoader, optional
            The data loader used for scoring, by default None
        X : List[Any], optional
            The input data, by default None
        y : List[int], optional
            The target data, by default None
        print_prefix : str, optional
            The prefix used for printing, by default ""

        Returns
        -------
        float
            The accuracy of the model.
        """
        return self.val(data_loader, X, y, print_prefix)
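
# Illustrative sketch (not part of the diff): score() delegates to val() and reports
# accuracy, mirroring sklearn's estimator.score(X, y). The labels are hypothetical
# and reuse the test_X sketch above.
acc = base_model.score(X=test_X, y=[0, 1, 0, 1])  # accuracy as a float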
    def _data_loader(self, X, y=None):
    def _data_loader(
        self,
        X: List[Any],
        y: List[int] = None,
    ) -> DataLoader:
        """
        Generate data_loader for user provided data.

        Parameters
        ----------
        X : List[Any]
            The input data.
        y : List[int], optional
            The target data, by default None

        Returns
        -------
        DataLoader
            The data loader.
        """
        collate_fn = self.collate_fn
        transform = self.transform

@@ -238,7 +527,7 @@ class BasicModel:
            y = [0] * len(X)
        dataset = XYDataset(X, y, transform=transform)
        sampler = None
        data_loader = torch.utils.data.DataLoader(
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,

@@ -248,7 +537,17 @@ class BasicModel:
        )
        return data_loader

    def save(self, epoch_id, save_dir):
    def save(self, epoch_id: int, save_dir: str = ""):
        """
        Save the model and the optimizer.

        Parameters
        ----------
        epoch_id : int
            The epoch id.
        save_dir : str, optional
            The directory to save the model, by default ""
        """
        recorder = self.recorder
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

@@ -259,7 +558,17 @@ class BasicModel:
        save_path = os.path.join(save_dir, str(epoch_id) + "_opt.pth")
        torch.save(self.optimizer.state_dict(), save_path)
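
# Illustrative sketch (not part of the diff): save() writes the optimizer state dict
# as "<epoch_id>_opt.pth" (and, judging by load() below, the network as
# "<epoch_id>_net.pth") under save_dir, while load() restores them from load_dir.
# The directory name is hypothetical.
base_model.save(epoch_id=1, save_dir="./checkpoints")
base_model.load(epoch_id=1, load_dir="./checkpoints")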
    def load(self, epoch_id, load_dir):
    def load(self, epoch_id: int, load_dir: str = ""):
        """
        Load the model and the optimizer.

        Parameters
        ----------
        epoch_id : int
            The epoch id.
        load_dir : str, optional
            The directory to load the model, by default ""
        """
        recorder = self.recorder
        recorder.print("Loading model and opter")
        load_path = os.path.join(load_dir, str(epoch_id) + "_net.pth")