
[FIX] change sample to example

troyyyyy committed 1 year ago (commit a21d1ddc79, branch pull/1/head)
31 changed files with 1093 additions and 506 deletions
  1. abl/bridge/base_bridge.py (+10 -10)
  2. abl/bridge/simple_bridge.py (+78 -78)
  3. abl/evaluation/base_metric.py (+5 -5)
  4. abl/evaluation/reasoning_metric.py (+11 -11)
  5. abl/evaluation/symbol_metric.py (+7 -7)
  6. abl/learning/abl_model.py (+15 -15)
  7. abl/reasoning/kb.py (+41 -41)
  8. abl/reasoning/reasoner.py (+44 -44)
  9. abl/structures/base_data_element.py (+13 -13)
  10. docs/Examples/MNISTAdd.rst (+14 -14)
  11. docs/Intro/Basics.rst (+4 -4)
  12. docs/Intro/Bridge.rst (+17 -17)
  13. docs/Intro/Datasets.rst (+3 -3)
  14. docs/Intro/Evaluation.rst (+3 -3)
  15. docs/Intro/Learning.rst (+2 -2)
  16. docs/Intro/Quick-Start.rst (+2 -2)
  17. docs/Intro/Reasoning.rst (+38 -38)
  18. examples/hed/datasets/equation_generator.py (+173 -0)
  19. examples/hed/datasets/learn_add.pl (+5 -5)
  20. examples/hed/hed_bridge.py (+43 -43)
  21. examples/hed/hed_example.ipynb (+202 -31)
  22. examples/hed/utils.py (+4 -4)
  23. examples/hwf/datasets/get_dataset.py (+12 -11)
  24. examples/hwf/hwf.ipynb (+16 -19)
  25. examples/hwf/main.py (+3 -6)
  26. examples/mnist_add/datasets/get_dataset.py (+3 -3)
  27. examples/mnist_add/mnist_add.ipynb (+81 -22)
  28. examples/zoo/main.py (+189 -0)
  29. examples/zoo/zoo_example.ipynb (+4 -4)
  30. tests/conftest.py (+20 -20)
  31. tests/test_reasoning.py (+31 -31)

abl/bridge/base_bridge.py (+10 -10)

@@ -14,7 +14,7 @@ class BaseBridge(metaclass=ABCMeta):
to construct a typical pipeline of Abductive Learning (corresponding to ``train``),
which involves the following four methods:

- predict: Predict class indices on the given data samples.
- predict: Predict class indices on the given data examples.
- idx_to_pseudo_label: Map indices into pseudo labels.
- abduce_pseudo_label: Revise pseudo labels based on abductive reasoning.
- pseudo_label_to_idx: Map revised pseudo labels back into indices.
@@ -42,30 +42,30 @@ class BaseBridge(metaclass=ABCMeta):
self.reasoner = reasoner

@abstractmethod
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
def predict(self, data_examples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
"""Placeholder for predicting class indices from input."""

@abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
"""Placeholder for revising pseudo labels based on abdutive reasoning."""

@abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
"""Placeholder for mapping indices to pseudo labels."""

@abstractmethod
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
"""Placeholder for mapping pseudo labels to indices."""

def filter_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
def filter_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
"""Default filter function for pseudo label."""
non_empty_idx = [
i
for i in range(len(data_samples.abduced_pseudo_label))
if data_samples.abduced_pseudo_label[i]
for i in range(len(data_examples.abduced_pseudo_label))
if data_examples.abduced_pseudo_label[i]
]
data_samples.update(data_samples[non_empty_idx])
return data_samples
data_examples.update(data_examples[non_empty_idx])
return data_examples

@abstractmethod
def train(
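
For context, the default ``filter_pseudo_label`` above keeps only the examples whose abduction produced a non-empty result. A standalone sketch of that logic (plain Python, not part of this commit):

.. code:: python

    # Keep only the indices whose abduced pseudo labels are non-empty;
    # examples for which abduction failed are dropped.
    abduced_pseudo_label = [[1, 2], [], [0, 7]]
    non_empty_idx = [i for i in range(len(abduced_pseudo_label)) if abduced_pseudo_label[i]]
    print(non_empty_idx)  # [0, 2]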


abl/bridge/simple_bridge.py (+78 -78)

@@ -18,7 +18,7 @@ class SimpleBridge(BaseBridge):
This class implements the typical pipeline of Abductive Learning, which involves
the following five steps:

- Predict class probabilities and indices for the given data samples.
- Predict class probabilities and indices for the given data examples.
- Map indices into pseudo labels.
- Revise pseudo labels based on abductive reasoning.
- Map the revised pseudo labels to indices.
@@ -44,69 +44,69 @@ class SimpleBridge(BaseBridge):
super().__init__(model, reasoner)
self.metric_list = metric_list

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
"""
Predict class indices and probabilities (if ``predict_proba`` is implemented in
``self.model.base_model``) on the given data samples.
``self.model.base_model``) on the given data examples.

Parameters
----------
data_samples : ListData
Data samples on which predictions are to be made.
data_examples : ListData
Data examples on which predictions are to be made.

Returns
-------
Tuple[List[ndarray], List[ndarray]]
A tuple containing lists of predicted indices and probabilities.
"""
self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.pred_prob
self.model.predict(data_examples)
return data_examples.pred_idx, data_examples.pred_prob

def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
"""
Revise predicted pseudo labels of the given data samples using abduction.
Revise predicted pseudo labels of the given data examples using abduction.

Parameters
----------
data_samples : ListData
Data samples containing predicted pseudo labels.
data_examples : ListData
Data examples containing predicted pseudo labels.

Returns
-------
List[List[Any]]
A list of abduced pseudo labels for the given data samples.
A list of abduced pseudo labels for the given data examples.
"""
self.reasoner.batch_abduce(data_samples)
return data_samples.abduced_pseudo_label
self.reasoner.batch_abduce(data_examples)
return data_examples.abduced_pseudo_label

def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
"""
Map indices of data samples into pseudo labels.
Map indices of data examples into pseudo labels.

Parameters
----------
data_samples : ListData
Data samples containing the indices.
data_examples : ListData
Data examples containing the indices.

Returns
-------
List[List[Any]]
A list of pseudo labels converted from indices.
"""
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
pred_idx = data_examples.pred_idx
data_examples.pred_pseudo_label = [
[self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx
]
return data_samples.pred_pseudo_label
return data_examples.pred_pseudo_label

def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
"""
Map pseudo labels of data samples into indices.
Map pseudo labels of data examples into indices.

Parameters
----------
data_samples : ListData
Data samples containing pseudo labels.
data_examples : ListData
Data examples containing pseudo labels.

Returns
-------
@@ -115,10 +115,10 @@ class SimpleBridge(BaseBridge):
"""
abduced_idx = [
[self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
for sub_list in data_examples.abduced_pseudo_label
]
data_samples.abduced_idx = abduced_idx
return data_samples.abduced_idx
data_examples.abduced_idx = abduced_idx
return data_examples.abduced_idx

def data_preprocess(
self,
@@ -141,49 +141,49 @@ class SimpleBridge(BaseBridge):
The preprocessed ListData object.
"""
if isinstance(data, ListData):
data_samples = data
data_examples = data
if not (
hasattr(data_samples, "X")
and hasattr(data_samples, "gt_pseudo_label")
and hasattr(data_samples, "Y")
hasattr(data_examples, "X")
and hasattr(data_examples, "gt_pseudo_label")
and hasattr(data_examples, "Y")
):
raise ValueError(
f"{prefix}data should have X, gt_pseudo_label and Y attribute but "
f"only {data_samples.all_keys()} are provided."
f"only {data_examples.all_keys()} are provided."
)
else:
X, gt_pseudo_label, Y = data
data_samples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)

return data_samples
return data_examples

def concat_data_samples(
self, unlabel_data_samples: ListData, label_data_samples: Optional[ListData]
def concat_data_examples(
self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData]
) -> ListData:
"""
Concatenate unlabeled and labeled data samples. ``abduced_pseudo_label`` of unlabeled data samples and ``gt_pseudo_label`` of labeled data samples will be used to train the model.
Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model.

Parameters
----------
unlabel_data_samples : ListData
Unlabeled data samples to concatenate.
label_data_samples : Optional[ListData]
Labeled data samples to concatenate, if available.
unlabel_data_examples : ListData
Unlabeled data examples to concatenate.
label_data_examples : Optional[ListData]
Labeled data examples to concatenate, if available.

Returns
-------
ListData
Concatenated data samples.
Concatenated data examples.
"""
if label_data_samples is None:
return unlabel_data_samples
if label_data_examples is None:
return unlabel_data_examples

unlabel_data_samples.X = unlabel_data_samples.X + label_data_samples.X
unlabel_data_samples.abduced_pseudo_label = (
unlabel_data_samples.abduced_pseudo_label + label_data_samples.gt_pseudo_label
unlabel_data_examples.X = unlabel_data_examples.X + label_data_examples.X
unlabel_data_examples.abduced_pseudo_label = (
unlabel_data_examples.abduced_pseudo_label + label_data_examples.gt_pseudo_label
)
unlabel_data_samples.Y = unlabel_data_samples.Y + label_data_samples.Y
return unlabel_data_samples
unlabel_data_examples.Y = unlabel_data_examples.Y + label_data_examples.Y
return unlabel_data_examples

def train(
self,
@@ -227,51 +227,51 @@ class SimpleBridge(BaseBridge):
save_dir : Optional[str]
Directory to save the model, by default None.
"""
data_samples = self.data_preprocess("train", train_data)
data_examples = self.data_preprocess("train", train_data)

if label_data is not None:
label_data_samples = self.data_preprocess("label", label_data)
label_data_examples = self.data_preprocess("label", label_data)
else:
label_data_samples = None
label_data_examples = None

if val_data is not None:
val_data_samples = self.data_preprocess("val", val_data)
val_data_examples = self.data_preprocess("val", val_data)
else:
val_data_samples = data_samples
val_data_examples = data_examples

if isinstance(segment_size, int):
if segment_size <= 0:
raise ValueError("segment_size should be positive.")
elif isinstance(segment_size, float):
if 0 < segment_size <= 1:
segment_size = int(segment_size * len(data_samples))
segment_size = int(segment_size * len(data_examples))
else:
raise ValueError("segment_size should be in (0, 1].")
else:
raise ValueError("segment_size should be int or float.")

for loop in range(loops):
for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
print_log(
f"loop(train) [{loop + 1}/{loops}] segment(train) "
f"[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] ",
f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ",
logger="current",
)

sub_data_samples = data_samples[
sub_data_examples = data_examples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
self.predict(sub_data_samples)
self.idx_to_pseudo_label(sub_data_samples)
self.abduce_pseudo_label(sub_data_samples)
self.filter_pseudo_label(sub_data_samples)
self.concat_data_samples(sub_data_samples, label_data_samples)
self.pseudo_label_to_idx(sub_data_samples)
self.model.train(sub_data_samples)
self.predict(sub_data_examples)
self.idx_to_pseudo_label(sub_data_examples)
self.abduce_pseudo_label(sub_data_examples)
self.filter_pseudo_label(sub_data_examples)
self.concat_data_examples(sub_data_examples, label_data_examples)
self.pseudo_label_to_idx(sub_data_examples)
self.model.train(sub_data_examples)

if (loop + 1) % eval_interval == 0 or loop == loops - 1:
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current")
self._valid(val_data_samples)
self._valid(val_data_examples)

if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
@@ -279,20 +279,20 @@ class SimpleBridge(BaseBridge):
save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
)

def _valid(self, data_samples: ListData) -> None:
def _valid(self, data_examples: ListData) -> None:
"""
Internal method for validating the model with given data samples.
Internal method for validating the model with given data examples.

Parameters
----------
data_samples : ListData
Data samples to be used for validation.
data_examples : ListData
Data examples to be used for validation.
"""
self.predict(data_samples)
self.idx_to_pseudo_label(data_samples)
self.predict(data_examples)
self.idx_to_pseudo_label(data_examples)

for metric in self.metric_list:
metric.process(data_samples)
metric.process(data_examples)

res = dict()
for metric in self.metric_list:
@@ -314,8 +314,8 @@ class SimpleBridge(BaseBridge):
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Validation data to be used for model evaluation.
"""
val_data_samples = self.data_preprocess(val_data)
self._valid(val_data_samples)
val_data_examples = self.data_preprocess(val_data)
self._valid(val_data_examples)

def test(
self,
@@ -329,5 +329,5 @@ class SimpleBridge(BaseBridge):
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
Test data to be used for model evaluation.
"""
test_data_samples = self.data_preprocess("test", test_data)
self._valid(test_data_samples)
test_data_examples = self.data_preprocess("test", test_data)
self._valid(test_data_examples)
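
Taken together, the renamed pipeline is typically driven as follows. This is a hypothetical usage sketch: ``model``, ``reasoner``, ``metric_list`` and the data tuples are assumed to be built as in the docs sections further down.

.. code:: python

    # train_data and test_data follow the (X, gt_pseudo_label, Y) layout
    # described in docs/Intro/Datasets.rst.
    bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)
    bridge.train(train_data=(X, gt_pseudo_label, Y), loops=5, segment_size=1/3)
    bridge.test(test_data=(test_X, test_gt_pseudo_label, test_Y))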

abl/evaluation/base_metric.py (+5 -5)

@@ -10,7 +10,7 @@ class BaseMetric(metaclass=ABCMeta):
"""
Base class for a metric.

The metric first processes each batch of data_samples and appends the processed
The metric first processes each batch of data_examples and appends the processed
results to the results list. Then, it computes the metrics of the entire dataset.

Parameters
@@ -30,16 +30,16 @@ class BaseMetric(metaclass=ABCMeta):
self.prefix = prefix or self.default_prefix

@abstractmethod
def process(self, data_samples: ListData) -> None:
def process(self, data_examples: ListData) -> None:
"""
Process one batch of data samples. The processed results should be stored
Process one batch of data examples. The processed results should be stored
in ``self.results``, which will be used to compute the metrics when all
batches have been processed.

Parameters
----------
data_samples : ListData
A batch of data samples.
data_examples : ListData
A batch of data examples.
"""

@abstractmethod
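
To make the ``process``/``compute_metrics`` contract concrete, here is a minimal hypothetical subclass (illustrative only, not part of this commit):

.. code:: python

    class CountMetric(BaseMetric):
        def process(self, data_examples):
            # Append one per-batch result to self.results.
            self.results.append(len(data_examples))

        def compute_metrics(self):
            # Aggregate the per-batch results over the whole dataset.
            return {"num_examples": sum(self.results)}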


abl/evaluation/reasoning_metric.py (+11 -11)

@@ -25,7 +25,7 @@ class ReasoningMetric(BaseMetric):

Notes
-----
The `ReasoningMetric` expects data_samples to have the attributes `pred_pseudo_label`,
The `ReasoningMetric` expects data_examples to have the attributes `pred_pseudo_label`,
`Y`, and `X`, corresponding to the predicted pseudo labels, ground truth of reasoning
results, and input data, respectively.
"""
@@ -34,24 +34,24 @@ class ReasoningMetric(BaseMetric):
super().__init__(prefix)
self.kb = kb

def process(self, data_samples: ListData) -> None:
def process(self, data_examples: ListData) -> None:
"""
Process a batch of data samples.
Process a batch of data examples.

This method takes in a batch of data samples, each containing predicted pseudo labels (pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It
evaluates the reasoning accuracy of each sample by comparing the logical reasoning
This method takes in a batch of data examples, each containing predicted pseudo labels (pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It
evaluates the reasoning accuracy of each example by comparing the logical reasoning
result (derived using the knowledge base) of the predicted pseudo labels against Y.
The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended
to ``self.results``.

Parameters
----------
data_samples : ListData
A batch of data samples.
data_examples : ListData
A batch of data examples.
"""
pred_pseudo_label_list = data_samples.pred_pseudo_label
y_list = data_samples.Y
x_list = data_samples.X
pred_pseudo_label_list = data_examples.pred_pseudo_label
y_list = data_examples.Y
x_list = data_examples.X
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
if self.kb._check_equal(
self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y
@@ -63,7 +63,7 @@ class ReasoningMetric(BaseMetric):
def compute_metrics(self) -> dict:
"""
Compute the reasoning accuracy metrics from ``self.results``. It calculates the
percentage of correctly reasoned samples over all samples.
percentage of correctly reasoned examples over all examples.

Returns
-------
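
A hypothetical walk-through of the comparison described above, assuming an MNIST-Addition-style knowledge base ``add_kb`` (as constructed in docs/Intro/Reasoning.rst) and image inputs ``img1``/``img2``:

.. code:: python

    metric = ReasoningMetric(kb=add_kb)
    data_examples = ListData(X=[[img1, img2]], pred_pseudo_label=[[1, 2]], Y=[3])
    # add_kb.logic_forward([1, 2]) == 3 matches Y, so a 1 is appended to
    # metric.results; compute_metrics() then reports accuracy 1.0.
    metric.process(data_examples)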


abl/evaluation/symbol_metric.py (+7 -7)

@@ -23,19 +23,19 @@ class SymbolMetric(BaseMetric):
def __init__(self, prefix: Optional[str] = None) -> None:
super().__init__(prefix)

def process(self, data_samples: ListData) -> None:
def process(self, data_examples: ListData) -> None:
"""
Processes a batch of data samples.
Processes a batch of data examples.

This method takes in a batch of data samples, each containing a list of predicted
This method takes in a batch of data examples, each containing a list of predicted
pseudo labels (pred_pseudo_label) and their ground truth (gt_pseudo_label). It
calculates the accuracy by comparing the two lists. Then, a tuple of correct symbol
count and total symbol count is appended to `self.results`.

Parameters
----------
data_samples : ListData
A batch of data samples, each containing:
data_examples : ListData
A batch of data examples, each containing:
- `pred_pseudo_label`: List of predicted pseudo labels.
- `gt_pseudo_label`: List of ground truth pseudo labels.

@@ -44,8 +44,8 @@ class SymbolMetric(BaseMetric):
ValueError
If the lengths of predicted and ground truth symbol lists are not equal.
"""
pred_pseudo_label_list = data_samples.flatten("pred_pseudo_label")
gt_pseudo_label_list = data_samples.flatten("gt_pseudo_label")
pred_pseudo_label_list = data_examples.flatten("pred_pseudo_label")
gt_pseudo_label_list = data_examples.flatten("gt_pseudo_label")

if not len(pred_pseudo_label_list) == len(gt_pseudo_label_list):
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")


abl/learning/abl_model.py (+15 -15)

@@ -24,13 +24,13 @@ class ABLModel:

self.base_model = base_model

def predict(self, data_samples: ListData) -> Dict:
def predict(self, data_examples: ListData) -> Dict:
"""
Predict the labels and probabilities for the given data.

Parameters
----------
data_samples : ListData
data_examples : ListData
A batch of data to predict on.

Returns
@@ -39,28 +39,28 @@ class ABLModel:
A dictionary containing the predicted labels and probabilities.
"""
model = self.base_model
data_X = data_samples.flatten("X")
data_X = data_examples.flatten("X")
if hasattr(model, "predict_proba"):
prob = model.predict_proba(X=data_X)
label = prob.argmax(axis=1)
prob = reform_list(prob, data_samples.X)
prob = reform_list(prob, data_examples.X)
else:
prob = None
label = model.predict(X=data_X)
label = reform_list(label, data_samples.X)
label = reform_list(label, data_examples.X)

data_samples.pred_idx = label
data_samples.pred_prob = prob
data_examples.pred_idx = label
data_examples.pred_prob = prob

return {"label": label, "prob": prob}

def train(self, data_samples: ListData) -> float:
def train(self, data_examples: ListData) -> float:
"""
Train the model on the given data.

Parameters
----------
data_samples : ListData
data_examples : ListData
A batch of data to train on, which typically contains the data, `X`, and the
corresponding labels, `abduced_idx`.

@@ -69,17 +69,17 @@ class ABLModel:
float
The loss value of the trained model.
"""
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("abduced_idx")
data_X = data_examples.flatten("X")
data_y = data_examples.flatten("abduced_idx")
return self.base_model.fit(X=data_X, y=data_y)

def valid(self, data_samples: ListData) -> float:
def valid(self, data_examples: ListData) -> float:
"""
Validate the model on the given data.

Parameters
----------
data_samples : ListData
data_examples : ListData
A batch of data to validate on, which typically contains the data, `X`,
and the corresponding labels, `abduced_idx`.

@@ -88,8 +88,8 @@ class ABLModel:
float
The accuracy of the trained model.
"""
data_X = data_samples.flatten("X")
data_y = data_samples.flatten("abduced_idx")
data_X = data_examples.flatten("X")
data_y = data_examples.flatten("abduced_idx")
score = self.base_model.score(X=data_X, y=data_y)
return score
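
For intuition, ``reform_list`` restores the example-level nesting of ``X`` on the flat instance-level predictions. A hypothetical sketch (``x1`` through ``x5`` stand for instance inputs):

.. code:: python

    # X holds 2 examples with 2 and 3 instances; the base model predicts
    # on the 5 flattened instances and the labels are nested like X again.
    data_examples = ListData(X=[[x1, x2], [x3, x4, x5]])
    out = model.predict(data_examples)  # model is an ABLModel instance
    # out["label"] mirrors X's structure, e.g. [[3, 7], [0, 4, 9]]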



abl/reasoning/kb.py (+41 -41)

@@ -27,7 +27,7 @@ class KBBase(ABC):
list so that each aligns with its corresponding index in the base model: the first with
the 0th index, the second with the 1st, and so forth.
max_err : float, optional
The upper tolerance limit when comparing the similarity between a pseudo label sample's
The upper tolerance limit when comparing the similarity between a pseudo label example's
reasoning result and the ground truth. This is only applicable when the reasoning
result is of a numerical type. This is particularly relevant for regression problems where
exact matches might not be feasible. Defaults to 1e-10.
@@ -82,15 +82,15 @@ class KBBase(ABC):
@abstractmethod
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:
"""
How to perform (deductive) logical reasoning, i.e. matching each pseudo label sample to
How to perform (deductive) logical reasoning, i.e. matching each pseudo label example to
its reasoning result. Users are required to provide this.

Parameters
----------
pseudo_label : List[Any]
Pseudo label sample.
Pseudo label example.
x : Optional[List[Any]]
The corresponding input sample. If deductive logical reasoning does not require any
The corresponding input example. If deductive logical reasoning does not require any
information from the input, the overridden function provided by the user can omit
this parameter.
@@ -114,13 +114,13 @@ class KBBase(ABC):
Parameters
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised by abductive reasoning).
Pseudo label example (to be revised by abductive reasoning).
y : Any
Ground truth of the reasoning result for the sample.
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input sample.
The corresponding input example.
max_revision_num : int
The upper limit on the number of revised labels for each sample.
The upper limit on the number of revised labels for each example.
require_more_revision : int
Specifies additional number of revisions permitted beyond the minimum required.

@@ -128,7 +128,7 @@ class KBBase(ABC):
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
pseudo label examples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
@@ -136,7 +136,7 @@ class KBBase(ABC):

def _check_equal(self, reasoning_result: Any, y: Any) -> bool:
"""
Check whether the reasoning result of a pseudo label sample is equal to the ground truth
Check whether the reasoning result of a pseudo label example is equal to the ground truth
(or, within the maximum error allowed for numerical results).

Returns
@@ -160,24 +160,24 @@ class KBBase(ABC):
revision_idx: List[int],
) -> List[List[Any]]:
"""
Revise the pseudo label sample at specified index positions.
Revise the pseudo label example at specified index positions.

Parameters
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised).
Pseudo label example (to be revised).
y : Any
Ground truth of the reasoning result for the sample.
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input sample.
The corresponding input example.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo label sample.
A list specifying indices of where revisions should be made to the pseudo label example.

Returns
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
pseudo label examples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
@@ -200,7 +200,7 @@ class KBBase(ABC):
x: List[Any],
) -> List[List[Any]]:
"""
For a specified number of labels in a pseudo label sample to revise, iterate through
For a specified number of labels in a pseudo label example to revise, iterate through
all possible indices to find any candidates that are compatible with the knowledge base.
"""
new_candidates, new_reasoning_results = [], []
@@ -221,29 +221,29 @@ class KBBase(ABC):
) -> List[List[Any]]:
"""
Perform abductive reasoning by exhaustive search. Specifically, begin with 0 and
continuously increase the number of labels in a pseudo label sample to revise, until
continuously increase the number of labels in a pseudo label example to revise, until
candidates that are compatible with the knowledge base are found.

Parameters
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised).
Pseudo label example (to be revised).
y : Any
Ground truth of the reasoning result for the sample.
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input sample.
The corresponding input example.
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int
If larger than 0, then after having found any candidates compatible with the
knowledge base, continue to increase the number of labels in a pseudo label sample to
knowledge base, continue to increase the number of labels in a pseudo label example to
revise to get more possible compatible candidates.

Returns
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
pseudo label examples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
@@ -286,7 +286,7 @@ class GroundKB(KBBase):
pseudo_label_list : list
Refer to class `KBBase`.
GKB_len_list : list
List of possible lengths for a pseudo label sample.
List of possible lengths for a pseudo label example.
max_err : float, optional
Refer to class `KBBase`.

@@ -359,13 +359,13 @@ class GroundKB(KBBase):
Parameters
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised by abductive reasoning).
Pseudo label example (to be revised by abductive reasoning).
y : Any
Ground truth of the reasoning result for the sample.
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input sample (unused in GroundKB).
The corresponding input example (unused in GroundKB).
max_revision_num : int
The upper limit on the number of revised labels for each sample.
The upper limit on the number of revised labels for each example.
require_more_revision : int
Specifies additional number of revisions permitted beyond the minimum required.

@@ -373,7 +373,7 @@ class GroundKB(KBBase):
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
pseudo label examples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""
@@ -477,7 +477,7 @@ class PrologKB(KBBase):
Parameters
----------
pseudo_label : List[Any]
Pseudo label sample.
Pseudo label example.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"]
if result == "true":
@@ -519,13 +519,13 @@ class PrologKB(KBBase):
Parameters
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised by abductive reasoning).
Pseudo label example (to be revised by abductive reasoning).
y : Any
Ground truth of the reasoning result for the sample.
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input sample.
The corresponding input example.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo label sample.
A list specifying indices of where revisions should be made to the pseudo label example.

Returns
-------
@@ -546,26 +546,26 @@ class PrologKB(KBBase):
revision_idx: List[int],
) -> List[List[Any]]:
"""
Revise the pseudo label sample at specified index positions by querying Prolog.
Revise the pseudo label example at specified index positions by querying Prolog.

Parameters
----------
pseudo_label : List[Any]
Pseudo label sample (to be revised).
Pseudo label example (to be revised).
y : Any
Ground truth of the reasoning result for the sample.
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input sample.
The corresponding input example.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo label sample.
A list specifying indices of where revisions should be made to the pseudo label example.

Returns
-------
Tuple[List[List[Any]], List[Any]]
A list of candidates, i.e. revised pseudo label samples that are compatible with the
A list of candidates, i.e. revised pseudo label examples that are compatible with the
knowledge base.
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo label samples that are compatible with the knowledge base. The second element is
pseudo label examples that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
"""


abl/reasoning/reasoner.py (+44 -44)

@@ -24,11 +24,11 @@ class Reasoner:
abduced label. It can be either a string representing a predefined distance
function or a callable function. The available predefined distance functions:
'hamming' | 'confidence'. 'hamming': directly calculates the Hamming
distance between the predicted pseudo label in the data sample and each
distance between the predicted pseudo label in the data example and each
candidate, 'confidence': calculates the distance between the prediction
and each candidate based on confidence derived from the predicted probability
in the data sample. The callable function should have the signature
dist_func(data_sample, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element
in the data example. The callable function should have the signature
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
@@ -36,7 +36,7 @@ class Reasoner:
A mapping from index in the base model to label. If not provided, a default
order-based index to label mapping is created. Defaults to None.
max_revision : Union[int, float], optional
The upper limit on the number of revisions for each data sample when
The upper limit on the number of revisions for each data example when
performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the
number of revisions. Defaults to -1.
@@ -100,7 +100,7 @@ class Reasoner:

def _get_one_candidate(
self,
data_sample: ListData,
data_example: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> List[Any]:
@@ -111,8 +111,8 @@ class Reasoner:

Parameters
----------
data_sample : ListData
Data sample.
data_example : ListData
Data example.
candidates : List[List[Any]]
Multiple compatible candidates.
reasoning_results : List[Any]
@@ -128,23 +128,23 @@ class Reasoner:
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self._get_cost_list(data_sample, candidates, reasoning_results)
cost_array = self._get_cost_list(data_example, candidates, reasoning_results)
candidate = candidates[np.argmin(cost_array)]
return candidate

def _get_cost_list(
self,
data_sample: ListData,
data_example: ListData,
candidates: List[List[Any]],
reasoning_results: List[Any],
) -> Union[List[Union[int, float]], np.ndarray]:
"""
Get the list of costs between each candidate and the given data sample.
Get the list of costs between each candidate and the given data example.

Parameters
----------
data_sample : ListData
Data sample.
data_example : ListData
Data example.
candidates : List[List[Any]]
Multiple compatible candidates.
reasoning_results : List[Any]
@@ -156,13 +156,13 @@ class Reasoner:
The list of costs.
"""
if self.dist_func == "hamming":
return hamming_dist(data_sample.pred_pseudo_label, candidates)
return hamming_dist(data_example.pred_pseudo_label, candidates)
elif self.dist_func == "confidence":
candidates = [[self.label_to_idx[x] for x in c] for c in candidates]
return confidence_dist(data_sample.pred_prob, candidates)
return confidence_dist(data_example.pred_prob, candidates)
else:
candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
cost_list = self.dist_func(data_sample, candidates, candidate_idxs, reasoning_results)
cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(
f"The length of the array returned by dist_func must be equal to the number of candidates. "
@@ -173,7 +173,7 @@ class Reasoner:
def _zoopt_get_solution(
self,
symbol_num: int,
data_sample: ListData,
data_example: ListData,
max_revision_num: int,
) -> List[bool]:
"""
@@ -184,8 +184,8 @@ class Reasoner:
----------
symbol_num : int
Number of total symbols.
data_sample : ListData
Data sample.
data_example : ListData
Data example.
max_revision_num : int
Specifies the maximum number of revisions allowed.

@@ -196,7 +196,7 @@ class Reasoner:
"""
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol),
lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
@@ -207,7 +207,7 @@ class Reasoner:
def zoopt_revision_score(
self,
symbol_num: int,
data_sample: ListData,
data_example: ListData,
sol: List[bool],
) -> int:
"""
@@ -218,8 +218,8 @@ class Reasoner:
----------
symbol_num : int
Number of total symbols.
data_sample : ListData
Data sample.
data_example : ListData
Data example.
sol: List[bool]
The solution for ZOOpt library.

@@ -230,10 +230,10 @@ class Reasoner:
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates, reasoning_results = self.kb.revise_at_idx(
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
)
if len(candidates) > 0:
return np.min(self._get_cost_list(data_sample, candidates, reasoning_results))
return np.min(self._get_cost_list(data_example, candidates, reasoning_results))
else:
return symbol_num

@@ -267,53 +267,53 @@ class Reasoner:
)
return max_revision

def abduce(self, data_sample: ListData) -> List[Any]:
def abduce(self, data_example: ListData) -> List[Any]:
"""
Perform abductive reasoning on the given data sample.
Perform abductive reasoning on the given data example.

Parameters
----------
data_sample : ListData
Data sample.
data_example : ListData
Data example.

Returns
-------
List[Any]
A revised pseudo label sample through abductive reasoning, which is compatible
A revised pseudo label example through abductive reasoning, which is compatible
with the knowledge base.
"""
symbol_num = data_sample.elements_num("pred_pseudo_label")
symbol_num = data_example.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

if self.use_zoopt:
solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates, reasoning_results = self.kb.revise_at_idx(
pseudo_label=data_sample.pred_pseudo_label,
y=data_sample.Y,
x=data_sample.X,
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
revision_idx=revision_idx
)
else:
candidates, reasoning_results = self.kb.abduce_candidates(
pseudo_label=data_sample.pred_pseudo_label,
y=data_sample.Y,
x=data_sample.X,
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
x=data_example.X,
max_revision_num=max_revision_num,
require_more_revision=self.require_more_revision
)

candidate = self._get_one_candidate(data_sample, candidates, reasoning_results)
candidate = self._get_one_candidate(data_example, candidates, reasoning_results)
return candidate

def batch_abduce(self, data_samples: ListData) -> List[List[Any]]:
def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
"""
Perform abductive reasoning on the given prediction data samples.
Perform abductive reasoning on the given prediction data examples.
For detailed information, refer to `abduce`.
"""
abduced_pseudo_label = [self.abduce(data_sample) for data_sample in data_samples]
data_samples.abduced_pseudo_label = abduced_pseudo_label
abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples]
data_examples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

def __call__(self, data_samples: ListData) -> List[List[Any]]:
return self.batch_abduce(data_samples)
def __call__(self, data_examples: ListData) -> List[List[Any]]:
return self.batch_abduce(data_examples)
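
The documented ``dist_func`` signature also admits user-defined costs. A hypothetical Hamming-style implementation (the positional ``kb`` argument to ``Reasoner`` is assumed from context):

.. code:: python

    def my_dist(data_example, candidates, candidate_idxs, reasoning_results):
        # Return one numerical cost per candidate, same length as candidates.
        pred = data_example.pred_pseudo_label
        return [sum(p != c for p, c in zip(pred, cand)) for cand in candidates]

    reasoner = Reasoner(kb, dist_func=my_dist)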

abl/structures/base_data_element.py (+13 -13)

@@ -25,21 +25,21 @@ class BaseDataElement:
``LabelData`` inheriting from ``BaseDataElement`` to represent different
types of ground truth labels or predictions.

Another common data element is sample data. A sample data consists of input
Another common data element is a data example. A data example consists of input
data (such as an image) and its annotations and predictions. In general,
an image can have multiple types of annotations and/or predictions at the
same time (for example, both pixel-level semantic segmentation annotations
and instance-level detection bboxes annotations). All labels and
predictions of a training sample are often passed between Dataset, Model,
predictions of a training example are often passed between Dataset, Model,
Visualizer, and Evaluator components. In order to simplify the interface
between components, we can treat them as a large data element and
encapsulate them. Such data elements are generally called XXDataSample in
the OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement`
allows `BaseDataElement` as its attribute. Such a class generally
encapsulates all the data of a sample in the algorithm library, and its
encapsulates all the data of an example in the algorithm library, and its
attributes generally are various types of data elements. For example,
MMDetection is assigned by the BaseDataElement to encapsulate all the data
elements of the sample labeling and prediction of a sample in the
elements of the example labeling and prediction of an example in the
algorithm library.

The attributes in ``BaseDataElement`` are divided into two parts,
@@ -150,9 +150,9 @@ class BaseDataElement:
>>> metainfo = dict(img_shape=(800, 1196, 3))
>>> gt_instances = BaseDataElement(
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
>>> sample = BaseDataElement(metainfo=metainfo,
>>> example = BaseDataElement(metainfo=metainfo,
... gt_instances=gt_instances)
>>> print(sample)
>>> print(example)
<BaseDataElement(
META INFORMATION
img_shape: (800, 1196, 3)
@@ -196,15 +196,15 @@ class BaseDataElement:
... @pred_instances.deleter
... def pred_instances(self):
... del self._pred_instances
>>> det_sample = DetDataSample()
>>> det_example = DetDataSample()
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
>>> det_sample.proposals = proposals
>>> assert 'proposals' in det_sample
>>> assert det_sample.proposals == proposals
>>> del det_sample.proposals
>>> assert 'proposals' not in det_sample
>>> det_example.proposals = proposals
>>> assert 'proposals' in det_example
>>> assert det_example.proposals == proposals
>>> del det_example.proposals
>>> assert 'proposals' not in det_example
>>> with self.assertRaises(AssertionError):
... det_sample.proposals = torch.rand((5, 4))
... det_example.proposals = torch.rand((5, 4))
"""

def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:


docs/Examples/MNISTAdd.rst (+14 -14)

@@ -114,24 +114,24 @@ image, respectively. As shown below:
.. code:: ipython3

pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
print(f"Shape of pred_idx for a batch of 32 samples: {pred_idx.shape}")
print(f"Shape of pred_idx for a batch of 32 examples: {pred_idx.shape}")
pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
print(f"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}")
print(f"Shape of pred_prob for a batch of 32 examples: {pred_prob.shape}")


Out:
.. code:: none
:class: code-out

Shape of pred_idx for a batch of 32 samples: (32,)
Shape of pred_prob for a batch of 32 samples: (32, 10)
Shape of pred_idx for a batch of 32 examples: (32,)
Shape of pred_prob for a batch of 32 examples: (32, 10)

However, the base model built above deals with instance-level data
(i.e., a single image), and can not directly deal with sample-level
(i.e., a single image), and can not directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model
into ``ABLModel``, which enables the learning part to train, test,
and predict on sample-level data.
and predict on example-level data.

.. code:: ipython3

@@ -142,10 +142,10 @@ TODO: an example showing the difference between ABLModel's predict and the base model's predict
.. code:: ipython3

# from abl.structures import ListData
# data_samples = ListData()
# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
# data_examples = ListData()
# data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
# model.predict(data_samples)
# model.predict(data_examples)

Building the Reasoning Part
---------------------------
@@ -174,16 +174,16 @@ performing (deductive) reasoning: # TODO: ABDUCTIVE REASONING

.. code:: ipython3

pseudo_label_sample = [1, 2]
reasoning_result = kb.logic_forward(pseudo_label_sample)
print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.")
pseudo_label_example = [1, 2]
reasoning_result = kb.logic_forward(pseudo_label_example)
print(f"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.")


Out:
.. code:: none
:class: code-out

Reasoning result of pseudo label sample [1, 2] is 3.
Reasoning result of pseudo label example [1, 2] is 3.

.. note::
@@ -211,7 +211,7 @@ candidate that has highest consistency.
When creating the reasoner, the definition of “consistency” can be
customized within the ``dist_func`` parameter. In the code above, we
employ a consistency measurement based on confidence, which calculates
the consistency between the data sample and candidates based on the
the consistency between the data example and candidates based on the
confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
provide options for utilizing other forms of consistency measurement.
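
Filling in the TODO earlier in this file, a hypothetical cell contrasting instance-level ``base_model.predict`` with example-level ``model.predict`` (mirroring the commented snippet above):

.. code:: python

    # base_model consumes a flat batch of 32 single-digit images...
    instance_batch = [torch.randn(1, 28, 28).to(device) for _ in range(32)]
    base_model.predict(X=instance_batch)  # -> 32 class indices

    # ...while the wrapped ABLModel consumes nested, example-level data.
    data_examples = ListData()
    data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
    model.predict(data_examples)  # fills data_examples.pred_idx as 3 pairs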



docs/Intro/Basics.rst (+4 -4)

@@ -56,13 +56,13 @@ Use ABL-Package Step by Step

In a typical Abductive Learning process, as illustrated below,
data inputs are first predicted by a machine learning model, and the outcomes are a pseudo label
sample (which consists of multiple pseudo labels).
example (which consists of multiple pseudo labels).
These labels then pass through a knowledge base :math:`\mathcal{KB}`
to obtain the reasoning result by deductive reasoning. During training,
alongside the aforementioned forward flow (i.e., prediction --> deduction reasoning),
there also exists a reverse flow, which starts from the reasoning result and
involves abductive reasoning to generate possible pseudo label samples.
Subsequently, these samples are processed to minimize inconsistencies with machine learning,
involves abductive reasoning to generate possible pseudo label examples.
Subsequently, these examples are processed to minimize inconsistencies with machine learning,
which in turn revises the outcomes of the machine learning model; these are then
fed back into the machine learning model for further training.
To implement this process, the following five steps are necessary:
@@ -81,7 +81,7 @@ To implement this process, the following five steps are necessary:
3. Build the reasoning part

Define a knowledge base by building a subclass of ``KBBase``, specifying how to
map pseudo label samples to reasoning results.
map pseudo label examples to reasoning results.
Also, create a ``Reasoner`` for minimizing inconsistencies
between the knowledge base and the learning part.



docs/Intro/Bridge.rst (+17 -17)

@@ -27,15 +27,15 @@ In this section, we will look at how to bridge learning and reasoning parts to t
+---------------------------------------+----------------------------------------------------+
| Method Signature | Description |
+=======================================+====================================================+
| ``predict(data_samples)`` | Predicts class probabilities and indices |
| | for the given data samples. |
| ``predict(data_examples)`` | Predicts class probabilities and indices |
| | for the given data examples. |
+---------------------------------------+----------------------------------------------------+
| ``abduce_pseudo_label(data_samples)`` | Abduces pseudo labels for the given data samples. |
| ``abduce_pseudo_label(data_examples)`` | Abduces pseudo labels for the given data examples. |
+---------------------------------------+----------------------------------------------------+
| ``idx_to_pseudo_label(data_samples)`` | Converts indices to pseudo labels using |
| ``idx_to_pseudo_label(data_examples)`` | Converts indices to pseudo labels using |
| | the provided or default mapping. |
+---------------------------------------+----------------------------------------------------+
| ``pseudo_label_to_idx(data_samples)`` | Converts pseudo labels to indices |
| ``pseudo_label_to_idx(data_examples)`` | Converts pseudo labels to indices |
| | using the provided or default remapping. |
+---------------------------------------+----------------------------------------------------+
| ``train(train_data)`` | Train the model. |
@@ -43,11 +43,11 @@ In this section, we will look at how to bridge learning and reasoning parts to t
| ``test(test_data)`` | Test the model. |
+---------------------------------------+----------------------------------------------------+

where ``train_data`` and ``test_data`` are both in the form of ``(X, gt_pseudo_label, Y)``. They will be used to construct ``ListData`` instances which are referred to as ``data_samples`` in the ``train`` and ``test`` methods respectively. More details can be found in `preparing datasets <Datasets.html>`_.
where ``train_data`` and ``test_data`` are both in the form of ``(X, gt_pseudo_label, Y)``. They will be used to construct ``ListData`` instances which are referred to as ``data_examples`` in the ``train`` and ``test`` methods respectively. More details can be found in `preparing datasets <Datasets.html>`_.

``SimpleBridge`` inherits from ``BaseBridge`` and provides a basic implementation. Besides the ``model`` and ``reasoner``, ``SimpleBridge`` has an extra initialization argument, ``metric_list``, which will be used to evaluate model performance. Its training process involves several Abductive Learning loops and each loop consists of the following five steps:

1. Predict class probabilities and indices for the given data samples.
1. Predict class probabilities and indices for the given data examples.
2. Transform indices into pseudo labels.
3. Revise pseudo labels based on abductive reasoning.
4. Transform the revised pseudo labels to indices.
@@ -71,21 +71,21 @@ The fundamental part of the ``train`` method is as follows:
will be used together to train the model.
"""
if isinstance(train_data, ListData):
data_samples = train_data
data_examples = train_data
else:
data_samples = self.data_preprocess(*train_data)
data_examples = self.data_preprocess(*train_data)
if isinstance(segment_size, float):
segment_size = int(segment_size * len(data_samples))
segment_size = int(segment_size * len(data_examples))

for loop in range(loops):
for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
sub_data_samples = data_samples[
for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
sub_data_examples = data_examples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
self.predict(sub_data_samples) # 1
self.idx_to_pseudo_label(sub_data_samples) # 2
self.abduce_pseudo_label(sub_data_samples) # 3
self.pseudo_label_to_idx(sub_data_samples) # 4
loss = self.model.train(sub_data_samples) # 5, self.model is an ABLModel object
self.predict(sub_data_examples) # 1
self.idx_to_pseudo_label(sub_data_examples) # 2
self.abduce_pseudo_label(sub_data_examples) # 3
self.pseudo_label_to_idx(sub_data_examples) # 4
loss = self.model.train(sub_data_examples) # 5, self.model is an ABLModel object


docs/Intro/Datasets.rst (+3 -3)

@@ -24,11 +24,11 @@ Dataset
ABL-Package assumes user data to be structured as a tuple, comprising the following three components:

- ``X``: List[List[Any]]
A list of sublists representing the input data. We refer to each sublist in ``X`` as an sample and each sample may contain several instances.
A list of sublists representing the input data. We refer to each sublist in ``X`` as an example and each example may contain several instances.
- ``gt_pseudo_label``: List[List[Any]], optional
A list of sublists with each sublist representing a ground truth pseudo label sample. Each sample consists of ground truth pseudo labels for each **instance** within a sample of ``X``.
A list of sublists with each sublist representing a ground truth pseudo label example. Each example consists of ground truth pseudo labels for each **instance** within an example of ``X``.
- ``Y``: List[Any]
A list representing the ground truth reasoning result for each **sample** in ``X``.
A list representing the ground truth reasoning result for each **example** in ``X``.

.. warning::
Each sublist in ``gt_pseudo_label`` should have the same length as the sublist in ``X``. ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model. If the pseudo labels of the instances in the datasets are unlabeled, ``gt_pseudo_label`` can be ``None``.
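
Concretely, an MNIST-Addition-shaped dataset in this layout might look like the following hypothetical sketch (``img1`` through ``img4`` stand for image tensors):

.. code:: python

    X = [[img1, img2], [img3, img4]]    # 2 examples, 2 instances each
    gt_pseudo_label = [[3, 7], [1, 2]]  # per-instance ground truth pseudo labels
    Y = [10, 3]                         # per-example ground truth reasoning results
    train_data = (X, gt_pseudo_label, Y)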


docs/Intro/Evaluation.rst (+3 -3)

@@ -34,11 +34,11 @@ We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, w
# prefix is used to distinguish different metrics
super().__init__(prefix)

def process(self, data_samples: Sequence[dict]) -> None:
def process(self, data_examples: Sequence[dict]) -> None:
# pred_pseudo_label and gt_pseudo_label are both of type List[List[Any]]
# and have the same length
pred_pseudo_label = data_samples.pred_pseudo_label
gt_pseudo_label = data_samples.gt_pseudo_label
pred_pseudo_label = data_examples.pred_pseudo_label
gt_pseudo_label = data_examples.gt_pseudo_label
for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
correct_num = 0


docs/Intro/Learning.rst (+2 -2)

@@ -15,7 +15,7 @@ In this section, we will look at how to build the learning part.
In ABL-Package, building the learning part involves two steps:

1. Build a base machine learning model used to make predictions on instance-level data, typically referred to as ``base_model``.
2. Instantiate an ``ABLModel`` with the ``base_model``, which enables the learning part to train, test, and predict on sample-level data.
2. Instantiate an ``ABLModel`` with the ``base_model``, which enables the learning part to train, test, and predict on example-level data.

.. code:: python

@@ -77,7 +77,7 @@ Besides the necessary methods required to instantiate an ``ABLModel``, i.e., ``f
Instantiating an ABLModel
-------------------------

Typically, ``base_model`` is trained to make predictions on instance-level data, and can not directly utilize sample-level data to train and predict, which is not suitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all ``base_model``, which enables the learning part to train, test, and predict on sample-level data.
Typically, ``base_model`` is trained to make predictions on instance-level data, and can not directly utilize example-level data to train and predict, which is not suitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all ``base_model``, which enables the learning part to train, test, and predict on example-level data.

Generally, we can simply instantiate an ``ABLModel`` by:
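
A minimal sketch of that instantiation (the import path is an assumption consistent with the file layout above):

.. code:: python

    from abl.learning import ABLModel  # assumed import path

    model = ABLModel(base_model)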



docs/Intro/Quick-Start.rst (+2 -2)

@@ -56,7 +56,7 @@ To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = BasicNN(cls, loss_fn, optimizer, device)

However, the base model built above is trained to make predictions on instance-level data (e.g., a single image), which is not sufficient for our task. Therefore, we then wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on sample-level data (e.g., images that comprise the equation).
However, the base model built above is trained to make predictions on instance-level data (e.g., a single image), which is not sufficient for our task. Therefore, we then wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on example-level data (e.g., images that comprise the equation).

.. code:: python

@@ -71,7 +71,7 @@ Building the Reasoning Part

To build the reasoning part, we first define a knowledge base by
creating a subclass of ``KBBase``, which specifies how to map a pseudo
label sample to its reasoning result. In the subclass, we initialize the
label example to its reasoning result. In the subclass, we initialize the
``pseudo_label_list`` parameter and override the ``logic_forward``
function specifying how to perform (deductive) reasoning.
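
A minimal sketch of such a subclass for the addition task, mirroring the ``add_kb`` described in docs/Intro/Reasoning.rst (the class name is illustrative):

.. code:: python

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        def logic_forward(self, pseudo_label):
            # Deductive reasoning: an example's result is the sum of
            # its two pseudo labels.
            return sum(pseudo_label)

    kb = AddKB()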



docs/Intro/Reasoning.rst (+38 -38)

@@ -15,7 +15,7 @@ leverage domain knowledge and perform deductive or abductive reasoning.
In ABL-Package, building the reasoning part involves two steps:

1. Build a knowledge base by creating a subclass of ``KBBase``, which
specifies how to map pseudo label samples to reasoning results.
specifies how to map pseudo label examples to reasoning results.
2. Create a reasoner by instantiating the class ``Reasoner``
to minimize inconsistencies between the knowledge base and pseudo
labels predicted by the learning part.
@@ -43,7 +43,7 @@ and override the ``logic_forward`` function:
- ``pseudo_label_list`` is the list of possible pseudo labels (also,
the output of the machine learning model).
- ``logic_forward`` defines how to perform (deductive) reasoning,
i.e., matching each pseudo label sample (often consisting of multiple
i.e., matching each pseudo label example (often consisting of multiple
pseudo labels) to its reasoning result.

After that, other operations, including how to perform abductive
@@ -54,7 +54,7 @@ MNIST Addition example

As an example, the ``pseudo_label_list`` passed in MNIST Addition is all the
possible digits, namely, ``[0,1,2,...,9]``, and the ``logic_forward``
should be: “Add the two labels in the pseudo label sample to get the result.” Therefore, the
should be: “Add the two labels in the pseudo label example to get the result.” Therefore, the
construction of the KB (``add_kb``) for MNIST Addition would be:

.. code:: python
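
    # A minimal sketch consistent with the description above (the class
    # name AddKB is illustrative; only add_kb is referred to later):
    from abl.reasoning import KBBase

    class AddKB(KBBase):
        def __init__(self):
            super().__init__(pseudo_label_list=list(range(10)))

        def logic_forward(self, pseudo_label):
            # add the two labels in the pseudo label example to get the result
            return sum(pseudo_label)

    add_kb = AddKB()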
@@ -72,15 +72,15 @@ and (deductive) reasoning in ``add_kb`` would be:

.. code:: python

pseudo_label_sample = [1, 2]
reasoning_result = add_kb.logic_forward(pseudo_label_sample)
print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.")
pseudo_label_example = [1, 2]
reasoning_result = add_kb.logic_forward(pseudo_label_example)
print(f"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.")

Out:
.. code:: none
:class: code-out

Reasoning result of pseudo label sample [1, 2] is 3.
Reasoning result of pseudo label example [1, 2] is 3.

.. _other-par:

@@ -91,13 +91,13 @@ We can also pass the following parameters in the ``__init__`` function when buil
knowledge base:

- ``max_err`` (float, optional), specifying the upper tolerance limit
when comparing the similarity between a pseudo label sample's reasoning result
when comparing the similarity between a pseudo label example's reasoning result
and the ground truth during abductive reasoning. This is only
applicable when the reasoning result is of a numerical type. This is
particularly relevant for regression problems where exact matches
might not be feasible. Defaults to 1e-10. See :ref:`an example <kb-abd-2>`.
- ``use_cache`` (bool, optional), indicating whether to use cache to store
previous candidates (pseudo label samples generated from abductive reasoning)
previous candidates (pseudo label examples generated from abductive reasoning)
to speed up subsequent abductive reasoning operations. Defaults to True.
For more information of abductive reasoning, please refer to :ref:`this <kb-abd>`.
- ``cache_size`` (int, optional), specifying the maximum cache
@@ -173,7 +173,7 @@ override the ``logic_forward`` function, and are allowed to pass other
:ref:`optional parameters <other-par>`. Additionally, we are required to pass the
``GKB_len_list`` parameter in the ``__init__`` function.

- ``GKB_len_list`` is the list of possible lengths for a pseudo label sample.
- ``GKB_len_list`` is the list of possible lengths for a pseudo label example.

After that, other operations, including auto-construction of GKB, and
how to perform abductive reasoning, will be **automatically** set up.
@@ -206,29 +206,29 @@ Performing abductive reasoning in the knowledge base
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As mentioned in :ref:`What is Abductive Reasoning? <abd>`, abductive reasoning
enables the inference of candidates (which are pseudo label samples) as potential
enables the inference of candidates (which are pseudo label examples) as potential
explanations for the reasoning result. Also, in Abductive Learning where
an observation (a pseudo label sample predicted by the learning part) is
an observation (a pseudo label example predicted by the learning part) is
available, we aim to ensure that the candidate does not largely revise the
previously identified pseudo label sample.
previously identified pseudo label example.

``KBBase`` (as well as ``GroundKB`` and ``PrologKB``) implements the method
``abduce_candidates(pseudo_label, y, max_revision_num, require_more_revision)``
for performing abductive reasoning, where the parameters are:

- ``pseudo_label``, the pseudo label sample to be revised by abductive
- ``pseudo_label``, the pseudo label example to be revised by abductive
reasoning, usually generated by the learning part.
- ``y``, the ground truth of the reasoning result for the sample. The
- ``y``, the ground truth of the reasoning result for the example. The
returned candidates should be compatible with it.
- ``max_revision_num``, an int value specifying the upper limit on the
number of revised labels for each sample.
number of revised labels for each example.
- ``require_more_revision``, an int value specifying additional number
of revisions permitted beyond the minimum required. (e.g., If we set
it to 0, even if ``max_revision_num`` is set to a high value, the
method will only output candidates with the minimum possible
revisions.)

And it returns a list of candidates (i.e., revised pseudo label samples) that
And it returns a list of candidates (i.e., revised pseudo label examples) that
are all compatible with ``y``.
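
For instance, a direct call would look like the following (a sketch, assuming
the ``add_kb`` knowledge base built earlier):

.. code:: python

    candidates = add_kb.abduce_candidates(
        pseudo_label=[1, 1], y=8, max_revision_num=2, require_more_revision=0
    )
    # Both returned candidates, [1, 7] and [7, 1], are compatible with y = 8.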

MNIST Addition example (cont.)
@@ -292,7 +292,7 @@ When instantiating, besides the required knowledge base, we may also
specify:

- ``max_revision`` (int or float, optional), specifies the upper limit
on the number of revisions for each sample when performing
on the number of revisions for each example when performing
:ref:`abductive reasoning in the knowledge base <kb-abd>`. If float, denotes the
fraction of the total length that can be revised. A value of -1
implies no restriction on the number of revisions. Defaults to -1.
@@ -308,18 +308,18 @@ specify:
candidate returned from knowledge base. Valid options include
“confidence” (default) and “hamming”. For “confidence”, it calculates
the distance between the prediction and candidate based on confidence
derived from the predicted probability in the data sample. For
derived from the predicted probability in the data example. For
“hamming”, it directly calculates the Hamming distance between the
predicted pseudo label in the data sample and candidate.
predicted pseudo label in the data example and candidate.

The main method implemented by ``Reasoner`` is
``abduce(data_sample)``, which obtains the most consistent candidate
``abduce(data_example)``, which obtains the most consistent candidate
based on the distance function defined in ``dist_func``.

MNIST Addition example (cont.)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As an example, consider these data samples for MNIST Addition:
As an example, consider these data examples for MNIST Addition:

.. code:: python

@@ -331,37 +331,37 @@ As an example, consider these data samples for MNIST Addition:
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]

sample1 = ListData()
sample1.pred_pseudo_label = [1, 1]
sample1.pred_prob = prob1
sample1.Y = 8
example1 = ListData()
example1.pred_pseudo_label = [1, 1]
example1.pred_prob = prob1
example1.Y = 8

sample2 = ListData()
sample2.pred_pseudo_label = [1, 1]
sample2.pred_prob = prob2
sample2.Y = 8
example2 = ListData()
example2.pred_pseudo_label = [1, 1]
example2.pred_prob = prob2
example2.Y = 8

The compatible candidates after abductive reasoning for both samples
The compatible candidates after abductive reasoning for both examples
would be ``[[1,7], [7,1]]``. However, when the reasoner calls ``abduce``
to select only one candidate based on the ``confidence`` distance function,
the output would differ for each sample:
the output would differ for each example:

.. code:: python

reasoner_add = Reasoner(kb_add, dist_func="confidence")
candidate1 = reasoner_add.abduce(sample1)
candidate2 = reasoner_add.abduce(sample2)
print(f"The outputs for sample1 and sample2 are {candidate1} and {candidate2}, respectively.")
candidate1 = reasoner_add.abduce(example1)
candidate2 = reasoner_add.abduce(example2)
print(f"The outputs for example1 and example2 are {candidate1} and {candidate2}, respectively.")

Out:
.. code:: none
:class: code-out

The outputs for sample1 and sample2 are [1,7] and [7,1], respectively.
The outputs for example1 and example2 are [1,7] and [7,1], respectively.

Specifically, as mentioned before, ``confidence`` calculates the distance between the data
sample and candidates based on the confidence derived from the predicted probability.
Take ``sample1`` as an example: its ``pred_prob`` indicates a higher
example and candidates based on the confidence derived from the predicted probability.
Take ``example1`` as an example: its ``pred_prob`` indicates a higher
confidence that the first label should be "1" rather than "7". Therefore, among the
candidates [1,7] and [7,1], it would be closer to [1,7] (as its first label is "1").
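
The gist can be sketched as follows (the actual measure is implemented by
``confidence_dist`` in ``abl.utils``; for illustration we assume ``prob1``
assigns probability 0.99 to label "1" in its first row and 0.99 to label "7"
in its second row, as described above):

.. code:: python

    conf_17 = prob1[0][1] * prob1[1][7]  # 0.99 * 0.99 = 0.9801
    conf_71 = prob1[0][7] * prob1[1][1]  # 0.01 * 0.01 = 0.0001
    # [1, 7] carries far more confidence, so abduce selects it for this example.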


+ 173
- 0
examples/hed/datasets/equation_generator.py View File

@@ -0,0 +1,173 @@
import os
import itertools
import random
import numpy as np
from PIL import Image
import pickle

def get_sign_path_list(data_dir, sign_names):
sign_num = len(sign_names)
index_dict = dict(zip(sign_names, list(range(sign_num))))
ret = [[] for _ in range(sign_num)]
for path in os.listdir(data_dir):
if (path in sign_names):
index = index_dict[path]
sign_path = os.path.join(data_dir, path)
for p in os.listdir(sign_path):
ret[index].append(os.path.join(sign_path, p))
return ret

def split_pool_by_rate(pools, rate, seed = None):
if seed is not None:
random.seed(seed)
ret1 = []
ret2 = []
for pool in pools:
random.shuffle(pool)
num = int(len(pool) * rate)
ret1.append(pool[:num])
ret2.append(pool[num:])
return ret1, ret2

def int_to_system_form(num, system_num):
if num == 0:
return "0"
ret = ""
while (num > 0):
ret += str(num % system_num)
num //= system_num
return ret[::-1]
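
# e.g., int_to_system_form(5, 2) -> "101"; int_to_system_form(10, 3) -> "101"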

def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type):
expr_len = left_opt_len + right_opt_len
num_list = "".join([str(i) for i in range(system_num)])
ret = []
if generate_type == "all":
candidates = list(itertools.product(num_list, repeat = expr_len))
else:
candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))]
random.shuffle(candidates)
for nums in candidates:
left_num = "".join(nums[:left_opt_len])
right_num = "".join(nums[left_opt_len:])
left_value = int(left_num, system_num)
right_value = int(right_num, system_num)
result_value = left_value + right_value
if (label == 'negative'):
result_value += random.randint(-result_value, result_value)
if (left_value + right_value == result_value):
continue
result_num = int_to_system_form(result_value, system_num)
#leading zeros
if (res_opt_len != len(result_num)):
continue
if ((left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0')):
continue

#add leading zeros
if (res_opt_len < len(result_num)):
continue
while (len(result_num) < res_opt_len):
result_num = '0' + result_num
#continue
ret.append(left_num + '+' + right_num + '=' + result_num) # current only consider '+' and '='
#print(ret[-1])
return ret
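
# e.g., generator_equations(1, 1, 2, 2, 'positive', 'all') yields the single
# binary equation "1+1=10"; all other length-2 digit pairs fail the checks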

def generator_equation_by_len(equation_len, system_num = 2, label = 0, require_num = 1):
generate_type = "one"
ret = []
equation_sign_num = 2 # '+' and '='
while len(ret) < require_num:
left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num)
right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num)
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type))
return ret

def generator_equations_by_len(equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all"):
ret = []
equation_sign_num = 2 # '+' and '='
for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1):
for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1):
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num
for i in range(repeat_times): #generate more equations
if random.random() > keep ** (equation_len):
continue
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type))
return ret

def generator_equations_by_max_len(max_equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all", num_per_len = None):
ret = []
equation_sign_num = 2 # '+' and '='
for equation_len in range(3 + equation_sign_num, max_equation_len + 1):
if (num_per_len is None):
ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type))
else:
ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num = num_per_len))
return ret

def generator_equation_images(image_pools, equations, signs, shape, seed, is_color):
if (seed is not None):
random.seed(seed)
ret = []
sign_num = len(signs)
sign_index_dict = dict(zip(signs, list(range(sign_num))))
for equation in equations:
data = []
for sign in equation:
index = sign_index_dict[sign]
pick = random.randint(0, len(image_pools[index]) - 1)
if is_color:
image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape)
else:
image = Image.open(image_pools[index][pick]).convert('I').resize(shape)
image_array = np.array(image)
image_array = (image_array-127)*(1./128)
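# normalize raw pixel values from [0, 255] to roughly [-1, 1]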
data.append(image_array)
ret.append(np.array(data))
return ret

def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape = (28, 28), train_max_equation_len = 10, test_max_equation_len = 10, system_num = 2, tmp_file_prev = None, seed = None, train_num_per_len = 10, test_num_per_len = 10, is_color = False):
tmp_file = ""
if (tmp_file_prev is not None):
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num)
if (os.path.exists(tmp_file)):
return pickle.load(open(tmp_file, "rb"))

image_pools = get_sign_path_list(data_dir, sign_dir_lists)
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed)

ret = {}
for label in ["positive", "negative"]:
print("Generating equations.")
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len)
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len)
print(train_equations)
print(test_equations)
print("Generated equations.")
print("Generating equation image data.")
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color)
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color)
print("Generated equation image data.")

if (tmp_file_prev is not None):
pickle.dump(ret, open(tmp_file, "wb"))
return ret

if __name__ == "__main__":
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"]
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"]
for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs):
data = get_equation_std_data(data_dir = data_dir,\
sign_dir_lists = ['0', '1', '10', '11'],\
sign_output_lists = ['0', '1', '+', '='],\
shape = (28, 28),\
train_max_equation_len = 26, \
test_max_equation_len = 26, \
system_num = 2, \
tmp_file_prev = tmp_file_prev, \
train_num_per_len = 300, \
test_num_per_len = 300, \
is_color = False)

+ 5
- 5
examples/hed/datasets/learn_add.pl View File

@@ -16,7 +16,7 @@ eval_eq(Ex, Feature):-
%%%%%%%%%%%%%%
%% Abduction
%%%%%%%%%%%%%%
% Make abduction when given examples that have been interpreted as pseudo-labels
% Make abduction when given samples that have been interpreted as pseudo-labels
abduce(Exs, Delta_C) :-
abduce(Exs, [], Delta_C).
abduce([], Delta_C, Delta_C).
@@ -45,13 +45,13 @@ consistent_inst_feature(Exs, Delta_C):-
%% (Experimental) Parallel abduction
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
abduce_consistent_exs_concurrent(Exs) :-
% Split the current data batch into grounding examples and variable examples (which need to be revised)
% Split the current data batch into grounding samples and variable samples (which need to be revised)
split_exs(Exs, Ground_Exs, Var_Exs),
% Find the simplest Delta_C for grounding examples.
% Find the simplest Delta_C for grounding samples.
abduce(Ground_Exs, Ground_Delta_C), !,
% Extend Ground Delta_C into all possible variations
extend_op_rule(Ground_Delta_C, Possible_Deltas),
% Concurrently abduce the variable examples
% Concurrently abduce the variable samples
maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List),
maplist(=.., Goals, Call_List),
% writeln(Goals),
@@ -76,7 +76,7 @@ extend_op_rule(Rules, Ext) :-
% abduction without learning new Delta_C (Because they have been extended!)
abduce2([], _, _).
abduce2([E|Exs], Ground_Exs, Delta_C) :-
% abduce by finding ground examples
% abduce by finding ground samples
member(E, Ground_Exs),
abduce2(Exs, Ground_Exs, Delta_C).
abduce2([E|Exs], Ground_Exs, Delta_C) :-


+ 43
- 43
examples/hed/hed_bridge.py View File

@@ -62,15 +62,15 @@ class HEDBridge(SimpleBridge):

self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth"))

def select_mapping_and_abduce(self, data_samples: ListData):
def select_mapping_and_abduce(self, data_examples: ListData):
candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
mapping_score = []
abduced_pseudo_label_list = []
for _mapping in candidate_mappings:
self.reasoner.idx_to_label = _mapping
self.reasoner.label_to_idx = dict(zip(_mapping.values(), _mapping.keys()))
self.idx_to_pseudo_label(data_samples)
abduced_pseudo_label = self.reasoner.abduce(data_samples)
self.idx_to_pseudo_label(data_examples)
abduced_pseudo_label = self.reasoner.abduce(data_examples)
mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([]))
abduced_pseudo_label_list.append(abduced_pseudo_label)

@@ -80,18 +80,18 @@ class HEDBridge(SimpleBridge):
self.reasoner.label_to_idx = dict(
zip(self.reasoner.idx_to_label.values(), self.reasoner.idx_to_label.keys())
)
self.idx_to_pseudo_label(data_samples)
data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx]
self.idx_to_pseudo_label(data_examples)
data_examples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx]

return data_samples.abduced_pseudo_label
return data_examples.abduced_pseudo_label

def abduce_pseudo_label(self, data_samples: ListData):
self.reasoner.abduce(data_samples)
return data_samples.abduced_pseudo_label
def abduce_pseudo_label(self, data_examples: ListData):
self.reasoner.abduce(data_examples)
return data_examples.abduced_pseudo_label

def check_training_impact(self, filtered_data_samples, data_samples):
character_accuracy = self.model.valid(filtered_data_samples)
revisable_ratio = len(filtered_data_samples.X) / len(data_samples.X)
def check_training_impact(self, filtered_data_examples, data_examples):
character_accuracy = self.model.valid(filtered_data_examples)
revisable_ratio = len(filtered_data_examples.X) / len(data_examples.X)
log_string = (
f"Revisible ratio is {revisible_ratio:.3f}, Character "
f"accuracy is {character_accuracy:.3f}"
@@ -119,23 +119,23 @@ class HEDBridge(SimpleBridge):
return True
return False

def calc_consistent_ratio(self, data_samples, rule):
self.predict(data_samples)
pred_pseudo_label = self.idx_to_pseudo_label(data_samples)
def calc_consistent_ratio(self, data_examples, rule):
self.predict(data_examples)
pred_pseudo_label = self.idx_to_pseudo_label(data_examples)
consistent_num = sum(
[self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label]
)
return consistent_num / len(data_samples.X)
return consistent_num / len(data_examples.X)

def get_rules_from_data(self, data_samples, samples_per_rule, samples_num):
def get_rules_from_data(self, data_examples, examples_per_rule, examples_num):
rules = []
sampler = InfiniteSampler(len(data_samples), batch_size=samples_per_rule)
sampler = InfiniteSampler(len(data_examples), batch_size=examples_per_rule)

for _ in range(samples_num):
for _ in range(examples_num):
for select_idx in sampler:
sub_data_samples = data_samples[select_idx]
self.predict(sub_data_samples)
pred_pseudo_label = self.idx_to_pseudo_label(sub_data_samples)
sub_data_examples = data_examples[select_idx]
self.predict(sub_data_examples)
pred_pseudo_label = self.idx_to_pseudo_label(sub_data_examples)
consistent_instance = []
for instance in pred_pseudo_label:
if self.reasoner.kb.logic_forward([instance]):
@@ -157,13 +157,13 @@ class HEDBridge(SimpleBridge):
return rules

@staticmethod
def filter_empty(data_samples: ListData):
def filter_empty(data_examples: ListData):
consistent_idx = [
i
for i in range(len(data_samples.abduced_pseudo_label))
if len(data_samples.abduced_pseudo_label[i]) > 0
for i in range(len(data_examples.abduced_pseudo_label))
if len(data_examples.abduced_pseudo_label[i]) > 0
]
return data_samples[consistent_idx]
return data_examples[consistent_idx]

@staticmethod
def select_rules(rule_dict):
@@ -184,12 +184,12 @@ class HEDBridge(SimpleBridge):
return list(rule_dict)

def data_preprocess(self, data, equation_len) -> ListData:
data_samples = ListData()
data_samples.X = data[equation_len] + data[equation_len + 1]
data_samples.gt_pseudo_label = None
data_samples.Y = [None] * len(data_samples.X)
data_examples = ListData()
data_examples.X = data[equation_len] + data[equation_len + 1]
data_examples.gt_pseudo_label = None
data_examples.Y = [None] * len(data_examples.X)

return data_samples
return data_examples

def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8):
for equation_len in range(min_len, max_len):
@@ -199,25 +199,25 @@ class HEDBridge(SimpleBridge):
)

condition_num = 0
data_samples = self.data_preprocess(train_data[1], equation_len)
sampler = InfiniteSampler(len(data_samples), batch_size=segment_size)
data_examples = self.data_preprocess(train_data[1], equation_len)
sampler = InfiniteSampler(len(data_examples), batch_size=segment_size)
for seg_idx, select_idx in enumerate(sampler):
print_log(
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}]",
logger="current",
)
sub_data_samples = data_samples[select_idx]
self.predict(sub_data_samples)
sub_data_examples = data_examples[select_idx]
self.predict(sub_data_examples)
if equation_len == min_len:
self.select_mapping_and_abduce(sub_data_samples)
self.select_mapping_and_abduce(sub_data_examples)
else:
self.idx_to_pseudo_label(sub_data_samples)
self.abduce_pseudo_label(sub_data_samples)
filtered_sub_data_samples = self.filter_empty(sub_data_samples)
self.pseudo_label_to_idx(filtered_sub_data_samples)
loss = self.model.train(filtered_sub_data_samples)
self.idx_to_pseudo_label(sub_data_examples)
self.abduce_pseudo_label(sub_data_examples)
filtered_sub_data_examples = self.filter_empty(sub_data_examples)
self.pseudo_label_to_idx(filtered_sub_data_examples)
loss = self.model.train(filtered_sub_data_examples)

if self.check_training_impact(filtered_sub_data_samples, sub_data_samples):
if self.check_training_impact(filtered_sub_data_examples, sub_data_examples):
condition_num += 1
else:
condition_num = 0
@@ -225,7 +225,7 @@ class HEDBridge(SimpleBridge):
if condition_num >= 5:
print_log("Now checking if we can go to next course", logger="current")
rules = self.get_rules_from_data(
data_samples, samples_per_rule=3, samples_num=50
data_examples, examples_per_rule=3, examples_num=50
)
print_log("Learned rules from data: " + str(rules), logger="current")



+ 202
- 31
examples/hed/hed_example.ipynb View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -24,9 +24,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/18 09:01:12 - abl - INFO - Abductive Learning on the HED example.\n"
]
}
],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n",
@@ -46,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -69,20 +77,20 @@
"\n",
"\n",
"class HedReasoner(Reasoner):\n",
" def revise_at_idx(self, data_sample):\n",
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n",
" def revise_at_idx(self, data_example):\n",
" revision_idx = np.where(np.array(data_example.flatten(\"revision_flag\")) != 0)[0]\n",
" candidate = self.kb.revise_at_idx(\n",
" data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx\n",
" data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx\n",
" )\n",
" return candidate\n",
"\n",
" def zoopt_revision_score(self, symbol_num, data_sample, sol):\n",
" def zoopt_revision_score(self, symbol_num, data_example, sol):\n",
" revision_flag = reform_list(\n",
" list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label\n",
" list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label\n",
" )\n",
" data_sample.revision_flag = revision_flag\n",
" data_example.revision_flag = revision_flag\n",
"\n",
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n",
" lefted_idxs = [i for i in range(len(data_example.pred_idx))]\n",
" candidate_size = []\n",
" max_consistent_idxs = []\n",
" while lefted_idxs:\n",
@@ -90,10 +98,10 @@
" idxs.append(lefted_idxs.pop(0))\n",
" max_candidate_idxs = []\n",
" found = False\n",
" for idx in range(-1, len(data_sample.pred_idx)):\n",
" for idx in range(-1, len(data_example.pred_idx)):\n",
" if (not idx in idxs) and (idx >= 0):\n",
" idxs.append(idx)\n",
" candidates, _ = self.revise_at_idx(data_sample[idxs])\n",
" candidates, _ = self.revise_at_idx(data_example[idxs])\n",
" if len(candidates) == 0:\n",
" if len(idxs) > 1:\n",
" idxs.pop()\n",
@@ -115,10 +123,10 @@
" score -= math.exp(-i) * candidate_size[i]\n",
" return score, max_consistent_idxs\n",
" \n",
" def _zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):\n",
" def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):\n",
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n",
" objective = Objective(\n",
" lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol)[0],\n",
" lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],\n",
" dim=dimension,\n",
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n",
" )\n",
@@ -126,20 +134,20 @@
" solution = Opt.min(objective, parameter)\n",
" return solution\n",
"\n",
" def abduce(self, data_sample):\n",
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n",
" def abduce(self, data_example):\n",
" symbol_num = data_example.elements_num(\"pred_pseudo_label\")\n",
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n",
"\n",
" solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)\n",
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_sample, solution)\n",
" solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)\n",
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)\n",
"\n",
" abduced_pseudo_label = [[] for _ in range(len(data_sample))]\n",
" abduced_pseudo_label = [[] for _ in range(len(data_example))]\n",
"\n",
" if len(max_candidate_idxs) > 0:\n",
" candidates, _ = self.revise_at_idx(data_sample[max_candidate_idxs])\n",
" candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])\n",
" for i, idx in enumerate(max_candidate_idxs):\n",
" abduced_pseudo_label[idx] = candidates[0][i]\n",
" data_sample.abduced_pseudo_label = abduced_pseudo_label\n",
" data_example.abduced_pseudo_label = abduced_pseudo_label\n",
" return abduced_pseudo_label\n",
"\n",
" def abduce_rules(self, pred_res):\n",
@@ -160,7 +168,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -173,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -194,7 +202,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -214,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -232,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -249,7 +257,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -268,13 +276,176 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/18 09:04:27 - abl - INFO - Pretrain Start\n",
"12/18 09:04:31 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_1.pth\n",
"12/18 09:04:33 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_2.pth\n",
"12/18 09:04:34 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_3.pth\n",
"12/18 09:04:36 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_4.pth\n",
"12/18 09:04:37 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_5.pth\n",
"12/18 09:04:38 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_6.pth\n",
"12/18 09:04:40 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_7.pth\n",
"12/18 09:04:41 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_8.pth\n",
"12/18 09:04:43 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_9.pth\n",
"12/18 09:04:44 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_10.pth\n",
"12/18 09:04:44 - abl - INFO - model loss: 0.78453\n",
"12/18 09:04:44 - abl - INFO - min loss is <abl.learning.basic_nn.BasicNN object at 0x7f6c4f9393d0>\n",
"12/18 09:04:44 - abl - INFO - Loads checkpoint by local backend from path: ./weights/pretrain_weights.pth\n",
"12/18 09:04:44 - abl - INFO - ============== equation_len: 5-6 ================\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 9.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 9.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 6.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 9.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [0.0, 10.0]\n",
"[zoopt] x: array([0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-1.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.])\n",
"[zoopt] value: [-2.0, 8.0]\n",
"12/18 09:05:16 - abl - INFO - Checkpoints will be saved to results/20231218_09_01_12/weights/model_checkpoint_epoch_1.pth\n",
"12/18 09:05:16 - abl - INFO - model loss: 0.59495\n"
]
},
{
"ename": "TypeError",
"evalue": "unsupported format string passed to BasicNN.__format__",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Input \u001b[0;32mIn [13]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m bridge\u001b[38;5;241m.\u001b[39mpretrain(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./weights\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mbridge\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_data\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/ABL-Package/examples/hed/hed_bridge.py:217\u001b[0m, in \u001b[0;36mHEDBridge.train\u001b[0;34m(self, train_data, val_data, segment_size, min_len, max_len)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mabduce_pseudo_label(sub_data_examples)\n\u001b[1;32m 216\u001b[0m filtered_sub_data_examples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_empty(sub_data_examples)\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpseudo_label_to_idx(filtered_sub_data_examples)\n\u001b[1;32m 218\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtrain(filtered_sub_data_examples)\n\u001b[1;32m 220\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_training_impact(filtered_sub_data_examples, sub_data_examples):\n",
"\u001b[0;31mTypeError\u001b[0m: unsupported format string passed to BasicNN.__format__"
]
}
],
"source": [
"bridge.pretrain(\"./weights\")\n",
"bridge.train(train_data, val_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -293,7 +464,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {


+ 4
- 4
examples/hed/utils.py View File

@@ -5,14 +5,14 @@ import torch.utils.data.sampler as sampler


class InfiniteSampler(sampler.Sampler):
def __init__(self, num_samples, batch_size=1):
self.num_samples = num_samples
def __init__(self, num_examples, batch_size=1):
self.num_examples = num_examples
self.batch_size = batch_size

def __iter__(self):
while True:
order = np.random.permutation(self.num_samples)
for i in range(0, self.num_samples, self.batch_size):
order = np.random.permutation(self.num_examples)
for i in range(0, self.num_examples, self.batch_size):
yield order[i : i + self.batch_size]
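
# e.g., InfiniteSampler(100, batch_size=10) yields ten arrays of 10 shuffled
# indices per pass over the data, then reshuffles and continues indefinitely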



+ 12
- 11
examples/hwf/datasets/get_dataset.py View File

@@ -13,8 +13,8 @@ img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
def download_and_unzip(url, zip_file_name):
try:
gdown.download(url, zip_file_name)
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
zip_ref.extractall()
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
zip_ref.extractall(CURRENT_DIR)
os.remove(zip_file_name)
except Exception as e:
if os.path.exists(zip_file_name):
@@ -23,11 +23,11 @@ def download_and_unzip(url, zip_file_name):

def get_dataset(train=True, get_pseudo_label=False):
data_dir = CURRENT_DIR + '/data'
url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download'

if not os.path.exists(data_dir):
print("Dataset does not exist, downloading it...")
download_and_unzip(url, 'HWF.zip')
url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download'
download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip"))
print("Download and extraction complete.")
if train:
@@ -36,7 +36,7 @@ def get_dataset(train=True, get_pseudo_label=False):
file = os.path.join(data_dir, "expr_test.json")

X = []
Z = [] if get_pseudo_label else None
pseudo_label = [] if get_pseudo_label else None
Y = []
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/")
with open(file) as f:
@@ -50,12 +50,13 @@ def get_dataset(train=True, get_pseudo_label=False):
img = img_transform(img)
imgs.append(img)
if get_pseudo_label:
imgs_pseudo_label.append(img_path.split("/")[0])
label_mappings = {"times": "*", "div": "/"}
label = img_path.split("/")[0]
label = label_mappings.get(label, label)
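# e.g., an image under "times/..." is mapped to pseudo-label "*", while digit
# folders such as "3/..." keep their folder name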
imgs_pseudo_label.append(label)
X.append(imgs)
if get_pseudo_label:
Z.append(imgs_pseudo_label)
pseudo_label.append(imgs_pseudo_label)
Y.append(data[idx]["res"])

return X, Z, Y

get_dataset()
return X, pseudo_label, Y

+ 16
- 19
examples/hwf/hwf.ipynb View File

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -36,7 +36,18 @@
"from abl.utils import ABLLogger, print_log\n",
"\n",
"from examples.models.nn import SymbolNet\n",
"from examples.hwf.datasets.get_dataset import get_dataset"
"from examples.hwf.datasets import get_dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_data = get_dataset(train=True, get_pseudo_label=True)\n",
"test_data = get_dataset(train=False, get_pseudo_label=True)"
]
},
{
@@ -75,21 +86,18 @@
" for i in range(len(formula)):\n",
" if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n",
" return False\n",
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"times\", \"div\"]:\n",
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"*\", \"/\"]:\n",
" return False\n",
" return True\n",
"\n",
" def logic_forward(self, formula):\n",
" if not self._valid_candidate(formula):\n",
" return np.inf\n",
" mapping = {str(i): str(i) for i in range(1, 10)}\n",
" mapping.update({\"+\": \"+\", \"-\": \"-\", \"times\": \"*\", \"div\": \"/\"})\n",
" formula = [mapping[f] for f in formula]\n",
" return np.info\n",
" return eval(\"\".join(formula))\n",
"\n",
"\n",
"kb = HWF_KB(\n",
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"times\", \"div\"],\n",
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"],\n",
" max_err=1e-10,\n",
" use_cache=False,\n",
")\n",
@@ -175,17 +183,6 @@
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_data = get_dataset(train=True, get_pseudo_label=True)\n",
"test_data = get_dataset(train=False, get_pseudo_label=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",


+ 3
- 6
examples/hwf/main.py View File

@@ -33,21 +33,18 @@ class HWF_KB(KBBase):
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return np.inf
return eval("".join(formula))
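
# e.g., logic_forward(["5", "+", "2"]) == 7, while logic_forward(["+", "5"])
# fails the validity check and returns np.inf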


kb = HWF_KB(
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"],
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
max_err=1e-10,
use_cache=False,
)


+ 3
- 3
examples/mnist_add/datasets/get_dataset.py View File

@@ -18,13 +18,13 @@ def get_dataset(train=True, get_pseudo_label=False):
file = os.path.join(CURRENT_DIR, "test_data.txt")

X = []
Z = [] if get_pseudo_label else None
pseudo_label = [] if get_pseudo_label else None
Y = []
with open(file) as f:
for line in f:
x1, x2, y = map(int, line.strip().split(" "))
X.append([img_dataset[x1][0], img_dataset[x2][0]])
if get_pseudo_label:
Z.append([img_dataset[x1][1], img_dataset[x2][1]])
pseudo_label.append([img_dataset[x1][1], img_dataset[x2][1]])
Y.append(y)
return X, Z, Y
return X, pseudo_label, Y

+ 81
- 22
examples/mnist_add/mnist_add.ipynb View File

@@ -59,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -122,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -150,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 17,
"metadata": {},
"outputs": [
{
@@ -164,21 +164,21 @@
],
"source": [
"pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])\n",
"print(f\"Shape of pred_idx for a batch of 32 samples: {pred_idx.shape}\")\n",
"print(f\"Shape of pred_idx for a batch of 32 examples: {pred_idx.shape}\")\n",
"pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])\n",
"print(f\"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}\")"
"print(f\"Shape of pred_prob for a batch of 32 examples: {pred_prob.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, base model built above are trained to make predictions on instance-level data, i.e., a single image, and can not directly utilize sample-level data, i.e., a pair of images. Therefore, we then wrap the base model into `ABLModel` which enables the learning part to train, test, and predict on sample-level data."
"However, base model built above are trained to make predictions on instance-level data, i.e., a single image, and can not directly utilize example-level data, i.e., a pair of images. Therefore, we then wrap the base model into `ABLModel` which enables the learning part to train, test, and predict on example-level data."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -194,15 +194,15 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# from abl.structures import ListData\n",
"# data_samples = ListData()\n",
"# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]\n",
"# data_examples = ListData()\n",
"# data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]\n",
"\n",
"# model.predict(data_samples)"
"# model.predict(data_examples)"
]
},
{
@@ -221,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -245,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -257,9 +257,9 @@
}
],
"source": [
"pseudo_label_sample = [1, 2]\n",
"reasoning_result = kb.logic_forward(pseudo_label_sample)\n",
"print(f\"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.\")"
"pseudo_label_example = [1, 2]\n",
"reasoning_result = kb.logic_forward(pseudo_label_example)\n",
"print(f\"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.\")"
]
},
{
@@ -278,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@@ -289,7 +289,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: During creating reasoner, the definition of \"consistency\" can be customized within the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data sample and candidates based on the confidence derived from the predicted probability. In `main.py`, we provide options for utilizing other forms of consistency measurement.\n",
"Note: During creating reasoner, the definition of \"consistency\" can be customized within the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data example and candidates based on the confidence derived from the predicted probability. In `main.py`, we provide options for utilizing other forms of consistency measurement.\n",
"\n",
"Also, during process of inconsistency minimization, one can leverage [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration. Options for this are also available in `main.py`. Those interested are encouraged to explore these features."
]
@@ -311,7 +311,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -330,7 +330,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@@ -346,9 +346,68 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 25,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12/19 14:41:46 - abl - INFO - Abductive Learning on the MNIST Addition example.\n",
"12/19 14:41:46 - abl - INFO - loop(train) [1/5] segment(train) [1/3] \n",
"12/19 14:41:51 - abl - INFO - model loss: 1.81279\n",
"12/19 14:41:51 - abl - INFO - loop(train) [1/5] segment(train) [2/3] \n",
"12/19 14:41:56 - abl - INFO - model loss: 1.40474\n",
"12/19 14:41:56 - abl - INFO - loop(train) [1/5] segment(train) [3/3] \n",
"12/19 14:42:01 - abl - INFO - model loss: 1.17817\n",
"12/19 14:42:01 - abl - INFO - Evaluation start: loop(val) [1]\n",
"12/19 14:42:02 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.496 mnist_add/reasoning_accuracy: 0.336 \n",
"12/19 14:42:02 - abl - INFO - Saving model: loop(save) [1]\n",
"12/19 14:42:02 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_1.pth\n",
"12/19 14:42:02 - abl - INFO - loop(train) [2/5] segment(train) [1/3] \n",
"12/19 14:42:07 - abl - INFO - model loss: 0.85932\n",
"12/19 14:42:07 - abl - INFO - loop(train) [2/5] segment(train) [2/3] \n",
"12/19 14:42:11 - abl - INFO - model loss: 0.62120\n",
"12/19 14:42:11 - abl - INFO - loop(train) [2/5] segment(train) [3/3] \n",
"12/19 14:42:16 - abl - INFO - model loss: 0.35382\n",
"12/19 14:42:16 - abl - INFO - Evaluation start: loop(val) [2]\n",
"12/19 14:42:17 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.980 mnist_add/reasoning_accuracy: 0.961 \n",
"12/19 14:42:17 - abl - INFO - Saving model: loop(save) [2]\n",
"12/19 14:42:17 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_2.pth\n",
"12/19 14:42:17 - abl - INFO - loop(train) [3/5] segment(train) [1/3] \n",
"12/19 14:42:21 - abl - INFO - model loss: 0.08302\n",
"12/19 14:42:21 - abl - INFO - loop(train) [3/5] segment(train) [2/3] \n",
"12/19 14:42:25 - abl - INFO - model loss: 0.05917\n",
"12/19 14:42:25 - abl - INFO - loop(train) [3/5] segment(train) [3/3] \n",
"12/19 14:42:30 - abl - INFO - model loss: 0.05425\n",
"12/19 14:42:30 - abl - INFO - Evaluation start: loop(val) [3]\n",
"12/19 14:42:31 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.988 mnist_add/reasoning_accuracy: 0.976 \n",
"12/19 14:42:31 - abl - INFO - Saving model: loop(save) [3]\n",
"12/19 14:42:31 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_3.pth\n",
"12/19 14:42:31 - abl - INFO - loop(train) [4/5] segment(train) [1/3] \n",
"12/19 14:42:35 - abl - INFO - model loss: 0.04650\n",
"12/19 14:42:35 - abl - INFO - loop(train) [4/5] segment(train) [2/3] \n",
"12/19 14:42:39 - abl - INFO - model loss: 0.04175\n",
"12/19 14:42:39 - abl - INFO - loop(train) [4/5] segment(train) [3/3] \n",
"12/19 14:42:44 - abl - INFO - model loss: 0.04207\n",
"12/19 14:42:44 - abl - INFO - Evaluation start: loop(val) [4]\n",
"12/19 14:42:45 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.979 \n",
"12/19 14:42:45 - abl - INFO - Saving model: loop(save) [4]\n",
"12/19 14:42:45 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_4.pth\n",
"12/19 14:42:45 - abl - INFO - loop(train) [5/5] segment(train) [1/3] \n",
"12/19 14:42:49 - abl - INFO - model loss: 0.03484\n",
"12/19 14:42:49 - abl - INFO - loop(train) [5/5] segment(train) [2/3] \n",
"12/19 14:42:53 - abl - INFO - model loss: 0.03319\n",
"12/19 14:42:53 - abl - INFO - loop(train) [5/5] segment(train) [3/3] \n",
"12/19 14:42:58 - abl - INFO - model loss: 0.03510\n",
"12/19 14:42:58 - abl - INFO - Evaluation start: loop(val) [5]\n",
"12/19 14:42:59 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.993 mnist_add/reasoning_accuracy: 0.987 \n",
"12/19 14:42:59 - abl - INFO - Saving model: loop(save) [5]\n",
"12/19 14:42:59 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_5.pth\n",
"12/19 14:42:59 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.988 mnist_add/reasoning_accuracy: 0.976 \n"
]
}
],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the MNIST Addition example.\", logger=\"current\")\n",


+ 189
- 0
examples/zoo/main.py View File

@@ -0,0 +1,189 @@
import os.path as osp

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from z3 import Solver, Int, If, Not, Implies, Sum, sat
import openml

from abl.learning import ABLModel
from abl.reasoning import KBBase, Reasoner
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.bridge import SimpleBridge
from abl.utils.utils import confidence_dist
from abl.utils import ABLLogger, print_log

# Build logger
print_log("Abductive Learning on the Zoo example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")

# Learning Part
rf = RandomForestClassifier()
model = ABLModel(rf)

# %% [markdown]
# ### Logic Part

# %%
class ZooKB(KBBase):
def __init__(self):
super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
# Use z3 solver
self.solver = Solver()

# Load information of Zoo dataset
dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
self.attribute_names = attribute_names
self.target_names = y.cat.categories.tolist()
# Define variables
for name in self.attribute_names+self.target_names:
exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules
# Define rules
rules = [
Implies(milk == 1, mammal == 1),
Implies(mammal == 1, milk == 1),
Implies(mammal == 1, backbone == 1),
Implies(mammal == 1, breathes == 1),
Implies(feathers == 1, bird == 1),
Implies(bird == 1, feathers == 1),
Implies(bird == 1, eggs == 1),
Implies(bird == 1, backbone == 1),
Implies(bird == 1, breathes == 1),
Implies(bird == 1, legs == 2),
Implies(bird == 1, tail == 1),
Implies(reptile == 1, backbone == 1),
Implies(reptile == 1, breathes == 1),
Implies(reptile == 1, tail == 1),
Implies(fish == 1, aquatic == 1),
Implies(fish == 1, toothed == 1),
Implies(fish == 1, backbone == 1),
Implies(fish == 1, Not(breathes == 1)),
Implies(fish == 1, fins == 1),
Implies(fish == 1, legs == 0),
Implies(fish == 1, tail == 1),
Implies(amphibian == 1, eggs == 1),
Implies(amphibian == 1, aquatic == 1),
Implies(amphibian == 1, backbone == 1),
Implies(amphibian == 1, breathes == 1),
Implies(amphibian == 1, legs == 4),
Implies(insect == 1, eggs == 1),
Implies(insect == 1, Not(backbone == 1)),
Implies(insect == 1, legs == 6),
Implies(invertebrate == 1, Not(backbone == 1))
]
# Define weights and sum of violated weights
self.weights = {rule: 1 for rule in rules}
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])

def logic_forward(self, pseudo_label, data_point):
attribute_names, target_names = self.attribute_names, self.target_names
solver = self.solver
total_violation_weight = self.total_violation_weight
pseudo_label, data_point = pseudo_label[0], data_point[0]
self.solver.reset()
for name, value in zip(attribute_names, data_point):
solver.add(eval(f"{name} == {value}"))
for cate, name in zip(self.pseudo_label_list, target_names):
value = 1 if (cate == pseudo_label) else 0
solver.add(eval(f"{name} == {value}"))
if solver.check() == sat:
model = solver.model()
total_weight = model.evaluate(total_violation_weight)
return total_weight.as_long()
else:
# No solution found
return 1e10
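
# How logic_forward scores a candidate, as a reading of the code above: only the
# attribute/target equalities are added to the solver, so every variable is fixed,
# check() is always sat, and the returned value is simply the number of violated
# rules (each rule carries weight 1). For example, labelling an animal that has
# milk == 1 as anything other than mammal violates Implies(milk == 1, mammal == 1)
# and adds 1 to the score; a score of 0 means the pseudo-label is fully consistent
# with the rules.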

def consistency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob
model_scores = confidence_dist(pred_prob, candidate_idxs)
rule_scores = np.array(reasoning_results)
scores = model_scores + rule_scores
return scores
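
# Contract expected of a user-defined dist_func (enforced by Reasoner; see the
# assertions in tests/test_reasoning.py further down this page): it must take
# exactly four parameters (data_example, candidates, candidate_idxs,
# reasoning_results) and return one score per candidate. Scores act as costs,
# so the candidate with the smallest combined model/rule score is preferred.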

kb = ZooKB()
reasoner = Reasoner(kb, dist_func=consistency)

# %% [markdown]
# ### Datasets and Evaluation Metrics

# %%
# Function to load and preprocess the dataset
def load_and_preprocess_dataset(dataset_id):
dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
# Convert data types
for col in X.select_dtypes(include='bool').columns:
X[col] = X[col].astype(int)
y = y.cat.codes.astype(int)
X, y = X.to_numpy(), y.to_numpy()
return X, y

# Function to split data (one shot): exactly one labeled row per class, the rest divided into unlabeled and test rows
def split_dataset(X, y, test_size=0.3):
# For every class: 1 labeled : (1-test_size)*(len-1) unlabeled : test_size*(len-1) test
label_indices, unlabel_indices, test_indices = [], [], []
for class_label in np.unique(y):
idxs = np.where(y == class_label)[0]
np.random.shuffle(idxs)
n_train_unlabel = int((1-test_size)*(len(idxs)-1))
label_indices.append(idxs[0])
unlabel_indices.extend(idxs[1:1+n_train_unlabel])
test_indices.extend(idxs[1+n_train_unlabel:])
X_label, y_label = X[label_indices], y[label_indices]
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
X_test, y_test = X[test_indices], y[test_indices]
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test
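
# Worked example of the split (hypothetical class with 11 rows, test_size=0.3):
# 1 labeled row, int(0.7 * 10) = 7 unlabeled rows, and the remaining 3 rows for test.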

# Load and preprocess the Zoo dataset
X, y = load_and_preprocess_dataset(dataset_id=62)

# Split data into labeled/unlabeled/test data
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)

# Transform tabular data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)
# For tabular data in abl, each example contains a single instance (a row from the dataset).
# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.
def transform_tab_data(X, y):
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
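# Illustration: for rows r1, r2 with labels 2 and 5,
# transform_tab_data([r1, r2], [2, 5]) == ([[r1], [r2]], [[2], [5]], [0, 0])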
label_data = transform_tab_data(X_label, y_label)
test_data = transform_tab_data(X_test, y_test)
train_data = transform_tab_data(X_unlabel, y_unlabel)

# %%
# Set up metrics
metric_list = [SymbolMetric(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]

# %% [markdown]
# ### Bridge Machine Learning and Logic Reasoning

# %%
bridge = SimpleBridge(model, reasoner, metric_list)

# %% [markdown]
# ### Train and Test

# %%
# Pre-train the machine learning model
rf.fit(X_label, y_label)

# %%
# Test the initial model
print("------- Test the initial model -----------")
bridge.test(test_data)
print("------- Use ABL to train the model -----------")
# Use ABL to train the model
bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)
print("------- Test the final model -----------")
# Test the final model
bridge.test(test_data)



+ 4
- 4
examples/zoo/zoo_example.ipynb View File

@@ -140,8 +140,8 @@
" # No solution found\n",
" return 1e10\n",
" \n",
"def consitency(data_sample, candidates, candidate_idxs, reasoning_results):\n",
" pred_prob = data_sample.pred_prob\n",
"def consitency(data_example, candidates, candidate_idxs, reasoning_results):\n",
" pred_prob = data_example.pred_prob\n",
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n",
" rule_scores = np.array(reasoning_results)\n",
" scores = model_scores + rule_scores\n",
@@ -198,8 +198,8 @@
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n",
"\n",
"# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n",
"# For tabular data in abl, each sample contains a single instance (a row from the dataset).\n",
"# For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated.\n",
"# For tabular data in abl, each example contains a single instance (a row from the dataset).\n",
"# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.\n",
"def transform_tab_data(X, y):\n",
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"label_data = transform_tab_data(X_label, y_label)\n",


+ 20
- 20
tests/conftest.py View File

@@ -61,15 +61,15 @@ def base_model_instance():
# Fixture for ListData instance
@pytest.fixture
def list_data_instance():
data_samples = ListData()
data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
data_samples.Y = [1, 2, 3]
data_samples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
return data_samples
data_examples = ListData()
data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]
data_examples.Y = [1, 2, 3]
data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
return data_examples


@pytest.fixture
def data_samples_add():
def data_examples_add():
# favor 1 in first one
prob1 = [
[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
@@ -81,27 +81,27 @@ def data_samples_add():
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]

data_samples_add = ListData()
data_samples_add.X = None
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_samples_add.Y = [8, 8, 17, 10]
return data_samples_add
data_examples_add = ListData()
data_examples_add.X = None
data_examples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_examples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_examples_add.Y = [8, 8, 17, 10]
return data_examples_add


@pytest.fixture
def data_samples_hwf():
data_samples_hwf = ListData()
data_samples_hwf.X = None
data_samples_hwf.pred_pseudo_label = [
def data_examples_hwf():
data_examples_hwf = ListData()
data_examples_hwf.X = None
data_examples_hwf.pred_pseudo_label = [
["5", "+", "2"],
["5", "+", "9"],
["5", "+", "9"],
["5", "-", "8", "8", "8"],
]
data_samples_hwf.pred_prob = [None, None, None, None]
data_samples_hwf.Y = [3, 64, 65, 3.17]
return data_samples_hwf
data_examples_hwf.pred_prob = [None, None, None, None]
data_examples_hwf.Y = [3, 64, 65, 3.17]
return data_examples_hwf


class AddKB(KBBase):
@@ -199,7 +199,7 @@ def kb_add_ground():

@pytest.fixture
def kb_add_prolog():
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl")
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl")
return kb

@pytest.fixture

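A minimal sketch of how the renamed fixtures are consumed, assuming the AddKB class and the kb_add fixture defined earlier in conftest.py (kb_add and prob1 stand for the objects the fixtures return; the expected output is taken from the assertions in tests/test_reasoning.py below):

data_examples = ListData()
data_examples.X = None
data_examples.pred_pseudo_label = [[1, 1]]
data_examples.pred_prob = [prob1]  # distribution favoring digit 1 in the first position
data_examples.Y = [8]
reasoner = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0)
reasoner.batch_abduce(data_examples)  # -> [[1, 7]], matching test_batch_abduce_add
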

+ 31
- 31
tests/test_reasoning.py View File

@@ -53,7 +53,7 @@ class TestGroundKB(object):
class TestPrologKB(object):
def test_init_pl1(self, kb_add_prolog):
assert kb_add_prolog.pseudo_label_list == list(range(10))
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl"
assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl"

def test_init_pl2(self, kb_hed):
assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
@@ -101,7 +101,7 @@ class TestReaonser(object):
excinfo.value
)
def random_dist(self, data_sample, candidates, candidate_idxs, reasoning_results):
def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results):
cost_list = [np.random.rand() for _ in candidates]
return cost_list
@@ -113,11 +113,11 @@ class TestReaonser(object):
cost_list = np.array([np.random.rand() for _ in candidates])
return cost_list
def invalid_dist2(self, data_sample, candidates, candidate_idxs, reasoning_results):
def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results):
cost_list = np.array([np.random.rand() for _ in candidates])
return np.append(cost_list, np.random.rand())
def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add):
def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add):
with pytest.raises(ValueError) as excinfo:
Reasoner(kb_add, self.invalid_dist1)
assert 'User-defined dist_func must have exactly four parameters' in str(
@@ -125,98 +125,98 @@ class TestReaonser(object):
)
with pytest.raises(ValueError) as excinfo:
reasoner = Reasoner(kb_add, self.invalid_dist2)
reasoner.batch_abduce(data_samples_add)
reasoner.batch_abduce(data_examples_add)
assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str(
excinfo.value
)


class TestBatchAbduce(object):
def test_batch_abduce_add(self, kb_add, data_samples_add):
def test_batch_abduce_add(self, kb_add, data_examples_add):
reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1)
reasoner3 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=0)
reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_samples_add) == [
assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_examples_add) == [
[1, 7],
[7, 1],
[8, 9],
[1, 9],
]
assert reasoner4.batch_abduce(data_samples_add) == [
assert reasoner4.batch_abduce(data_examples_add) == [
[1, 7],
[7, 1],
[8, 9],
[7, 3],
]

def test_batch_abduce_ground(self, kb_add_ground, data_samples_add):
def test_batch_abduce_ground(self, kb_add_ground, data_examples_add):
reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=1)
reasoner3 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=0)
reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner3.batch_abduce(data_samples_add) == [
assert reasoner1.batch_abduce(data_examples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner2.batch_abduce(data_examples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner3.batch_abduce(data_examples_add) == [
(1, 7),
(7, 1),
(8, 9),
(1, 9),
]
assert reasoner4.batch_abduce(data_samples_add) == [
assert reasoner4.batch_abduce(data_examples_add) == [
(1, 7),
(7, 1),
(8, 9),
(7, 3),
]

def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add):
def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add):
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1)
reasoner3 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=0)
reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_samples_add) == [
assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_examples_add) == [
[1, 7],
[7, 1],
[8, 9],
[1, 9],
]
assert reasoner4.batch_abduce(data_samples_add) == [
assert reasoner4.batch_abduce(data_examples_add) == [
[1, 7],
[7, 1],
[8, 9],
[7, 3],
]

def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add):
def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add):
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [
assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_examples_add) == [
[1, 7],
[7, 1],
[8, 9],
[7, 3],
]

def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf):
def test_batch_abduce_hwf1(self, kb_hwf1, data_examples_hwf):
reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0)
reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0)
reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0)
res = reasoner1.batch_abduce(data_samples_hwf)
res = reasoner1.batch_abduce(data_examples_hwf)
assert res == [
["1", "+", "2"],
["8", "times", "8"],
[],
["4", "-", "6", "div", "8"],
]
res = reasoner2.batch_abduce(data_samples_hwf)
res = reasoner2.batch_abduce(data_examples_hwf)
assert res == [["1", "+", "2"], [], [], []]
res = reasoner3.batch_abduce(data_samples_hwf)
res = reasoner3.batch_abduce(data_examples_hwf)
assert res == [
["1", "+", "2"],
["8", "times", "8"],
@@ -224,25 +224,25 @@ class TestBatchAbduce(object):
["4", "-", "6", "div", "8"],
]

def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf):
def test_batch_abduce_hwf2(self, kb_hwf2, data_examples_hwf):
reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0)
reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0)
reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0)
res = reasoner1.batch_abduce(data_samples_hwf)
res = reasoner1.batch_abduce(data_examples_hwf)
assert res == [
["1", "+", "2"],
["7", "times", "9"],
["8", "times", "8"],
["5", "-", "8", "div", "8"],
]
res = reasoner2.batch_abduce(data_samples_hwf)
res = reasoner2.batch_abduce(data_examples_hwf)
assert res == [
["1", "+", "2"],
["7", "times", "9"],
[],
["5", "-", "8", "div", "8"],
]
res = reasoner3.batch_abduce(data_samples_hwf)
res = reasoner3.batch_abduce(data_examples_hwf)
assert res == [
["1", "+", "2"],
["7", "times", "9"],


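Two return conventions are visible in the assertions above (a reading of the tests rather than separately documented behavior): batch_abduce yields one revised pseudo-label sequence per example, an empty list marks an example for which no candidate fit within the revision budget, and a GroundKB keeps its candidates in tuple form:

res = reasoner1.batch_abduce(data_examples_add)
# KBBase / PrologKB: [[1, 7], [7, 1], [], [1, 9]]   (third example fails under max_revision=1)
# GroundKB:          [(1, 7), (7, 1), [], (1, 9)]   (same abductions, tuple candidates)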