Browse Source

[MNT] add and modify docstring in bridge folder

pull/1/head
Gao Enhao 1 year ago
parent
commit
c8f537a727
2 changed files with 186 additions and 7 deletions
  1. +28
    -6
      abl/bridge/base_bridge.py
  2. +158
    -1
      abl/bridge/simple_bridge.py

+ 28
- 6
abl/bridge/base_bridge.py View File

@@ -7,6 +7,27 @@ from ..structures import ListData


class BaseBridge(metaclass=ABCMeta):
"""
A base class for bridging machine learning and reasoning parts.

This class provides necessary methods that need to be overridden in subclasses
to construct a typical pipeline of Abductive learning (corresponding to ``train``),
which involves the following four methods:

- predict: Predict class indices on the given data samples.
- idx_to_pseudo_label: Map indices into pseudo labels.
- abduce_pseudo_label: Revise pseudo labels based on abdutive reasoning.
- pseudo_label_to_idx: Map revised pseudo labels back into indices.

Parameters
----------
model : ABLModel
The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training.
reasoner : Reasoner
The reasoning part wrapped in ``Reasoner``, which is used for pseudo label revision.
"""

def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
if not isinstance(model, ABLModel):
raise TypeError(
@@ -22,24 +43,25 @@ class BaseBridge(metaclass=ABCMeta):

@abstractmethod
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predicting labels from input."""
"""Placeholder for predicting class indices from input."""

@abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abducing pseudo labels."""
"""Placeholder for revising pseudo labels based on abdutive reasoning."""

@abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for mapping indexes to pseudo labels."""
"""Placeholder for mapping indices to pseudo labels."""

@abstractmethod
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for mapping pseudo labels to indexes."""
"""Placeholder for mapping pseudo labels to indices."""

def filter_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
'''Default filter function for pseudo label.'''
"""Default filter function for pseudo label."""
non_empty_idx = [
i for i in range(len(data_samples.abduced_pseudo_label))
i
for i in range(len(data_samples.abduced_pseudo_label))
if data_samples.abduced_pseudo_label[i]
]
data_samples.update(data_samples[non_empty_idx])


+ 158
- 1
abl/bridge/simple_bridge.py View File

@@ -12,6 +12,29 @@ from .base_bridge import BaseBridge


class SimpleBridge(BaseBridge):
"""
A basic implementation for bridging machine learning and reasoning parts.

This class implements the typical pipeline of Abductive learning, which involves
the following five steps:

- Predict class probabilities and indices for the given data samples.
- Map indices into pseudo labels.
- Revise pseudo labels based on abdutive reasoning.
- Map the revised pseudo labels to indices.
- Train the model.

Parameters
----------
model : ABLModel
The machine learning model wrapped in ``ABLModel``, which is mainly used for
prediction and model training.
reasoner : Reasoner
The reasoning part wrapped in ``Reasoner``, which is used for pseudo label revision.
metric_list : List[BaseMetric]
A list of metrics used for evaluating the model's performance.
"""

def __init__(
self,
model: ABLModel,
@@ -22,14 +45,54 @@ class SimpleBridge(BaseBridge):
self.metric_list = metric_list

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
"""
Predict class indices and probabilities (if ``predict_proba`` is implemented in
``self.model.base_model``) on the given data samples.

Parameters
----------
data_samples : ListData
Data samples on which predictions are to be made.

Returns
-------
Tuple[List[ndarray], List[ndarray]]
A tuple containing lists of predicted indices and probabilities.
"""
self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.pred_prob

def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""
Revise predicted pseudo labels of the given data samples using abduction.

Parameters
----------
data_samples : ListData
Data samples containing predicted pseudo labels.

Returns
-------
List[List[Any]]
A list of abduced pseudo labels for the given data samples.
"""
self.reasoner.batch_abduce(data_samples)
return data_samples.abduced_pseudo_label

def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""
Map indices of data samples into pseudo labels.

Parameters
----------
data_samples : ListData
Data samples containing the indices.

Returns
-------
List[List[Any]]
A list of pseudo labels converted from indices.
"""
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
@@ -37,6 +100,19 @@ class SimpleBridge(BaseBridge):
return data_samples.pred_pseudo_label

def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
"""
Map pseudo labels of data samples into indices.

Parameters
----------
data_samples : ListData
Data samples containing pseudo labels.

Returns
-------
List[List[Any]]
A list of indices converted from pseudo labels.
"""
abduced_idx = [
[self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
@@ -49,6 +125,21 @@ class SimpleBridge(BaseBridge):
prefix: str,
data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> ListData:
"""
Transform data in the form of (X, gt_pseudo_label, Y) into ListData.

Parameters
----------
prefix : str
A prefix indicating the type of data processing (e.g., 'train', 'test').
data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Data to be preprocessed. Can be ListData or a tuple of lists.

Returns
-------
ListData
The preprocessed ListData object.
"""
if isinstance(data, ListData):
data_samples = data
if not (
@@ -69,6 +160,21 @@ class SimpleBridge(BaseBridge):
def concat_data_samples(
self, unlabel_data_samples: ListData, label_data_samples: Optional[ListData]
) -> ListData:
"""
Concatenate unlabeled and labeled data samples. ``abduced_pseudo_label`` of unlabeled data samples and ``gt_pseudo_label`` of labeled data samples will be used to train the model.

Parameters
----------
unlabel_data_samples : ListData
Unlabeled data samples to concatenate.
label_data_samples : Optional[ListData]
Labeled data samples to concatenate, if available.

Returns
-------
ListData
Concatenated data samples.
"""
if label_data_samples is None:
return unlabel_data_samples

@@ -89,11 +195,38 @@ class SimpleBridge(BaseBridge):
Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
] = None,
loops: int = 50,
segment_size: Union[int, float] = -1,
segment_size: Union[int, float] = 1.0,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
):
"""
A typical training pipeline of Abuductive Learning.

Parameters
----------
train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Training data.
label_data : Optional[Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]]
Data with ``gt_pseudo_label`` that can be used to train the model, by
default None.
val_data : Optional[Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]]
Validation data, by default None.
loops : int
Machine Learning part and Reasoning part will be iteratively optimized
for ``loops`` times, by default 50.
segment_size : Union[int, float]
Data will be split into segments of this size and data in each segment
will be used together to train the model, by default 1.0.
eval_interval : int
The model will be evaluated every ``eval_interval`` loops during training,
by default 1.
save_interval : Optional[int]
The model will be saved every ``eval_interval`` loops during training, by
default None.
save_dir : Optional[str]
Directory to save the model, by default None.
"""
data_samples = self.data_preprocess("train", train_data)

if label_data is not None:
@@ -147,6 +280,14 @@ class SimpleBridge(BaseBridge):
)

def _valid(self, data_samples: ListData) -> None:
"""
Internal method for validating the model with given data samples.

Parameters
----------
data_samples : ListData
Data samples to be used for validation.
"""
self.predict(data_samples)
self.idx_to_pseudo_label(data_samples)

@@ -165,6 +306,14 @@ class SimpleBridge(BaseBridge):
self,
val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""
Validate the model with the given validation data.

Parameters
----------
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Validation data to be used for model evaluation.
"""
val_data_samples = self.data_preprocess(val_data)
self._valid(val_data_samples)

@@ -172,5 +321,13 @@ class SimpleBridge(BaseBridge):
self,
test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
) -> None:
"""
Test the model with the given test data.

Parameters
----------
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Test data to be used for model evaluation.
"""
test_data_samples = self.data_preprocess("test", test_data)
self._valid(test_data_samples)

Loading…
Cancel
Save