@@ -14,7 +14,7 @@ class BaseBridge(metaclass=ABCMeta): | |||
to construct a typical pipeline of Abductive Learning (corresponding to ``train``), | |||
which involves the following four methods: | |||
- predict: Predict class indices on the given data samples. | |||
- predict: Predict class indices on the given data examples. | |||
- idx_to_pseudo_label: Map indices into pseudo labels. | |||
- abduce_pseudo_label: Revise pseudo labels based on abdutive reasoning. | |||
- pseudo_label_to_idx: Map revised pseudo labels back into indices. | |||
@@ -42,30 +42,30 @@ class BaseBridge(metaclass=ABCMeta): | |||
self.reasoner = reasoner | |||
@abstractmethod | |||
def predict(self, data_samples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
def predict(self, data_examples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]: | |||
"""Placeholder for predicting class indices from input.""" | |||
@abstractmethod | |||
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||
"""Placeholder for revising pseudo labels based on abdutive reasoning.""" | |||
@abstractmethod | |||
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||
"""Placeholder for mapping indices to pseudo labels.""" | |||
@abstractmethod | |||
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: | |||
def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]: | |||
"""Placeholder for mapping pseudo labels to indices.""" | |||
def filter_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
def filter_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||
"""Default filter function for pseudo label.""" | |||
non_empty_idx = [ | |||
i | |||
for i in range(len(data_samples.abduced_pseudo_label)) | |||
if data_samples.abduced_pseudo_label[i] | |||
for i in range(len(data_examples.abduced_pseudo_label)) | |||
if data_examples.abduced_pseudo_label[i] | |||
] | |||
data_samples.update(data_samples[non_empty_idx]) | |||
return data_samples | |||
data_examples.update(data_examples[non_empty_idx]) | |||
return data_examples | |||
@abstractmethod | |||
def train( | |||
@@ -18,7 +18,7 @@ class SimpleBridge(BaseBridge): | |||
This class implements the typical pipeline of Abductive Learning, which involves | |||
the following five steps: | |||
- Predict class probabilities and indices for the given data samples. | |||
- Predict class probabilities and indices for the given data examples. | |||
- Map indices into pseudo labels. | |||
- Revise pseudo labels based on abdutive reasoning. | |||
- Map the revised pseudo labels to indices. | |||
@@ -44,69 +44,69 @@ class SimpleBridge(BaseBridge): | |||
super().__init__(model, reasoner) | |||
self.metric_list = metric_list | |||
def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]: | |||
""" | |||
Predict class indices and probabilities (if ``predict_proba`` is implemented in | |||
``self.model.base_model``) on the given data samples. | |||
``self.model.base_model``) on the given data examples. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
Data samples on which predictions are to be made. | |||
data_examples : ListData | |||
Data examples on which predictions are to be made. | |||
Returns | |||
------- | |||
Tuple[List[ndarray], List[ndarray]] | |||
A tuple containing lists of predicted indices and probabilities. | |||
""" | |||
self.model.predict(data_samples) | |||
return data_samples.pred_idx, data_samples.pred_prob | |||
self.model.predict(data_examples) | |||
return data_examples.pred_idx, data_examples.pred_prob | |||
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||
""" | |||
Revise predicted pseudo labels of the given data samples using abduction. | |||
Revise predicted pseudo labels of the given data examples using abduction. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
Data samples containing predicted pseudo labels. | |||
data_examples : ListData | |||
Data examples containing predicted pseudo labels. | |||
Returns | |||
------- | |||
List[List[Any]] | |||
A list of abduced pseudo labels for the given data samples. | |||
A list of abduced pseudo labels for the given data examples. | |||
""" | |||
self.reasoner.batch_abduce(data_samples) | |||
return data_samples.abduced_pseudo_label | |||
self.reasoner.batch_abduce(data_examples) | |||
return data_examples.abduced_pseudo_label | |||
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: | |||
""" | |||
Map indices of data samples into pseudo labels. | |||
Map indices of data examples into pseudo labels. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
Data samples containing the indices. | |||
data_examples : ListData | |||
Data examples containing the indices. | |||
Returns | |||
------- | |||
List[List[Any]] | |||
A list of pseudo labels converted from indices. | |||
""" | |||
pred_idx = data_samples.pred_idx | |||
data_samples.pred_pseudo_label = [ | |||
pred_idx = data_examples.pred_idx | |||
data_examples.pred_pseudo_label = [ | |||
[self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx | |||
] | |||
return data_samples.pred_pseudo_label | |||
return data_examples.pred_pseudo_label | |||
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]: | |||
def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]: | |||
""" | |||
Map pseudo labels of data samples into indices. | |||
Map pseudo labels of data examples into indices. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
Data samples containing pseudo labels. | |||
data_examples : ListData | |||
Data examples containing pseudo labels. | |||
Returns | |||
------- | |||
@@ -115,10 +115,10 @@ class SimpleBridge(BaseBridge): | |||
""" | |||
abduced_idx = [ | |||
[self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list] | |||
for sub_list in data_samples.abduced_pseudo_label | |||
for sub_list in data_examples.abduced_pseudo_label | |||
] | |||
data_samples.abduced_idx = abduced_idx | |||
return data_samples.abduced_idx | |||
data_examples.abduced_idx = abduced_idx | |||
return data_examples.abduced_idx | |||
def data_preprocess( | |||
self, | |||
@@ -141,49 +141,49 @@ class SimpleBridge(BaseBridge): | |||
The preprocessed ListData object. | |||
""" | |||
if isinstance(data, ListData): | |||
data_samples = data | |||
data_examples = data | |||
if not ( | |||
hasattr(data_samples, "X") | |||
and hasattr(data_samples, "gt_pseudo_label") | |||
and hasattr(data_samples, "Y") | |||
hasattr(data_examples, "X") | |||
and hasattr(data_examples, "gt_pseudo_label") | |||
and hasattr(data_examples, "Y") | |||
): | |||
raise ValueError( | |||
f"{prefix}data should have X, gt_pseudo_label and Y attribute but " | |||
f"only {data_samples.all_keys()} are provided." | |||
f"only {data_examples.all_keys()} are provided." | |||
) | |||
else: | |||
X, gt_pseudo_label, Y = data | |||
data_samples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y) | |||
data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y) | |||
return data_samples | |||
return data_examples | |||
def concat_data_samples( | |||
self, unlabel_data_samples: ListData, label_data_samples: Optional[ListData] | |||
def concat_data_examples( | |||
self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData] | |||
) -> ListData: | |||
""" | |||
Concatenate unlabeled and labeled data samples. ``abduced_pseudo_label`` of unlabeled data samples and ``gt_pseudo_label`` of labeled data samples will be used to train the model. | |||
Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model. | |||
Parameters | |||
---------- | |||
unlabel_data_samples : ListData | |||
Unlabeled data samples to concatenate. | |||
label_data_samples : Optional[ListData] | |||
Labeled data samples to concatenate, if available. | |||
unlabel_data_examples : ListData | |||
Unlabeled data examples to concatenate. | |||
label_data_examples : Optional[ListData] | |||
Labeled data examples to concatenate, if available. | |||
Returns | |||
------- | |||
ListData | |||
Concatenated data samples. | |||
Concatenated data examples. | |||
""" | |||
if label_data_samples is None: | |||
return unlabel_data_samples | |||
if label_data_examples is None: | |||
return unlabel_data_examples | |||
unlabel_data_samples.X = unlabel_data_samples.X + label_data_samples.X | |||
unlabel_data_samples.abduced_pseudo_label = ( | |||
unlabel_data_samples.abduced_pseudo_label + label_data_samples.gt_pseudo_label | |||
unlabel_data_examples.X = unlabel_data_examples.X + label_data_examples.X | |||
unlabel_data_examples.abduced_pseudo_label = ( | |||
unlabel_data_examples.abduced_pseudo_label + label_data_examples.gt_pseudo_label | |||
) | |||
unlabel_data_samples.Y = unlabel_data_samples.Y + label_data_samples.Y | |||
return unlabel_data_samples | |||
unlabel_data_examples.Y = unlabel_data_examples.Y + label_data_examples.Y | |||
return unlabel_data_examples | |||
def train( | |||
self, | |||
@@ -227,51 +227,51 @@ class SimpleBridge(BaseBridge): | |||
save_dir : Optional[str] | |||
Directory to save the model, by default None. | |||
""" | |||
data_samples = self.data_preprocess("train", train_data) | |||
data_examples = self.data_preprocess("train", train_data) | |||
if label_data is not None: | |||
label_data_samples = self.data_preprocess("label", label_data) | |||
label_data_examples = self.data_preprocess("label", label_data) | |||
else: | |||
label_data_samples = None | |||
label_data_examples = None | |||
if val_data is not None: | |||
val_data_samples = self.data_preprocess("val", val_data) | |||
val_data_examples = self.data_preprocess("val", val_data) | |||
else: | |||
val_data_samples = data_samples | |||
val_data_examples = data_examples | |||
if isinstance(segment_size, int): | |||
if segment_size <= 0: | |||
raise ValueError("segment_size should be positive.") | |||
elif isinstance(segment_size, float): | |||
if 0 < segment_size <= 1: | |||
segment_size = int(segment_size * len(data_samples)) | |||
segment_size = int(segment_size * len(data_examples)) | |||
else: | |||
raise ValueError("segment_size should be in (0, 1].") | |||
else: | |||
raise ValueError("segment_size should be int or float.") | |||
for loop in range(loops): | |||
for seg_idx in range((len(data_samples) - 1) // segment_size + 1): | |||
for seg_idx in range((len(data_examples) - 1) // segment_size + 1): | |||
print_log( | |||
f"loop(train) [{loop + 1}/{loops}] segment(train) " | |||
f"[{(seg_idx + 1)}/{(len(data_samples) - 1) // segment_size + 1}] ", | |||
f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ", | |||
logger="current", | |||
) | |||
sub_data_samples = data_samples[ | |||
sub_data_examples = data_examples[ | |||
seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
] | |||
self.predict(sub_data_samples) | |||
self.idx_to_pseudo_label(sub_data_samples) | |||
self.abduce_pseudo_label(sub_data_samples) | |||
self.filter_pseudo_label(sub_data_samples) | |||
self.concat_data_samples(sub_data_samples, label_data_samples) | |||
self.pseudo_label_to_idx(sub_data_samples) | |||
self.model.train(sub_data_samples) | |||
self.predict(sub_data_examples) | |||
self.idx_to_pseudo_label(sub_data_examples) | |||
self.abduce_pseudo_label(sub_data_examples) | |||
self.filter_pseudo_label(sub_data_examples) | |||
self.concat_data_examples(sub_data_examples, label_data_examples) | |||
self.pseudo_label_to_idx(sub_data_examples) | |||
self.model.train(sub_data_examples) | |||
if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||
print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") | |||
self._valid(val_data_samples) | |||
self._valid(val_data_examples) | |||
if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||
print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||
@@ -279,20 +279,20 @@ class SimpleBridge(BaseBridge): | |||
save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") | |||
) | |||
def _valid(self, data_samples: ListData) -> None: | |||
def _valid(self, data_examples: ListData) -> None: | |||
""" | |||
Internal method for validating the model with given data samples. | |||
Internal method for validating the model with given data examples. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
Data samples to be used for validation. | |||
data_examples : ListData | |||
Data examples to be used for validation. | |||
""" | |||
self.predict(data_samples) | |||
self.idx_to_pseudo_label(data_samples) | |||
self.predict(data_examples) | |||
self.idx_to_pseudo_label(data_examples) | |||
for metric in self.metric_list: | |||
metric.process(data_samples) | |||
metric.process(data_examples) | |||
res = dict() | |||
for metric in self.metric_list: | |||
@@ -314,8 +314,8 @@ class SimpleBridge(BaseBridge): | |||
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]] | |||
Validation data to be used for model evaluation. | |||
""" | |||
val_data_samples = self.data_preprocess(val_data) | |||
self._valid(val_data_samples) | |||
val_data_examples = self.data_preprocess(val_data) | |||
self._valid(val_data_examples) | |||
def test( | |||
self, | |||
@@ -329,5 +329,5 @@ class SimpleBridge(BaseBridge): | |||
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]] | |||
Test data to be used for model evaluation. | |||
""" | |||
test_data_samples = self.data_preprocess("test", test_data) | |||
self._valid(test_data_samples) | |||
test_data_examples = self.data_preprocess("test", test_data) | |||
self._valid(test_data_examples) |
@@ -10,7 +10,7 @@ class BaseMetric(metaclass=ABCMeta): | |||
""" | |||
Base class for a metrics. | |||
The metrics first processes each batch of data_samples and appends the processed | |||
The metrics first processes each batch of data_examples and appends the processed | |||
results to the results list. Then, it computes the metrics of the entire dataset. | |||
Parameters | |||
@@ -30,16 +30,16 @@ class BaseMetric(metaclass=ABCMeta): | |||
self.prefix = prefix or self.default_prefix | |||
@abstractmethod | |||
def process(self, data_samples: ListData) -> None: | |||
def process(self, data_examples: ListData) -> None: | |||
""" | |||
Process one batch of data samples. The processed results should be stored | |||
Process one batch of data examples. The processed results should be stored | |||
in ``self.results``, which will be used to compute the metrics when all | |||
batches have been processed. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
A batch of data samples. | |||
data_examples : ListData | |||
A batch of data examples. | |||
""" | |||
@abstractmethod | |||
@@ -25,7 +25,7 @@ class ReasoningMetric(BaseMetric): | |||
Notes | |||
----- | |||
The `ReasoningMetric` expects data_samples to have the attributes `pred_pseudo_label`, | |||
The `ReasoningMetric` expects data_examples to have the attributes `pred_pseudo_label`, | |||
`Y`, and `X`, corresponding to the predicted pseduo labels, ground truth of reasoning | |||
results, and input data, respectively. | |||
""" | |||
@@ -34,24 +34,24 @@ class ReasoningMetric(BaseMetric): | |||
super().__init__(prefix) | |||
self.kb = kb | |||
def process(self, data_samples: ListData) -> None: | |||
def process(self, data_examples: ListData) -> None: | |||
""" | |||
Process a batch of data samples. | |||
Process a batch of data examples. | |||
This method takes in a batch of data samples, each containing predicted pseudo labels(pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It | |||
evaluates the reasoning accuracy of each sample by comparing the logical reasoning | |||
This method takes in a batch of data examples, each containing predicted pseudo labels(pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It | |||
evaluates the reasoning accuracy of each example by comparing the logical reasoning | |||
result (derived using the knowledge base) of the predicted pseudo labels against Y | |||
The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended | |||
to ``self.results``. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
A batch of data samples. | |||
data_examples : ListData | |||
A batch of data examples. | |||
""" | |||
pred_pseudo_label_list = data_samples.pred_pseudo_label | |||
y_list = data_samples.Y | |||
x_list = data_samples.X | |||
pred_pseudo_label_list = data_examples.pred_pseudo_label | |||
y_list = data_examples.Y | |||
x_list = data_examples.X | |||
for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): | |||
if self.kb._check_equal( | |||
self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y | |||
@@ -63,7 +63,7 @@ class ReasoningMetric(BaseMetric): | |||
def compute_metrics(self) -> dict: | |||
""" | |||
Compute the reasoning accuracy metrics from ``self.results``. It calculates the | |||
percentage of correctly reasoned samples over all samples. | |||
percentage of correctly reasoned examples over all examples. | |||
Returns | |||
------- | |||
@@ -23,19 +23,19 @@ class SymbolMetric(BaseMetric): | |||
def __init__(self, prefix: Optional[str] = None) -> None: | |||
super().__init__(prefix) | |||
def process(self, data_samples: ListData) -> None: | |||
def process(self, data_examples: ListData) -> None: | |||
""" | |||
Processes a batch of data samples. | |||
Processes a batch of data examples. | |||
This method takes in a batch of data samples, each containing a list of predicted | |||
This method takes in a batch of data examples, each containing a list of predicted | |||
pseudo labels (pred_pseudo_label) and their ground truth (gt_pseudo_label). It | |||
calculates the accuracy by comparing the two lists. Then, a tuple of correct symbol | |||
count and total symbol count is appended to `self.results`. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
A batch of data samples, each containing: | |||
data_examples : ListData | |||
A batch of data examples, each containing: | |||
- `pred_pseudo_label`: List of predicted pseudo labels. | |||
- `gt_pseudo_label`: List of ground truth pseudo labels. | |||
@@ -44,8 +44,8 @@ class SymbolMetric(BaseMetric): | |||
ValueError | |||
If the lengths of predicted and ground truth symbol lists are not equal. | |||
""" | |||
pred_pseudo_label_list = data_samples.flatten("pred_pseudo_label") | |||
gt_pseudo_label_list = data_samples.flatten("gt_pseudo_label") | |||
pred_pseudo_label_list = data_examples.flatten("pred_pseudo_label") | |||
gt_pseudo_label_list = data_examples.flatten("gt_pseudo_label") | |||
if not len(pred_pseudo_label_list) == len(gt_pseudo_label_list): | |||
raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal") | |||
@@ -24,13 +24,13 @@ class ABLModel: | |||
self.base_model = base_model | |||
def predict(self, data_samples: ListData) -> Dict: | |||
def predict(self, data_examples: ListData) -> Dict: | |||
""" | |||
Predict the labels and probabilities for the given data. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
data_examples : ListData | |||
A batch of data to predict on. | |||
Returns | |||
@@ -39,28 +39,28 @@ class ABLModel: | |||
A dictionary containing the predicted labels and probabilities. | |||
""" | |||
model = self.base_model | |||
data_X = data_samples.flatten("X") | |||
data_X = data_examples.flatten("X") | |||
if hasattr(model, "predict_proba"): | |||
prob = model.predict_proba(X=data_X) | |||
label = prob.argmax(axis=1) | |||
prob = reform_list(prob, data_samples.X) | |||
prob = reform_list(prob, data_examples.X) | |||
else: | |||
prob = None | |||
label = model.predict(X=data_X) | |||
label = reform_list(label, data_samples.X) | |||
label = reform_list(label, data_examples.X) | |||
data_samples.pred_idx = label | |||
data_samples.pred_prob = prob | |||
data_examples.pred_idx = label | |||
data_examples.pred_prob = prob | |||
return {"label": label, "prob": prob} | |||
def train(self, data_samples: ListData) -> float: | |||
def train(self, data_examples: ListData) -> float: | |||
""" | |||
Train the model on the given data. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
data_examples : ListData | |||
A batch of data to train on, which typically contains the data, `X`, and the | |||
corresponding labels, `abduced_idx`. | |||
@@ -69,17 +69,17 @@ class ABLModel: | |||
float | |||
The loss value of the trained model. | |||
""" | |||
data_X = data_samples.flatten("X") | |||
data_y = data_samples.flatten("abduced_idx") | |||
data_X = data_examples.flatten("X") | |||
data_y = data_examples.flatten("abduced_idx") | |||
return self.base_model.fit(X=data_X, y=data_y) | |||
def valid(self, data_samples: ListData) -> float: | |||
def valid(self, data_examples: ListData) -> float: | |||
""" | |||
Validate the model on the given data. | |||
Parameters | |||
---------- | |||
data_samples : ListData | |||
data_examples : ListData | |||
A batch of data to train on, which typically contains the data, `X`, | |||
and the corresponding labels, `abduced_idx`. | |||
@@ -88,8 +88,8 @@ class ABLModel: | |||
float | |||
The accuracy the trained model. | |||
""" | |||
data_X = data_samples.flatten("X") | |||
data_y = data_samples.flatten("abduced_idx") | |||
data_X = data_examples.flatten("X") | |||
data_y = data_examples.flatten("abduced_idx") | |||
score = self.base_model.score(X=data_X, y=data_y) | |||
return score | |||
@@ -27,7 +27,7 @@ class KBBase(ABC): | |||
list so that each aligns with its corresponding index in the base model: the first with | |||
the 0th index, the second with the 1st, and so forth. | |||
max_err : float, optional | |||
The upper tolerance limit when comparing the similarity between a pseudo label sample's | |||
The upper tolerance limit when comparing the similarity between a pseudo label example's | |||
reasoning result and the ground truth. This is only applicable when the reasoning | |||
result is of a numerical type. This is particularly relevant for regression problems where | |||
exact matches might not be feasible. Defaults to 1e-10. | |||
@@ -82,15 +82,15 @@ class KBBase(ABC): | |||
@abstractmethod | |||
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: | |||
""" | |||
How to perform (deductive) logical reasoning, i.e. matching each pseudo label sample to | |||
How to perform (deductive) logical reasoning, i.e. matching each pseudo label example to | |||
their reasoning result. Users are required to provide this. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample. | |||
Pseudo label example. | |||
x : Optional[List[Any]] | |||
The corresponding input sample. If deductive logical reasoning does not require any | |||
The corresponding input example. If deductive logical reasoning does not require any | |||
information from the input, the overridden function provided by the user can omit | |||
this parameter. | |||
@@ -114,13 +114,13 @@ class KBBase(ABC): | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample (to be revised by abductive reasoning). | |||
Pseudo label example (to be revised by abductive reasoning). | |||
y : Any | |||
Ground truth of the reasoning result for the sample. | |||
Ground truth of the reasoning result for the example. | |||
x : List[Any] | |||
The corresponding input sample. | |||
The corresponding input example. | |||
max_revision_num : int | |||
The upper limit on the number of revised labels for each sample. | |||
The upper limit on the number of revised labels for each example. | |||
require_more_revision : int | |||
Specifies additional number of revisions permitted beyond the minimum required. | |||
@@ -128,7 +128,7 @@ class KBBase(ABC): | |||
------- | |||
Tuple[List[List[Any]], List[Any]] | |||
A tuple of two element. The first element is a list of candidate revisions, i.e. revised | |||
pseudo label samples that are compatible with the knowledge base. The second element is | |||
pseudo label examples that are compatible with the knowledge base. The second element is | |||
a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
logic_forward function. | |||
""" | |||
@@ -136,7 +136,7 @@ class KBBase(ABC): | |||
def _check_equal(self, reasoning_result: Any, y: Any) -> bool: | |||
""" | |||
Check whether the reasoning result of a pseduo label sample is equal to the ground truth | |||
Check whether the reasoning result of a pseduo label example is equal to the ground truth | |||
(or, within the maximum error allowed for numerical results). | |||
Returns | |||
@@ -160,24 +160,24 @@ class KBBase(ABC): | |||
revision_idx: List[int], | |||
) -> List[List[Any]]: | |||
""" | |||
Revise the pseudo label sample at specified index positions. | |||
Revise the pseudo label example at specified index positions. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample (to be revised). | |||
Pseudo label example (to be revised). | |||
y : Any | |||
Ground truth of the reasoning result for the sample. | |||
Ground truth of the reasoning result for the example. | |||
x : List[Any] | |||
The corresponding input sample. | |||
The corresponding input example. | |||
revision_idx : List[int] | |||
A list specifying indices of where revisions should be made to the pseudo label sample. | |||
A list specifying indices of where revisions should be made to the pseudo label example. | |||
Returns | |||
------- | |||
Tuple[List[List[Any]], List[Any]] | |||
A tuple of two element. The first element is a list of candidate revisions, i.e. revised | |||
pseudo label samples that are compatible with the knowledge base. The second element is | |||
pseudo label examples that are compatible with the knowledge base. The second element is | |||
a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
logic_forward function. | |||
""" | |||
@@ -200,7 +200,7 @@ class KBBase(ABC): | |||
x: List[Any], | |||
) -> List[List[Any]]: | |||
""" | |||
For a specified number of labels in a pseudo label sample to revise, iterate through | |||
For a specified number of labels in a pseudo label example to revise, iterate through | |||
all possible indices to find any candidates that are compatible with the knowledge base. | |||
""" | |||
new_candidates, new_reasoning_results = [], [] | |||
@@ -221,29 +221,29 @@ class KBBase(ABC): | |||
) -> List[List[Any]]: | |||
""" | |||
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and | |||
continuously increase the number of labels in a pseudo label sample to revise, until | |||
continuously increase the number of labels in a pseudo label example to revise, until | |||
candidates that are compatible with the knowledge base are found. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample (to be revised). | |||
Pseudo label example (to be revised). | |||
y : Any | |||
Ground truth of the reasoning result for the sample. | |||
Ground truth of the reasoning result for the example. | |||
x : List[Any] | |||
The corresponding input sample. | |||
The corresponding input example. | |||
max_revision_num : int | |||
The upper limit on the number of revisions. | |||
require_more_revision : int | |||
If larger than 0, then after having found any candidates compatible with the | |||
knowledge base, continue to increase the number of labels in a pseudo label sample to | |||
knowledge base, continue to increase the number of labels in a pseudo label example to | |||
revise to get more possible compatible candidates. | |||
Returns | |||
------- | |||
Tuple[List[List[Any]], List[Any]] | |||
A tuple of two element. The first element is a list of candidate revisions, i.e. revised | |||
pseudo label samples that are compatible with the knowledge base. The second element is | |||
pseudo label examples that are compatible with the knowledge base. The second element is | |||
a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
logic_forward function. | |||
""" | |||
@@ -286,7 +286,7 @@ class GroundKB(KBBase): | |||
pseudo_label_list : list | |||
Refer to class `KBBase`. | |||
GKB_len_list : list | |||
List of possible lengths for a pseudo label sample. | |||
List of possible lengths for a pseudo label example. | |||
max_err : float, optional | |||
Refer to class `KBBase`. | |||
@@ -359,13 +359,13 @@ class GroundKB(KBBase): | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample (to be revised by abductive reasoning). | |||
Pseudo label example (to be revised by abductive reasoning). | |||
y : Any | |||
Ground truth of the reasoning result for the sample. | |||
Ground truth of the reasoning result for the example. | |||
x : List[Any] | |||
The corresponding input sample (unused in GroundKB). | |||
The corresponding input example (unused in GroundKB). | |||
max_revision_num : int | |||
The upper limit on the number of revised labels for each sample. | |||
The upper limit on the number of revised labels for each example. | |||
require_more_revision : int | |||
Specifies additional number of revisions permitted beyond the minimum required. | |||
@@ -373,7 +373,7 @@ class GroundKB(KBBase): | |||
------- | |||
Tuple[List[List[Any]], List[Any]] | |||
A tuple of two element. The first element is a list of candidate revisions, i.e. revised | |||
pseudo label samples that are compatible with the knowledge base. The second element is | |||
pseudo label examples that are compatible with the knowledge base. The second element is | |||
a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
logic_forward function. | |||
""" | |||
@@ -477,7 +477,7 @@ class PrologKB(KBBase): | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample. | |||
Pseudo label example. | |||
""" | |||
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"] | |||
if result == "true": | |||
@@ -519,13 +519,13 @@ class PrologKB(KBBase): | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample (to be revised by abductive reasoning). | |||
Pseudo label example (to be revised by abductive reasoning). | |||
y : Any | |||
Ground truth of the reasoning result for the sample. | |||
Ground truth of the reasoning result for the example. | |||
x : List[Any] | |||
The corresponding input sample. | |||
The corresponding input example. | |||
revision_idx : List[int] | |||
A list specifying indices of where revisions should be made to the pseudo label sample. | |||
A list specifying indices of where revisions should be made to the pseudo label example. | |||
Returns | |||
------- | |||
@@ -546,26 +546,26 @@ class PrologKB(KBBase): | |||
revision_idx: List[int], | |||
) -> List[List[Any]]: | |||
""" | |||
Revise the pseudo label sample at specified index positions by querying Prolog. | |||
Revise the pseudo label example at specified index positions by querying Prolog. | |||
Parameters | |||
---------- | |||
pseudo_label : List[Any] | |||
Pseudo label sample (to be revised). | |||
Pseudo label example (to be revised). | |||
y : Any | |||
Ground truth of the reasoning result for the sample. | |||
Ground truth of the reasoning result for the example. | |||
x : List[Any] | |||
The corresponding input sample. | |||
The corresponding input example. | |||
revision_idx : List[int] | |||
A list specifying indices of where revisions should be made to the pseudo label sample. | |||
A list specifying indices of where revisions should be made to the pseudo label example. | |||
Returns | |||
------- | |||
Tuple[List[List[Any]], List[Any]] | |||
A list of candidates, i.e. revised pseudo label samples that are compatible with the | |||
A list of candidates, i.e. revised pseudo label examples that are compatible with the | |||
knowledge base. | |||
A tuple of two element. The first element is a list of candidate revisions, i.e. revised | |||
pseudo label samples that are compatible with the knowledge base. The second element is | |||
pseudo label examples that are compatible with the knowledge base. The second element is | |||
a list of reasoning results corresponding to each candidate, i.e., the outcome of the | |||
logic_forward function. | |||
""" | |||
@@ -24,11 +24,11 @@ class Reasoner: | |||
abduced label. It can be either a string representing a predefined distance | |||
function or a callable function. The available predefined distance functions: | |||
'hamming' | 'confidence'. 'hamming': directly calculates the Hamming | |||
distance between the predicted pseudo label in the data sample and each | |||
distance between the predicted pseudo label in the data example and each | |||
candidate, 'confidence': calculates the distance between the prediction | |||
and each candidate based on confidence derived from the predicted probability | |||
in the data sample. The callable function should have the signature | |||
dist_func(data_sample, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element | |||
in the data example. The callable function should have the signature | |||
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element | |||
in this cost list should be a numerical value representing the cost for each | |||
candidate, and the list should have the same length as candidates. | |||
Defaults to 'confidence'. | |||
@@ -36,7 +36,7 @@ class Reasoner: | |||
A mapping from index in the base model to label. If not provided, a default | |||
order-based index to label mapping is created. Defaults to None. | |||
max_revision : Union[int, float], optional | |||
The upper limit on the number of revisions for each data sample when | |||
The upper limit on the number of revisions for each data example when | |||
performing abductive reasoning. If float, denotes the fraction of the total | |||
length that can be revised. A value of -1 implies no restriction on the | |||
number of revisions. Defaults to -1. | |||
@@ -100,7 +100,7 @@ class Reasoner: | |||
def _get_one_candidate( | |||
self, | |||
data_sample: ListData, | |||
data_example: ListData, | |||
candidates: List[List[Any]], | |||
reasoning_results: List[Any], | |||
) -> List[Any]: | |||
@@ -111,8 +111,8 @@ class Reasoner: | |||
Parameters | |||
---------- | |||
data_sample : ListData | |||
Data sample. | |||
data_example : ListData | |||
Data example. | |||
candidates : List[List[Any]] | |||
Multiple compatible candidates. | |||
reasoning_results : List[Any] | |||
@@ -128,23 +128,23 @@ class Reasoner: | |||
elif len(candidates) == 1: | |||
return candidates[0] | |||
else: | |||
cost_array = self._get_cost_list(data_sample, candidates, reasoning_results) | |||
cost_array = self._get_cost_list(data_example, candidates, reasoning_results) | |||
candidate = candidates[np.argmin(cost_array)] | |||
return candidate | |||
def _get_cost_list( | |||
self, | |||
data_sample: ListData, | |||
data_example: ListData, | |||
candidates: List[List[Any]], | |||
reasoning_results: List[Any], | |||
) -> Union[List[Union[int, float]], np.ndarray]: | |||
""" | |||
Get the list of costs between each candidate and the given data sample. | |||
Get the list of costs between each candidate and the given data example. | |||
Parameters | |||
---------- | |||
data_sample : ListData | |||
Data sample. | |||
data_example : ListData | |||
Data example. | |||
candidates : List[List[Any]] | |||
Multiple compatible candidates. | |||
reasoning_results : List[Any] | |||
@@ -156,13 +156,13 @@ class Reasoner: | |||
The list of costs. | |||
""" | |||
if self.dist_func == "hamming": | |||
return hamming_dist(data_sample.pred_pseudo_label, candidates) | |||
return hamming_dist(data_example.pred_pseudo_label, candidates) | |||
elif self.dist_func == "confidence": | |||
candidates = [[self.label_to_idx[x] for x in c] for c in candidates] | |||
return confidence_dist(data_sample.pred_prob, candidates) | |||
return confidence_dist(data_example.pred_prob, candidates) | |||
else: | |||
candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] | |||
cost_list = self.dist_func(data_sample, candidates, candidate_idxs, reasoning_results) | |||
cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) | |||
if len(cost_list) != len(candidates): | |||
raise ValueError( | |||
f"The length of the array returned by dist_func must be equal to the number of candidates. " | |||
@@ -173,7 +173,7 @@ class Reasoner: | |||
def _zoopt_get_solution( | |||
self, | |||
symbol_num: int, | |||
data_sample: ListData, | |||
data_example: ListData, | |||
max_revision_num: int, | |||
) -> List[bool]: | |||
""" | |||
@@ -184,8 +184,8 @@ class Reasoner: | |||
---------- | |||
symbol_num : int | |||
Number of total symbols. | |||
data_sample : ListData | |||
Data sample. | |||
data_example : ListData | |||
Data example. | |||
max_revision_num : int | |||
Specifies the maximum number of revisions allowed. | |||
@@ -196,7 +196,7 @@ class Reasoner: | |||
""" | |||
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||
objective = Objective( | |||
lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol), | |||
lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol), | |||
dim=dimension, | |||
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
) | |||
@@ -207,7 +207,7 @@ class Reasoner: | |||
def zoopt_revision_score( | |||
self, | |||
symbol_num: int, | |||
data_sample: ListData, | |||
data_example: ListData, | |||
sol: List[bool], | |||
) -> int: | |||
""" | |||
@@ -218,8 +218,8 @@ class Reasoner: | |||
---------- | |||
symbol_num : int | |||
Number of total symbols. | |||
data_sample : ListData | |||
Data sample. | |||
data_example : ListData | |||
Data example. | |||
sol: List[bool] | |||
The solution for ZOOpt library. | |||
@@ -230,10 +230,10 @@ class Reasoner: | |||
""" | |||
revision_idx = np.where(sol.get_x() != 0)[0] | |||
candidates, reasoning_results = self.kb.revise_at_idx( | |||
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx | |||
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx | |||
) | |||
if len(candidates) > 0: | |||
return np.min(self._get_cost_list(data_sample, candidates, reasoning_results)) | |||
return np.min(self._get_cost_list(data_example, candidates, reasoning_results)) | |||
else: | |||
return symbol_num | |||
@@ -267,53 +267,53 @@ class Reasoner: | |||
) | |||
return max_revision | |||
def abduce(self, data_sample: ListData) -> List[Any]: | |||
def abduce(self, data_example: ListData) -> List[Any]: | |||
""" | |||
Perform abductive reasoning on the given data sample. | |||
Perform abductive reasoning on the given data example. | |||
Parameters | |||
---------- | |||
data_sample : ListData | |||
Data sample. | |||
data_example : ListData | |||
Data example. | |||
Returns | |||
------- | |||
List[Any] | |||
A revised pseudo label sample through abductive reasoning, which is compatible | |||
A revised pseudo label example through abductive reasoning, which is compatible | |||
with the knowledge base. | |||
""" | |||
symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
symbol_num = data_example.elements_num("pred_pseudo_label") | |||
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||
if self.use_zoopt: | |||
solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num) | |||
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num) | |||
revision_idx = np.where(solution != 0)[0] | |||
candidates, reasoning_results = self.kb.revise_at_idx( | |||
pseudo_label=data_sample.pred_pseudo_label, | |||
y=data_sample.Y, | |||
x=data_sample.X, | |||
pseudo_label=data_example.pred_pseudo_label, | |||
y=data_example.Y, | |||
x=data_example.X, | |||
revision_idx=revision_idx | |||
) | |||
else: | |||
candidates, reasoning_results = self.kb.abduce_candidates( | |||
pseudo_label=data_sample.pred_pseudo_label, | |||
y=data_sample.Y, | |||
x=data_sample.X, | |||
pseudo_label=data_example.pred_pseudo_label, | |||
y=data_example.Y, | |||
x=data_example.X, | |||
max_revision_num=max_revision_num, | |||
require_more_revision=self.require_more_revision | |||
) | |||
candidate = self._get_one_candidate(data_sample, candidates, reasoning_results) | |||
candidate = self._get_one_candidate(data_example, candidates, reasoning_results) | |||
return candidate | |||
def batch_abduce(self, data_samples: ListData) -> List[List[Any]]: | |||
def batch_abduce(self, data_examples: ListData) -> List[List[Any]]: | |||
""" | |||
Perform abductive reasoning on the given prediction data samples. | |||
Perform abductive reasoning on the given prediction data examples. | |||
For detailed information, refer to `abduce`. | |||
""" | |||
abduced_pseudo_label = [self.abduce(data_sample) for data_sample in data_samples] | |||
data_samples.abduced_pseudo_label = abduced_pseudo_label | |||
abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples] | |||
data_examples.abduced_pseudo_label = abduced_pseudo_label | |||
return abduced_pseudo_label | |||
def __call__(self, data_samples: ListData) -> List[List[Any]]: | |||
return self.batch_abduce(data_samples) | |||
def __call__(self, data_examples: ListData) -> List[List[Any]]: | |||
return self.batch_abduce(data_examples) |
@@ -25,21 +25,21 @@ class BaseDataElement: | |||
``LabelData`` inheriting from ``BaseDataElement`` to represent different | |||
types of ground truth labels or predictions. | |||
Another common data element is sample data. A sample data consists of input | |||
Another common data element is data example. A data example consists of input | |||
data (such as an image) and its annotations and predictions. In general, | |||
an image can have multiple types of annotations and/or predictions at the | |||
same time (for example, both pixel-level semantic segmentation annotations | |||
and instance-level detection bboxes annotations). All labels and | |||
predictions of a training sample are often passed between Dataset, Model, | |||
predictions of a training example are often passed between Dataset, Model, | |||
Visualizer, and Evaluator components. In order to simplify the interface | |||
between components, we can treat them as a large data element and | |||
encapsulate them. Such data elements are generally called XXDataSample in | |||
the OpenMMLab. Therefore, Similar to `nn.Module`, the `BaseDataElement` | |||
allows `BaseDataElement` as its attribute. Such a class generally | |||
encapsulates all the data of a sample in the algorithm library, and its | |||
encapsulates all the data of a example in the algorithm library, and its | |||
attributes generally are various types of data elements. For example, | |||
MMDetection is assigned by the BaseDataElement to encapsulate all the data | |||
elements of the sample labeling and prediction of a sample in the | |||
elements of the example labeling and prediction of a example in the | |||
algorithm library. | |||
The attributes in ``BaseDataElement`` are divided into two parts, | |||
@@ -150,9 +150,9 @@ class BaseDataElement: | |||
>>> metainfo = dict(img_shape=(800, 1196, 3)) | |||
>>> gt_instances = BaseDataElement( | |||
... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) | |||
>>> sample = BaseDataElement(metainfo=metainfo, | |||
>>> example = BaseDataElement(metainfo=metainfo, | |||
... gt_instances=gt_instances) | |||
>>> print(sample) | |||
>>> print(example) | |||
<BaseDataElement( | |||
META INFORMATION | |||
img_shape: (800, 1196, 3) | |||
@@ -196,15 +196,15 @@ class BaseDataElement: | |||
... @pred_instances.deleter | |||
... def pred_instances(self): | |||
... del self._pred_instances | |||
>>> det_sample = DetDataSample() | |||
>>> det_example = DetDataSample() | |||
>>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) | |||
>>> det_sample.proposals = proposals | |||
>>> assert 'proposals' in det_sample | |||
>>> assert det_sample.proposals == proposals | |||
>>> del det_sample.proposals | |||
>>> assert 'proposals' not in det_sample | |||
>>> det_example.proposals = proposals | |||
>>> assert 'proposals' in det_example | |||
>>> assert det_example.proposals == proposals | |||
>>> del det_example.proposals | |||
>>> assert 'proposals' not in det_example | |||
>>> with self.assertRaises(AssertionError): | |||
... det_sample.proposals = torch.rand((5, 4)) | |||
... det_example.proposals = torch.rand((5, 4)) | |||
""" | |||
def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: | |||
@@ -114,24 +114,24 @@ image, respectively. As shown below: | |||
.. code:: ipython3 | |||
pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)]) | |||
print(f"Shape of pred_idx for a batch of 32 samples: {pred_idx.shape}") | |||
print(f"Shape of pred_idx for a batch of 32 examples: {pred_idx.shape}") | |||
pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)]) | |||
print(f"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}") | |||
print(f"Shape of pred_prob for a batch of 32 examples: {pred_prob.shape}") | |||
Out: | |||
.. code:: none | |||
:class: code-out | |||
Shape of pred_idx for a batch of 32 samples: (32,) | |||
Shape of pred_prob for a batch of 32 samples: (32, 10) | |||
Shape of pred_idx for a batch of 32 examples: (32,) | |||
Shape of pred_prob for a batch of 32 examples: (32, 10) | |||
However, the base model built above deals with instance-level data | |||
(i.e., a single image), and can not directly deal with sample-level | |||
(i.e., a single image), and can not directly deal with example-level | |||
data (i.e., a pair of images). Therefore, we wrap the base model | |||
into ``ABLModel``, which enables the learning part to train, test, | |||
and predict on sample-level data. | |||
and predict on example-level data. | |||
.. code:: ipython3 | |||
@@ -142,10 +142,10 @@ TODO: 示例展示ablmodel和base model的predict的不同 | |||
.. code:: ipython3 | |||
# from abl.structures import ListData | |||
# data_samples = ListData() | |||
# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)] | |||
# data_examples = ListData() | |||
# data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)] | |||
# model.predict(data_samples) | |||
# model.predict(data_examples) | |||
Building the Reasoning Part | |||
--------------------------- | |||
@@ -174,16 +174,16 @@ performing (deductive) reasoning: # TODO: ABDUCTIVE REASONING | |||
.. code:: ipython3 | |||
pseudo_label_sample = [1, 2] | |||
reasoning_result = kb.logic_forward(pseudo_label_sample) | |||
print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.") | |||
pseudo_label_example = [1, 2] | |||
reasoning_result = kb.logic_forward(pseudo_label_example) | |||
print(f"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.") | |||
Out: | |||
.. code:: none | |||
:class: code-out | |||
Reasoning result of pseudo label sample [1, 2] is 3. | |||
Reasoning result of pseudo label example [1, 2] is 3. | |||
.. note:: | |||
@@ -211,7 +211,7 @@ candidate that has highest consistency. | |||
During creating reasoner, the definition of “consistency” can be | |||
customized within the ``dist_func`` parameter. In the code above, we | |||
employ a consistency measurement based on confidence, which calculates | |||
the consistency between the data sample and candidates based on the | |||
the consistency between the data example and candidates based on the | |||
confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we | |||
provide options for utilizing other forms of consistency measurement. | |||
@@ -56,13 +56,13 @@ Use ABL-Package Step by Step | |||
In a typical Abductive Learning process, as illustrated below, | |||
data inputs are first predicted by a machine learning model, and the outcomes are a pseudo label | |||
sample (which consists of multiple pseudo labels). | |||
example (which consists of multiple pseudo labels). | |||
These labels then pass through a knowledge base :math:`\mathcal{KB}` | |||
to obtain the reasoning result by deductive reasoning. During training, | |||
alongside the aforementioned forward flow (i.e., prediction --> deduction reasoning), | |||
there also exists a reverse flow, which starts from the reasoning result and | |||
involves abductive reasoning to generate possible pseudo label samples. | |||
Subsequently, these samples are processed to minimize inconsistencies with machine learning, | |||
involves abductive reasoning to generate possible pseudo label examples. | |||
Subsequently, these examples are processed to minimize inconsistencies with machine learning, | |||
which in turn revise the outcomes of the machine learning model, and then | |||
fed back into the machine learning model for further training. | |||
To implement this process, the following five steps are necessary: | |||
@@ -81,7 +81,7 @@ To implement this process, the following five steps are necessary: | |||
3. Build the reasoning part | |||
Define a knowledge base by building a subclass of ``KBBase``, specifying how to | |||
map pseudo label samples to reasoning results. | |||
map pseudo label examples to reasoning results. | |||
Also, create a ``Reasoner`` for minimizing of inconsistencies | |||
between the knowledge base and the learning part. | |||
@@ -27,15 +27,15 @@ In this section, we will look at how to bridge learning and reasoning parts to t | |||
+---------------------------------------+----------------------------------------------------+ | |||
| Method Signature | Description | | |||
+=======================================+====================================================+ | |||
| ``predict(data_samples)`` | Predicts class probabilities and indices | | |||
| | for the given data samples. | | |||
| ``predict(data_examples)`` | Predicts class probabilities and indices | | |||
| | for the given data examples. | | |||
+---------------------------------------+----------------------------------------------------+ | |||
| ``abduce_pseudo_label(data_samples)`` | Abduces pseudo labels for the given data samples. | | |||
| ``abduce_pseudo_label(data_examples)`` | Abduces pseudo labels for the given data examples. | | |||
+---------------------------------------+----------------------------------------------------+ | |||
| ``idx_to_pseudo_label(data_samples)`` | Converts indices to pseudo labels using | | |||
| ``idx_to_pseudo_label(data_examples)`` | Converts indices to pseudo labels using | | |||
| | the provided or default mapping. | | |||
+---------------------------------------+----------------------------------------------------+ | |||
| ``pseudo_label_to_idx(data_samples)`` | Converts pseudo labels to indices | | |||
| ``pseudo_label_to_idx(data_examples)`` | Converts pseudo labels to indices | | |||
| | using the provided or default remapping. | | |||
+---------------------------------------+----------------------------------------------------+ | |||
| ``train(train_data)`` | Train the model. | | |||
@@ -43,11 +43,11 @@ In this section, we will look at how to bridge learning and reasoning parts to t | |||
| ``test(test_data)`` | Test the model. | | |||
+---------------------------------------+----------------------------------------------------+ | |||
where ``train_data`` and ``test_data`` are both in the form of ``(X, gt_pseudo_label, Y)``. They will be used to construct ``ListData`` instances which are referred to as ``data_samples`` in the ``train`` and ``test`` methods respectively. More details can be found in `preparing datasets <Datasets.html>`_. | |||
where ``train_data`` and ``test_data`` are both in the form of ``(X, gt_pseudo_label, Y)``. They will be used to construct ``ListData`` instances which are referred to as ``data_examples`` in the ``train`` and ``test`` methods respectively. More details can be found in `preparing datasets <Datasets.html>`_. | |||
``SimpleBridge`` inherits from ``BaseBridge`` and provides a basic implementation. Besides the ``model`` and ``reasoner``, ``SimpleBridge`` has an extra initialization arguments, ``metric_list``, which will be used to evaluate model performance. Its training process involves several Abductive Learning loops and each loop consists of the following five steps: | |||
1. Predict class probabilities and indices for the given data samples. | |||
1. Predict class probabilities and indices for the given data examples. | |||
2. Transform indices into pseudo labels. | |||
3. Revise pseudo labels based on abdutive reasoning. | |||
4. Transform the revised pseudo labels to indices. | |||
@@ -71,21 +71,21 @@ The fundamental part of the ``train`` method is as follows: | |||
will be used together to train the model. | |||
""" | |||
if isinstance(train_data, ListData): | |||
data_samples = train_data | |||
data_examples = train_data | |||
else: | |||
data_samples = self.data_preprocess(*train_data) | |||
data_examples = self.data_preprocess(*train_data) | |||
if isinstance(segment_size, float): | |||
segment_size = int(segment_size * len(data_samples)) | |||
segment_size = int(segment_size * len(data_examples)) | |||
for loop in range(loops): | |||
for seg_idx in range((len(data_samples) - 1) // segment_size + 1): | |||
sub_data_samples = data_samples[ | |||
for seg_idx in range((len(data_examples) - 1) // segment_size + 1): | |||
sub_data_examples = data_examples[ | |||
seg_idx * segment_size : (seg_idx + 1) * segment_size | |||
] | |||
self.predict(sub_data_samples) # 1 | |||
self.idx_to_pseudo_label(sub_data_samples) # 2 | |||
self.abduce_pseudo_label(sub_data_samples) # 3 | |||
self.pseudo_label_to_idx(sub_data_samples) # 4 | |||
loss = self.model.train(sub_data_samples) # 5, self.model is an ABLModel object | |||
self.predict(sub_data_examples) # 1 | |||
self.idx_to_pseudo_label(sub_data_examples) # 2 | |||
self.abduce_pseudo_label(sub_data_examples) # 3 | |||
self.pseudo_label_to_idx(sub_data_examples) # 4 | |||
loss = self.model.train(sub_data_examples) # 5, self.model is an ABLModel object | |||
@@ -24,11 +24,11 @@ Dataset | |||
ABL-Package assumes user data to be structured as a tuple, comprising the following three components: | |||
- ``X``: List[List[Any]] | |||
A list of sublists representing the input data. We refer to each sublist in ``X`` as an sample and each sample may contain several instances. | |||
A list of sublists representing the input data. We refer to each sublist in ``X`` as an example and each example may contain several instances. | |||
- ``gt_pseudo_label``: List[List[Any]], optional | |||
A list of sublists with each sublist representing a ground truth pseudo label sample. Each sample consists of ground truth pseudo labels for each **instance** within a sample of ``X``. | |||
A list of sublists with each sublist representing a ground truth pseudo label example. Each example consists of ground truth pseudo labels for each **instance** within a example of ``X``. | |||
- ``Y``: List[Any] | |||
A list representing the ground truth reasoning result for each **sample** in ``X``. | |||
A list representing the ground truth reasoning result for each **example** in ``X``. | |||
.. warning:: | |||
Each sublist in ``gt_pseudo_label`` should have the same length as the sublist in ``X``. ``gt_pseudo_label`` is only used to evaluate the performance of the learning part but not to train the model. If the pseudo label of the instances in the datasets are unlabeled, ``gt_pseudo_label`` can be ``None``. | |||
@@ -34,11 +34,11 @@ We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, w | |||
# prefix is used to distinguish different metrics | |||
super().__init__(prefix) | |||
def process(self, data_samples: Sequence[dict]) -> None: | |||
def process(self, data_examples: Sequence[dict]) -> None: | |||
# pred_pseudo_label and gt_pseudo_label are both of type List[List[Any]] | |||
# and have the same length | |||
pred_pseudo_label = data_samples.pred_pseudo_label | |||
gt_pseudo_label = data_samples.gt_pseudo_label | |||
pred_pseudo_label = data_examples.pred_pseudo_label | |||
gt_pseudo_label = data_examples.gt_pseudo_label | |||
for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label): | |||
correct_num = 0 | |||
@@ -15,7 +15,7 @@ In this section, we will look at how to build the learning part. | |||
In ABL-Package, building the learning part involves two steps: | |||
1. Build a base machine learning model used to make predictions on instance-level data, typically referred to as ``base_model``. | |||
2. Instantiate an ``ABLModel`` with the ``base_model``, which enables the learning part to train, test, and predict on sample-level data. | |||
2. Instantiate an ``ABLModel`` with the ``base_model``, which enables the learning part to train, test, and predict on example-level data. | |||
.. code:: python | |||
@@ -77,7 +77,7 @@ Besides the necessary methods required to instantiate an ``ABLModel``, i.e., ``f | |||
Instantiating an ABLModel | |||
------------------------- | |||
Typically, ``base_model`` is trained to make predictions on instance-level data, and can not directly utilize sample-level data to train and predict, which is not suitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all ``base_model``, which enables the learning part to train, test, and predict on sample-level data. | |||
Typically, ``base_model`` is trained to make predictions on instance-level data, and can not directly utilize example-level data to train and predict, which is not suitable for most neural-symbolic tasks. ABL-Package provides the ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all ``base_model``, which enables the learning part to train, test, and predict on example-level data. | |||
Generally, we can simply instantiate an ``ABLModel`` by: | |||
@@ -56,7 +56,7 @@ To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to | |||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
base_model = BasicNN(cls, loss_fn, optimizer, device) | |||
However, Base model built above are trained to make predictions on instance-level data (e.g., a single image), which is not suitable enough for our task. Therefore, we then wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on sample-level data, (e.g., images that comprise the equation). | |||
However, Base model built above are trained to make predictions on instance-level data (e.g., a single image), which is not suitable enough for our task. Therefore, we then wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on example-level data, (e.g., images that comprise the equation). | |||
.. code:: python | |||
@@ -71,7 +71,7 @@ Building the Reasoning Part | |||
To build the reasoning part, we first define a knowledge base by | |||
creating a subclass of ``KBBase``, which specifies how to map a pseudo | |||
label sample to its reasoning result. In the subclass, we initialize the | |||
label example to its reasoning result. In the subclass, we initialize the | |||
``pseudo_label_list`` parameter and override the ``logic_forward`` | |||
function specifying how to perform (deductive) reasoning. | |||
@@ -15,7 +15,7 @@ leverage domain knowledge and perform deductive or abductive reasoning. | |||
In ABL-Package, building the reasoning part involves two steps: | |||
1. Build a knowledge base by creating a subclass of ``KBBase``, which | |||
specifies how to map pseudo label samples to reasoning results. | |||
specifies how to map pseudo label examples to reasoning results. | |||
2. Create a reasoner by instantiating the class ``Reasoner`` | |||
to minimize inconsistencies between the knowledge base and pseudo | |||
labels predicted by the learning part. | |||
@@ -43,7 +43,7 @@ and override the ``logic_forward`` function: | |||
- ``pseudo_label_list`` is the list of possible pseudo labels (also, | |||
the output of the machine learning model). | |||
- ``logic_forward`` defines how to perform (deductive) reasoning, | |||
i.e. matching each pseudo label sample (often consisting of multiple | |||
i.e. matching each pseudo label example (often consisting of multiple | |||
pseudo labels) to its reasoning result. | |||
After that, other operations, including how to perform abductive | |||
@@ -54,7 +54,7 @@ MNIST Addition example | |||
As an example, the ``pseudo_label_list`` passed in MNIST Addition is all the | |||
possible digits, namely, ``[0,1,2,...,9]``, and the ``logic_forward`` | |||
should be: “Add the two labels in the pseudo label sample to get the result.”. Therefore, the | |||
should be: “Add the two labels in the pseudo label example to get the result.”. Therefore, the | |||
construction of the KB (``add_kb``) for MNIST Addition would be: | |||
.. code:: python | |||
@@ -72,15 +72,15 @@ and (deductive) reasoning in ``add_kb`` would be: | |||
.. code:: python | |||
pseudo_label_sample = [1, 2] | |||
reasoning_result = add_kb.logic_forward(pseudo_label_sample) | |||
print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.") | |||
pseudo_label_example = [1, 2] | |||
reasoning_result = add_kb.logic_forward(pseudo_label_example) | |||
print(f"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.") | |||
Out: | |||
.. code:: none | |||
:class: code-out | |||
Reasoning result of pseudo label sample [1, 2] is 3 | |||
Reasoning result of pseudo label example [1, 2] is 3 | |||
.. _other-par: | |||
@@ -91,13 +91,13 @@ We can also pass the following parameters in the ``__init__`` function when buil | |||
knowledge base: | |||
- ``max_err`` (float, optional), specifying the upper tolerance limit | |||
when comparing the similarity between a pseudo label sample's reasoning result | |||
when comparing the similarity between a pseudo label example's reasoning result | |||
and the ground truth during abductive reasoning. This is only | |||
applicable when the reasoning result is of a numerical type. This is | |||
particularly relevant for regression problems where exact matches | |||
might not be feasible. Defaults to 1e-10. See :ref:`an example <kb-abd-2>`. | |||
- ``use_cache`` (bool, optional), indicating whether to use cache to store | |||
previous candidates (pseudo label samples generated from abductive reasoning) | |||
previous candidates (pseudo label examples generated from abductive reasoning) | |||
to speed up subsequent abductive reasoning operations. Defaults to True. | |||
For more information of abductive reasoning, please refer to :ref:`this <kb-abd>`. | |||
- ``cache_size`` (int, optional), specifying the maximum cache | |||
@@ -173,7 +173,7 @@ override the ``logic_forward`` function, and are allowed to pass other | |||
:ref:`optional parameters <other-par>`. Additionally, we are required pass the | |||
``GKB_len_list`` parameter in the ``__init__`` function. | |||
- ``GKB_len_list`` is the list of possible lengths for a pseudo label sample. | |||
- ``GKB_len_list`` is the list of possible lengths for a pseudo label example. | |||
After that, other operations, including auto-construction of GKB, and | |||
how to perform abductive reasoning, will be **automatically** set up. | |||
@@ -206,29 +206,29 @@ Performing abductive reasoning in the knowledge base | |||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
As mentioned in :ref:`What is Abductive Reasoning? <abd>`, abductive reasoning | |||
enables the inference of candidates (which are pseudo label samples) as potential | |||
enables the inference of candidates (which are pseudo label examples) as potential | |||
explanations for the reasoning result. Also, in Abductive Learning where | |||
an observation (a pseudo label sample predicted by the learning part) is | |||
an observation (a pseudo label example predicted by the learning part) is | |||
available, we aim to let the candidate do not largely revise the | |||
previously identified pseudo label sample. | |||
previously identified pseudo label example. | |||
``KBBase`` (also, ``GroundKB`` and ``PrologKB``) implement the method | |||
``abduce_candidates(pseudo_label, y, max_revision_num, require_more_revision)`` | |||
for performing abductive reasoning, where the parameters are: | |||
- ``pseudo_label``, the pseudo label sample to be revised by abductive | |||
- ``pseudo_label``, the pseudo label example to be revised by abductive | |||
reasoning, usually generated by the learning part. | |||
- ``y``, the ground truth of the reasoning result for the sample. The | |||
- ``y``, the ground truth of the reasoning result for the example. The | |||
returned candidates should be compatible with it. | |||
- ``max_revision_num``, an int value specifying the upper limit on the | |||
number of revised labels for each sample. | |||
number of revised labels for each example. | |||
- ``require_more_revision``, an int value specifying additional number | |||
of revisions permitted beyond the minimum required. (e.g., If we set | |||
it to 0, even if ``max_revision_num`` is set to a high value, the | |||
method will only output candidates with the minimum possible | |||
revisions.) | |||
And it return a list of candidates (i.e., revised pseudo label samples) that | |||
And it return a list of candidates (i.e., revised pseudo label examples) that | |||
are all compatible with ``y``. | |||
MNIST Addition example (cont.) | |||
@@ -292,7 +292,7 @@ When instantiating, besides the required knowledge base, we may also | |||
specify: | |||
- ``max_revision`` (int or float, optional), specifies the upper limit | |||
on the number of revisions for each sample when performing | |||
on the number of revisions for each example when performing | |||
:ref:`abductive reasoning in the knowledge base <kb-abd>`. If float, denotes the | |||
fraction of the total length that can be revised. A value of -1 | |||
implies no restriction on the number of revisions. Defaults to -1. | |||
@@ -308,18 +308,18 @@ specify: | |||
candidate returned from knowledge base. Valid options include | |||
“confidence” (default) and “hamming”. For “confidence”, it calculates | |||
the distance between the prediction and candidate based on confidence | |||
derived from the predicted probability in the data sample. For | |||
derived from the predicted probability in the data example. For | |||
“hamming”, it directly calculates the Hamming distance between the | |||
predicted pseudo label in the data sample and candidate. | |||
predicted pseudo label in the data example and candidate. | |||
The main method implemented by ``Reasoner`` is | |||
``abduce(data_sample)``, which obtains the most consistent candidate | |||
``abduce(data_example)``, which obtains the most consistent candidate | |||
based on the distance function defined in ``dist_func``. | |||
MNIST Addition example (cont.) | |||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |||
As an example, consider these data samples for MNIST Addition: | |||
As an example, consider these data examples for MNIST Addition: | |||
.. code:: python | |||
@@ -331,37 +331,37 @@ As an example, consider these data samples for MNIST Addition: | |||
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
sample1 = ListData() | |||
sample1.pred_pseudo_label = [1, 1] | |||
sample1.pred_prob = prob1 | |||
sample1.Y = 8 | |||
example1 = ListData() | |||
example1.pred_pseudo_label = [1, 1] | |||
example1.pred_prob = prob1 | |||
example1.Y = 8 | |||
sample2 = ListData() | |||
sample2.pred_pseudo_label = [1, 1] | |||
sample2.pred_prob = prob2 | |||
sample2.Y = 8 | |||
example2 = ListData() | |||
example2.pred_pseudo_label = [1, 1] | |||
example2.pred_prob = prob2 | |||
example2.Y = 8 | |||
The compatible candidates after abductive reasoning for both samples | |||
The compatible candidates after abductive reasoning for both examples | |||
would be ``[[1,7], [7,1]]``. However, when the reasoner call ``abduce`` | |||
to select only one candidate based on the ``confidence`` distance function, | |||
the output would differ for each sample: | |||
the output would differ for each example: | |||
.. code:: python | |||
reasoner_add = Reasoner(kb_add, dist_func="confidence") | |||
candidate1 = reasoner_add.abduce(sample1) | |||
candidate2 = reasoner_add.abduce(sample2) | |||
print(f"The outputs for sample1 and sample2 are {candidate1} and {candidate2}, respectively.") | |||
candidate1 = reasoner_add.abduce(example1) | |||
candidate2 = reasoner_add.abduce(example2) | |||
print(f"The outputs for example1 and example2 are {candidate1} and {candidate2}, respectively.") | |||
Out: | |||
.. code:: none | |||
:class: code-out | |||
The outputs for sample1 and sample2 are [1,7] and [7,1], respectively. | |||
The outputs for example1 and example2 are [1,7] and [7,1], respectively. | |||
Specifically, as mentioned before, ``confidence`` calculates the distance between the data | |||
sample and candidates based on the confidence derived from the predicted probability. | |||
Take ``sample1`` as an example, the ``pred_prob`` in it indicates a higher | |||
example and candidates based on the confidence derived from the predicted probability. | |||
Take ``example1`` as an example, the ``pred_prob`` in it indicates a higher | |||
confidence that the first label should be "1" rather than "7". Therefore, among the | |||
candidates [1,7] and [7,1], it would be closer to [1,7] (as its first label is "1"). | |||
@@ -0,0 +1,173 @@ | |||
import os | |||
import itertools | |||
import random | |||
import numpy as np | |||
from PIL import Image | |||
import pickle | |||
def get_sign_path_list(data_dir, sign_names): | |||
sign_num = len(sign_names) | |||
index_dict = dict(zip(sign_names, list(range(sign_num)))) | |||
ret = [[] for _ in range(sign_num)] | |||
for path in os.listdir(data_dir): | |||
if (path in sign_names): | |||
index = index_dict[path] | |||
sign_path = os.path.join(data_dir, path) | |||
for p in os.listdir(sign_path): | |||
ret[index].append(os.path.join(sign_path, p)) | |||
return ret | |||
def split_pool_by_rate(pools, rate, seed = None): | |||
if seed is not None: | |||
random.seed(seed) | |||
ret1 = [] | |||
ret2 = [] | |||
for pool in pools: | |||
random.shuffle(pool) | |||
num = int(len(pool) * rate) | |||
ret1.append(pool[:num]) | |||
ret2.append(pool[num:]) | |||
return ret1, ret2 | |||
def int_to_system_form(num, system_num): | |||
if num == 0: | |||
return "0" | |||
ret = "" | |||
while (num > 0): | |||
ret += str(num % system_num) | |||
num //= system_num | |||
return ret[::-1] | |||
def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type): | |||
expr_len = left_opt_len + right_opt_len | |||
num_list = "".join([str(i) for i in range(system_num)]) | |||
ret = [] | |||
if generate_type == "all": | |||
candidates = itertools.product(num_list, repeat = expr_len) | |||
else: | |||
candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))] | |||
random.shuffle(candidates) | |||
for nums in candidates: | |||
left_num = "".join(nums[:left_opt_len]) | |||
right_num = "".join(nums[left_opt_len:]) | |||
left_value = int(left_num, system_num) | |||
right_value = int(right_num, system_num) | |||
result_value = left_value + right_value | |||
if (label == 'negative'): | |||
result_value += random.randint(-result_value, result_value) | |||
if (left_value + right_value == result_value): | |||
continue | |||
result_num = int_to_system_form(result_value, system_num) | |||
#leading zeros | |||
if (res_opt_len != len(result_num)): | |||
continue | |||
if ((left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0')): | |||
continue | |||
#add leading zeros | |||
if (res_opt_len < len(result_num)): | |||
continue | |||
while (len(result_num) < res_opt_len): | |||
result_num = '0' + result_num | |||
#continue | |||
ret.append(left_num + '+' + right_num + '=' + result_num) # current only consider '+' and '=' | |||
#print(ret[-1]) | |||
return ret | |||
def generator_equation_by_len(equation_len, system_num = 2, label = 0, require_num = 1): | |||
generate_type = "one" | |||
ret = [] | |||
equation_sign_num = 2 # '+' and '=' | |||
while len(ret) < require_num: | |||
left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) | |||
right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num) | |||
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||
return ret | |||
def generator_equations_by_len(equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all"): | |||
ret = [] | |||
equation_sign_num = 2 # '+' and '=' | |||
for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1): | |||
for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1): | |||
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
for i in range(repeat_times): #generate more equations | |||
if random.random() > keep ** (equation_len): | |||
continue | |||
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type)) | |||
return ret | |||
def generator_equations_by_max_len(max_equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all", num_per_len = None): | |||
ret = [] | |||
equation_sign_num = 2 # '+' and '=' | |||
for equation_len in range(3 + equation_sign_num, max_equation_len + 1): | |||
if (num_per_len is None): | |||
ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type)) | |||
else: | |||
ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num = num_per_len)) | |||
return ret | |||
def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): | |||
if (seed is not None): | |||
random.seed(seed) | |||
ret = [] | |||
sign_num = len(signs) | |||
sign_index_dict = dict(zip(signs, list(range(sign_num)))) | |||
for equation in equations: | |||
data = [] | |||
for sign in equation: | |||
index = sign_index_dict[sign] | |||
pick = random.randint(0, len(image_pools[index]) - 1) | |||
if is_color: | |||
image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape) | |||
else: | |||
image = Image.open(image_pools[index][pick]).convert('I').resize(shape) | |||
image_array = np.array(image) | |||
image_array = (image_array-127)*(1./128) | |||
data.append(image_array) | |||
ret.append(np.array(data)) | |||
return ret | |||
def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape = (28, 28), train_max_equation_len = 10, test_max_equation_len = 10, system_num = 2, tmp_file_prev = | |||
None, seed = None, train_num_per_len = 10, test_num_per_len = 10, is_color = False): | |||
tmp_file = "" | |||
if (tmp_file_prev is not None): | |||
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num) | |||
if (os.path.exists(tmp_file)): | |||
return pickle.load(open(tmp_file, "rb")) | |||
image_pools = get_sign_path_list(data_dir, sign_dir_lists) | |||
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) | |||
ret = {} | |||
for label in ["positive", "negative"]: | |||
print("Generating equations.") | |||
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len) | |||
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len) | |||
print(train_equations) | |||
print(test_equations) | |||
print("Generated equations.") | |||
print("Generating equation image data.") | |||
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color) | |||
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color) | |||
print("Generated equation image data.") | |||
if (tmp_file_prev is not None): | |||
pickle.dump(ret, open(tmp_file, "wb")) | |||
return ret | |||
if __name__ == "__main__": | |||
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"] | |||
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"] | |||
for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): | |||
data = get_equation_std_data(data_dir = data_dir,\ | |||
sign_dir_lists = ['0', '1', '10', '11'],\ | |||
sign_output_lists = ['0', '1', '+', '='],\ | |||
shape = (28, 28),\ | |||
train_max_equation_len = 26, \ | |||
test_max_equation_len = 26, \ | |||
system_num = 2, \ | |||
tmp_file_prev = tmp_file_prev, \ | |||
train_num_per_len = 300, \ | |||
test_num_per_len = 300, \ | |||
is_color = False) |
@@ -16,7 +16,7 @@ eval_eq(Ex, Feature):- | |||
%%%%%%%%%%%%%% | |||
%% Abduction | |||
%%%%%%%%%%%%%% | |||
% Make abduction when given examples that have been interpreted as pseudo-labels | |||
% Make abduction when given samples that have been interpreted as pseudo-labels | |||
abduce(Exs, Delta_C) :- | |||
abduce(Exs, [], Delta_C). | |||
abduce([], Delta_C, Delta_C). | |||
@@ -45,13 +45,13 @@ consistent_inst_feature(Exs, Delta_C):- | |||
%% (Experimental) Parallel abduction | |||
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
abduce_consistent_exs_concurrent(Exs) :- | |||
% Split the current data batch into grounding examples and variable examples (which need to be revised) | |||
% Split the current data batch into grounding samples and variable samples (which need to be revised) | |||
split_exs(Exs, Ground_Exs, Var_Exs), | |||
% Find the simplest Delta_C for grounding examples. | |||
% Find the simplest Delta_C for grounding samples. | |||
abduce(Ground_Exs, Ground_Delta_C), !, | |||
% Extend Ground Delta_C into all possible variations | |||
extend_op_rule(Ground_Delta_C, Possible_Deltas), | |||
% Concurrently abduce the variable examples | |||
% Concurrently abduce the variable samples | |||
maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List), | |||
maplist(=.., Goals, Call_List), | |||
% writeln(Goals), | |||
@@ -76,7 +76,7 @@ extend_op_rule(Rules, Ext) :- | |||
% abduction without learning new Delta_C (Because they have been extended!) | |||
abduce2([], _, _). | |||
abduce2([E|Exs], Ground_Exs, Delta_C) :- | |||
% abduce by finding ground examples | |||
% abduce by finding ground samples | |||
member(E, Ground_Exs), | |||
abduce2(Exs, Ground_Exs, Delta_C). | |||
abduce2([E|Exs], Ground_Exs, Delta_C) :- | |||
@@ -62,15 +62,15 @@ class HEDBridge(SimpleBridge): | |||
self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth")) | |||
def select_mapping_and_abduce(self, data_samples: ListData): | |||
def select_mapping_and_abduce(self, data_examples: ListData): | |||
candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) | |||
mapping_score = [] | |||
abduced_pseudo_label_list = [] | |||
for _mapping in candidate_mappings: | |||
self.reasoner.idx_to_label = _mapping | |||
self.reasoner.label_to_idx = dict(zip(_mapping.values(), _mapping.keys())) | |||
self.idx_to_pseudo_label(data_samples) | |||
abduced_pseudo_label = self.reasoner.abduce(data_samples) | |||
self.idx_to_pseudo_label(data_examples) | |||
abduced_pseudo_label = self.reasoner.abduce(data_examples) | |||
mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([])) | |||
abduced_pseudo_label_list.append(abduced_pseudo_label) | |||
@@ -80,18 +80,18 @@ class HEDBridge(SimpleBridge): | |||
self.reasoner.label_to_idx = dict( | |||
zip(self.reasoner.idx_to_label.values(), self.reasoner.idx_to_label.keys()) | |||
) | |||
self.idx_to_pseudo_label(data_samples) | |||
data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx] | |||
self.idx_to_pseudo_label(data_examples) | |||
data_examples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx] | |||
return data_samples.abduced_pseudo_label | |||
return data_examples.abduced_pseudo_label | |||
def abduce_pseudo_label(self, data_samples: ListData): | |||
self.reasoner.abduce(data_samples) | |||
return data_samples.abduced_pseudo_label | |||
def abduce_pseudo_label(self, data_examples: ListData): | |||
self.reasoner.abduce(data_examples) | |||
return data_examples.abduced_pseudo_label | |||
def check_training_impact(self, filtered_data_samples, data_samples): | |||
character_accuracy = self.model.valid(filtered_data_samples) | |||
revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) | |||
def check_training_impact(self, filtered_data_examples, data_examples): | |||
character_accuracy = self.model.valid(filtered_data_examples) | |||
revisible_ratio = len(filtered_data_examples.X) / len(data_examples.X) | |||
log_string = ( | |||
f"Revisible ratio is {revisible_ratio:.3f}, Character " | |||
f"accuracy is {character_accuracy:.3f}" | |||
@@ -119,23 +119,23 @@ class HEDBridge(SimpleBridge): | |||
return True | |||
return False | |||
def calc_consistent_ratio(self, data_samples, rule): | |||
self.predict(data_samples) | |||
pred_pseudo_label = self.idx_to_pseudo_label(data_samples) | |||
def calc_consistent_ratio(self, data_examples, rule): | |||
self.predict(data_examples) | |||
pred_pseudo_label = self.idx_to_pseudo_label(data_examples) | |||
consistent_num = sum( | |||
[self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label] | |||
) | |||
return consistent_num / len(data_samples.X) | |||
return consistent_num / len(data_examples.X) | |||
def get_rules_from_data(self, data_samples, samples_per_rule, samples_num): | |||
def get_rules_from_data(self, data_examples, examples_per_rule, examples_num): | |||
rules = [] | |||
sampler = InfiniteSampler(len(data_samples), batch_size=samples_per_rule) | |||
sampler = InfiniteSampler(len(data_examples), batch_size=examples_per_rule) | |||
for _ in range(samples_num): | |||
for _ in range(examples_num): | |||
for select_idx in sampler: | |||
sub_data_samples = data_samples[select_idx] | |||
self.predict(sub_data_samples) | |||
pred_pseudo_label = self.idx_to_pseudo_label(sub_data_samples) | |||
sub_data_examples = data_examples[select_idx] | |||
self.predict(sub_data_examples) | |||
pred_pseudo_label = self.idx_to_pseudo_label(sub_data_examples) | |||
consistent_instance = [] | |||
for instance in pred_pseudo_label: | |||
if self.reasoner.kb.logic_forward([instance]): | |||
@@ -157,13 +157,13 @@ class HEDBridge(SimpleBridge): | |||
return rules | |||
@staticmethod | |||
def filter_empty(data_samples: ListData): | |||
def filter_empty(data_examples: ListData): | |||
consistent_dix = [ | |||
i | |||
for i in range(len(data_samples.abduced_pseudo_label)) | |||
if len(data_samples.abduced_pseudo_label[i]) > 0 | |||
for i in range(len(data_examples.abduced_pseudo_label)) | |||
if len(data_examples.abduced_pseudo_label[i]) > 0 | |||
] | |||
return data_samples[consistent_dix] | |||
return data_examples[consistent_dix] | |||
@staticmethod | |||
def select_rules(rule_dict): | |||
@@ -184,12 +184,12 @@ class HEDBridge(SimpleBridge): | |||
return list(rule_dict) | |||
def data_preprocess(self, data, equation_len) -> ListData: | |||
data_samples = ListData() | |||
data_samples.X = data[equation_len] + data[equation_len + 1] | |||
data_samples.gt_pseudo_label = None | |||
data_samples.Y = [None] * len(data_samples.X) | |||
data_examples = ListData() | |||
data_examples.X = data[equation_len] + data[equation_len + 1] | |||
data_examples.gt_pseudo_label = None | |||
data_examples.Y = [None] * len(data_examples.X) | |||
return data_samples | |||
return data_examples | |||
def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8): | |||
for equation_len in range(min_len, max_len): | |||
@@ -199,25 +199,25 @@ class HEDBridge(SimpleBridge): | |||
) | |||
condition_num = 0 | |||
data_samples = self.data_preprocess(train_data[1], equation_len) | |||
sampler = InfiniteSampler(len(data_samples), batch_size=segment_size) | |||
data_examples = self.data_preprocess(train_data[1], equation_len) | |||
sampler = InfiniteSampler(len(data_examples), batch_size=segment_size) | |||
for seg_idx, select_idx in enumerate(sampler): | |||
print_log( | |||
f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}]", | |||
logger="current", | |||
) | |||
sub_data_samples = data_samples[select_idx] | |||
self.predict(sub_data_samples) | |||
sub_data_examples = data_examples[select_idx] | |||
self.predict(sub_data_examples) | |||
if equation_len == min_len: | |||
self.select_mapping_and_abduce(sub_data_samples) | |||
self.select_mapping_and_abduce(sub_data_examples) | |||
else: | |||
self.idx_to_pseudo_label(sub_data_samples) | |||
self.abduce_pseudo_label(sub_data_samples) | |||
filtered_sub_data_samples = self.filter_empty(sub_data_samples) | |||
self.pseudo_label_to_idx(filtered_sub_data_samples) | |||
loss = self.model.train(filtered_sub_data_samples) | |||
self.idx_to_pseudo_label(sub_data_examples) | |||
self.abduce_pseudo_label(sub_data_examples) | |||
filtered_sub_data_examples = self.filter_empty(sub_data_examples) | |||
self.pseudo_label_to_idx(filtered_sub_data_examples) | |||
loss = self.model.train(filtered_sub_data_examples) | |||
if self.check_training_impact(filtered_sub_data_samples, sub_data_samples): | |||
if self.check_training_impact(filtered_sub_data_examples, sub_data_examples): | |||
condition_num += 1 | |||
else: | |||
condition_num = 0 | |||
@@ -225,7 +225,7 @@ class HEDBridge(SimpleBridge): | |||
if condition_num >= 5: | |||
print_log("Now checking if we can go to next course", logger="current") | |||
rules = self.get_rules_from_data( | |||
data_samples, samples_per_rule=3, samples_num=50 | |||
data_examples, examples_per_rule=3, examples_num=50 | |||
) | |||
print_log("Learned rules from data: " + str(rules), logger="current") | |||
@@ -2,7 +2,7 @@ | |||
"cells": [ | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 2, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -24,9 +24,17 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 3, | |||
"metadata": {}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"12/18 09:01:12 - abl - INFO - Abductive Learning on the HED example.\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# Build logger\n", | |||
"print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n", | |||
@@ -46,7 +54,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 4, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -69,20 +77,20 @@ | |||
"\n", | |||
"\n", | |||
"class HedReasoner(Reasoner):\n", | |||
" def revise_at_idx(self, data_sample):\n", | |||
" revision_idx = np.where(np.array(data_sample.flatten(\"revision_flag\")) != 0)[0]\n", | |||
" def revise_at_idx(self, data_example):\n", | |||
" revision_idx = np.where(np.array(data_example.flatten(\"revision_flag\")) != 0)[0]\n", | |||
" candidate = self.kb.revise_at_idx(\n", | |||
" data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx\n", | |||
" data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx\n", | |||
" )\n", | |||
" return candidate\n", | |||
"\n", | |||
" def zoopt_revision_score(self, symbol_num, data_sample, sol):\n", | |||
" def zoopt_revision_score(self, symbol_num, data_example, sol):\n", | |||
" revision_flag = reform_list(\n", | |||
" list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label\n", | |||
" list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label\n", | |||
" )\n", | |||
" data_sample.revision_flag = revision_flag\n", | |||
" data_example.revision_flag = revision_flag\n", | |||
"\n", | |||
" lefted_idxs = [i for i in range(len(data_sample.pred_idx))]\n", | |||
" lefted_idxs = [i for i in range(len(data_example.pred_idx))]\n", | |||
" candidate_size = []\n", | |||
" max_consistent_idxs = []\n", | |||
" while lefted_idxs:\n", | |||
@@ -90,10 +98,10 @@ | |||
" idxs.append(lefted_idxs.pop(0))\n", | |||
" max_candidate_idxs = []\n", | |||
" found = False\n", | |||
" for idx in range(-1, len(data_sample.pred_idx)):\n", | |||
" for idx in range(-1, len(data_example.pred_idx)):\n", | |||
" if (not idx in idxs) and (idx >= 0):\n", | |||
" idxs.append(idx)\n", | |||
" candidates, _ = self.revise_at_idx(data_sample[idxs])\n", | |||
" candidates, _ = self.revise_at_idx(data_example[idxs])\n", | |||
" if len(candidates) == 0:\n", | |||
" if len(idxs) > 1:\n", | |||
" idxs.pop()\n", | |||
@@ -115,10 +123,10 @@ | |||
" score -= math.exp(-i) * candidate_size[i]\n", | |||
" return score, max_consistent_idxs\n", | |||
" \n", | |||
" def _zoopt_get_solution(self, symbol_num, data_sample, max_revision_num):\n", | |||
" def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):\n", | |||
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n", | |||
" objective = Objective(\n", | |||
" lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol)[0],\n", | |||
" lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],\n", | |||
" dim=dimension,\n", | |||
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n", | |||
" )\n", | |||
@@ -126,20 +134,20 @@ | |||
" solution = Opt.min(objective, parameter)\n", | |||
" return solution\n", | |||
"\n", | |||
" def abduce(self, data_sample):\n", | |||
" symbol_num = data_sample.elements_num(\"pred_pseudo_label\")\n", | |||
" def abduce(self, data_example):\n", | |||
" symbol_num = data_example.elements_num(\"pred_pseudo_label\")\n", | |||
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n", | |||
"\n", | |||
" solution = self._zoopt_get_solution(symbol_num, data_sample, max_revision_num)\n", | |||
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_sample, solution)\n", | |||
" solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)\n", | |||
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)\n", | |||
"\n", | |||
" abduced_pseudo_label = [[] for _ in range(len(data_sample))]\n", | |||
" abduced_pseudo_label = [[] for _ in range(len(data_example))]\n", | |||
"\n", | |||
" if len(max_candidate_idxs) > 0:\n", | |||
" candidates, _ = self.revise_at_idx(data_sample[max_candidate_idxs])\n", | |||
" candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])\n", | |||
" for i, idx in enumerate(max_candidate_idxs):\n", | |||
" abduced_pseudo_label[idx] = candidates[0][i]\n", | |||
" data_sample.abduced_pseudo_label = abduced_pseudo_label\n", | |||
" data_example.abduced_pseudo_label = abduced_pseudo_label\n", | |||
" return abduced_pseudo_label\n", | |||
"\n", | |||
" def abduce_rules(self, pred_res):\n", | |||
@@ -160,7 +168,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 5, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -173,7 +181,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 6, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -194,7 +202,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 7, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -214,7 +222,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 8, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -232,7 +240,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 9, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -249,7 +257,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 12, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -268,13 +276,176 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 13, | |||
"metadata": {}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"12/18 09:04:27 - abl - INFO - Pretrain Start\n", | |||
"12/18 09:04:31 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_1.pth\n", | |||
"12/18 09:04:33 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_2.pth\n", | |||
"12/18 09:04:34 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_3.pth\n", | |||
"12/18 09:04:36 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_4.pth\n", | |||
"12/18 09:04:37 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_5.pth\n", | |||
"12/18 09:04:38 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_6.pth\n", | |||
"12/18 09:04:40 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_7.pth\n", | |||
"12/18 09:04:41 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_8.pth\n", | |||
"12/18 09:04:43 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_9.pth\n", | |||
"12/18 09:04:44 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_10.pth\n", | |||
"12/18 09:04:44 - abl - INFO - model loss: 0.78453\n", | |||
"12/18 09:04:44 - abl - INFO - min loss is <abl.learning.basic_nn.BasicNN object at 0x7f6c4f9393d0>\n", | |||
"12/18 09:04:44 - abl - INFO - Loads checkpoint by local backend from path: ./weights/pretrain_weights.pth\n", | |||
"12/18 09:04:44 - abl - INFO - ============== equation_len: 5-6 ================\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-1.0, 9.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-1.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-2.0, 8.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-1.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-2.0, 8.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-1.0, 9.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-2.0, 6.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-1.0, 9.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [0.0, 10.0]\n", | |||
"[zoopt] x: array([0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-1.0, 8.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-2.0, 8.0]\n", | |||
"[zoopt] x: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||
" 0., 0.])\n", | |||
"[zoopt] value: [-2.0, 8.0]\n", | |||
"12/18 09:05:16 - abl - INFO - Checkpoints will be saved to results/20231218_09_01_12/weights/model_checkpoint_epoch_1.pth\n", | |||
"12/18 09:05:16 - abl - INFO - model loss: 0.59495\n" | |||
] | |||
}, | |||
{ | |||
"ename": "TypeError", | |||
"evalue": "unsupported format string passed to BasicNN.__format__", | |||
"output_type": "error", | |||
"traceback": [ | |||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |||
"Input \u001b[0;32mIn [13]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m bridge\u001b[38;5;241m.\u001b[39mpretrain(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./weights\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mbridge\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_data\u001b[49m\u001b[43m)\u001b[49m\n", | |||
"File \u001b[0;32m~/ABL-Package/examples/hed/hed_bridge.py:217\u001b[0m, in \u001b[0;36mHEDBridge.train\u001b[0;34m(self, train_data, val_data, segment_size, min_len, max_len)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mabduce_pseudo_label(sub_data_examples)\n\u001b[1;32m 216\u001b[0m filtered_sub_data_examples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilter_empty(sub_data_examples)\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpseudo_label_to_idx(filtered_sub_data_examples)\n\u001b[1;32m 218\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mtrain(filtered_sub_data_examples)\n\u001b[1;32m 220\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_training_impact(filtered_sub_data_examples, sub_data_examples):\n", | |||
"\u001b[0;31mTypeError\u001b[0m: unsupported format string passed to BasicNN.__format__" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"bridge.pretrain(\"./weights\")\n", | |||
"bridge.train(train_data, val_data)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
@@ -293,7 +464,7 @@ | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.8.18" | |||
"version": "3.8.13" | |||
}, | |||
"orig_nbformat": 4, | |||
"vscode": { | |||
@@ -5,14 +5,14 @@ import torch.utils.data.sampler as sampler | |||
class InfiniteSampler(sampler.Sampler): | |||
def __init__(self, num_samples, batch_size=1): | |||
self.num_samples = num_samples | |||
def __init__(self, num_examples, batch_size=1): | |||
self.num_examples = num_examples | |||
self.batch_size = batch_size | |||
def __iter__(self): | |||
while True: | |||
order = np.random.permutation(self.num_samples) | |||
for i in range(self.num_samples): | |||
order = np.random.permutation(self.num_examples) | |||
for i in range(self.num_examples): | |||
yield order[i : i + self.batch_size] | |||
i += self.batch_size | |||
@@ -13,8 +13,8 @@ img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( | |||
def download_and_unzip(url, zip_file_name): | |||
try: | |||
gdown.download(url, zip_file_name) | |||
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref: | |||
zip_ref.extractall() | |||
with zipfile.pseudo_labelipFile(zip_file_name, 'r') as zip_ref: | |||
zip_ref.extractall(CURRENT_DIR) | |||
os.remove(zip_file_name) | |||
except Exception as e: | |||
if os.path.exists(zip_file_name): | |||
@@ -23,11 +23,11 @@ def download_and_unzip(url, zip_file_name): | |||
def get_dataset(train=True, get_pseudo_label=False): | |||
data_dir = CURRENT_DIR + '/data' | |||
url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download' | |||
if not os.path.exists(data_dir): | |||
print("Dataset not exist, downloading it...") | |||
download_and_unzip(url, 'HWF.zip') | |||
url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download' | |||
download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip")) | |||
print("Download and extraction complete.") | |||
if train: | |||
@@ -36,7 +36,7 @@ def get_dataset(train=True, get_pseudo_label=False): | |||
file = os.path.join(data_dir, "expr_test.json") | |||
X = [] | |||
Z = [] if get_pseudo_label else None | |||
pseudo_label = [] if get_pseudo_label else None | |||
Y = [] | |||
img_dir = os.path.join(CURRENT_DIR, "data/Handwritten_Math_Symbols/") | |||
with open(file) as f: | |||
@@ -50,12 +50,13 @@ def get_dataset(train=True, get_pseudo_label=False): | |||
img = img_transform(img) | |||
imgs.append(img) | |||
if get_pseudo_label: | |||
imgs_pseudo_label.append(img_path.split("/")[0]) | |||
label_mappings = {"times": "*", "div": "/"} | |||
label = img_path.split("/")[0] | |||
label = label_mappings.get(label, label) | |||
imgs_pseudo_label.append(label) | |||
X.append(imgs) | |||
if get_pseudo_label: | |||
Z.append(imgs_pseudo_label) | |||
pseudo_label.append(imgs_pseudo_label) | |||
Y.append(data[idx]["res"]) | |||
return X, Z, Y | |||
get_dataset() | |||
return X, pseudo_label, Y |
@@ -20,7 +20,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"execution_count": 1, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -36,7 +36,18 @@ | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"\n", | |||
"from examples.models.nn import SymbolNet\n", | |||
"from examples.hwf.datasets.get_dataset import get_dataset" | |||
"from examples.hwf.datasets import get_dataset" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Get training and testing data\n", | |||
"train_data = get_dataset(train=True, get_pseudo_label=True)\n", | |||
"test_data = get_dataset(train=False, get_pseudo_label=True)" | |||
] | |||
}, | |||
{ | |||
@@ -75,21 +86,18 @@ | |||
" for i in range(len(formula)):\n", | |||
" if i % 2 == 0 and formula[i] not in [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]:\n", | |||
" return False\n", | |||
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"times\", \"div\"]:\n", | |||
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"*\", \"/\"]:\n", | |||
" return False\n", | |||
" return True\n", | |||
"\n", | |||
" def logic_forward(self, formula):\n", | |||
" if not self._valid_candidate(formula):\n", | |||
" return np.inf\n", | |||
" mapping = {str(i): str(i) for i in range(1, 10)}\n", | |||
" mapping.update({\"+\": \"+\", \"-\": \"-\", \"times\": \"*\", \"div\": \"/\"})\n", | |||
" formula = [mapping[f] for f in formula]\n", | |||
" return np.info\n", | |||
" return eval(\"\".join(formula))\n", | |||
"\n", | |||
"\n", | |||
"kb = HWF_KB(\n", | |||
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"times\", \"div\"],\n", | |||
" pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"],\n", | |||
" max_err=1e-10,\n", | |||
" use_cache=False,\n", | |||
")\n", | |||
@@ -175,17 +183,6 @@ | |||
"### Dataset" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# Get training and testing data\n", | |||
"train_data = get_dataset(train=True, get_pseudo_label=True)\n", | |||
"test_data = get_dataset(train=False, get_pseudo_label=True)" | |||
] | |||
}, | |||
{ | |||
"attachments": {}, | |||
"cell_type": "markdown", | |||
@@ -33,21 +33,18 @@ class HWF_KB(KBBase): | |||
for i in range(len(formula)): | |||
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
return False | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: | |||
return False | |||
return True | |||
def logic_forward(self, formula): | |||
if not self._valid_candidate(formula): | |||
return np.inf | |||
mapping = {str(i): str(i) for i in range(1, 10)} | |||
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
formula = [mapping[f] for f in formula] | |||
return np.info | |||
return eval("".join(formula)) | |||
kb = HWF_KB( | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "times", "div"], | |||
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], | |||
max_err=1e-10, | |||
use_cache=False, | |||
) | |||
@@ -18,13 +18,13 @@ def get_dataset(train=True, get_pseudo_label=False): | |||
file = os.path.join(CURRENT_DIR, "test_data.txt") | |||
X = [] | |||
Z = [] if get_pseudo_label else None | |||
pseudo_label = [] if get_pseudo_label else None | |||
Y = [] | |||
with open(file) as f: | |||
for line in f: | |||
x1, x2, y = map(int, line.strip().split(" ")) | |||
X.append([img_dataset[x1][0], img_dataset[x2][0]]) | |||
if get_pseudo_label: | |||
Z.append([img_dataset[x1][1], img_dataset[x2][1]]) | |||
pseudo_label.append([img_dataset[x1][1], img_dataset[x2][1]]) | |||
Y.append(y) | |||
return X, Z, Y | |||
return X, pseudo_label, Y |
@@ -59,7 +59,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": 15, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -122,7 +122,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": 16, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -150,7 +150,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": 17, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -164,21 +164,21 @@ | |||
], | |||
"source": [ | |||
"pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])\n", | |||
"print(f\"Shape of pred_idx for a batch of 32 samples: {pred_idx.shape}\")\n", | |||
"print(f\"Shape of pred_idx for a batch of 32 examples: {pred_idx.shape}\")\n", | |||
"pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])\n", | |||
"print(f\"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}\")" | |||
"print(f\"Shape of pred_prob for a batch of 32 examples: {pred_prob.shape}\")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"However, base model built above are trained to make predictions on instance-level data, i.e., a single image, and can not directly utilize sample-level data, i.e., a pair of images. Therefore, we then wrap the base model into `ABLModel` which enables the learning part to train, test, and predict on sample-level data." | |||
"However, base model built above are trained to make predictions on instance-level data, i.e., a single image, and can not directly utilize example-level data, i.e., a pair of images. Therefore, we then wrap the base model into `ABLModel` which enables the learning part to train, test, and predict on example-level data." | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": 18, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -194,15 +194,15 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": 19, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# from abl.structures import ListData\n", | |||
"# data_samples = ListData()\n", | |||
"# data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]\n", | |||
"# data_examples = ListData()\n", | |||
"# data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)]\n", | |||
"\n", | |||
"# model.predict(data_samples)" | |||
"# model.predict(data_examples)" | |||
] | |||
}, | |||
{ | |||
@@ -221,7 +221,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": 20, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -245,7 +245,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": 21, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -257,9 +257,9 @@ | |||
} | |||
], | |||
"source": [ | |||
"pseudo_label_sample = [1, 2]\n", | |||
"reasoning_result = kb.logic_forward(pseudo_label_sample)\n", | |||
"print(f\"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.\")" | |||
"pseudo_label_example = [1, 2]\n", | |||
"reasoning_result = kb.logic_forward(pseudo_label_example)\n", | |||
"print(f\"Reasoning result of pseudo label example {pseudo_label_example} is {reasoning_result}.\")" | |||
] | |||
}, | |||
{ | |||
@@ -278,7 +278,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": 22, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -289,7 +289,7 @@ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"Note: During creating reasoner, the definition of \"consistency\" can be customized within the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data sample and candidates based on the confidence derived from the predicted probability. In `main.py`, we provide options for utilizing other forms of consistency measurement.\n", | |||
"Note: During creating reasoner, the definition of \"consistency\" can be customized within the `dist_func` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data example and candidates based on the confidence derived from the predicted probability. In `main.py`, we provide options for utilizing other forms of consistency measurement.\n", | |||
"\n", | |||
"Also, during process of inconsistency minimization, one can leverage [ZOOpt library](https://github.com/polixir/ZOOpt) for acceleration. Options for this are also available in `main.py`. Those interested are encouraged to explore these features." | |||
] | |||
@@ -311,7 +311,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"execution_count": 23, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -330,7 +330,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"execution_count": 24, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -346,9 +346,68 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 25, | |||
"metadata": {}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"12/19 14:41:46 - abl - INFO - Abductive Learning on the MNIST Addition example.\n", | |||
"12/19 14:41:46 - abl - INFO - loop(train) [1/5] segment(train) [1/3] \n", | |||
"12/19 14:41:51 - abl - INFO - model loss: 1.81279\n", | |||
"12/19 14:41:51 - abl - INFO - loop(train) [1/5] segment(train) [2/3] \n", | |||
"12/19 14:41:56 - abl - INFO - model loss: 1.40474\n", | |||
"12/19 14:41:56 - abl - INFO - loop(train) [1/5] segment(train) [3/3] \n", | |||
"12/19 14:42:01 - abl - INFO - model loss: 1.17817\n", | |||
"12/19 14:42:01 - abl - INFO - Evaluation start: loop(val) [1]\n", | |||
"12/19 14:42:02 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.496 mnist_add/reasoning_accuracy: 0.336 \n", | |||
"12/19 14:42:02 - abl - INFO - Saving model: loop(save) [1]\n", | |||
"12/19 14:42:02 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_1.pth\n", | |||
"12/19 14:42:02 - abl - INFO - loop(train) [2/5] segment(train) [1/3] \n", | |||
"12/19 14:42:07 - abl - INFO - model loss: 0.85932\n", | |||
"12/19 14:42:07 - abl - INFO - loop(train) [2/5] segment(train) [2/3] \n", | |||
"12/19 14:42:11 - abl - INFO - model loss: 0.62120\n", | |||
"12/19 14:42:11 - abl - INFO - loop(train) [2/5] segment(train) [3/3] \n", | |||
"12/19 14:42:16 - abl - INFO - model loss: 0.35382\n", | |||
"12/19 14:42:16 - abl - INFO - Evaluation start: loop(val) [2]\n", | |||
"12/19 14:42:17 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.980 mnist_add/reasoning_accuracy: 0.961 \n", | |||
"12/19 14:42:17 - abl - INFO - Saving model: loop(save) [2]\n", | |||
"12/19 14:42:17 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_2.pth\n", | |||
"12/19 14:42:17 - abl - INFO - loop(train) [3/5] segment(train) [1/3] \n", | |||
"12/19 14:42:21 - abl - INFO - model loss: 0.08302\n", | |||
"12/19 14:42:21 - abl - INFO - loop(train) [3/5] segment(train) [2/3] \n", | |||
"12/19 14:42:25 - abl - INFO - model loss: 0.05917\n", | |||
"12/19 14:42:25 - abl - INFO - loop(train) [3/5] segment(train) [3/3] \n", | |||
"12/19 14:42:30 - abl - INFO - model loss: 0.05425\n", | |||
"12/19 14:42:30 - abl - INFO - Evaluation start: loop(val) [3]\n", | |||
"12/19 14:42:31 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.988 mnist_add/reasoning_accuracy: 0.976 \n", | |||
"12/19 14:42:31 - abl - INFO - Saving model: loop(save) [3]\n", | |||
"12/19 14:42:31 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_3.pth\n", | |||
"12/19 14:42:31 - abl - INFO - loop(train) [4/5] segment(train) [1/3] \n", | |||
"12/19 14:42:35 - abl - INFO - model loss: 0.04650\n", | |||
"12/19 14:42:35 - abl - INFO - loop(train) [4/5] segment(train) [2/3] \n", | |||
"12/19 14:42:39 - abl - INFO - model loss: 0.04175\n", | |||
"12/19 14:42:39 - abl - INFO - loop(train) [4/5] segment(train) [3/3] \n", | |||
"12/19 14:42:44 - abl - INFO - model loss: 0.04207\n", | |||
"12/19 14:42:44 - abl - INFO - Evaluation start: loop(val) [4]\n", | |||
"12/19 14:42:45 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.979 \n", | |||
"12/19 14:42:45 - abl - INFO - Saving model: loop(save) [4]\n", | |||
"12/19 14:42:45 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_4.pth\n", | |||
"12/19 14:42:45 - abl - INFO - loop(train) [5/5] segment(train) [1/3] \n", | |||
"12/19 14:42:49 - abl - INFO - model loss: 0.03484\n", | |||
"12/19 14:42:49 - abl - INFO - loop(train) [5/5] segment(train) [2/3] \n", | |||
"12/19 14:42:53 - abl - INFO - model loss: 0.03319\n", | |||
"12/19 14:42:53 - abl - INFO - loop(train) [5/5] segment(train) [3/3] \n", | |||
"12/19 14:42:58 - abl - INFO - model loss: 0.03510\n", | |||
"12/19 14:42:58 - abl - INFO - Evaluation start: loop(val) [5]\n", | |||
"12/19 14:42:59 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.993 mnist_add/reasoning_accuracy: 0.987 \n", | |||
"12/19 14:42:59 - abl - INFO - Saving model: loop(save) [5]\n", | |||
"12/19 14:42:59 - abl - INFO - Checkpoints will be saved to results/20231219_14_41_46/weights/model_checkpoint_loop_5.pth\n", | |||
"12/19 14:42:59 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.988 mnist_add/reasoning_accuracy: 0.976 \n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# Build logger\n", | |||
"print_log(\"Abductive Learning on the MNIST Addition example.\", logger=\"current\")\n", | |||
@@ -0,0 +1,189 @@ | |||
import os.path as osp | |||
import numpy as np | |||
from sklearn.ensemble import RandomForestClassifier | |||
from sklearn.metrics import accuracy_score | |||
from z3 import Solver, Int, If, Not, Implies, Sum, sat | |||
import openml | |||
from abl.learning import ABLModel | |||
from abl.reasoning import KBBase, Reasoner | |||
from abl.evaluation import ReasoningMetric, SymbolMetric | |||
from abl.bridge import SimpleBridge | |||
from abl.utils.utils import confidence_dist | |||
from abl.utils import ABLLogger, print_log | |||
# Build logger | |||
print_log("Abductive Learning on the Zoo example.", logger="current") | |||
# Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
log_dir = ABLLogger.get_current_instance().log_dir | |||
weights_dir = osp.join(log_dir, "weights") | |||
# Learning Part | |||
rf = RandomForestClassifier() | |||
model = ABLModel(rf) | |||
# %% [markdown] | |||
# ### Logic Part | |||
# %% | |||
class ZooKB(KBBase): | |||
def __init__(self): | |||
super().__init__(pseudo_label_list=list(range(7)), use_cache=False) | |||
# Use z3 solver | |||
self.solver = Solver() | |||
# Load information of Zoo dataset | |||
dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False) | |||
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||
self.attribute_names = attribute_names | |||
self.target_names = y.cat.categories.tolist() | |||
# Define variables | |||
for name in self.attribute_names+self.target_names: | |||
exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules | |||
# Define rules | |||
rules = [ | |||
Implies(milk == 1, mammal == 1), | |||
Implies(mammal == 1, milk == 1), | |||
Implies(mammal == 1, backbone == 1), | |||
Implies(mammal == 1, breathes == 1), | |||
Implies(feathers == 1, bird == 1), | |||
Implies(bird == 1, feathers == 1), | |||
Implies(bird == 1, eggs == 1), | |||
Implies(bird == 1, backbone == 1), | |||
Implies(bird == 1, breathes == 1), | |||
Implies(bird == 1, legs == 2), | |||
Implies(bird == 1, tail == 1), | |||
Implies(reptile == 1, backbone == 1), | |||
Implies(reptile == 1, breathes == 1), | |||
Implies(reptile == 1, tail == 1), | |||
Implies(fish == 1, aquatic == 1), | |||
Implies(fish == 1, toothed == 1), | |||
Implies(fish == 1, backbone == 1), | |||
Implies(fish == 1, Not(breathes == 1)), | |||
Implies(fish == 1, fins == 1), | |||
Implies(fish == 1, legs == 0), | |||
Implies(fish == 1, tail == 1), | |||
Implies(amphibian == 1, eggs == 1), | |||
Implies(amphibian == 1, aquatic == 1), | |||
Implies(amphibian == 1, backbone == 1), | |||
Implies(amphibian == 1, breathes == 1), | |||
Implies(amphibian == 1, legs == 4), | |||
Implies(insect == 1, eggs == 1), | |||
Implies(insect == 1, Not(backbone == 1)), | |||
Implies(insect == 1, legs == 6), | |||
Implies(invertebrate == 1, Not(backbone == 1)) | |||
] | |||
# Define weights and sum of violated weights | |||
self.weights = {rule: 1 for rule in rules} | |||
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights]) | |||
def logic_forward(self, pseudo_label, data_point): | |||
attribute_names, target_names = self.attribute_names, self.target_names | |||
solver = self.solver | |||
total_violation_weight = self.total_violation_weight | |||
pseudo_label, data_point = pseudo_label[0], data_point[0] | |||
self.solver.reset() | |||
for name, value in zip(attribute_names, data_point): | |||
solver.add(eval(f"{name} == {value}")) | |||
for cate, name in zip(self.pseudo_label_list,target_names): | |||
value = 1 if (cate == pseudo_label) else 0 | |||
solver.add(eval(f"{name} == {value}")) | |||
if solver.check() == sat: | |||
model = solver.model() | |||
total_weight = model.evaluate(total_violation_weight) | |||
return total_weight.as_long() | |||
else: | |||
# No solution found | |||
return 1e10 | |||
def consitency(data_example, candidates, candidate_idxs, reasoning_results): | |||
pred_prob = data_example.pred_prob | |||
model_scores = confidence_dist(pred_prob, candidate_idxs) | |||
rule_scores = np.array(reasoning_results) | |||
scores = model_scores + rule_scores | |||
return scores | |||
kb = ZooKB() | |||
reasoner = Reasoner(kb, dist_func=consitency) | |||
# %% [markdown] | |||
# ### Datasets and Evaluation Metrics | |||
# %% | |||
# Function to load and preprocess the dataset | |||
def load_and_preprocess_dataset(dataset_id): | |||
dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False) | |||
X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute) | |||
# Convert data types | |||
for col in X.select_dtypes(include='bool').columns: | |||
X[col] = X[col].astype(int) | |||
y = y.cat.codes.astype(int) | |||
X, y = X.to_numpy(), y.to_numpy() | |||
return X, y | |||
# Function to split data (one shot) | |||
def split_dataset(X, y, test_size = 0.3): | |||
# For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1) | |||
label_indices, unlabel_indices, test_indices = [], [], [] | |||
for class_label in np.unique(y): | |||
idxs = np.where(y == class_label)[0] | |||
np.random.shuffle(idxs) | |||
n_train_unlabel = int((1-test_size)*(len(idxs)-1)) | |||
label_indices.append(idxs[0]) | |||
unlabel_indices.extend(idxs[1:1+n_train_unlabel]) | |||
test_indices.extend(idxs[1+n_train_unlabel:]) | |||
X_label, y_label = X[label_indices], y[label_indices] | |||
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices] | |||
X_test, y_test = X[test_indices], y[test_indices] | |||
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test | |||
# Load and preprocess the Zoo dataset | |||
X, y = load_and_preprocess_dataset(dataset_id=62) | |||
# Split data into labeled/unlabeled/test data | |||
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | |||
# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results) | |||
# For tabular data in abl, each example contains a single instance (a row from the dataset). | |||
# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated. | |||
def transform_tab_data(X, y): | |||
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) | |||
label_data = transform_tab_data(X_label, y_label) | |||
test_data = transform_tab_data(X_test, y_test) | |||
train_data = transform_tab_data(X_unlabel, y_unlabel) | |||
# %% | |||
# Set up metrics | |||
metric_list = [SymbolMetric(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")] | |||
# %% [markdown] | |||
# ### Bridge Machine Learning and Logic Reasoning | |||
# %% | |||
bridge = SimpleBridge(model, reasoner, metric_list) | |||
# %% [markdown] | |||
# ### Train and Test | |||
# %% | |||
# Pre-train the machine learning model | |||
rf.fit(X_label, y_label) | |||
# %% | |||
# Test the initial model | |||
print("------- Test the initial model -----------") | |||
bridge.test(test_data) | |||
print("------- Use ABL to train the model -----------") | |||
# Use ABL to train the model | |||
bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir) | |||
print("------- Test the final model -----------") | |||
# Test the final model | |||
bridge.test(test_data) | |||
@@ -140,8 +140,8 @@ | |||
" # No solution found\n", | |||
" return 1e10\n", | |||
" \n", | |||
"def consitency(data_sample, candidates, candidate_idxs, reasoning_results):\n", | |||
" pred_prob = data_sample.pred_prob\n", | |||
"def consitency(data_example, candidates, candidate_idxs, reasoning_results):\n", | |||
" pred_prob = data_example.pred_prob\n", | |||
" model_scores = confidence_dist(pred_prob, candidate_idxs)\n", | |||
" rule_scores = np.array(reasoning_results)\n", | |||
" scores = model_scores + rule_scores\n", | |||
@@ -198,8 +198,8 @@ | |||
"X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)\n", | |||
"\n", | |||
"# Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)\n", | |||
"# For tabular data in abl, each sample contains a single instance (a row from the dataset).\n", | |||
"# For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated.\n", | |||
"# For tabular data in abl, each example contains a single instance (a row from the dataset).\n", | |||
"# For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.\n", | |||
"def transform_tab_data(X, y):\n", | |||
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", | |||
"label_data = transform_tab_data(X_label, y_label)\n", | |||
@@ -61,15 +61,15 @@ def base_model_instance(): | |||
# Fixture for ListData instance | |||
@pytest.fixture | |||
def list_data_instance(): | |||
data_samples = ListData() | |||
data_samples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)] | |||
data_samples.Y = [1, 2, 3] | |||
data_samples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] | |||
return data_samples | |||
data_examples = ListData() | |||
data_examples.X = [list(torch.randn(2, 1, 28, 28)) for _ in range(3)] | |||
data_examples.Y = [1, 2, 3] | |||
data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] | |||
return data_examples | |||
@pytest.fixture | |||
def data_samples_add(): | |||
def data_examples_add(): | |||
# favor 1 in first one | |||
prob1 = [ | |||
[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||
@@ -81,27 +81,27 @@ def data_samples_add(): | |||
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
] | |||
data_samples_add = ListData() | |||
data_samples_add.X = None | |||
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | |||
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | |||
data_samples_add.Y = [8, 8, 17, 10] | |||
return data_samples_add | |||
data_examples_add = ListData() | |||
data_examples_add.X = None | |||
data_examples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | |||
data_examples_add.pred_prob = [prob1, prob2, prob1, prob2] | |||
data_examples_add.Y = [8, 8, 17, 10] | |||
return data_examples_add | |||
@pytest.fixture | |||
def data_samples_hwf(): | |||
data_samples_hwf = ListData() | |||
data_samples_hwf.X = None | |||
data_samples_hwf.pred_pseudo_label = [ | |||
def data_examples_hwf(): | |||
data_examples_hwf = ListData() | |||
data_examples_hwf.X = None | |||
data_examples_hwf.pred_pseudo_label = [ | |||
["5", "+", "2"], | |||
["5", "+", "9"], | |||
["5", "+", "9"], | |||
["5", "-", "8", "8", "8"], | |||
] | |||
data_samples_hwf.pred_prob = [None, None, None, None] | |||
data_samples_hwf.Y = [3, 64, 65, 3.17] | |||
return data_samples_hwf | |||
data_examples_hwf.pred_prob = [None, None, None, None] | |||
data_examples_hwf.Y = [3, 64, 65, 3.17] | |||
return data_examples_hwf | |||
class AddKB(KBBase): | |||
@@ -199,7 +199,7 @@ def kb_add_ground(): | |||
@pytest.fixture | |||
def kb_add_prolog(): | |||
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl") | |||
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl") | |||
return kb | |||
@pytest.fixture | |||
@@ -53,7 +53,7 @@ class TestGroundKB(object): | |||
class TestPrologKB(object): | |||
def test_init_pl1(self, kb_add_prolog): | |||
assert kb_add_prolog.pseudo_label_list == list(range(10)) | |||
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl" | |||
assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl" | |||
def test_init_pl2(self, kb_hed): | |||
assert kb_hed.pseudo_label_list == [1, 0, "+", "="] | |||
@@ -101,7 +101,7 @@ class TestReaonser(object): | |||
excinfo.value | |||
) | |||
def random_dist(self, data_sample, candidates, candidate_idxs, reasoning_results): | |||
def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results): | |||
cost_list = [np.random.rand() for _ in candidates] | |||
return cost_list | |||
@@ -113,11 +113,11 @@ class TestReaonser(object): | |||
cost_list = np.array([np.random.rand() for _ in candidates]) | |||
return cost_list | |||
def invalid_dist2(self, data_sample, candidates, candidate_idxs, reasoning_results): | |||
def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results): | |||
cost_list = np.array([np.random.rand() for _ in candidates]) | |||
return np.append(cost_list, np.random.rand()) | |||
def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add): | |||
def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add): | |||
with pytest.raises(ValueError) as excinfo: | |||
Reasoner(kb_add, self.invalid_dist1) | |||
assert 'User-defined dist_func must have exactly four parameters' in str( | |||
@@ -125,98 +125,98 @@ class TestReaonser(object): | |||
) | |||
with pytest.raises(ValueError) as excinfo: | |||
reasoner = Reasoner(kb_add, self.invalid_dist2) | |||
reasoner.batch_abduce(data_samples_add) | |||
reasoner.batch_abduce(data_examples_add) | |||
assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str( | |||
excinfo.value | |||
) | |||
class TestBatchAbduce(object): | |||
def test_batch_abduce_add(self, kb_add, data_samples_add): | |||
def test_batch_abduce_add(self, kb_add, data_examples_add): | |||
reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1) | |||
reasoner3 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=0) | |||
reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1) | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||
assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner3.batch_abduce(data_examples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[1, 9], | |||
] | |||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||
assert reasoner4.batch_abduce(data_examples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[7, 3], | |||
] | |||
def test_batch_abduce_ground(self, kb_add_ground, data_samples_add): | |||
def test_batch_abduce_ground(self, kb_add_ground, data_examples_add): | |||
reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=1) | |||
reasoner3 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=0) | |||
reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1) | |||
assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | |||
assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)] | |||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||
assert reasoner1.batch_abduce(data_examples_add) == [(1, 7), (7, 1), [], (1, 9)] | |||
assert reasoner2.batch_abduce(data_examples_add) == [(1, 7), (7, 1), [], (1, 9)] | |||
assert reasoner3.batch_abduce(data_examples_add) == [ | |||
(1, 7), | |||
(7, 1), | |||
(8, 9), | |||
(1, 9), | |||
] | |||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||
assert reasoner4.batch_abduce(data_examples_add) == [ | |||
(1, 7), | |||
(7, 1), | |||
(8, 9), | |||
(7, 3), | |||
] | |||
def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add): | |||
def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add): | |||
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1) | |||
reasoner3 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=0) | |||
reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1) | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner3.batch_abduce(data_samples_add) == [ | |||
assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner3.batch_abduce(data_examples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[1, 9], | |||
] | |||
assert reasoner4.batch_abduce(data_samples_add) == [ | |||
assert reasoner4.batch_abduce(data_examples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[7, 3], | |||
] | |||
def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add): | |||
def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add): | |||
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1) | |||
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2) | |||
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_samples_add) == [ | |||
assert reasoner1.batch_abduce(data_examples_add) == [[1, 7], [7, 1], [], [1, 9]] | |||
assert reasoner2.batch_abduce(data_examples_add) == [ | |||
[1, 7], | |||
[7, 1], | |||
[8, 9], | |||
[7, 3], | |||
] | |||
def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf): | |||
def test_batch_abduce_hwf1(self, kb_hwf1, data_examples_hwf): | |||
reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0) | |||
reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0) | |||
res = reasoner1.batch_abduce(data_samples_hwf) | |||
res = reasoner1.batch_abduce(data_examples_hwf) | |||
assert res == [ | |||
["1", "+", "2"], | |||
["8", "times", "8"], | |||
[], | |||
["4", "-", "6", "div", "8"], | |||
] | |||
res = reasoner2.batch_abduce(data_samples_hwf) | |||
res = reasoner2.batch_abduce(data_examples_hwf) | |||
assert res == [["1", "+", "2"], [], [], []] | |||
res = reasoner3.batch_abduce(data_samples_hwf) | |||
res = reasoner3.batch_abduce(data_examples_hwf) | |||
assert res == [ | |||
["1", "+", "2"], | |||
["8", "times", "8"], | |||
@@ -224,25 +224,25 @@ class TestBatchAbduce(object): | |||
["4", "-", "6", "div", "8"], | |||
] | |||
def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf): | |||
def test_batch_abduce_hwf2(self, kb_hwf2, data_examples_hwf): | |||
reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0) | |||
reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0) | |||
reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0) | |||
res = reasoner1.batch_abduce(data_samples_hwf) | |||
res = reasoner1.batch_abduce(data_examples_hwf) | |||
assert res == [ | |||
["1", "+", "2"], | |||
["7", "times", "9"], | |||
["8", "times", "8"], | |||
["5", "-", "8", "div", "8"], | |||
] | |||
res = reasoner2.batch_abduce(data_samples_hwf) | |||
res = reasoner2.batch_abduce(data_examples_hwf) | |||
assert res == [ | |||
["1", "+", "2"], | |||
["7", "times", "9"], | |||
[], | |||
["5", "-", "8", "div", "8"], | |||
] | |||
res = reasoner3.batch_abduce(data_samples_hwf) | |||
res = reasoner3.batch_abduce(data_examples_hwf) | |||
assert res == [ | |||
["1", "+", "2"], | |||
["7", "times", "9"], | |||