Browse Source

[ENH] implement test for hed

pull/1/head
Gao Enhao 1 year ago
parent
commit
bf55d3e06a
6 changed files with 66 additions and 17 deletions
  1. +1
    -0
      abl/data/evaluation/base_metric.py
  2. +1
    -1
      abl/reasoning/reasoner.py
  3. +21
    -3
      examples/hed/bridge.py
  4. +28
    -0
      examples/hed/consistency_metric.py
  5. +14
    -13
      examples/hed/hed.ipynb
  6. +1
    -0
      examples/hed/reasoning/reasoning.py

+ 1
- 0
abl/data/evaluation/base_metric.py View File

@@ -26,6 +26,7 @@ class BaseMetric(metaclass=ABCMeta):
self,
prefix: Optional[str] = None,
) -> None:
self.default_prefix = ""
self.results: List[Any] = []
self.prefix = prefix or self.default_prefix



+ 1
- 1
abl/reasoning/reasoner.py View File

@@ -204,7 +204,7 @@ class Reasoner:
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
parameter = Parameter(budget=200, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter)
return solution



+ 21
- 3
examples/hed/bridge.py View File

@@ -1,5 +1,6 @@
import os
from collections import defaultdict
from typing import Any, List, Optional, Tuple, Union

import torch

@@ -41,7 +42,7 @@ class HedBridge(SimpleBridge):
cls_autoencoder,
loss_fn,
optimizer,
device,
device=device,
save_interval=1,
save_dir=weights_dir,
num_epochs=10,
@@ -115,7 +116,7 @@ class HedBridge(SimpleBridge):
)
print_log(log_string, logger="current")

if true_ratio > 0.95 and false_ratio < 0.1:
if true_ratio > 0.9 and false_ratio < 0.05:
return True
return False

@@ -215,7 +216,7 @@ class HedBridge(SimpleBridge):
self.abduce_pseudo_label(sub_data_examples)
filtered_sub_data_examples = self.filter_empty(sub_data_examples)
self.pseudo_label_to_idx(filtered_sub_data_examples)
loss = self.model.train(filtered_sub_data_examples)
self.model.train(filtered_sub_data_examples)

if self.check_training_impact(filtered_sub_data_examples, sub_data_examples):
condition_num += 1
@@ -231,6 +232,7 @@ class HedBridge(SimpleBridge):

seems_good = self.check_rule_quality(rules, val_data, equation_len)
if seems_good:
self.reasoner.kb.learned_rules.update({equation_len: rules})
self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth")
break
else:
@@ -244,3 +246,19 @@ class HedBridge(SimpleBridge):
self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth")
condition_num = 0
print_log("Reload Model and retrain", logger="current")

def test(
self,
test_data: Union[
ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
],
min_len=5,
max_len=8,
) -> None:
for equation_len in range(min_len, max_len):
test_data_examples = self.data_preprocess(test_data[1], equation_len)
print_log(f"Test on true equations with length {equation_len}", logger="current")
self._valid(test_data_examples)
test_data_examples = self.data_preprocess(test_data[0], equation_len)
print_log(f"Test on false equations with length {equation_len}", logger="current")
self._valid(test_data_examples)

+ 28
- 0
examples/hed/consistency_metric.py View File

@@ -0,0 +1,28 @@
from typing import Optional

from abl.reasoning import KBBase
from abl.data.structures import ListData
from abl.data.evaluation.base_metric import BaseMetric


class ConsistencyMetric(BaseMetric):
def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
super().__init__(prefix)
self.kb = kb

def process(self, data_examples: ListData) -> None:
pred_pseudo_label = data_examples.pred_pseudo_label
learned_rules = self.kb.learned_rules
consistent_num = sum(
[
self.kb.consist_rule(instance, learned_rules[len(instance)])
for instance in pred_pseudo_label
]
)
self.results.append((consistent_num, len(pred_pseudo_label)))

def compute_metrics(self) -> dict:
results = self.results
metrics = dict()
metrics["consistency"] = sum(t[0] for t in results) / sum(t[1] for t in results)
return metrics

+ 14
- 13
examples/hed/hed.ipynb View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -26,7 +26,8 @@
"from examples.models.nn import SymbolNet\n",
"from abl.learning import ABLModel, BasicNN\n",
"from examples.hed.reasoning import HedKB, HedReasoner\n",
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
"from abl.data.evaluation import SymbolAccuracy\n",
"from examples.hed.consistency_metric import ConsistencyMetric\n",
"from abl.utils import ABLLogger, print_log\n",
"from examples.hed.bridge import HedBridge"
]
@@ -47,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -65,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -119,7 +120,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -240,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -254,7 +255,7 @@
" cls,\n",
" loss_fn,\n",
" optimizer,\n",
" device,\n",
" device=device,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
" stop_loss=None,\n",
@@ -270,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -298,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -316,7 +317,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -340,11 +341,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metric_list = [SymbolAccuracy(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]"
"metric_list = [ConsistencyMetric(kb=kb)]"
]
},
{
@@ -359,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [


+ 1
- 0
examples/hed/reasoning/reasoning.py View File

@@ -9,6 +9,7 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
class HedKB(PrologKB):
def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")):
super().__init__(pseudo_label_list, pl_file)
self.learned_rules = {}

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")


Loading…
Cancel
Save