
[FIX] remove private method in HedReasoner

pull/1/head
troyyyyy, 1 year ago · commit 5df67e4b91
14 changed files with 1024 additions and 491 deletions
  1. +9 -9  abl/reasoning/reasoner.py
  2. +0 -0  examples/hed/README.md
  3. +1 -1  examples/hed/bridge.py
  4. +4 -0  examples/hed/datasets/__init__.py
  5. +0 -173  examples/hed/datasets/equation_generator.py
  6. +1 -1  examples/hed/datasets/get_dataset.py
  7. +912 -0  examples/hed/hed.ipynb
  8. +0 -307  examples/hed/hed_example.ipynb
  9. +0 -0  examples/hed/main.py
  10. +0 -0  examples/hed/reasoning/BK.pl
  11. +3 -0  examples/hed/reasoning/__init__.py
  12. +0 -0  examples/hed/reasoning/learn_add.pl
  13. +93 -0  examples/hed/reasoning/reasoning.py
  14. +1 -0  examples/hed/requirements.txt

+9 -9  abl/reasoning/reasoner.py

@@ -2,7 +2,7 @@ import inspect
from typing import Callable, Any, List, Optional, Union

import numpy as np
-from zoopt import Dimension, Objective, Opt, Parameter
+from zoopt import Dimension, Objective, Opt, Parameter, Solution

from ..reasoning import KBBase
from ..structures import ListData
@@ -175,9 +175,9 @@ class Reasoner:
symbol_num: int,
data_example: ListData,
max_revision_num: int,
-) -> List[bool]:
+) -> Solution:
"""
-Get the optimal solution using ZOOpt library. The solution is a list of
+Get the optimal solution using ZOOpt library. From the solution, we can get a list of
boolean values, where '1' (True) indicates the indices chosen to be revised.

Parameters
@@ -191,7 +191,7 @@ class Reasoner:

Returns
-------
-List[bool]
+Solution
The solution for ZOOpt library.
"""
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
@@ -201,14 +201,14 @@ class Reasoner:
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
-solution = Opt.min(objective, parameter).get_x()
+solution = Opt.min(objective, parameter)
return solution

def zoopt_revision_score(
self,
symbol_num: int,
data_example: ListData,
-sol: List[bool],
+sol: Solution,
) -> int:
"""
Get the revision score for a solution. A lower score suggests that ZOOpt library
@@ -220,7 +220,7 @@ class Reasoner:
Number of total symbols.
data_example : ListData
Data example.
-sol: List[bool]
+sol: Solution
The solution for ZOOpt library.

Returns
@@ -237,7 +237,7 @@ class Reasoner:
else:
return symbol_num

-def _constrain_revision_num(self, solution: List[bool], max_revision_num: int) -> int:
+def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
"""
Constrain that the total number of revisions chosen by the solution does not exceed
maximum number of revisions allowed.
@@ -287,7 +287,7 @@ class Reasoner:

if self.use_zoopt:
solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
-revision_idx = np.where(solution != 0)[0]
+revision_idx = np.where(solution.get_x() != 0)[0]
candidates, reasoning_results = self.kb.revise_at_idx(
pseudo_label=data_example.pred_pseudo_label,
y=data_example.Y,
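Note on the zoopt change above: Opt.min now returns a zoopt.Solution object rather than a plain list, so callers pull the revision vector out with solution.get_x(). A minimal sketch of the pattern, with a toy objective standing in for zoopt_revision_score (not the repository's actual scoring code):

    import numpy as np
    from zoopt import Dimension, Objective, Opt, Parameter

    symbol_num = 5        # stand-in problem size
    max_revision_num = 2  # stand-in revision budget

    # One binary (0/1) variable per symbol.
    dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)

    # Toy objective: revise as few symbols as possible; a solution is feasible
    # when the number of revisions does not exceed max_revision_num.
    objective = Objective(
        lambda sol: float(np.sum(sol.get_x())),
        dim=dimension,
        constraint=lambda sol: max_revision_num - int(np.sum(sol.get_x())),
    )

    parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
    solution = Opt.min(objective, parameter)  # a zoopt.Solution

    # The Solution wraps both the chosen point and its score.
    revision_idx = np.where(np.array(solution.get_x()) != 0)[0]
    print(revision_idx, solution.get_value())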


+0 -0  examples/hed/README.md


examples/hed/hed_bridge.py → examples/hed/bridge.py

@@ -10,7 +10,7 @@ from abl.learning import ABLModel, BasicNN
from abl.reasoning import Reasoner
from abl.structures import ListData
from abl.utils import print_log
-from examples.hed.datasets.get_hed import get_pretrain_data
+from examples.hed.datasets.get_dataset import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder


+4 -0  examples/hed/datasets/__init__.py

@@ -0,0 +1,4 @@
from .get_dataset import get_dataset, split_equation


__all__ = ["get_dataset", "split_equation"]
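With get_hed.py renamed to get_dataset.py and re-exported from the new package __init__, downstream code imports the loaders from examples.hed.datasets. A usage sketch based on the removed hed_example.ipynb (the 3:1 split ratio is the one used there):

    from examples.hed.datasets import get_dataset, split_equation

    # Load the MNIST-based handwritten-equation data.
    total_train_data = get_dataset(dataset="mnist", train=True)
    test_data = get_dataset(dataset="mnist", train=False)

    # Split equations (grouped by length) into training and validation parts, 3:1.
    train_data, val_data = split_equation(total_train_data, prop_train=3, prop_val=1)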

+0 -173  examples/hed/datasets/equation_generator.py

@@ -1,173 +0,0 @@
import os
import itertools
import random
import numpy as np
from PIL import Image
import pickle

def get_sign_path_list(data_dir, sign_names):
sign_num = len(sign_names)
index_dict = dict(zip(sign_names, list(range(sign_num))))
ret = [[] for _ in range(sign_num)]
for path in os.listdir(data_dir):
if (path in sign_names):
index = index_dict[path]
sign_path = os.path.join(data_dir, path)
for p in os.listdir(sign_path):
ret[index].append(os.path.join(sign_path, p))
return ret

def split_pool_by_rate(pools, rate, seed = None):
if seed is not None:
random.seed(seed)
ret1 = []
ret2 = []
for pool in pools:
random.shuffle(pool)
num = int(len(pool) * rate)
ret1.append(pool[:num])
ret2.append(pool[num:])
return ret1, ret2

def int_to_system_form(num, system_num):
if num == 0:
return "0"
ret = ""
while (num > 0):
ret += str(num % system_num)
num //= system_num
return ret[::-1]

def generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type):
expr_len = left_opt_len + right_opt_len
num_list = "".join([str(i) for i in range(system_num)])
ret = []
if generate_type == "all":
candidates = itertools.product(num_list, repeat = expr_len)
else:
candidates = [''.join(random.sample(['0', '1'] * expr_len, expr_len))]
random.shuffle(candidates)
for nums in candidates:
left_num = "".join(nums[:left_opt_len])
right_num = "".join(nums[left_opt_len:])
left_value = int(left_num, system_num)
right_value = int(right_num, system_num)
result_value = left_value + right_value
if (label == 'negative'):
result_value += random.randint(-result_value, result_value)
if (left_value + right_value == result_value):
continue
result_num = int_to_system_form(result_value, system_num)
#leading zeros
if (res_opt_len != len(result_num)):
continue
if ((left_opt_len > 1 and left_num[0] == '0') or (right_opt_len > 1 and right_num[0] == '0')):
continue

#add leading zeros
if (res_opt_len < len(result_num)):
continue
while (len(result_num) < res_opt_len):
result_num = '0' + result_num
#continue
ret.append(left_num + '+' + right_num + '=' + result_num) # current only consider '+' and '='
#print(ret[-1])
return ret

def generator_equation_by_len(equation_len, system_num = 2, label = 0, require_num = 1):
generate_type = "one"
ret = []
equation_sign_num = 2 # '+' and '='
while len(ret) < require_num:
left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num)
right_opt_len = random.randint(1, equation_len - left_opt_len - equation_sign_num)
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type))
return ret

def generator_equations_by_len(equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all"):
ret = []
equation_sign_num = 2 # '+' and '='
for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1):
for right_opt_len in range(1, equation_len - left_opt_len - (1 + equation_sign_num) + 1):
res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num
for i in range(repeat_times): #generate more equations
if random.random() > keep ** (equation_len):
continue
ret.extend(generator_equations(left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type))
return ret

def generator_equations_by_max_len(max_equation_len, system_num = 2, label = 0, repeat_times = 1, keep = 1, generate_type = "all", num_per_len = None):
ret = []
equation_sign_num = 2 # '+' and '='
for equation_len in range(3 + equation_sign_num, max_equation_len + 1):
if (num_per_len is None):
ret.extend(generator_equations_by_len(equation_len, system_num, label, repeat_times, keep, generate_type))
else:
ret.extend(generator_equation_by_len(equation_len, system_num, label, require_num = num_per_len))
return ret

def generator_equation_images(image_pools, equations, signs, shape, seed, is_color):
if (seed is not None):
random.seed(seed)
ret = []
sign_num = len(signs)
sign_index_dict = dict(zip(signs, list(range(sign_num))))
for equation in equations:
data = []
for sign in equation:
index = sign_index_dict[sign]
pick = random.randint(0, len(image_pools[index]) - 1)
if is_color:
image = Image.open(image_pools[index][pick]).convert('RGB').resize(shape)
else:
image = Image.open(image_pools[index][pick]).convert('I').resize(shape)
image_array = np.array(image)
image_array = (image_array-127)*(1./128)
data.append(image_array)
ret.append(np.array(data))
return ret

def get_equation_std_data(data_dir, sign_dir_lists, sign_output_lists, shape = (28, 28), train_max_equation_len = 10, test_max_equation_len = 10, system_num = 2, tmp_file_prev =
None, seed = None, train_num_per_len = 10, test_num_per_len = 10, is_color = False):
tmp_file = ""
if (tmp_file_prev is not None):
tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % (tmp_file_prev, train_max_equation_len, test_max_equation_len, system_num)
if (os.path.exists(tmp_file)):
return pickle.load(open(tmp_file, "rb"))

image_pools = get_sign_path_list(data_dir, sign_dir_lists)
train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed)

ret = {}
for label in ["positive", "negative"]:
print("Generating equations.")
train_equations = generator_equations_by_max_len(train_max_equation_len, system_num, label, num_per_len = train_num_per_len)
test_equations = generator_equations_by_max_len(test_max_equation_len, system_num, label, num_per_len = test_num_per_len)
print(train_equations)
print(test_equations)
print("Generated equations.")
print("Generating equation image data.")
ret["train:%s" % (label)] = generator_equation_images(train_pool, train_equations, sign_output_lists, shape, seed, is_color)
ret["test:%s" % (label)] = generator_equation_images(test_pool, test_equations, sign_output_lists, shape, seed, is_color)
print("Generated equation image data.")

if (tmp_file_prev is not None):
pickle.dump(ret, open(tmp_file, "wb"))
return ret

if __name__ == "__main__":
data_dirs = ["./dataset/hed/mnist_images", "./dataset/hed/random_images"] #, "../dataset/cifar10_images"]
tmp_file_prevs = ["mnist_equation_data", "random_equation_data"] #, "cifar10_equation_data"]
for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs):
data = get_equation_std_data(data_dir = data_dir,\
sign_dir_lists = ['0', '1', '10', '11'],\
sign_output_lists = ['0', '1', '+', '='],\
shape = (28, 28),\
train_max_equation_len = 26, \
test_max_equation_len = 26, \
system_num = 2, \
tmp_file_prev = tmp_file_prev, \
train_num_per_len = 300, \
test_num_per_len = 300, \
is_color = False)

examples/hed/datasets/get_hed.py → examples/hed/datasets/get_dataset.py

@@ -81,7 +81,7 @@ def split_equation(equations_by_len, prop_train, prop_val):
return train_equations_by_len, val_equations_by_len


-def get_hed(dataset="mnist", train=True):
+def get_dataset(dataset="mnist", train=True):
if dataset == "mnist":
file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk")
elif dataset == "random":

+912 -0  examples/hed/hed.ipynb
File diff suppressed because it is too large


+0 -307  examples/hed/hed_example.ipynb

@@ -1,307 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from zoopt import Dimension, Objective, Opt, Parameter\n",
"\n",
"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
"from abl.learning import ABLModel, BasicNN\n",
"from abl.reasoning import PrologKB, Reasoner\n",
"from abl.utils import ABLLogger, print_log, reform_list\n",
"from examples.hed.datasets.get_hed import get_hed, split_equation\n",
"from examples.hed.hed_bridge import HEDBridge\n",
"from examples.models.nn import SymbolNet"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build logger\n",
"print_log(\"Abductive Learning on the HED example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logic Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"class HedKB(PrologKB):\n",
" def __init__(self, pseudo_label_list, pl_file):\n",
" super().__init__(pseudo_label_list, pl_file)\n",
"\n",
" def consist_rule(self, exs, rules):\n",
" rules = str(rules).replace(\"'\", \"\")\n",
" return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n",
"\n",
" def abduce_rules(self, pred_res):\n",
" prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n",
" if len(prolog_result) == 0:\n",
" return None\n",
" prolog_rules = prolog_result[0][\"X\"]\n",
" rules = [rule.value for rule in prolog_rules]\n",
" return rules\n",
"\n",
"\n",
"class HedReasoner(Reasoner):\n",
" def revise_at_idx(self, data_example):\n",
" revision_idx = np.where(np.array(data_example.flatten(\"revision_flag\")) != 0)[0]\n",
" candidate = self.kb.revise_at_idx(\n",
" data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx\n",
" )\n",
" return candidate\n",
"\n",
" def zoopt_revision_score(self, symbol_num, data_example, sol):\n",
" revision_flag = reform_list(\n",
" list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label\n",
" )\n",
" data_example.revision_flag = revision_flag\n",
"\n",
" lefted_idxs = [i for i in range(len(data_example.pred_idx))]\n",
" candidate_size = []\n",
" max_consistent_idxs = []\n",
" while lefted_idxs:\n",
" idxs = []\n",
" idxs.append(lefted_idxs.pop(0))\n",
" max_candidate_idxs = []\n",
" found = False\n",
" for idx in range(-1, len(data_example.pred_idx)):\n",
" if (not idx in idxs) and (idx >= 0):\n",
" idxs.append(idx)\n",
" candidates, _ = self.revise_at_idx(data_example[idxs])\n",
" if len(candidates) == 0:\n",
" if len(idxs) > 1:\n",
" idxs.pop()\n",
" else:\n",
" if len(idxs) > len(max_candidate_idxs):\n",
" found = True\n",
" max_candidate_idxs = idxs.copy()\n",
" removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n",
" if found:\n",
" removed.insert(0, idxs[0])\n",
" candidate_size.append(len(removed))\n",
" max_consistent_idxs = max_candidate_idxs.copy()\n",
" lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n",
" candidate_size.sort()\n",
" score = 0\n",
" import math\n",
"\n",
" for i in range(0, len(candidate_size)):\n",
" score -= math.exp(-i) * candidate_size[i]\n",
" return score, max_consistent_idxs\n",
" \n",
" def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):\n",
" dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)\n",
" objective = Objective(\n",
" lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],\n",
" dim=dimension,\n",
" constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),\n",
" )\n",
" parameter = Parameter(budget=200, intermediate_result=False, autoset=True)\n",
" solution = Opt.min(objective, parameter)\n",
" return solution\n",
"\n",
" def abduce(self, data_example):\n",
" symbol_num = data_example.elements_num(\"pred_pseudo_label\")\n",
" max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)\n",
"\n",
" solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)\n",
" _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)\n",
"\n",
" abduced_pseudo_label = [[] for _ in range(len(data_example))]\n",
"\n",
" if len(max_candidate_idxs) > 0:\n",
" candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])\n",
" for i, idx in enumerate(max_candidate_idxs):\n",
" abduced_pseudo_label[idx] = candidates[0][i]\n",
" data_example.abduced_pseudo_label = abduced_pseudo_label\n",
" return abduced_pseudo_label\n",
"\n",
" def abduce_rules(self, pred_res):\n",
" return self.kb.abduce_rules(pred_res)\n",
"\n",
"\n",
"kb = HedKB(pseudo_label_list=[1, 0, \"+\", \"=\"], pl_file=\"./datasets/learn_add.pl\")\n",
"reasoner = HedReasoner(kb, dist_func=\"hamming\", use_zoopt=True, max_revision=10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Machine Learning Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build necessary components for BasicNN\n",
"cls = SymbolNet(num_classes=4)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build BasicNN\n",
"# The function of BasicNN is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicNN(\n",
" cls,\n",
" loss_fn,\n",
" optimizer,\n",
" device,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
" save_interval=1,\n",
" stop_loss=None,\n",
" save_dir=weights_dir,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build ABLModel\n",
"# The main function of the ABL model is to serialize data and\n",
"# provide a unified interface for different machine learning models\n",
"model = ABLModel(base_model)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metric"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set up metrics\n",
"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(kb=kb, prefix=\"hed\")]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Bridge Machine Learning and Logic Reasoning"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bridge = HEDBridge(model, reasoner, metric_list)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"total_train_data = get_hed(train=True)\n",
"train_data, val_data = split_equation(total_train_data, 3, 1)\n",
"test_data = get_hed(train=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bridge.pretrain(\"./weights\")\n",
"bridge.train(train_data, val_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ABL",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

+0 -0  examples/hed/main.py


examples/hed/datasets/BK.pl → examples/hed/reasoning/BK.pl


+3 -0  examples/hed/reasoning/__init__.py

@@ -0,0 +1,3 @@
from .reasoning import HedKB, HedReasoner

__all__ = ["HedKB", "HedReasoner"]

examples/hed/datasets/learn_add.pl → examples/hed/reasoning/learn_add.pl


+93 -0  examples/hed/reasoning/reasoning.py

@@ -0,0 +1,93 @@
import os
import numpy as np
import math
from zoopt import Dimension, Objective, Opt, Parameter
from abl.reasoning import PrologKB, Reasoner
from abl.utils import reform_list

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

class HedKB(PrologKB):
def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0

def abduce_rules(self, pred_res):
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
rules = [rule.value for rule in prolog_rules]
return rules


class HedReasoner(Reasoner):
def revise_at_idx(self, data_example):
revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0]
candidate = self.kb.revise_at_idx(
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
)
return candidate

def zoopt_revision_score(self, symbol_num, data_example, sol, get_score=True):
revision_flag = reform_list(
list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label
)
data_example.revision_flag = revision_flag

lefted_idxs = [i for i in range(len(data_example.pred_idx))]
candidate_size = []
max_consistent_idxs = []
while lefted_idxs:
idxs = []
idxs.append(lefted_idxs.pop(0))
max_candidate_idxs = []
found = False
for idx in range(-1, len(data_example.pred_idx)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidates, _ = self.revise_at_idx(data_example[idxs])
if len(candidates) == 0:
if len(idxs) > 1:
idxs.pop()
else:
if len(idxs) > len(max_candidate_idxs):
found = True
max_candidate_idxs = idxs.copy()
removed = [i for i in lefted_idxs if i in max_candidate_idxs]
if found:
removed.insert(0, idxs[0])
candidate_size.append(len(removed))
max_consistent_idxs = max_candidate_idxs.copy()
lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
candidate_size.sort()
score = 0

for i in range(0, len(candidate_size)):
score -= math.exp(-i) * candidate_size[i]
if get_score:
return score
else:
return max_consistent_idxs

def abduce(self, data_example):
symbol_num = data_example.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution, get_score=False)

abduced_pseudo_label = [[] for _ in range(len(data_example))]

if len(max_candidate_idxs) > 0:
candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])
for i, idx in enumerate(max_candidate_idxs):
abduced_pseudo_label[idx] = candidates[0][i]
data_example.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)
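The HedKB and HedReasoner classes that the old notebook defined inline are now importable from examples.hed.reasoning. A construction sketch mirroring the removed hed_example.ipynb (dist_func, use_zoopt and max_revision take the values used there):

    from examples.hed.reasoning import HedKB, HedReasoner

    # HedKB defaults to pseudo_label_list=[1, 0, "+", "="] and the bundled learn_add.pl.
    kb = HedKB()

    # HedReasoner runs abduction with ZOOpt choosing which symbols to revise.
    reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10)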

+1 -0  examples/hed/requirements.txt

@@ -0,0 +1 @@
abl
