diff --git a/abl/data/evaluation/base_metric.py b/abl/data/evaluation/base_metric.py index 3371190..37e36dd 100644 --- a/abl/data/evaluation/base_metric.py +++ b/abl/data/evaluation/base_metric.py @@ -26,6 +26,7 @@ class BaseMetric(metaclass=ABCMeta): self, prefix: Optional[str] = None, ) -> None: + self.default_prefix = "" self.results: List[Any] = [] self.prefix = prefix or self.default_prefix diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index ec031f3..a7f2c4d 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -204,7 +204,7 @@ class Reasoner: dim=dimension, constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), ) - parameter = Parameter(budget=100, intermediate_result=False, autoset=True) + parameter = Parameter(budget=200, intermediate_result=False, autoset=True) solution = Opt.min(objective, parameter) return solution diff --git a/examples/hed/bridge.py b/examples/hed/bridge.py index 2852ca0..2a27908 100644 --- a/examples/hed/bridge.py +++ b/examples/hed/bridge.py @@ -1,5 +1,6 @@ import os from collections import defaultdict +from typing import Any, List, Optional, Tuple, Union import torch @@ -41,7 +42,7 @@ class HedBridge(SimpleBridge): cls_autoencoder, loss_fn, optimizer, - device, + device=device, save_interval=1, save_dir=weights_dir, num_epochs=10, @@ -115,7 +116,7 @@ class HedBridge(SimpleBridge): ) print_log(log_string, logger="current") - if true_ratio > 0.95 and false_ratio < 0.1: + if true_ratio > 0.9 and false_ratio < 0.05: return True return False @@ -215,7 +216,7 @@ class HedBridge(SimpleBridge): self.abduce_pseudo_label(sub_data_examples) filtered_sub_data_examples = self.filter_empty(sub_data_examples) self.pseudo_label_to_idx(filtered_sub_data_examples) - loss = self.model.train(filtered_sub_data_examples) + self.model.train(filtered_sub_data_examples) if self.check_training_impact(filtered_sub_data_examples, sub_data_examples): condition_num += 1 @@ -231,6 +232,7 @@ class HedBridge(SimpleBridge): seems_good = self.check_rule_quality(rules, val_data, equation_len) if seems_good: + self.reasoner.kb.learned_rules.update({equation_len: rules}) self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") break else: @@ -244,3 +246,19 @@ class HedBridge(SimpleBridge): self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") condition_num = 0 print_log("Reload Model and retrain", logger="current") + + def test( + self, + test_data: Union[ + ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] + ], + min_len=5, + max_len=8, + ) -> None: + for equation_len in range(min_len, max_len): + test_data_examples = self.data_preprocess(test_data[1], equation_len) + print_log(f"Test on true equations with length {equation_len}", logger="current") + self._valid(test_data_examples) + test_data_examples = self.data_preprocess(test_data[0], equation_len) + print_log(f"Test on false equations with length {equation_len}", logger="current") + self._valid(test_data_examples) diff --git a/examples/hed/consistency_metric.py b/examples/hed/consistency_metric.py new file mode 100644 index 0000000..5c68eb6 --- /dev/null +++ b/examples/hed/consistency_metric.py @@ -0,0 +1,28 @@ +from typing import Optional + +from abl.reasoning import KBBase +from abl.data.structures import ListData +from abl.data.evaluation.base_metric import BaseMetric + + +class ConsistencyMetric(BaseMetric): + def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: + super().__init__(prefix) + self.kb = kb + + def process(self, data_examples: ListData) -> None: + pred_pseudo_label = data_examples.pred_pseudo_label + learned_rules = self.kb.learned_rules + consistent_num = sum( + [ + self.kb.consist_rule(instance, learned_rules[len(instance)]) + for instance in pred_pseudo_label + ] + ) + self.results.append((consistent_num, len(pred_pseudo_label))) + + def compute_metrics(self) -> dict: + results = self.results + metrics = dict() + metrics["consistency"] = sum(t[0] for t in results) / sum(t[1] for t in results) + return metrics diff --git a/examples/hed/hed.ipynb b/examples/hed/hed.ipynb index 0470634..cfbf0d6 100644 --- a/examples/hed/hed.ipynb +++ b/examples/hed/hed.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,8 @@ "from examples.models.nn import SymbolNet\n", "from abl.learning import ABLModel, BasicNN\n", "from examples.hed.reasoning import HedKB, HedReasoner\n", - "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", + "from abl.data.evaluation import SymbolAccuracy\n", + "from examples.hed.consistency_metric import ConsistencyMetric\n", "from abl.utils import ABLLogger, print_log\n", "from examples.hed.bridge import HedBridge" ] @@ -47,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -65,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -119,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -240,7 +241,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -254,7 +255,7 @@ " cls,\n", " loss_fn,\n", " optimizer,\n", - " device,\n", + " device=device,\n", " batch_size=32,\n", " num_epochs=1,\n", " stop_loss=None,\n", @@ -270,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -298,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -316,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -340,11 +341,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "metric_list = [SymbolAccuracy(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]" + "metric_list = [ConsistencyMetric(kb=kb)]" ] }, { @@ -359,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/examples/hed/reasoning/reasoning.py b/examples/hed/reasoning/reasoning.py index 3d6013f..e969386 100644 --- a/examples/hed/reasoning/reasoning.py +++ b/examples/hed/reasoning/reasoning.py @@ -9,6 +9,7 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) class HedKB(PrologKB): def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): super().__init__(pseudo_label_list, pl_file) + self.learned_rules = {} def consist_rule(self, exs, rules): rules = str(rules).replace("'", "")