You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abl_model.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2020 Freecss All rights reserved.
  4. #
  5. # File Name :models.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2020/04/02
  9. # Description :
  10. #
  11. # ================================================================#
  12. from itertools import chain
  13. from typing import List, Any, Optional
  14. def get_part_data(X, i):
  15. return list(map(lambda x: x[i], X))
  16. class ABLModel:
  17. """
  18. Serialize data and provide a unified interface for different machine learning models.
  19. Parameters
  20. ----------
  21. base_model : Machine Learning Model
  22. The base model to use for training and prediction.
  23. Attributes
  24. ----------
  25. classifier_list : List[Any]
  26. A list of classifiers.
  27. Methods
  28. -------
  29. predict(X: List[List[Any]], mapping: Optional[dict]) -> dict
  30. Predict the labels and probabilities for the given data.
  31. valid(X: List[List[Any]], Y: List[Any]) -> float
  32. Calculate the accuracy score for the given data.
  33. train(X: List[List[Any]], Y: List[Any])
  34. Train the model on the given data.
  35. """
  36. def __init__(self, base_model) -> None:
  37. self.classifier_list = []
  38. self.classifier_list.append(base_model)
  39. def predict(self, X: List[List[Any]], mapping: Optional[dict] = None) -> dict:
  40. """
  41. Predict the labels and probabilities for the given data.
  42. Parameters
  43. ----------
  44. X : List[List[Any]]
  45. The data to predict on.
  46. Returns
  47. -------
  48. dict
  49. A dictionary containing the predicted labels and probabilities.
  50. """
  51. data_X, marks = self.merge_data(X)
  52. prob = self.classifier_list[0].predict_proba(X=data_X)
  53. label = prob.argmax(axis=1)
  54. if mapping is not None:
  55. label = [mapping[x] for x in label]
  56. prob = self.reshape_data(prob, marks)
  57. label = self.reshape_data(label, marks)
  58. return {"label": label, "prob": prob}
  59. def valid(self, X: List[List[Any]], Y: List[Any]) -> float:
  60. """
  61. Calculate the accuracy for the given data.
  62. Parameters
  63. ----------
  64. X : List[List[Any]]
  65. The data to calculate the accuracy on.
  66. Y : List[Any]
  67. The true labels for the given data.
  68. Returns
  69. -------
  70. float
  71. The accuracy score for the given data.
  72. """
  73. data_X, _ = self.merge_data(X)
  74. data_Y, _ = self.merge_data(Y)
  75. score = self.classifier_list[0].score(X=data_X, y=data_Y)
  76. return score
  77. def train(self, X: List[List[Any]], Y: List[Any]):
  78. """
  79. Train the model on the given data.
  80. Parameters
  81. ----------
  82. X : List[List[Any]]
  83. The data to train on.
  84. Y : List[Any]
  85. The true labels for the given data.
  86. """
  87. data_X, _ = self.merge_data(X)
  88. data_Y, _ = self.merge_data(Y)
  89. self.classifier_list[0].fit(X=data_X, y=data_Y)
  90. @staticmethod
  91. def merge_data(X):
  92. ret_mark = list(map(lambda x: len(x), X))
  93. ret_X = list(chain(*X))
  94. return ret_X, ret_mark
  95. @staticmethod
  96. def reshape_data(Y, marks):
  97. begin_mark = 0
  98. ret_Y = []
  99. for mark in marks:
  100. end_mark = begin_mark + mark
  101. ret_Y.append(list(Y[begin_mark:end_mark]))
  102. begin_mark = end_mark
  103. return ret_Y

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.