|
|
@@ -1,248 +0,0 @@ |
|
|
|
import bisect |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
from collections import defaultdict |
|
|
|
from functools import lru_cache |
|
|
|
from itertools import combinations, product |
|
|
|
from multiprocessing import Pool |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pyswip |
|
|
|
|
|
|
|
from ..utils.utils import (check_equal, flatten, hamming_dist, |
|
|
|
hashable_to_list, reform_idx, to_hashable) |
|
|
|
|
|
|
|
|
|
|
|
class KBBase(ABC): |
|
|
|
def __init__(self, pseudo_label_list, max_err=0, use_cache=True): |
|
|
|
# TODO:添加一下类型检查,比如 |
|
|
|
# if not isinstance(X, (np.ndarray, spmatrix)): |
|
|
|
# raise TypeError("X should be numpy array or sparse matrix") |
|
|
|
|
|
|
|
self.pseudo_label_list = pseudo_label_list |
|
|
|
self.max_err = max_err |
|
|
|
self.use_cache = use_cache |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
pass |
|
|
|
|
|
|
|
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0): |
|
|
|
if not self.use_cache: |
|
|
|
return self._abduce_by_search( |
|
|
|
pred_res, y, max_revision_num, require_more_revision |
|
|
|
) |
|
|
|
else: |
|
|
|
return self._abduce_by_search_cache( |
|
|
|
to_hashable(pred_res), |
|
|
|
to_hashable(y), |
|
|
|
max_revision_num, |
|
|
|
require_more_revision, |
|
|
|
) |
|
|
|
|
|
|
|
def revise_by_idx(self, pred_res, y, revision_idx): |
|
|
|
candidates = [] |
|
|
|
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx)) |
|
|
|
for c in abduce_c: |
|
|
|
candidate = pred_res.copy() |
|
|
|
for i, idx in enumerate(revision_idx): |
|
|
|
candidate[idx] = c[i] |
|
|
|
if check_equal(self.logic_forward(candidate), y, self.max_err): |
|
|
|
candidates.append(candidate) |
|
|
|
return candidates |
|
|
|
|
|
|
|
def _revision(self, revision_num, pred_res, y): |
|
|
|
new_candidates = [] |
|
|
|
revision_idx_list = combinations(range(len(pred_res)), revision_num) |
|
|
|
|
|
|
|
for revision_idx in revision_idx_list: |
|
|
|
candidates = self.revise_by_idx(pred_res, y, revision_idx) |
|
|
|
new_candidates.extend(candidates) |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision): |
|
|
|
candidates = [] |
|
|
|
for revision_num in range(len(pred_res) + 1): |
|
|
|
if revision_num == 0 and check_equal( |
|
|
|
self.logic_forward(pred_res), y, self.max_err |
|
|
|
): |
|
|
|
candidates.append(pred_res) |
|
|
|
elif revision_num > 0: |
|
|
|
candidates.extend(self._revision(revision_num, pred_res, y)) |
|
|
|
if len(candidates) > 0: |
|
|
|
min_revision_num = revision_num |
|
|
|
break |
|
|
|
if revision_num >= max_revision_num: |
|
|
|
return [] |
|
|
|
|
|
|
|
for revision_num in range( |
|
|
|
min_revision_num + 1, min_revision_num + require_more_revision + 1 |
|
|
|
): |
|
|
|
if revision_num > max_revision_num: |
|
|
|
return candidates |
|
|
|
candidates.extend(self._revision(revision_num, pred_res, y)) |
|
|
|
return candidates |
|
|
|
|
|
|
|
@lru_cache(maxsize=None) |
|
|
|
def _abduce_by_search_cache( |
|
|
|
self, pred_res, y, max_revision_num, require_more_revision |
|
|
|
): |
|
|
|
pred_res = hashable_to_list(pred_res) |
|
|
|
y = hashable_to_list(y) |
|
|
|
return self._abduce_by_search( |
|
|
|
pred_res, y, max_revision_num, require_more_revision |
|
|
|
) |
|
|
|
|
|
|
|
def _dict_len(self, dic): |
|
|
|
if not self.GKB_flag: |
|
|
|
return 0 |
|
|
|
else: |
|
|
|
return sum(len(c) for c in dic.values()) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
if not self.GKB_flag: |
|
|
|
return 0 |
|
|
|
else: |
|
|
|
return sum(self._dict_len(v) for v in self.base.values()) |
|
|
|
|
|
|
|
|
|
|
|
class ground_KB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
|
|
|
|
|
self.GKB_len_list = GKB_len_list |
|
|
|
self.base = {} |
|
|
|
X, Y = self._get_GKB() |
|
|
|
for x, y in zip(X, Y): |
|
|
|
self.base.setdefault(len(x), defaultdict(list))[y].append(x) |
|
|
|
|
|
|
|
# For parallel version of _get_GKB |
|
|
|
def _get_XY_list(self, args): |
|
|
|
pre_x, post_x_it = args[0], args[1] |
|
|
|
XY_list = [] |
|
|
|
for post_x in post_x_it: |
|
|
|
x = (pre_x,) + post_x |
|
|
|
y = self.logic_forward(x) |
|
|
|
if y is not None: |
|
|
|
XY_list.append((x, y)) |
|
|
|
return XY_list |
|
|
|
|
|
|
|
# Parallel _get_GKB |
|
|
|
def _get_GKB(self): |
|
|
|
X, Y = [], [] |
|
|
|
for length in self.GKB_len_list: |
|
|
|
print("Generating GKB of length %d" % length) |
|
|
|
arg_list = [] |
|
|
|
for pre_x in self.pseudo_label_list: |
|
|
|
post_x_it = product(self.pseudo_label_list, repeat=length - 1) |
|
|
|
arg_list.append((pre_x, post_x_it)) |
|
|
|
with Pool(processes=len(arg_list)) as pool: |
|
|
|
ret_list = pool.map(self._get_XY_list, arg_list) |
|
|
|
for XY_list in ret_list: |
|
|
|
if len(XY_list) == 0: |
|
|
|
continue |
|
|
|
part_X, part_Y = zip(*XY_list) |
|
|
|
X.extend(part_X) |
|
|
|
Y.extend(part_Y) |
|
|
|
if Y and isinstance(Y[0], (int, float)): |
|
|
|
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1])) |
|
|
|
return X, Y |
|
|
|
|
|
|
|
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision=0): |
|
|
|
return self._abduce_by_GKB( |
|
|
|
data_sample, max_revision_num, require_more_revision=require_more_revision |
|
|
|
) |
|
|
|
|
|
|
|
def _find_candidate_GKB(self, cache_key, data_sample): |
|
|
|
y = data_sample["Y"][0] |
|
|
|
if self.max_err == 0: |
|
|
|
return self.base[cache_key][y] |
|
|
|
else: |
|
|
|
potential_candidates = self.base[cache_key] |
|
|
|
key_list = list(potential_candidates.keys()) |
|
|
|
key_idx = bisect.bisect_left(key_list, y) |
|
|
|
|
|
|
|
all_candidates = [] |
|
|
|
for idx in range(key_idx - 1, -1, -1): |
|
|
|
k = key_list[idx] |
|
|
|
if abs(k - y) <= self.max_err: |
|
|
|
all_candidates.extend(potential_candidates[k]) |
|
|
|
else: |
|
|
|
break |
|
|
|
|
|
|
|
for idx in range(key_idx, len(key_list)): |
|
|
|
k = key_list[idx] |
|
|
|
if abs(k - y) <= self.max_err: |
|
|
|
all_candidates.extend(potential_candidates[k]) |
|
|
|
else: |
|
|
|
break |
|
|
|
return all_candidates |
|
|
|
|
|
|
|
def _abduce_by_GKB(self, data_sample, max_revision_num, require_more_revision=0): |
|
|
|
cache_key = len(data_sample["pred_pseudo_label"][0]) |
|
|
|
if self.base == {} or cache_key not in self.GKB_len_list: |
|
|
|
return [] |
|
|
|
|
|
|
|
all_candidates = self._find_candidate_GKB(cache_key, data_sample) |
|
|
|
if len(all_candidates) == 0: |
|
|
|
return [] |
|
|
|
|
|
|
|
cost_array = hamming_dist(data_sample["pred_pseudo_label"][0], all_candidates) |
|
|
|
min_revision_num = np.min(cost_array) |
|
|
|
revision_num = min(max_revision_num, min_revision_num + require_more_revision) |
|
|
|
idxs = np.where(cost_array <= revision_num)[0] |
|
|
|
candidates = [all_candidates[idx] for idx in idxs] |
|
|
|
return candidates |
|
|
|
|
|
|
|
|
|
|
|
class prolog_KB(KBBase): |
|
|
|
def __init__(self, pseudo_label_list, pl_file, max_err=0): |
|
|
|
super().__init__(pseudo_label_list, max_err) |
|
|
|
self.prolog = pyswip.Prolog() |
|
|
|
self.prolog.consult(pl_file) |
|
|
|
|
|
|
|
def logic_forward(self, pseudo_labels): |
|
|
|
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0][ |
|
|
|
"Res" |
|
|
|
] |
|
|
|
if result == "true": |
|
|
|
return True |
|
|
|
elif result == "false": |
|
|
|
return False |
|
|
|
return result |
|
|
|
|
|
|
|
def _revision_pred_res(self, pred_res, revision_idx): |
|
|
|
import re |
|
|
|
|
|
|
|
revision_pred_res = pred_res.copy() |
|
|
|
revision_pred_res = flatten(revision_pred_res) |
|
|
|
|
|
|
|
for idx in revision_idx: |
|
|
|
revision_pred_res[idx] = "P" + str(idx) |
|
|
|
revision_pred_res = reform_idx(revision_pred_res, pred_res) |
|
|
|
|
|
|
|
# TODO:不知道有没有更简洁的方法 |
|
|
|
regex = r"'P\d+'" |
|
|
|
return re.sub( |
|
|
|
regex, lambda x: x.group().replace("'", ""), str(revision_pred_res) |
|
|
|
) |
|
|
|
|
|
|
|
def get_query_string(self, pred_res, y, revision_idx): |
|
|
|
query_string = "logic_forward(" |
|
|
|
query_string += self._revision_pred_res(pred_res, revision_idx) |
|
|
|
key_is_none_flag = y is None or (type(y) == list and y[0] is None) |
|
|
|
query_string += ",%s)." % y if not key_is_none_flag else ")." |
|
|
|
return query_string |
|
|
|
|
|
|
|
def revise_by_idx(self, pred_res, y, revision_idx): |
|
|
|
candidates = [] |
|
|
|
query_string = self.get_query_string(pred_res, y, revision_idx) |
|
|
|
save_pred_res = pred_res |
|
|
|
pred_res = flatten(pred_res) |
|
|
|
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] |
|
|
|
for c in abduce_c: |
|
|
|
candidate = pred_res.copy() |
|
|
|
for i, idx in enumerate(revision_idx): |
|
|
|
candidate[idx] = c[i] |
|
|
|
candidate = reform_idx(candidate, save_pred_res) |
|
|
|
candidates.append(candidate) |
|
|
|
return candidates |