|
|
@@ -9,13 +9,11 @@ |
|
|
|
# Description : |
|
|
|
# |
|
|
|
# ================================================================# |
|
|
|
from typing import List, Any, Optional |
|
|
|
|
|
|
|
import pickle |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
from ..structures import ListData |
|
|
|
from ..utils import flatten, reform_idx |
|
|
|
|
|
|
|
from ..utils import reform_idx |
|
|
|
|
|
|
|
|
|
|
|
class ABLModel: |
|
|
@@ -34,7 +32,7 @@ class ABLModel: |
|
|
|
|
|
|
|
Methods |
|
|
|
------- |
|
|
|
predict(X: List[List[Any]], mapping: Optional[dict] = None) -> dict |
|
|
|
predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict |
|
|
|
Predict the labels and probabilities for the given data. |
|
|
|
valid(X: List[List[Any]], Y: List[Any]) -> float |
|
|
|
Calculate the accuracy score for the given data. |
|
|
@@ -46,10 +44,7 @@ class ABLModel: |
|
|
|
Load the model from a file. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, base_model) -> None: |
|
|
|
self.classifier_list = [] |
|
|
|
self.classifier_list.append(base_model) |
|
|
|
|
|
|
|
def __init__(self, base_model: Any) -> None: |
|
|
|
if not ( |
|
|
|
hasattr(base_model, "fit") |
|
|
|
and hasattr(base_model, "predict") |
|
|
@@ -59,7 +54,9 @@ class ABLModel: |
|
|
|
"base_model should have fit, predict and score methods." |
|
|
|
) |
|
|
|
|
|
|
|
def predict(self, data_samples: ListData, mapping: Optional[dict] = None) -> dict: |
|
|
|
self.base_model = base_model |
|
|
|
|
|
|
|
def predict(self, data_samples: ListData) -> Dict: |
|
|
|
""" |
|
|
|
Predict the labels and probabilities for the given data. |
|
|
|
|
|
|
@@ -67,32 +64,27 @@ class ABLModel: |
|
|
|
---------- |
|
|
|
X : List[List[Any]] |
|
|
|
The data to predict on. |
|
|
|
mapping : Optional[dict], optional |
|
|
|
A mapping dictionary to map labels to their original values, by default None. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
dict |
|
|
|
A dictionary containing the predicted labels and probabilities. |
|
|
|
""" |
|
|
|
model = self.classifier_list[0] |
|
|
|
data_X = flatten(data_samples["X"]) |
|
|
|
model = self.base_model |
|
|
|
data_X = data_samples.flatten("X") |
|
|
|
if hasattr(model, "predict_proba"): |
|
|
|
prob = model.predict_proba(X=data_X) |
|
|
|
label = prob.argmax(axis=1) |
|
|
|
prob = reform_idx(prob, data_samples["X"]) |
|
|
|
else: |
|
|
|
prob = None |
|
|
|
prob = [None] * len(data_samples) |
|
|
|
label = model.predict(X=data_X) |
|
|
|
|
|
|
|
if mapping is not None: |
|
|
|
label = [mapping[y] for y in label] |
|
|
|
|
|
|
|
label = reform_idx(label, data_samples["X"]) |
|
|
|
|
|
|
|
return {"label": label, "prob": prob} |
|
|
|
|
|
|
|
def valid(self, X: List[List[Any]], Y: List[Any]) -> float: |
|
|
|
def valid(self, data_samples: ListData) -> float: |
|
|
|
""" |
|
|
|
Calculate the accuracy for the given data. |
|
|
|
|
|
|
@@ -108,9 +100,9 @@ class ABLModel: |
|
|
|
float |
|
|
|
The accuracy score for the given data. |
|
|
|
""" |
|
|
|
data_X = flatten(X) |
|
|
|
data_Y = flatten(Y) |
|
|
|
score = self.classifier_list[0].score(X=data_X, y=data_Y) |
|
|
|
data_X = data_samples.flatten("X") |
|
|
|
data_y = data_samples.flatten("gt_idx") |
|
|
|
score = self.base_model.score(X=data_X, y=data_y) |
|
|
|
return score |
|
|
|
|
|
|
|
def train(self, data_samples: ListData) -> float: |
|
|
@@ -129,12 +121,12 @@ class ABLModel: |
|
|
|
float |
|
|
|
The loss value of the trained model. |
|
|
|
""" |
|
|
|
data_X = flatten(data_samples["X"]) |
|
|
|
data_y = flatten(data_samples["abduced_idx"]) |
|
|
|
return self.classifier_list[0].fit(X=data_X, y=data_y) |
|
|
|
data_X = data_samples.flatten("X") |
|
|
|
data_y = data_samples.flatten("abduced_idx") |
|
|
|
return self.base_model.fit(X=data_X, y=data_y) |
|
|
|
|
|
|
|
def _model_operation(self, operation: str, *args, **kwargs): |
|
|
|
model = self.classifier_list[0] |
|
|
|
model = self.base_model |
|
|
|
if hasattr(model, operation): |
|
|
|
method = getattr(model, operation) |
|
|
|
method(*args, **kwargs) |
|
|
@@ -143,11 +135,11 @@ class ABLModel: |
|
|
|
if not f"{operation}_path" in kwargs.keys(): |
|
|
|
raise ValueError(f"'{operation}_path' should not be None") |
|
|
|
if operation == "save": |
|
|
|
with open(kwargs["save_path"], 'wb') as file: |
|
|
|
with open(kwargs["save_path"], "wb") as file: |
|
|
|
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) |
|
|
|
elif operation == "load": |
|
|
|
with open(kwargs["load_path"], 'rb') as file: |
|
|
|
self.classifier_list[0] = pickle.load(file) |
|
|
|
with open(kwargs["load_path"], "rb") as file: |
|
|
|
self.base_model = pickle.load(file) |
|
|
|
except: |
|
|
|
raise NotImplementedError( |
|
|
|
f"{type(model).__name__} object doesn't have the {operation} method" |
|
|
|