# coding: utf-8
# ================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name     :models.py
#   Author        :freecss
#   Email         :karlfreecss@gmail.com
#   Created Date  :2020/04/02
#   Description   :
#
# ================================================================#

import pickle
from typing import Any, Dict

from ..structures import ListData
from ..utils import reform_idx


class ABLModel:
    """
    Wrap a machine learning model behind a unified predict/train/save/load
    interface that operates on ``ListData`` samples.

    Parameters
    ----------
    base_model : Any
        The wrapped model. It must expose ``fit`` and ``predict`` methods;
        ``predict_proba``, ``save`` and ``load`` are used when present.

    Attributes
    ----------
    base_model : Any
        The wrapped model. It is replaced by the unpickled object when the
        pickle-based ``load`` fallback is used.

    Methods
    -------
    predict(data_samples: ListData) -> Dict
        Predict labels (and probabilities, when available) for the samples.
    train(data_samples: ListData) -> float
        Fit the base model on the samples.
    save(*args, **kwargs) -> None
        Save the model to a file.
    load(*args, **kwargs) -> None
        Load the model from a file.
    """

    def __init__(self, base_model: Any) -> None:
        """
        Store ``base_model`` after checking it offers the minimal interface.

        Raises
        ------
        NotImplementedError
            If ``base_model`` lacks a ``fit`` or a ``predict`` attribute.
        """
        if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
            raise NotImplementedError("The base_model should implement fit and predict methods.")

        self.base_model = base_model

    def predict(self, data_samples: ListData) -> Dict:
        """
        Predict the labels and probabilities for the given data.

        The flattened ``"X"`` field of ``data_samples`` is fed to the base
        model. As a side effect the results are stored on ``data_samples``
        as ``pred_idx`` and, when probabilities are available, ``pred_prob``.

        Parameters
        ----------
        data_samples : ListData
            The samples to predict on; its ``"X"`` field holds the inputs.

        Returns
        -------
        dict
            ``{"label": ..., "prob": ...}``. ``prob`` is ``None`` when the
            base model has no ``predict_proba`` method.
        """
        model = self.base_model
        data_X = data_samples.flatten("X")
        if hasattr(model, "predict_proba"):
            prob = model.predict_proba(X=data_X)
            # The label is the index of the most probable class.
            label = prob.argmax(axis=1)
            prob = reform_idx(prob, data_samples["X"])
        else:
            prob = None
            label = model.predict(X=data_X)
        # NOTE(review): reform_idx presumably regroups the flat predictions to
        # mirror the nesting of data_samples["X"] -- confirm against utils.
        label = reform_idx(label, data_samples["X"])

        data_samples.pred_idx = label
        if prob is not None:
            data_samples.pred_prob = prob

        return {"label": label, "prob": prob}

    def train(self, data_samples: ListData) -> float:
        """
        Train the model on the given data.

        Parameters
        ----------
        data_samples : ListData
            The samples to train on. The flattened ``"X"`` field is used as
            input and the flattened ``"abduced_idx"`` field as target.

        Returns
        -------
        float
            Whatever ``base_model.fit`` returns (annotated as a loss value;
            the actual type depends on the wrapped model).
        """
        data_X = data_samples.flatten("X")
        data_y = data_samples.flatten("abduced_idx")
        return self.base_model.fit(X=data_X, y=data_y)

    def _model_operation(self, operation: str, *args, **kwargs):
        """
        Run ``save`` or ``load`` on the base model.

        If the base model has an attribute named ``operation`` it is called
        with the given arguments. Otherwise a pickle-based fallback is used,
        which requires the keyword argument ``save_path`` / ``load_path``.

        Raises
        ------
        ValueError
            If the fallback is needed and ``'<operation>_path'`` is missing
            or ``None``.
        NotImplementedError
            If the pickle-based fallback fails; the underlying error is
            attached as ``__cause__``.
        """
        model = self.base_model
        if hasattr(model, operation):
            getattr(model, operation)(*args, **kwargs)
            return

        path = kwargs.get(f"{operation}_path")
        if path is None:
            raise ValueError(f"'{operation}_path' should not be None")

        try:
            if operation == "save":
                with open(path, "wb") as file:
                    pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
            elif operation == "load":
                # SECURITY: pickle.load can execute arbitrary code; only load
                # files that come from a trusted source.
                with open(path, "rb") as file:
                    self.base_model = pickle.load(file)
        except Exception as err:
            # Keep the public exception type, but preserve the real cause and
            # let KeyboardInterrupt/SystemExit propagate untouched.
            raise NotImplementedError(
                f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
            ) from err

    def save(self, *args, **kwargs) -> None:
        """
        Save the model to a file.

        Arguments are forwarded to the base model's ``save`` method if it has
        one; otherwise pass ``save_path=...`` to pickle the base model.
        """
        self._model_operation("save", *args, **kwargs)

    def load(self, *args, **kwargs) -> None:
        """
        Load the model from a file.

        Arguments are forwarded to the base model's ``load`` method if it has
        one; otherwise pass ``load_path=...`` to unpickle a base model.
        """
        self._model_operation("load", *args, **kwargs)