Browse Source

add detailed test for ABLModel and BasicNN

pull/1/head
Gao Enhao 1 year ago
parent
commit
f94335c3c9
2 changed files with 127 additions and 36 deletions
  1. +35
    -3
      tests/test_abl_model.py
  2. +92
    -33
      tests/test_basic_nn.py

+ 35
- 3
tests/test_abl_model.py View File

@@ -1,12 +1,34 @@
import numpy as np
import pytest

from abl.learning import ABLModel
from unittest.mock import Mock, create_autospec


class TestABLModel(object):
def test_ablmodel_initialization(self, base_model_instance):
"""Test the initialization of the ABLModel class."""
model = ABLModel(base_model_instance)
def test_ablmodel_initialization(self):
"""Test the initialization method of the ABLModel class."""

invalid_base_model = Mock(spec=[])

with pytest.raises(NotImplementedError):
ABLModel(invalid_base_model)

fit = Mock(return_value=1.0)
predict = Mock(return_value=np.array(1.0))

invalid_base_model = Mock(spec=fit)
with pytest.raises(NotImplementedError):
ABLModel(invalid_base_model)

invalid_base_model = Mock(spec=predict)
with pytest.raises(NotImplementedError):
ABLModel(invalid_base_model)

base_model = Mock(spec=["fit", "predict"])
base_model.fit = fit
base_model.predict = predict
model = ABLModel(base_model)
assert hasattr(model, "base_model"), "The model should have a 'base_model' attribute."

def test_ablmodel_predict(self, base_model_instance, list_data_instance):
@@ -14,6 +36,16 @@ class TestABLModel(object):
model = ABLModel(base_model_instance)
predictions = model.predict(list_data_instance)
assert isinstance(predictions, dict), "Predictions should be returned as a dictionary."
assert isinstance(predictions["label"], list)
assert isinstance(predictions["prob"], list)

basic_nn_mock = create_autospec(base_model_instance, spec_set=True)
delattr(basic_nn_mock, "predict_proba")
model = ABLModel(basic_nn_mock)
predictions = model.predict(list_data_instance)
assert isinstance(predictions, dict), "Predictions should be returned as a dictionary."
assert isinstance(predictions["label"], list)
assert predictions["prob"] is None

def test_ablmodel_train(self, base_model_instance, list_data_instance):
"""Test the train method of the ABLModel class."""


+ 92
- 33
tests/test_basic_nn.py View File

@@ -1,4 +1,5 @@
import numpy
import pytest
import torch
import torch.nn as nn
import torch.optim as optim
@@ -6,53 +7,111 @@ from torch.utils.data import DataLoader, TensorDataset


class TestBasicNN(object):
# Test initialization
@pytest.fixture
def sample_data(self):
return torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))

@pytest.fixture
def sample_data_loader_with_label(self, sample_data):
X, y = sample_data
return DataLoader(TensorDataset(X, y), batch_size=4)

@pytest.fixture
def sample_data_loader_without_label(self, sample_data):
X, _ = sample_data
return DataLoader(
TensorDataset(X),
batch_size=4,
collate_fn=lambda batch: torch.stack([item[0] for item in batch]),
)

def test_initialization(self, basic_nn_instance):
"""Test initialization of the BasicNN class"""
assert basic_nn_instance.model is not None
assert isinstance(basic_nn_instance.criterion, nn.Module)
assert isinstance(basic_nn_instance.optimizer, optim.Optimizer)

# Test training epoch
def test_train_epoch(self, basic_nn_instance):
X = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))
data_loader = DataLoader(TensorDataset(X, y), batch_size=4)
loss = basic_nn_instance.train_epoch(data_loader)
def test_training_methods(self, basic_nn_instance, sample_data, sample_data_loader_with_label):
"""Test train_epoch, fit, and score methods of the BasicNN class"""

# Test train_epoch
loss = basic_nn_instance.train_epoch(sample_data_loader_with_label)
assert isinstance(loss, float)

# Test fit method
def test_fit(self, basic_nn_instance):
X = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))
data_loader = DataLoader(TensorDataset(X, y), batch_size=4)
loss = basic_nn_instance.fit(data_loader)
# Test fit with direct data
X, y = sample_data
loss = basic_nn_instance.fit(X=list(X), y=list(y))
assert isinstance(loss, float)

# Test predict method
def test_predict(self, basic_nn_instance):
X = list(torch.randn(32, 1, 28, 28))
predictions = basic_nn_instance.predict(X=X)
# Test fit with DataLoader
loss = basic_nn_instance.fit(data_loader=sample_data_loader_with_label)
assert isinstance(loss, float)

# Test invalid fit method input
with pytest.raises(ValueError):
basic_nn_instance.fit(X=None, y=None, data_loader=None)

# Test score with direct data
accuracy = basic_nn_instance.score(X=list(X), y=list(y))
assert 0 <= accuracy <= 1

# Test score with DataLoader
accuracy = basic_nn_instance.score(data_loader=sample_data_loader_with_label)
assert 0 <= accuracy <= 1

def test_prediction_methods(
self, basic_nn_instance, sample_data, sample_data_loader_without_label
):
"""Test predict and predict_proba methods of the BasicNN class"""
X, _ = sample_data

# Test predict with direct data
predictions = basic_nn_instance.predict(X=list(X))
assert len(predictions) == len(X)
assert numpy.isin(predictions, list(range(10))).all()

# Test predict_proba method
def test_predict_proba(self, basic_nn_instance):
X = list(torch.randn(32, 1, 28, 28))
predict_proba = basic_nn_instance.predict_proba(X=X)
# Test predict_proba with direct data
predict_proba = basic_nn_instance.predict_proba(X=list(X))
assert len(predict_proba) == len(X)
assert ((0 <= predict_proba) & (predict_proba <= 1)).all()

# Test score method
def test_score(self, basic_nn_instance):
X = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))
data_loader = DataLoader(TensorDataset(X, y), batch_size=4)
accuracy = basic_nn_instance.score(data_loader)
assert 0 <= accuracy <= 1
# Test predict and predict_proba with DataLoader
for method in [basic_nn_instance.predict, basic_nn_instance.predict_proba]:
result = method(data_loader=sample_data_loader_without_label)
assert len(result) == len(X)
if method == basic_nn_instance.predict:
assert numpy.isin(result, list(range(10))).all()
else:
assert ((0 <= result) & (result <= 1)).all()

# Test save and load methods
def test_save_load(self, basic_nn_instance, tmp_path):
model_path = tmp_path / "model.pth"
basic_nn_instance.save(epoch_id=1, save_path=str(model_path))
assert model_path.exists()
basic_nn_instance.load(load_path=str(model_path))
"""Test save and load methods of the BasicNN class"""

# Test save with explicit save_path
explicit_save_path = tmp_path / "model_explicit.pth"
basic_nn_instance.save(epoch_id=1, save_path=str(explicit_save_path))
assert explicit_save_path.exists(), "Model should be saved to the explicit path"

# Test save without providing save_path (using save_dir)
basic_nn_instance.save_dir = str(tmp_path)
implicit_save_path = tmp_path / "model_checkpoint_epoch_1.pth"
basic_nn_instance.save(epoch_id=1)
assert implicit_save_path.exists(), "Model should be saved to the implicit path in save_dir"

# Test error when save_path and save_dir are both None
basic_nn_instance.save_dir = None
with pytest.raises(ValueError):
basic_nn_instance.save(epoch_id=1)

# Test error on loading from a None path
with pytest.raises(ValueError):
basic_nn_instance.load(load_path=None)

# Test loading model state
original_state = basic_nn_instance.model.state_dict()
basic_nn_instance.load(load_path=str(explicit_save_path))
loaded_state = basic_nn_instance.model.state_dict()
for key in original_state:
assert torch.allclose(
original_state[key], loaded_state[key]
), "Model state should be restored after loading"

Loading…
Cancel
Save