bridge.py 1.1 kB

import numpy as np
from typing import List, Any

from ablkit.data import ListData
from ablkit.bridge import SimpleBridge


class BDDBridge(SimpleBridge):
    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        # Map each example's predicted indices to the reasoner's pseudo-labels.
        pred_idx = data_examples.pred_idx  # [ndarray(1, nc), ...]
        pred_pseudo_label = []
        for sub_list in pred_idx:
            # Flatten 1 x nc -> nc. reshape(-1) also keeps the nc == 1 case
            # iterable, where squeeze() would return a 0-d array.
            sub_list = np.asarray(sub_list).reshape(-1)
            pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list])
        data_examples.pred_pseudo_label = pred_pseudo_label
        return data_examples.pred_pseudo_label

    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        # Map each example's abduced pseudo-labels back to indices,
        # storing one NumPy array per example.
        abduced_pseudo_label = data_examples.abduced_pseudo_label
        abduced_idx = []
        for sub_list in abduced_pseudo_label:
            sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list])
            abduced_idx.append(sub_list)
        data_examples.abduced_idx = abduced_idx
        return data_examples.abduced_idx
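To see the index/label round trip that `BDDBridge` performs in isolation, here is a minimal sketch that runs without ABLkit. It relies on hypothetical stand-ins: `ToyReasoner` mimics only the `idx_to_label` / `label_to_idx` dictionaries the bridge reads from `self.reasoner`, a `SimpleNamespace` takes the place of `ListData`, and the label names "absent"/"present" are illustrative, not the BDD example's actual label set.

```python
import numpy as np
from types import SimpleNamespace
from typing import Any, Dict, List


class ToyReasoner:
    """Hypothetical stand-in exposing only the two lookup tables the bridge uses."""

    def __init__(self, pseudo_label_list: List[Any]):
        self.idx_to_label: Dict[int, Any] = dict(enumerate(pseudo_label_list))
        self.label_to_idx: Dict[Any, int] = {lab: i for i, lab in self.idx_to_label.items()}


def idx_to_pseudo_label(reasoner: ToyReasoner, data_examples: Any) -> List[List[Any]]:
    # Each prediction arrives as a (1, nc) array; flatten it before the lookup.
    pred_pseudo_label = []
    for sub_list in data_examples.pred_idx:
        flat = np.asarray(sub_list).reshape(-1)
        pred_pseudo_label.append([reasoner.idx_to_label[int(i)] for i in flat])
    data_examples.pred_pseudo_label = pred_pseudo_label
    return pred_pseudo_label


def pseudo_label_to_idx(reasoner: ToyReasoner, data_examples: Any) -> List[np.ndarray]:
    # Inverse direction: abduced labels back to one index array per example.
    abduced_idx = [
        np.array([reasoner.label_to_idx[lab] for lab in sub_list])
        for sub_list in data_examples.abduced_pseudo_label
    ]
    data_examples.abduced_idx = abduced_idx
    return abduced_idx


if __name__ == "__main__":
    reasoner = ToyReasoner(["absent", "present"])
    examples = SimpleNamespace(pred_idx=[np.array([[1, 0, 1]]), np.array([[0, 0, 1]])])
    print(idx_to_pseudo_label(reasoner, examples))
    # Pretend abduction flipped the second concept of the first example.
    examples.abduced_pseudo_label = [["present", "present", "present"], ["absent", "absent", "present"]]
    print([a.tolist() for a in pseudo_label_to_idx(reasoner, examples)])
```

Both directions write their result back onto the example container as well as returning it, matching how the methods above set `pred_pseudo_label` and `abduced_idx` on `data_examples`.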

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.