From 5111f03f7df30028499bfad3c2c7c361506d872c Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sat, 11 Nov 2023 00:35:39 +0800 Subject: [PATCH] [ENH] add abstract data interface to abl_model --- abl/learning/abl_model.py | 50 ++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index ae853bc..1a720ec 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -9,13 +9,11 @@ # Description : # # ================================================================# -from typing import List, Any, Optional - import pickle +from typing import Any, Dict from ..structures import ListData -from ..utils import flatten, reform_idx - +from ..utils import reform_idx class ABLModel: @@ -34,7 +32,7 @@ class ABLModel: Methods ------- - predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict + predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict Predict the labels and probabilities for the given data. valid(X: List[List[Any]], Y: List[Any]) -> float Calculate the accuracy score for the given data. @@ -46,10 +44,7 @@ class ABLModel: Load the model from a file. """ - def __init__(self, base_model) -> None: - self.classifier_list = [] - self.classifier_list.append(base_model) - + def __init__(self, base_model: Any) -> None: if not ( hasattr(base_model, "fit") and hasattr(base_model, "predict") @@ -59,7 +54,9 @@ class ABLModel: "base_model should have fit, predict and score methods." ) - def predict(self, data_samples: ListData, mapping: Optional[dict] = None) -> dict: + self.base_model = base_model + + def predict(self, data_samples: ListData) -> Dict: """ Predict the labels and probabilities for the given data. @@ -67,32 +64,27 @@ class ABLModel: ---------- X : List[List[Any]] The data to predict on. - mapping : Optional[dict], optional - A mapping dictionary to map labels to their original values, by default None. Returns ------- dict A dictionary containing the predicted labels and probabilities. """ - model = self.classifier_list[0] - data_X = flatten(data_samples["X"]) + model = self.base_model + data_X = data_samples.flatten("X") if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) label = prob.argmax(axis=1) prob = reform_idx(prob, data_samples["X"]) else: - prob = None + prob = [None] * len(data_samples) label = model.predict(X=data_X) - if mapping is not None: - label = [mapping[y] for y in label] - label = reform_idx(label, data_samples["X"]) return {"label": label, "prob": prob} - def valid(self, X: List[List[Any]], Y: List[Any]) -> float: + def valid(self, data_samples: ListData) -> float: """ Calculate the accuracy for the given data. @@ -108,9 +100,9 @@ class ABLModel: float The accuracy score for the given data. """ - data_X = flatten(X) - data_Y = flatten(Y) - score = self.classifier_list[0].score(X=data_X, y=data_Y) + data_X = data_samples.flatten("X") + data_y = data_samples.flatten("gt_idx") + score = self.base_model.score(X=data_X, y=data_y) return score def train(self, data_samples: ListData) -> float: @@ -129,12 +121,12 @@ class ABLModel: float The loss value of the trained model. """ - data_X = flatten(data_samples["X"]) - data_y = flatten(data_samples["abduced_idx"]) - return self.classifier_list[0].fit(X=data_X, y=data_y) + data_X = data_samples.flatten("X") + data_y = data_samples.flatten("abduced_idx") + return self.base_model.fit(X=data_X, y=data_y) def _model_operation(self, operation: str, *args, **kwargs): - model = self.classifier_list[0] + model = self.base_model if hasattr(model, operation): method = getattr(model, operation) method(*args, **kwargs) @@ -143,11 +135,11 @@ class ABLModel: if not f"{operation}_path" in kwargs.keys(): raise ValueError(f"'{operation}_path' should not be None") if operation == "save": - with open(kwargs["save_path"], 'wb') as file: + with open(kwargs["save_path"], "wb") as file: pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) elif operation == "load": - with open(kwargs["load_path"], 'rb') as file: - self.classifier_list[0] = pickle.load(file) + with open(kwargs["load_path"], "rb") as file: + self.base_model = pickle.load(file) except: raise NotImplementedError( f"{type(model).__name__} object doesn't have the {operation} method"