Browse Source

[ENH] add abstract data interface to abl_model

ab_data
Gao Enhao 1 year ago
parent
commit
5111f03f7d
1 changed files with 21 additions and 29 deletions
  1. +21
    -29
      abl/learning/abl_model.py

+ 21
- 29
abl/learning/abl_model.py View File

@@ -9,13 +9,11 @@
# Description :
#
# ================================================================#
from typing import List, Any, Optional

import pickle
from typing import Any, Dict

from ..structures import ListData
from ..utils import flatten, reform_idx

from ..utils import reform_idx


class ABLModel:
@@ -34,7 +32,7 @@ class ABLModel:

Methods
-------
predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict
predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict
Predict the labels and probabilities for the given data.
valid(X: List[List[Any]], Y: List[Any]) -> float
Calculate the accuracy score for the given data.
@@ -46,10 +44,7 @@ class ABLModel:
Load the model from a file.
"""

def __init__(self, base_model) -> None:
self.classifier_list = []
self.classifier_list.append(base_model)

def __init__(self, base_model: Any) -> None:
if not (
hasattr(base_model, "fit")
and hasattr(base_model, "predict")
@@ -59,7 +54,9 @@ class ABLModel:
"base_model should have fit, predict and score methods."
)

def predict(self, data_samples: ListData, mapping: Optional[dict] = None) -> dict:
self.base_model = base_model

def predict(self, data_samples: ListData) -> Dict:
"""
Predict the labels and probabilities for the given data.

@@ -67,32 +64,27 @@ class ABLModel:
----------
X : List[List[Any]]
The data to predict on.
mapping : Optional[dict], optional
A mapping dictionary to map labels to their original values, by default None.

Returns
-------
dict
A dictionary containing the predicted labels and probabilities.
"""
model = self.classifier_list[0]
data_X = flatten(data_samples["X"])
model = self.base_model
data_X = data_samples.flatten("X")
if hasattr(model, "predict_proba"):
prob = model.predict_proba(X=data_X)
label = prob.argmax(axis=1)
prob = reform_idx(prob, data_samples["X"])
else:
prob = None
prob = [None] * len(data_samples)
label = model.predict(X=data_X)

if mapping is not None:
label = [mapping[y] for y in label]

label = reform_idx(label, data_samples["X"])

return {"label": label, "prob": prob}

def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
def valid(self, data_samples: ListData) -> float:
"""
Calculate the accuracy for the given data.

@@ -108,9 +100,9 @@ class ABLModel:
float
The accuracy score for the given data.
"""
data_X = flatten(X)
data_Y = flatten(Y)
score = self.classifier_list[0].score(X=data_X, y=data_Y)
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("gt_idx")
score = self.base_model.score(X=data_X, y=data_y)
return score

def train(self, data_samples: ListData) -> float:
@@ -129,12 +121,12 @@ class ABLModel:
float
The loss value of the trained model.
"""
data_X = flatten(data_samples["X"])
data_y = flatten(data_samples["abduced_idx"])
return self.classifier_list[0].fit(X=data_X, y=data_y)
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("abduced_idx")
return self.base_model.fit(X=data_X, y=data_y)

def _model_operation(self, operation: str, *args, **kwargs):
model = self.classifier_list[0]
model = self.base_model
if hasattr(model, operation):
method = getattr(model, operation)
method(*args, **kwargs)
@@ -143,11 +135,11 @@ class ABLModel:
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
if operation == "save":
with open(kwargs["save_path"], 'wb') as file:
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], 'rb') as file:
self.classifier_list[0] = pickle.load(file)
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method"


Loading…
Cancel
Save