|
@@ -45,14 +45,8 @@ class ABLModel: |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, base_model: Any) -> None: |
|
|
def __init__(self, base_model: Any) -> None: |
|
|
if not ( |
|
|
|
|
|
hasattr(base_model, "fit") |
|
|
|
|
|
and hasattr(base_model, "predict") |
|
|
|
|
|
and hasattr(base_model, "score") |
|
|
|
|
|
): |
|
|
|
|
|
raise NotImplementedError( |
|
|
|
|
|
"base_model should have fit, predict and score methods." |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")): |
|
|
|
|
|
raise NotImplementedError("The base_model should implement fit and predict methods.") |
|
|
|
|
|
|
|
|
self.base_model = base_model |
|
|
self.base_model = base_model |
|
|
|
|
|
|
|
@@ -84,27 +78,6 @@ class ABLModel: |
|
|
|
|
|
|
|
|
return {"label": label, "prob": prob} |
|
|
return {"label": label, "prob": prob} |
|
|
|
|
|
|
|
|
def valid(self, data_samples: ListData) -> float: |
|
|
|
|
|
""" |
|
|
|
|
|
Calculate the accuracy for the given data. |
|
|
|
|
|
|
|
|
|
|
|
Parameters |
|
|
|
|
|
---------- |
|
|
|
|
|
X : List[List[Any]] |
|
|
|
|
|
The data to calculate the accuracy on. |
|
|
|
|
|
Y : List[Any] |
|
|
|
|
|
The true labels for the given data. |
|
|
|
|
|
|
|
|
|
|
|
Returns |
|
|
|
|
|
------- |
|
|
|
|
|
float |
|
|
|
|
|
The accuracy score for the given data. |
|
|
|
|
|
""" |
|
|
|
|
|
data_X = data_samples.flatten("X") |
|
|
|
|
|
data_y = data_samples.flatten("gt_idx") |
|
|
|
|
|
score = self.base_model.score(X=data_X, y=data_y) |
|
|
|
|
|
return score |
|
|
|
|
|
|
|
|
|
|
|
def train(self, data_samples: ListData) -> float: |
|
|
def train(self, data_samples: ListData) -> float: |
|
|
""" |
|
|
""" |
|
|
Train the model on the given data. |
|
|
Train the model on the given data. |
|
@@ -131,19 +104,20 @@ class ABLModel: |
|
|
method = getattr(model, operation) |
|
|
method = getattr(model, operation) |
|
|
method(*args, **kwargs) |
|
|
method(*args, **kwargs) |
|
|
else: |
|
|
else: |
|
|
try: |
|
|
|
|
|
if not f"{operation}_path" in kwargs.keys(): |
|
|
|
|
|
raise ValueError(f"'{operation}_path' should not be None") |
|
|
|
|
|
if operation == "save": |
|
|
|
|
|
with open(kwargs["save_path"], "wb") as file: |
|
|
|
|
|
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) |
|
|
|
|
|
elif operation == "load": |
|
|
|
|
|
with open(kwargs["load_path"], "rb") as file: |
|
|
|
|
|
self.base_model = pickle.load(file) |
|
|
|
|
|
except: |
|
|
|
|
|
raise NotImplementedError( |
|
|
|
|
|
f"{type(model).__name__} object doesn't have the {operation} method" |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
if not f"{operation}_path" in kwargs.keys(): |
|
|
|
|
|
raise ValueError(f"'{operation}_path' should not be None") |
|
|
|
|
|
else: |
|
|
|
|
|
try: |
|
|
|
|
|
if operation == "save": |
|
|
|
|
|
with open(kwargs["save_path"], "wb") as file: |
|
|
|
|
|
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL) |
|
|
|
|
|
elif operation == "load": |
|
|
|
|
|
with open(kwargs["load_path"], "rb") as file: |
|
|
|
|
|
self.base_model = pickle.load(file) |
|
|
|
|
|
except: |
|
|
|
|
|
raise NotImplementedError( |
|
|
|
|
|
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed." |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def save(self, *args, **kwargs) -> None: |
|
|
def save(self, *args, **kwargs) -> None: |
|
|
""" |
|
|
""" |
|
|