You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abl_model.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """
  2. This module contains the class ABLModel, which provides a unified interface for different
  3. machine learning models.
  4. Copyright (c) 2024 LAMDA. All rights reserved.
  5. """
  6. import pickle
  7. from typing import Any, Dict
  8. from ..data.structures import ListData
  9. from ..utils import reform_list
  10. class ABLModel:
  11. """
  12. Serialize data and provide a unified interface for different machine learning models.
  13. Parameters
  14. ----------
  15. base_model : Machine Learning Model
  16. The machine learning base model used for training and prediction. This model should
  17. implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for the
  18. model to also implement the ``predict_proba`` method for generating
  19. predictions on the probabilities.
  20. """
  21. def __init__(self, base_model: Any) -> None:
  22. if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
  23. raise NotImplementedError("The base_model should implement fit and predict methods.")
  24. self.base_model = base_model
  25. def predict(self, data_examples: ListData) -> Dict:
  26. """
  27. Predict the labels and probabilities for the given data.
  28. Parameters
  29. ----------
  30. data_examples : ListData
  31. A batch of data to predict on.
  32. Returns
  33. -------
  34. dict
  35. A dictionary containing the predicted labels and probabilities.
  36. """
  37. model = self.base_model
  38. data_X = data_examples.flatten("X")
  39. if hasattr(model, "predict_proba"):
  40. prob = model.predict_proba(X=data_X)
  41. label = prob.argmax(axis=1)
  42. prob = reform_list(prob, data_examples.X)
  43. else:
  44. prob = None
  45. label = model.predict(X=data_X)
  46. label = reform_list(label, data_examples.X)
  47. data_examples.pred_idx = label
  48. data_examples.pred_prob = prob
  49. return {"label": label, "prob": prob}
  50. def train(self, data_examples: ListData) -> float:
  51. """
  52. Train the model on the given data.
  53. Parameters
  54. ----------
  55. data_examples : ListData
  56. A batch of data to train on, which typically contains the data, ``X``, and the
  57. corresponding labels, ``abduced_idx``.
  58. Returns
  59. -------
  60. float
  61. The loss value of the trained model.
  62. """
  63. data_X = data_examples.flatten("X")
  64. data_y = data_examples.flatten("abduced_idx")
  65. return self.base_model.fit(X=data_X, y=data_y)
  66. def valid(self, data_examples: ListData) -> float:
  67. """
  68. Validate the model on the given data.
  69. Parameters
  70. ----------
  71. data_examples : ListData
  72. A batch of data to train on, which typically contains the data, ``X``,
  73. and the corresponding labels, ``abduced_idx``.
  74. Returns
  75. -------
  76. float
  77. The accuracy of the trained model.
  78. """
  79. data_X = data_examples.flatten("X")
  80. data_y = data_examples.flatten("abduced_idx")
  81. score = self.base_model.score(X=data_X, y=data_y)
  82. return score
  83. def _model_operation(self, operation: str, *args, **kwargs):
  84. model = self.base_model
  85. if hasattr(model, operation):
  86. method = getattr(model, operation)
  87. method(*args, **kwargs)
  88. else:
  89. if f"{operation}_path" not in kwargs:
  90. raise ValueError(f"'{operation}_path' should not be None")
  91. try:
  92. if operation == "save":
  93. with open(kwargs["save_path"], "wb") as file:
  94. pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)
  95. elif operation == "load":
  96. with open(kwargs["load_path"], "rb") as file:
  97. self.base_model = pickle.load(file)
  98. except (OSError, pickle.PickleError) as exc:
  99. raise NotImplementedError(
  100. f"{type(model).__name__} object doesn't have the {operation} method \
  101. and the default pickle-based {operation} method failed."
  102. ) from exc
  103. def save(self, *args, **kwargs) -> None:
  104. """
  105. Save the model to a file.
  106. This method delegates to the ``save`` method of self.base_model. The arguments passed to
  107. this method should match those expected by the ``save`` method of self.base_model.
  108. """
  109. self._model_operation("save", *args, **kwargs)
  110. def load(self, *args, **kwargs) -> None:
  111. """
  112. Load the model from a file.
  113. This method delegates to the ``load`` method of self.base_model. The arguments passed to
  114. this method should match those expected by the ``load`` method of self.base_model.
  115. """
  116. self._model_operation("load", *args, **kwargs)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.