You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_basic_nn.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import numpy
  2. import pytest
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from torch.utils.data import DataLoader, TensorDataset
  class TestBasicNN:
      """Tests for the ``BasicNN`` wrapper.

      The ``basic_nn_instance`` fixture used below is not defined in this class
      (presumably it comes from a ``conftest.py`` -- confirm). The tests treat it
      as an object exposing ``model``, ``loss_fn``, ``optimizer``, ``save_dir``
      and the ``train_epoch``/``fit``/``score``/``predict``/``predict_proba``/
      ``save``/``load`` methods.
      """

      @pytest.fixture
      def sample_data(self):
          """Return a reproducible ``(X, y)`` pair.

          ``X`` has shape (32, 1, 28, 28); ``y`` holds 32 integer labels in
          ``[0, 10)``.
          """
          # A local, seeded generator keeps the data deterministic across runs
          # without reseeding the global torch RNG used by the model under test.
          generator = torch.Generator().manual_seed(0)
          X = torch.randn(32, 1, 28, 28, generator=generator)
          y = torch.randint(0, 10, (32,), generator=generator)
          return X, y

      @pytest.fixture
      def sample_data_loader_with_label(self, sample_data):
          """Wrap ``sample_data`` in a DataLoader yielding ``(X, y)`` batches of 4."""
          X, y = sample_data
          return DataLoader(TensorDataset(X, y), batch_size=4)

      @pytest.fixture
      def sample_data_loader_without_label(self, sample_data):
          """Wrap only the inputs of ``sample_data`` in a DataLoader (batches of 4).

          ``TensorDataset(X)`` yields 1-tuples; the custom ``collate_fn`` unpacks
          them and stacks the inputs so every batch is a bare tensor rather than
          the default one-element list.
          """
          X, _ = sample_data
          return DataLoader(
              TensorDataset(X),
              batch_size=4,
              collate_fn=lambda batch: torch.stack([item[0] for item in batch]),
          )

      def test_initialization(self, basic_nn_instance):
          """Test initialization of the BasicNN class"""
          assert basic_nn_instance.model is not None
          assert isinstance(basic_nn_instance.loss_fn, nn.Module)
          assert isinstance(basic_nn_instance.optimizer, optim.Optimizer)

      def test_training_methods(
          self, basic_nn_instance, sample_data, sample_data_loader_with_label
      ):
          """Test train_epoch, fit, and score methods of the BasicNN class"""
          # train_epoch returns the loss as a plain Python float.
          loss = basic_nn_instance.train_epoch(sample_data_loader_with_label)
          assert isinstance(loss, float)

          # fit returns the very same instance (fluent API) -- check identity,
          # not equality, so a custom __eq__ cannot mask a returned copy.
          X, y = sample_data
          assert basic_nn_instance.fit(X=list(X), y=list(y)) is basic_nn_instance
          assert (
              basic_nn_instance.fit(data_loader=sample_data_loader_with_label)
              is basic_nn_instance
          )

          # fit with neither in-memory data nor a loader must be rejected.
          with pytest.raises(ValueError):
              basic_nn_instance.fit(X=None, y=None, data_loader=None)

          # score yields an accuracy in [0, 1] for both input styles.
          accuracy = basic_nn_instance.score(X=list(X), y=list(y))
          assert 0 <= accuracy <= 1
          accuracy = basic_nn_instance.score(data_loader=sample_data_loader_with_label)
          assert 0 <= accuracy <= 1

      def test_prediction_methods(
          self, basic_nn_instance, sample_data, sample_data_loader_without_label
      ):
          """Test predict and predict_proba methods of the BasicNN class"""
          X, _ = sample_data
          valid_labels = list(range(10))

          def check_labels(predictions):
              # One prediction per input, each a class index in [0, 10).
              assert len(predictions) == len(X)
              assert numpy.isin(predictions, valid_labels).all()

          def check_probabilities(probabilities):
              # One row per input, every entry a valid probability.
              assert len(probabilities) == len(X)
              assert ((0 <= probabilities) & (probabilities <= 1)).all()

          # In-memory inputs.
          check_labels(basic_nn_instance.predict(X=list(X)))
          check_probabilities(basic_nn_instance.predict_proba(X=list(X)))

          # DataLoader inputs (unlabelled loader, so batches are bare tensors).
          check_labels(
              basic_nn_instance.predict(data_loader=sample_data_loader_without_label)
          )
          check_probabilities(
              basic_nn_instance.predict_proba(
                  data_loader=sample_data_loader_without_label
              )
          )

      def test_save_load(self, basic_nn_instance, tmp_path):
          """Test save and load methods of the BasicNN class"""
          model = basic_nn_instance.model

          # Save with an explicit save_path.
          explicit_save_path = tmp_path / "model_explicit.pth"
          basic_nn_instance.save(epoch_id=1, save_path=str(explicit_save_path))
          assert explicit_save_path.exists(), "Model should be saved to the explicit path"

          # Snapshot the weights that were just written. state_dict() returns
          # tensors sharing storage with the live parameters; without clone()
          # the snapshot would follow any later in-place update (including the
          # copy performed by load_state_dict), making the final check vacuous.
          saved_state = {
              key: value.detach().clone() for key, value in model.state_dict().items()
          }

          # Save without save_path: falls back to
          # save_dir/model_checkpoint_epoch_<epoch_id>.pth.
          basic_nn_instance.save_dir = str(tmp_path)
          implicit_save_path = tmp_path / "model_checkpoint_epoch_1.pth"
          basic_nn_instance.save(epoch_id=1)
          assert implicit_save_path.exists(), "Model should be saved to the implicit path in save_dir"

          # Neither save_path nor save_dir available -> error.
          basic_nn_instance.save_dir = None
          with pytest.raises(ValueError):
              basic_nn_instance.save(epoch_id=1)

          # Loading from a None path -> error.
          with pytest.raises(ValueError):
              basic_nn_instance.load(load_path=None)

          # Perturb the parameters so that a successful load is observable.
          with torch.no_grad():
              for param in model.parameters():
                  param.add_(1.0)

          basic_nn_instance.load(load_path=str(explicit_save_path))
          loaded_state = basic_nn_instance.model.state_dict()
          for key, expected in saved_state.items():
              assert torch.allclose(
                  expected, loaded_state[key]
              ), "Model state should be restored after loading"

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.