|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- import numpy
- import pytest
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader, TensorDataset
-
-
- class TestBasicNN(object):
- @pytest.fixture
- def sample_data(self):
- return torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))
-
- @pytest.fixture
- def sample_data_loader_with_label(self, sample_data):
- X, y = sample_data
- return DataLoader(TensorDataset(X, y), batch_size=4)
-
- @pytest.fixture
- def sample_data_loader_without_label(self, sample_data):
- X, _ = sample_data
- return DataLoader(
- TensorDataset(X),
- batch_size=4,
- collate_fn=lambda batch: torch.stack([item[0] for item in batch]),
- )
-
- def test_initialization(self, basic_nn_instance):
- """Test initialization of the BasicNN class"""
- assert basic_nn_instance.model is not None
- assert isinstance(basic_nn_instance.loss_fn, nn.Module)
- assert isinstance(basic_nn_instance.optimizer, optim.Optimizer)
-
- def test_training_methods(self, basic_nn_instance, sample_data, sample_data_loader_with_label):
- """Test train_epoch, fit, and score methods of the BasicNN class"""
-
- # Test train_epoch
- loss = basic_nn_instance.train_epoch(sample_data_loader_with_label)
- assert isinstance(loss, float)
-
- # Test fit with direct data
- X, y = sample_data
- ins = basic_nn_instance.fit(X=list(X), y=list(y))
- assert ins == basic_nn_instance
-
- # Test fit with DataLoader
- ins = basic_nn_instance.fit(data_loader=sample_data_loader_with_label)
- assert ins == basic_nn_instance
-
- # Test invalid fit method input
- with pytest.raises(ValueError):
- basic_nn_instance.fit(X=None, y=None, data_loader=None)
-
- # Test score with direct data
- accuracy = basic_nn_instance.score(X=list(X), y=list(y))
- assert 0 <= accuracy <= 1
-
- # Test score with DataLoader
- accuracy = basic_nn_instance.score(data_loader=sample_data_loader_with_label)
- assert 0 <= accuracy <= 1
-
- def test_prediction_methods(
- self, basic_nn_instance, sample_data, sample_data_loader_without_label
- ):
- """Test predict and predict_proba methods of the BasicNN class"""
- X, _ = sample_data
-
- # Test predict with direct data
- predictions = basic_nn_instance.predict(X=list(X))
- assert len(predictions) == len(X)
- assert numpy.isin(predictions, list(range(10))).all()
-
- # Test predict_proba with direct data
- predict_proba = basic_nn_instance.predict_proba(X=list(X))
- assert len(predict_proba) == len(X)
- assert ((0 <= predict_proba) & (predict_proba <= 1)).all()
-
- # Test predict and predict_proba with DataLoader
- for method in [basic_nn_instance.predict, basic_nn_instance.predict_proba]:
- result = method(data_loader=sample_data_loader_without_label)
- assert len(result) == len(X)
- if method == basic_nn_instance.predict:
- assert numpy.isin(result, list(range(10))).all()
- else:
- assert ((0 <= result) & (result <= 1)).all()
-
- def test_save_load(self, basic_nn_instance, tmp_path):
- """Test save and load methods of the BasicNN class"""
-
- # Test save with explicit save_path
- explicit_save_path = tmp_path / "model_explicit.pth"
- basic_nn_instance.save(epoch_id=1, save_path=str(explicit_save_path))
- assert explicit_save_path.exists(), "Model should be saved to the explicit path"
-
- # Test save without providing save_path (using save_dir)
- basic_nn_instance.save_dir = str(tmp_path)
- implicit_save_path = tmp_path / "model_checkpoint_epoch_1.pth"
- basic_nn_instance.save(epoch_id=1)
- assert implicit_save_path.exists(), "Model should be saved to the implicit path in save_dir"
-
- # Test error when save_path and save_dir are both None
- basic_nn_instance.save_dir = None
- with pytest.raises(ValueError):
- basic_nn_instance.save(epoch_id=1)
-
- # Test error on loading from a None path
- with pytest.raises(ValueError):
- basic_nn_instance.load(load_path=None)
-
- # Test loading model state
- original_state = basic_nn_instance.model.state_dict()
- basic_nn_instance.load(load_path=str(explicit_save_path))
- loaded_state = basic_nn_instance.model.state_dict()
- for key in original_state:
- assert torch.allclose(
- original_state[key], loaded_state[key]
- ), "Model state should be restored after loading"
|