|
|
@@ -1,4 +1,5 @@ |
|
|
|
import numpy |
|
|
|
import pytest |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.optim as optim |
|
|
@@ -6,53 +7,111 @@ from torch.utils.data import DataLoader, TensorDataset |
|
|
|
|
|
|
|
|
|
|
|
class TestBasicNN(object): |
|
|
|
# Test initialization |
|
|
|
@pytest.fixture |
|
|
|
def sample_data(self): |
|
|
|
return torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)) |
|
|
|
|
|
|
|
@pytest.fixture |
|
|
|
def sample_data_loader_with_label(self, sample_data): |
|
|
|
X, y = sample_data |
|
|
|
return DataLoader(TensorDataset(X, y), batch_size=4) |
|
|
|
|
|
|
|
@pytest.fixture |
|
|
|
def sample_data_loader_without_label(self, sample_data): |
|
|
|
X, _ = sample_data |
|
|
|
return DataLoader( |
|
|
|
TensorDataset(X), |
|
|
|
batch_size=4, |
|
|
|
collate_fn=lambda batch: torch.stack([item[0] for item in batch]), |
|
|
|
) |
|
|
|
|
|
|
|
def test_initialization(self, basic_nn_instance): |
|
|
|
"""Test initialization of the BasicNN class""" |
|
|
|
assert basic_nn_instance.model is not None |
|
|
|
assert isinstance(basic_nn_instance.criterion, nn.Module) |
|
|
|
assert isinstance(basic_nn_instance.optimizer, optim.Optimizer) |
|
|
|
|
|
|
|
# Test training epoch |
|
|
|
def test_train_epoch(self, basic_nn_instance): |
|
|
|
X = torch.randn(32, 1, 28, 28) |
|
|
|
y = torch.randint(0, 10, (32,)) |
|
|
|
data_loader = DataLoader(TensorDataset(X, y), batch_size=4) |
|
|
|
loss = basic_nn_instance.train_epoch(data_loader) |
|
|
|
def test_training_methods(self, basic_nn_instance, sample_data, sample_data_loader_with_label): |
|
|
|
"""Test train_epoch, fit, and score methods of the BasicNN class""" |
|
|
|
|
|
|
|
# Test train_epoch |
|
|
|
loss = basic_nn_instance.train_epoch(sample_data_loader_with_label) |
|
|
|
assert isinstance(loss, float) |
|
|
|
|
|
|
|
# Test fit method |
|
|
|
def test_fit(self, basic_nn_instance): |
|
|
|
X = torch.randn(32, 1, 28, 28) |
|
|
|
y = torch.randint(0, 10, (32,)) |
|
|
|
data_loader = DataLoader(TensorDataset(X, y), batch_size=4) |
|
|
|
loss = basic_nn_instance.fit(data_loader) |
|
|
|
# Test fit with direct data |
|
|
|
X, y = sample_data |
|
|
|
loss = basic_nn_instance.fit(X=list(X), y=list(y)) |
|
|
|
assert isinstance(loss, float) |
|
|
|
|
|
|
|
# Test predict method |
|
|
|
def test_predict(self, basic_nn_instance): |
|
|
|
X = list(torch.randn(32, 1, 28, 28)) |
|
|
|
predictions = basic_nn_instance.predict(X=X) |
|
|
|
# Test fit with DataLoader |
|
|
|
loss = basic_nn_instance.fit(data_loader=sample_data_loader_with_label) |
|
|
|
assert isinstance(loss, float) |
|
|
|
|
|
|
|
# Test invalid fit method input |
|
|
|
with pytest.raises(ValueError): |
|
|
|
basic_nn_instance.fit(X=None, y=None, data_loader=None) |
|
|
|
|
|
|
|
# Test score with direct data |
|
|
|
accuracy = basic_nn_instance.score(X=list(X), y=list(y)) |
|
|
|
assert 0 <= accuracy <= 1 |
|
|
|
|
|
|
|
# Test score with DataLoader |
|
|
|
accuracy = basic_nn_instance.score(data_loader=sample_data_loader_with_label) |
|
|
|
assert 0 <= accuracy <= 1 |
|
|
|
|
|
|
|
def test_prediction_methods( |
|
|
|
self, basic_nn_instance, sample_data, sample_data_loader_without_label |
|
|
|
): |
|
|
|
"""Test predict and predict_proba methods of the BasicNN class""" |
|
|
|
X, _ = sample_data |
|
|
|
|
|
|
|
# Test predict with direct data |
|
|
|
predictions = basic_nn_instance.predict(X=list(X)) |
|
|
|
assert len(predictions) == len(X) |
|
|
|
assert numpy.isin(predictions, list(range(10))).all() |
|
|
|
|
|
|
|
# Test predict_proba method |
|
|
|
def test_predict_proba(self, basic_nn_instance): |
|
|
|
X = list(torch.randn(32, 1, 28, 28)) |
|
|
|
predict_proba = basic_nn_instance.predict_proba(X=X) |
|
|
|
# Test predict_proba with direct data |
|
|
|
predict_proba = basic_nn_instance.predict_proba(X=list(X)) |
|
|
|
assert len(predict_proba) == len(X) |
|
|
|
assert ((0 <= predict_proba) & (predict_proba <= 1)).all() |
|
|
|
|
|
|
|
# Test score method |
|
|
|
def test_score(self, basic_nn_instance): |
|
|
|
X = torch.randn(32, 1, 28, 28) |
|
|
|
y = torch.randint(0, 10, (32,)) |
|
|
|
data_loader = DataLoader(TensorDataset(X, y), batch_size=4) |
|
|
|
accuracy = basic_nn_instance.score(data_loader) |
|
|
|
assert 0 <= accuracy <= 1 |
|
|
|
# Test predict and predict_proba with DataLoader |
|
|
|
for method in [basic_nn_instance.predict, basic_nn_instance.predict_proba]: |
|
|
|
result = method(data_loader=sample_data_loader_without_label) |
|
|
|
assert len(result) == len(X) |
|
|
|
if method == basic_nn_instance.predict: |
|
|
|
assert numpy.isin(result, list(range(10))).all() |
|
|
|
else: |
|
|
|
assert ((0 <= result) & (result <= 1)).all() |
|
|
|
|
|
|
|
# Test save and load methods |
|
|
|
def test_save_load(self, basic_nn_instance, tmp_path): |
|
|
|
model_path = tmp_path / "model.pth" |
|
|
|
basic_nn_instance.save(epoch_id=1, save_path=str(model_path)) |
|
|
|
assert model_path.exists() |
|
|
|
basic_nn_instance.load(load_path=str(model_path)) |
|
|
|
"""Test save and load methods of the BasicNN class""" |
|
|
|
|
|
|
|
# Test save with explicit save_path |
|
|
|
explicit_save_path = tmp_path / "model_explicit.pth" |
|
|
|
basic_nn_instance.save(epoch_id=1, save_path=str(explicit_save_path)) |
|
|
|
assert explicit_save_path.exists(), "Model should be saved to the explicit path" |
|
|
|
|
|
|
|
# Test save without providing save_path (using save_dir) |
|
|
|
basic_nn_instance.save_dir = str(tmp_path) |
|
|
|
implicit_save_path = tmp_path / "model_checkpoint_epoch_1.pth" |
|
|
|
basic_nn_instance.save(epoch_id=1) |
|
|
|
assert implicit_save_path.exists(), "Model should be saved to the implicit path in save_dir" |
|
|
|
|
|
|
|
# Test error when save_path and save_dir are both None |
|
|
|
basic_nn_instance.save_dir = None |
|
|
|
with pytest.raises(ValueError): |
|
|
|
basic_nn_instance.save(epoch_id=1) |
|
|
|
|
|
|
|
# Test error on loading from a None path |
|
|
|
with pytest.raises(ValueError): |
|
|
|
basic_nn_instance.load(load_path=None) |
|
|
|
|
|
|
|
# Test loading model state |
|
|
|
original_state = basic_nn_instance.model.state_dict() |
|
|
|
basic_nn_instance.load(load_path=str(explicit_save_path)) |
|
|
|
loaded_state = basic_nn_instance.model.state_dict() |
|
|
|
for key in original_state: |
|
|
|
assert torch.allclose( |
|
|
|
original_state[key], loaded_state[key] |
|
|
|
), "Model state should be restored after loading" |