@@ -12,6 +12,29 @@ from .base_bridge import BaseBridge
class SimpleBridge(BaseBridge):
"""
A basic implementation for bridging machine learning and reasoning parts.
This class implements the typical pipeline of Abductive learning, which involves
the following five steps:
- Predict class probabilities and indices for the given data samples.
- Map indices into pseudo labels.
- Revise pseudo labels based on abductive reasoning.
- Map the revised pseudo labels to indices.
- Train the model.
Parameters
----------
model : ABLModel
The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training.
reasoner : Reasoner
The reasoning part wrapped in ``Reasoner``, which is used for pseudo label revision.
metric_list : List[BaseMetric]
A list of metrics used for evaluating the model's performance.
"""
def __init__(
self,
model: ABLModel,
@@ -22,14 +45,54 @@ class SimpleBridge(BaseBridge):
self.metric_list = metric_list
def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
    """
    Run the wrapped model on ``data_samples`` and hand back its predictions.

    ``self.model.predict`` is called with ``data_samples``; the predicted
    indices and probabilities are then read from the ``pred_idx`` and
    ``pred_prob`` attributes of ``data_samples``. Probabilities are available
    if ``predict_proba`` is implemented in ``self.model.base_model``.

    Parameters
    ----------
    data_samples : ListData
        Data samples on which predictions are to be made.

    Returns
    -------
    Tuple[List[ndarray], List[ndarray]]
        The pair ``(data_samples.pred_idx, data_samples.pred_prob)``.
    """
    # The return value of ``self.model.predict`` is not used here; results are
    # read from ``data_samples`` afterwards (presumably filled in by the call).
    self.model.predict(data_samples)
    predicted_idx = data_samples.pred_idx
    predicted_prob = data_samples.pred_prob
    return predicted_idx, predicted_prob
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
    """
    Revise the predicted pseudo labels of ``data_samples`` through abduction.

    ``self.reasoner.batch_abduce`` is called with ``data_samples``; the revised
    labels are then read from ``data_samples.abduced_pseudo_label``.

    Parameters
    ----------
    data_samples : ListData
        Data samples containing predicted pseudo labels.

    Returns
    -------
    List[List[Any]]
        The abduced pseudo labels stored on ``data_samples``.
    """
    # The reasoner's return value is ignored; the result is read from
    # ``data_samples`` (presumably written there by ``batch_abduce``).
    self.reasoner.batch_abduce(data_samples)
    revised_labels = data_samples.abduced_pseudo_label
    return revised_labels
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""
Map indices of data samples into pseudo labels.
Parameters
----------
data_samples : ListData
Data samples containing the indices.
Returns
-------
List[List[Any]]
A list of pseudo labels converted from indices.
"""
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
@@ -37,6 +100,19 @@ class SimpleBridge(BaseBridge):
return data_samples.pred_pseudo_label
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""
Map pseudo labels of data samples into indices.
Parameters
----------
data_samples : ListData
Data samples containing pseudo labels.
Returns
-------
List[List[Any]]
A list of indices converted from pseudo labels.
"""
abduced_idx = [
[self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
@@ -49,6 +125,21 @@ class SimpleBridge(BaseBridge):
prefix: str,
data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> ListData:
"""
Transform data in the form of (X, gt_pseudo_label, Y) into ListData.
Parameters
----------
prefix : str
A prefix indicating the type of data processing (e.g., 'train', 'test').
data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Data to be preprocessed. Can be ListData or a tuple of lists.
Returns
-------
ListData
The preprocessed ListData object.
"""
if isinstance(data, ListData):
data_samples = data
if not (
@@ -69,6 +160,21 @@ class SimpleBridge(BaseBridge):
def concat_data_samples(
self, unlabel_data_samples: ListData, label_data_samples: Optional[ListData]
) -> ListData:
"""
Concatenate unlabeled and labeled data samples. ``abduced_pseudo_label`` of unlabeled data samples and ``gt_pseudo_label`` of labeled data samples will be used to train the model.
Parameters
----------
unlabel_data_samples : ListData
Unlabeled data samples to concatenate.
label_data_samples : Optional[ListData]
Labeled data samples to concatenate, if available.
Returns
-------
ListData
Concatenated data samples.
"""
if label_data_samples is None:
return unlabel_data_samples
@@ -89,11 +195,38 @@ class SimpleBridge(BaseBridge):
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
] = None,
loops: int = 50,
segment_size: Union[int, float] = - 1,
segment_size: Union[int, float] = 1.0 ,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
):
"""
A typical training pipeline of Abductive Learning.
Parameters
----------
train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Training data.
label_data : Optional[Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]]
Data with ``gt_pseudo_label`` that can be used to train the model, by
default None.
val_data : Optional[Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]]
Validation data, by default None.
loops : int
Machine Learning part and Reasoning part will be iteratively optimized
for ``loops`` times, by default 50.
segment_size : Union[int, float]
Data will be split into segments of this size and data in each segment
will be used together to train the model, by default 1.0.
eval_interval : int
The model will be evaluated every ``eval_interval`` loops during training,
by default 1.
save_interval : Optional[int]
The model will be saved every ``save_interval`` loops during training, by
default None.
save_dir : Optional[str]
Directory to save the model, by default None.
"""
data_samples = self.data_preprocess("train", train_data)
if label_data is not None:
@@ -147,6 +280,14 @@ class SimpleBridge(BaseBridge):
)
def _valid(self, data_samples: ListData) -> None:
"""
Internal method for validating the model with given data samples.
Parameters
----------
data_samples : ListData
Data samples to be used for validation.
"""
self.predict(data_samples)
self.idx_to_pseudo_label(data_samples)
@@ -165,6 +306,14 @@ class SimpleBridge(BaseBridge):
self,
val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""
Validate the model with the given validation data.
Parameters
----------
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Validation data to be used for model evaluation.
"""
val_data_samples = self.data_preprocess(val_data)
self._valid(val_data_samples)
@@ -172,5 +321,13 @@ class SimpleBridge(BaseBridge):
self,
test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""
Test the model with the given test data.
Parameters
----------
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Test data to be used for model evaluation.
"""
test_data_samples = self.data_preprocess("test", test_data)
self._valid(test_data_samples)