Browse Source

[MNT] resolve comments in abl_model.py

ab_data
Gao Enhao 1 year ago
parent
commit
b18e42b0e0
1 changed files with 16 additions and 42 deletions
  1. +16
    -42
      abl/learning/abl_model.py

+ 16
- 42
abl/learning/abl_model.py View File

@@ -45,14 +45,8 @@ class ABLModel:
""" """


def __init__(self, base_model: Any) -> None: def __init__(self, base_model: Any) -> None:
if not (
hasattr(base_model, "fit")
and hasattr(base_model, "predict")
and hasattr(base_model, "score")
):
raise NotImplementedError(
"base_model should have fit, predict and score methods."
)
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
raise NotImplementedError("The base_model should implement fit and predict methods.")


self.base_model = base_model self.base_model = base_model


@@ -84,27 +78,6 @@ class ABLModel:


return {"label": label, "prob": prob} return {"label": label, "prob": prob}


def valid(self, data_samples: ListData) -> float:
"""
Calculate the accuracy for the given data.

Parameters
----------
X : List[List[Any]]
The data to calculate the accuracy on.
Y : List[Any]
The true labels for the given data.

Returns
-------
float
The accuracy score for the given data.
"""
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("gt_idx")
score = self.base_model.score(X=data_X, y=data_y)
return score

def train(self, data_samples: ListData) -> float: def train(self, data_samples: ListData) -> float:
""" """
Train the model on the given data. Train the model on the given data.
@@ -131,19 +104,20 @@ class ABLModel:
method = getattr(model, operation) method = getattr(model, operation)
method(*args, **kwargs) method(*args, **kwargs)
else: else:
try:
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method"
)
if not f"{operation}_path" in kwargs.keys():
raise ValueError(f"'{operation}_path' should not be None")
else:
try:
if operation == "save":
with open(kwargs["save_path"], "wb") as file:
pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
elif operation == "load":
with open(kwargs["load_path"], "rb") as file:
self.base_model = pickle.load(file)
except:
raise NotImplementedError(
f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
)


def save(self, *args, **kwargs) -> None: def save(self, *args, **kwargs) -> None:
""" """


Loading…
Cancel
Save