From e32c099444a85ae81807a3b75b57107364c93146 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 17 Nov 2023 17:05:31 +0800 Subject: [PATCH] [ENH] add test for learning --- tests/conftest.py | 34 +++++++++++++++++++ tests/test_abl_model.py | 44 ++++++++++++++++++++++++ tests/test_basic_nn.py | 51 ++++++++++++++++++++++++++++ tests/test_models.py | 75 ----------------------------------------- 4 files changed, 129 insertions(+), 75 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_abl_model.py create mode 100644 tests/test_basic_nn.py delete mode 100644 tests/test_models.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fac5523 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,34 @@ +import pytest +import torch +import torch.nn as nn +import torch.optim as optim + +from abl.learning import BasicNN +from abl.structures import ListData +from examples.models.nn import LeNet5 + + +# Fixture for BasicNN instance +@pytest.fixture +def basic_nn_instance(): + model = LeNet5() + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters()) + return BasicNN(model, criterion, optimizer) + +# Fixture for base_model instance +@pytest.fixture +def base_model_instance(): + model = LeNet5() + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters()) + return BasicNN(model, criterion, optimizer) + +# Fixture for ListData instance +@pytest.fixture +def list_data_instance(): + data_samples = ListData() + data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)] + data_samples.Y = [1, 2, 3] + data_samples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] + return data_samples diff --git a/tests/test_abl_model.py b/tests/test_abl_model.py new file mode 100644 index 0000000..5a2d50c --- /dev/null +++ b/tests/test_abl_model.py @@ -0,0 +1,44 @@ +import pytest + +from abl.learning import ABLModel + + +class TestABLModel(object): + def test_ablmodel_initialization(self, base_model_instance): + """Test the initialization of the ABLModel class.""" + model = ABLModel(base_model_instance) + assert hasattr(model, "base_model"), "The model should have a 'base_model' attribute." + + def test_ablmodel_predict(self, base_model_instance, list_data_instance): + """Test the predict method of the ABLModel class.""" + model = ABLModel(base_model_instance) + predictions = model.predict(list_data_instance) + assert isinstance(predictions, dict), "Predictions should be returned as a dictionary." + + def test_ablmodel_train(self, base_model_instance, list_data_instance): + """Test the train method of the ABLModel class.""" + model = ABLModel(base_model_instance) + list_data_instance.abduced_idx = [[1, 2], [3, 4], [5, 6]] + loss = model.train(list_data_instance) + assert isinstance(loss, float), "Training should return a float value indicating the loss." + + def test_ablmodel_save_load(self, base_model_instance, tmp_path): + """Test the save method of the ABLModel class.""" + model = ABLModel(base_model_instance) + model_path = tmp_path / "model.pth" + model.save(save_path=str(model_path)) + assert model_path.exists() + model.load(load_path=str(model_path)) + assert isinstance(model.base_model, type(base_model_instance)) + + def test_ablmodel_invalid_operation(self, base_model_instance): + """Test invalid operation handling in the ABLModel class.""" + model = ABLModel(base_model_instance) + with pytest.raises(ValueError): + model._model_operation("invalid_operation", save_path=None) + + def test_ablmodel_operation_without_path(self, base_model_instance): + """Test operation without providing a path in the ABLModel class.""" + model = ABLModel(base_model_instance) + with pytest.raises(ValueError): + model.save() # No path provided diff --git a/tests/test_basic_nn.py b/tests/test_basic_nn.py new file mode 100644 index 0000000..f344700 --- /dev/null +++ b/tests/test_basic_nn.py @@ -0,0 +1,51 @@ +import numpy +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + + +class TestBasicNN(object): + # Test initialization + def test_initialization(self, basic_nn_instance): + assert basic_nn_instance.model is not None + assert isinstance(basic_nn_instance.criterion, nn.Module) + assert isinstance(basic_nn_instance.optimizer, optim.Optimizer) + + # Test training epoch + def test_train_epoch(self, basic_nn_instance): + X = torch.randn(32, 1, 28, 28) + y = torch.randint(0, 10, (32,)) + data_loader = DataLoader(TensorDataset(X, y), batch_size=4) + loss = basic_nn_instance.train_epoch(data_loader) + assert isinstance(loss, float) + + # Test fit method + def test_fit(self, basic_nn_instance): + X = torch.randn(32, 1, 28, 28) + y = torch.randint(0, 10, (32,)) + data_loader = DataLoader(TensorDataset(X, y), batch_size=4) + loss = basic_nn_instance.fit(data_loader) + assert isinstance(loss, float) + + # Test predict method + def test_predict(self, basic_nn_instance): + X = list(torch.randn(32, 1, 28, 28)) + predictions = basic_nn_instance.predict(X=X) + assert len(predictions) == len(X) + assert numpy.isin(predictions, list(range(10))).all() + + # Test score method + def test_score(self, basic_nn_instance): + X = torch.randn(32, 1, 28, 28) + y = torch.randint(0, 10, (32,)) + data_loader = DataLoader(TensorDataset(X, y), batch_size=4) + accuracy = basic_nn_instance.score(data_loader) + assert 0 <= accuracy <= 1 + + # Test save and load methods + def test_save_load(self, basic_nn_instance, tmp_path): + model_path = tmp_path / "model.pth" + basic_nn_instance.save(epoch_id=1, save_path=str(model_path)) + assert model_path.exists() + basic_nn_instance.load(load_path=str(model_path)) diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index d3bf3e5..0000000 --- a/tests/test_models.py +++ /dev/null @@ -1,75 +0,0 @@ -import sys - -sys.path.insert(0, sys.path[0] + "/../") - -import os -import pytest -import torch -import torch.nn as nn -import numpy as np - -from examples.models.nn import LeNet5, SymbolNet -from abl.models.basic_model import BasicModel - - -class TestBasicModel(object): - @pytest.mark.parametrize("num_classes", [4, 10]) - @pytest.mark.parametrize("image_size", [(28, 28, 1), (45, 45, 1)]) - @pytest.mark.parametrize("cls", [LeNet5, SymbolNet]) - @pytest.mark.parametrize("criterion", [nn.CrossEntropyLoss]) - @pytest.mark.parametrize("optimizer", [torch.optim.RMSprop]) - @pytest.mark.parametrize("device", [torch.device("cpu")]) - def test_models(self, num_classes, image_size, cls, criterion, optimizer, device): - cls = cls(num_classes=num_classes, image_size=image_size) - criterion = criterion() - optimizer = optimizer(cls.parameters(), lr=0.001) - - self.num_classes = num_classes - self.image_size = image_size - self.model = BasicModel(cls, criterion, optimizer, device) - - self.data_X = [ - np.random.rand(image_size[2], image_size[0], image_size[1]).astype( - np.float32 - ) - for i in range(5) - ] - self.data_y = np.random.randint(0, num_classes, (5,)) - - self._test_fit() - self._test_predict() - self._test_predict_proba() - self._test_score() - self._test_save() - self._test_load() - - def _test_fit(self): - self.model.fit(X=self.data_X, y=self.data_y) - - def _test_predict(self): - predict_result = self.model.predict(X=self.data_X) - assert predict_result.dtype == int - assert predict_result.shape == (5,) - assert (0 <= predict_result).all() and (predict_result < self.num_classes).all() - - def _test_predict_proba(self): - predict_result = self.model.predict_proba(X=self.data_X) - assert predict_result.dtype == np.float32 - assert predict_result.shape == (5, self.num_classes) - assert (0 <= predict_result).all() and (predict_result <= 1).all() - - def _test_score(self): - accuracy = self.model.score(X=self.data_X, y=self.data_y) - assert type(accuracy) == float - assert 0 <= accuracy <= 1 - - def _test_save(self): - self.model.save(1, "results/test_models") - assert os.path.exists("results/test_models/1_net.pth") - assert os.path.exists("results/test_models/1_opt.pth") - os.remove("results/test_models/1_net.pth") - os.remove("results/test_models/1_opt.pth") - - def _test_load(self): - self.model.save(1, "results/test_models") - self.model.load(1, "results/test_models")