from unittest.mock import Mock, create_autospec

import numpy as np
import pytest

from ablkit.learning import ABLModel


def _make_base_model(spec):
    """Build a ``Mock`` base model exposing only the attributes named in *spec*.

    If present, ``fit`` returns ``1.0`` and ``predict`` returns ``np.array(1.0)``,
    so a spec of ``["fit", "predict"]`` yields a minimal valid base model.
    """
    base_model = Mock(spec=spec)
    if "fit" in spec:
        base_model.fit.return_value = 1.0
    if "predict" in spec:
        base_model.predict.return_value = np.array(1.0)
    return base_model


class TestABLModel:
    def test_ablmodel_initialization(self):
        """Test the initialization method of the ABLModel class."""
        # A base model lacking ``fit``, ``predict``, or both must be rejected.
        for spec in ([], ["fit"], ["predict"]):
            with pytest.raises(NotImplementedError):
                ABLModel(_make_base_model(spec))

        # A base model exposing both methods is accepted and stored.
        model = ABLModel(_make_base_model(["fit", "predict"]))
        assert hasattr(model, "base_model"), "The model should have a 'base_model' attribute."

    def test_ablmodel_predict(self, base_model_instance, list_data_instance):
        """Test the predict method of the ABLModel class."""
        model = ABLModel(base_model_instance)
        predictions = model.predict(list_data_instance)
        assert isinstance(predictions, dict), "Predictions should be returned as a dictionary."
        assert isinstance(predictions["label"], list), "'label' predictions should be a list."
        assert isinstance(predictions["prob"], list), "'prob' predictions should be a list."

        # Autospec the fixture, then remove ``predict_proba``: in that case the
        # model is expected to report ``prob`` as None.
        basic_nn_mock = create_autospec(base_model_instance, spec_set=True)
        delattr(basic_nn_mock, "predict_proba")
        model = ABLModel(basic_nn_mock)
        predictions = model.predict(list_data_instance)
        assert isinstance(predictions, dict), "Predictions should be returned as a dictionary."
        assert isinstance(predictions["label"], list), "'label' predictions should be a list."
        assert predictions["prob"] is None, "'prob' should be None without predict_proba."

    def test_ablmodel_train(self, base_model_instance, list_data_instance):
        """Test the train method of the ABLModel class."""
        model = ABLModel(base_model_instance)
        list_data_instance.abduced_idx = [[1, 2], [3, 4], [5, 6]]
        ins = model.train(list_data_instance)
        assert ins == model.base_model, "Training should return the base model instance."

    def test_ablmodel_save_load(self, base_model_instance, tmp_path):
        """Test the save and load methods of the ABLModel class."""
        model = ABLModel(base_model_instance)
        model_path = tmp_path / "model.pth"
        model.save(save_path=str(model_path))
        assert model_path.exists(), "Saving should create the model file."
        model.load(load_path=str(model_path))
        # NOTE(review): this only checks that base_model keeps its type after
        # loading; it does not compare the restored state with what was saved.
        assert isinstance(model.base_model, type(base_model_instance)), (
            "Loading should leave a base model of the original type."
        )

    def test_ablmodel_invalid_operation(self, base_model_instance):
        """Test invalid operation handling in the ABLModel class."""
        model = ABLModel(base_model_instance)
        with pytest.raises(ValueError):
            model._model_operation("invalid_operation", save_path=None)

    def test_ablmodel_operation_without_path(self, base_model_instance):
        """Test operation without providing a path in the ABLModel class."""
        model = ABLModel(base_model_instance)
        with pytest.raises(ValueError):
            model.save()  # No path provided