@@ -26,6 +26,7 @@ class BaseMetric(metaclass=ABCMeta): | |||
self, | |||
prefix: Optional[str] = None, | |||
) -> None: | |||
self.default_prefix = "" | |||
self.results: List[Any] = [] | |||
self.prefix = prefix or self.default_prefix | |||
@@ -204,7 +204,7 @@ class Reasoner: | |||
dim=dimension, | |||
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), | |||
) | |||
parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
parameter = Parameter(budget=200, intermediate_result=False, autoset=True) | |||
solution = Opt.min(objective, parameter) | |||
return solution | |||
@@ -1,5 +1,6 @@ | |||
import os | |||
from collections import defaultdict | |||
from typing import Any, List, Optional, Tuple, Union | |||
import torch | |||
@@ -41,7 +42,7 @@ class HedBridge(SimpleBridge): | |||
cls_autoencoder, | |||
loss_fn, | |||
optimizer, | |||
device, | |||
device=device, | |||
save_interval=1, | |||
save_dir=weights_dir, | |||
num_epochs=10, | |||
@@ -115,7 +116,7 @@ class HedBridge(SimpleBridge): | |||
) | |||
print_log(log_string, logger="current") | |||
if true_ratio > 0.95 and false_ratio < 0.1: | |||
if true_ratio > 0.9 and false_ratio < 0.05: | |||
return True | |||
return False | |||
@@ -215,7 +216,7 @@ class HedBridge(SimpleBridge): | |||
self.abduce_pseudo_label(sub_data_examples) | |||
filtered_sub_data_examples = self.filter_empty(sub_data_examples) | |||
self.pseudo_label_to_idx(filtered_sub_data_examples) | |||
loss = self.model.train(filtered_sub_data_examples) | |||
self.model.train(filtered_sub_data_examples) | |||
if self.check_training_impact(filtered_sub_data_examples, sub_data_examples): | |||
condition_num += 1 | |||
@@ -231,6 +232,7 @@ class HedBridge(SimpleBridge): | |||
seems_good = self.check_rule_quality(rules, val_data, equation_len) | |||
if seems_good: | |||
self.reasoner.kb.learned_rules.update({equation_len: rules}) | |||
self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") | |||
break | |||
else: | |||
@@ -244,3 +246,19 @@ class HedBridge(SimpleBridge): | |||
self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") | |||
condition_num = 0 | |||
print_log("Reload Model and retrain", logger="current") | |||
def test( | |||
self, | |||
test_data: Union[ | |||
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] | |||
], | |||
min_len=5, | |||
max_len=8, | |||
) -> None: | |||
for equation_len in range(min_len, max_len): | |||
test_data_examples = self.data_preprocess(test_data[1], equation_len) | |||
print_log(f"Test on true equations with length {equation_len}", logger="current") | |||
self._valid(test_data_examples) | |||
test_data_examples = self.data_preprocess(test_data[0], equation_len) | |||
print_log(f"Test on false equations with length {equation_len}", logger="current") | |||
self._valid(test_data_examples) |
@@ -0,0 +1,28 @@ | |||
from typing import Optional | |||
from abl.reasoning import KBBase | |||
from abl.data.structures import ListData | |||
from abl.data.evaluation.base_metric import BaseMetric | |||
class ConsistencyMetric(BaseMetric): | |||
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: | |||
super().__init__(prefix) | |||
self.kb = kb | |||
def process(self, data_examples: ListData) -> None: | |||
pred_pseudo_label = data_examples.pred_pseudo_label | |||
learned_rules = self.kb.learned_rules | |||
consistent_num = sum( | |||
[ | |||
self.kb.consist_rule(instance, learned_rules[len(instance)]) | |||
for instance in pred_pseudo_label | |||
] | |||
) | |||
self.results.append((consistent_num, len(pred_pseudo_label))) | |||
def compute_metrics(self) -> dict: | |||
results = self.results | |||
metrics = dict() | |||
metrics["consistency"] = sum(t[0] for t in results) / sum(t[1] for t in results) | |||
return metrics |
@@ -13,7 +13,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 1, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -26,7 +26,8 @@ | |||
"from examples.models.nn import SymbolNet\n", | |||
"from abl.learning import ABLModel, BasicNN\n", | |||
"from examples.hed.reasoning import HedKB, HedReasoner\n", | |||
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | |||
"from abl.data.evaluation import SymbolAccuracy\n", | |||
"from examples.hed.consistency_metric import ConsistencyMetric\n", | |||
"from abl.utils import ABLLogger, print_log\n", | |||
"from examples.hed.bridge import HedBridge" | |||
] | |||
@@ -47,7 +48,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -65,7 +66,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -119,7 +120,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -240,7 +241,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -254,7 +255,7 @@ | |||
" cls,\n", | |||
" loss_fn,\n", | |||
" optimizer,\n", | |||
" device,\n", | |||
" device=device,\n", | |||
" batch_size=32,\n", | |||
" num_epochs=1,\n", | |||
" stop_loss=None,\n", | |||
@@ -270,7 +271,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -298,7 +299,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -316,7 +317,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -340,11 +341,11 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"metric_list = [SymbolAccuracy(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]" | |||
"metric_list = [ConsistencyMetric(kb=kb)]" | |||
] | |||
}, | |||
{ | |||
@@ -359,7 +360,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -9,6 +9,7 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
class HedKB(PrologKB): | |||
def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")): | |||
super().__init__(pseudo_label_list, pl_file) | |||
self.learned_rules = {} | |||
def consist_rule(self, exs, rules): | |||
rules = str(rules).replace("'", "") | |||