You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_abl_model.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from unittest.mock import Mock, create_autospec
  2. import numpy as np
  3. import pytest
  4. from ablkit.learning import ABLModel
  5. class TestABLModel(object):
  6. def test_ablmodel_initialization(self):
  7. """Test the initialization method of the ABLModel class."""
  8. invalid_base_model = Mock(spec=[])
  9. with pytest.raises(NotImplementedError):
  10. ABLModel(invalid_base_model)
  11. invalid_base_model = Mock(spec=["fit"])
  12. invalid_base_model.fit.return_value = 1.0
  13. with pytest.raises(NotImplementedError):
  14. ABLModel(invalid_base_model)
  15. invalid_base_model = Mock(spec=["predict"])
  16. invalid_base_model.predict.return_value = np.array(1.0)
  17. with pytest.raises(NotImplementedError):
  18. ABLModel(invalid_base_model)
  19. base_model = Mock(spec=["fit", "predict"])
  20. base_model.fit.return_value = 1.0
  21. base_model.predict.return_value = np.array(1.0)
  22. model = ABLModel(base_model)
  23. assert hasattr(model, "base_model"), "The model should have a 'base_model' attribute."
  24. def test_ablmodel_predict(self, base_model_instance, list_data_instance):
  25. """Test the predict method of the ABLModel class."""
  26. model = ABLModel(base_model_instance)
  27. predictions = model.predict(list_data_instance)
  28. assert isinstance(predictions, dict), "Predictions should be returned as a dictionary."
  29. assert isinstance(predictions["label"], list)
  30. assert isinstance(predictions["prob"], list)
  31. basic_nn_mock = create_autospec(base_model_instance, spec_set=True)
  32. delattr(basic_nn_mock, "predict_proba")
  33. model = ABLModel(basic_nn_mock)
  34. predictions = model.predict(list_data_instance)
  35. assert isinstance(predictions, dict), "Predictions should be returned as a dictionary."
  36. assert isinstance(predictions["label"], list)
  37. assert predictions["prob"] is None
  38. def test_ablmodel_train(self, base_model_instance, list_data_instance):
  39. """Test the train method of the ABLModel class."""
  40. model = ABLModel(base_model_instance)
  41. list_data_instance.abduced_idx = [[1, 2], [3, 4], [5, 6]]
  42. ins = model.train(list_data_instance)
  43. assert ins == model.base_model, "Training should return the base model instance."
  44. def test_ablmodel_save_load(self, base_model_instance, tmp_path):
  45. """Test the save method of the ABLModel class."""
  46. model = ABLModel(base_model_instance)
  47. model_path = tmp_path / "model.pth"
  48. model.save(save_path=str(model_path))
  49. assert model_path.exists()
  50. model.load(load_path=str(model_path))
  51. assert isinstance(model.base_model, type(base_model_instance))
  52. def test_ablmodel_invalid_operation(self, base_model_instance):
  53. """Test invalid operation handling in the ABLModel class."""
  54. model = ABLModel(base_model_instance)
  55. with pytest.raises(ValueError):
  56. model._model_operation("invalid_operation", save_path=None)
  57. def test_ablmodel_operation_without_path(self, base_model_instance):
  58. """Test operation without providing a path in the ABLModel class."""
  59. model = ABLModel(base_model_instance)
  60. with pytest.raises(ValueError):
  61. model.save() # No path provided

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.