|
- import numpy as np
- from typing import List, Any
- from ablkit.data import ListData
- from ablkit.bridge import SimpleBridge
-
-
- class BDDBridge(SimpleBridge):
- def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
- pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ]
- pred_pseudo_label = []
- for sub_list in pred_idx:
- sub_list = sub_list.squeeze() # 1 x nc -> nc
- pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list])
- data_examples.pred_pseudo_label = pred_pseudo_label
- return data_examples.pred_pseudo_label
-
- def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
- abduced_pseudo_label = data_examples.abduced_pseudo_label
- abduced_idx = []
- for sub_list in abduced_pseudo_label:
- sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list])
- abduced_idx.append(sub_list)
- data_examples.abduced_idx = abduced_idx
- return data_examples.abduced_idx
|