You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abl_model.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :models.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/04/02
  9. # Description :
  10. #
  11. # ================================================================#
  12. import pickle
  13. from typing import Any, Dict
  14. from ..structures import ListData
  15. from ..utils import reform_idx
  16. class ABLModel:
  17. """
  18. Serialize data and provide a unified interface for different machine learning models.
  19. Parameters
  20. ----------
  21. base_model : Machine Learning Model
  22. The base model to use for training and prediction.
  23. Attributes
  24. ----------
  25. classifier_list : List[Any]
  26. A list of classifiers.
  27. Methods
  28. -------
  29. predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict
  30. Predict the labels and probabilities for the given data.
  31. valid(X: List[List[Any]], Y: List[Any]) -> float
  32. Calculate the accuracy score for the given data.
  33. train(X: List[List[Any]], Y: List[Any]) -> float
  34. Train the model on the given data.
  35. save(*args, **kwargs) -> None
  36. Save the model to a file.
  37. load(*args, **kwargs) -> None
  38. Load the model from a file.
  39. """
  40. def __init__(self, base_model: Any) -> None:
  41. if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
  42. raise NotImplementedError("The base_model should implement fit and predict methods.")
  43. self.base_model = base_model
  44. def predict(self, data_samples: ListData) -> Dict:
  45. """
  46. Predict the labels and probabilities for the given data.
  47. Parameters
  48. ----------
  49. X : List[List[Any]]
  50. The data to predict on.
  51. Returns
  52. -------
  53. dict
  54. A dictionary containing the predicted labels and probabilities.
  55. """
  56. model = self.base_model
  57. data_X = data_samples.flatten("X")
  58. if hasattr(model, "predict_proba"):
  59. prob = model.predict_proba(X=data_X)
  60. label = prob.argmax(axis=1)
  61. prob = reform_idx(prob, data_samples["X"])
  62. else:
  63. prob = [None] * len(data_samples)
  64. label = model.predict(X=data_X)
  65. label = reform_idx(label, data_samples["X"])
  66. return {"label": label, "prob": prob}
  67. def train(self, data_samples: ListData) -> float:
  68. """
  69. Train the model on the given data.
  70. Parameters
  71. ----------
  72. X : List[List[Any]]
  73. The data to train on.
  74. Y : List[Any]
  75. The true labels for the given data.
  76. Returns
  77. -------
  78. float
  79. The loss value of the trained model.
  80. """
  81. data_X = data_samples.flatten("X")
  82. data_y = data_samples.flatten("abduced_idx")
  83. return self.base_model.fit(X=data_X, y=data_y)
  84. def _model_operation(self, operation: str, *args, **kwargs):
  85. model = self.base_model
  86. if hasattr(model, operation):
  87. method = getattr(model, operation)
  88. method(*args, **kwargs)
  89. else:
  90. if not f"{operation}_path" in kwargs.keys():
  91. raise ValueError(f"'{operation}_path' should not be None")
  92. else:
  93. try:
  94. if operation == "save":
  95. with open(kwargs["save_path"], "wb") as file:
  96. pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
  97. elif operation == "load":
  98. with open(kwargs["load_path"], "rb") as file:
  99. self.base_model = pickle.load(file)
  100. except:
  101. raise NotImplementedError(
  102. f"{type(model).__name__} object doesn't have the {operation} method and the default pickle-based {operation} method failed."
  103. )
  104. def save(self, *args, **kwargs) -> None:
  105. """
  106. Save the model to a file.
  107. """
  108. self._model_operation("save", *args, **kwargs)
  109. def load(self, *args, **kwargs) -> None:
  110. """
  111. Load the model from a file.
  112. """
  113. self._model_operation("load", *args, **kwargs)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.