
bdd_model.py

```python
from typing import Dict

import numpy as np

from ablkit.data import ListData
from ablkit.learning import ABLModel
from ablkit.utils import reform_list


class BDDABLModel(ABLModel):
    def predict(self, data_examples: ListData) -> Dict:
        model = self.base_model
        # Flatten the nested per-example inputs into a single batch for the base model.
        data_X = data_examples.flatten("X")
        if hasattr(model, "predict_proba"):
            prob = model.predict_proba(X=data_X)
            # Binarize every probability at 0.5 to obtain 0/1 labels.
            label = np.where(prob > 0.5, 1, 0).astype(int)
            # Regroup the flat probabilities to match the structure of data_examples.X.
            prob = reform_list(prob, data_examples.X)
        else:
            # Models without predict_proba only provide hard labels.
            prob = None
            label = model.predict(X=data_X)
        # Regroup the flat labels (from either branch) per example.
        label = reform_list(label, data_examples.X)
        data_examples.pred_idx = label
        data_examples.pred_prob = prob
        return {"label": label, "prob": prob}
```
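To make the control flow of `predict` easier to follow without installing ABLKit, here is a minimal NumPy-only sketch of the same flatten, predict, threshold and regroup steps. Everything in it is hypothetical: `regroup` is an illustrative stand-in assumed to behave like `reform_list` for one level of nesting, and `ThresholdModel`, `LabelOnlyModel` and `predict_nested` are toy names, not part of the ABLKit API.

```python
from typing import Any, Dict, List, Optional, Sequence

import numpy as np


def regroup(flat: Sequence[Any], structure: List[List[Any]]) -> List[List[Any]]:
    """Hypothetical stand-in for reform_list: split `flat` into consecutive
    chunks whose lengths match the sublists of `structure`."""
    out, start = [], 0
    for sub in structure:
        out.append(list(flat[start:start + len(sub)]))
        start += len(sub)
    return out


class ThresholdModel:
    """Toy base model that exposes predict_proba (hypothetical)."""

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        # Sigmoid per feature: one independent "probability" per output.
        return 1.0 / (1.0 + np.exp(-X))


class LabelOnlyModel:
    """Toy base model that only exposes predict (hypothetical)."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        return (X > 0).astype(int)


def predict_nested(model: Any, X_nested: List[List[Any]]) -> Dict[str, Optional[list]]:
    # Flatten the per-example item lists into one batch.
    flat_X = np.array([x for sub in X_nested for x in sub])
    if hasattr(model, "predict_proba"):
        prob = model.predict_proba(X=flat_X)
        # Same 0.5 threshold as BDDABLModel.predict.
        label = np.where(prob > 0.5, 1, 0).astype(int)
        prob = regroup(prob, X_nested)
    else:
        prob = None
        label = model.predict(X=flat_X)
    # Labels are regrouped in both branches.
    label = regroup(label, X_nested)
    return {"label": label, "prob": prob}
```

For example, `predict_nested(ThresholdModel(), [[[2.0, -1.0]], [[-3.0, 0.5], [1.0, -2.0]]])` yields label rows `[[[1, 0]], [[0, 1], [1, 0]]]`, grouped per example. Flattening first lets the base model make one vectorized call over all items instead of one call per example.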

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.