# Grounded abductive learning | |||
This is the code repository for grounded abductive learning.
## Environment dependency | |||
... | |||
## Example | |||
`share_example.py` and `nonshare_example.py` are runnable examples of grounded abductive learning, using one classifier shared across all symbol positions and one classifier per position, respectively.
```bash | |||
python share_example.py | |||
``` | |||
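The non-shared variant can be run the same way; both scripts create the `share_result/` and `non_share_result/` directories for their experiment logs.
```bash
python nonshare_example.py
```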
## Authors | |||
- [Le-Wen Cai](http://www.lamda.nju.edu.cn/cailw/) (Nanjing University) | |||
- [Wang-Zhou Dai](http://daiwz.net) (Imperial College London) | |||
- [Yu-Xuan Huang](http://www.lamda.nju.edu.cn/huangyx/) (Nanjing University) | |||
## NOTICE | |||
This code can only be used for academic purposes. For other purposes, please contact the LAMDA Group (www.lamda.nju.edu.cn).
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :abducer_base.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/03 | |||
# Description : | |||
# | |||
#================================================================# | |||
import abc | |||
from abducer.kb import ClsKB, RegKB | |||
#from kb import ClsKB, RegKB | |||
import numpy as np | |||
def hamming_dist(A, B): | |||
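    # Hamming distance between the predicted sequence A and every candidate sequence in B.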
B = np.array(B) | |||
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) | |||
return np.sum(A != B, axis = 1) | |||
def confidence_dist(A, B): | |||
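    # Confidence-based distance: for each candidate in B, 1 minus the product of the
    # probabilities that A (a per-position probability matrix) assigns to that candidate's symbols.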
B = np.array(B) | |||
#print(A) | |||
A = np.clip(A, 1e-9, 1) | |||
A = np.expand_dims(A, axis=0) | |||
A = A.repeat(axis=0, repeats=(len(B))) | |||
rows = np.array(range(len(B))) | |||
rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) | |||
cols = np.array(range(len(B[0]))) | |||
cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) | |||
return 1 - np.prod(A[rows, cols, B], axis = 1) | |||
class AbducerBase(abc.ABC): | |||
def __init__(self, kb, dist_func = "hamming", pred_res_parse = None): | |||
self.kb = kb | |||
if dist_func == "hamming": | |||
dist_func = hamming_dist | |||
elif dist_func == "confidence": | |||
dist_func = confidence_dist | |||
self.dist_func = dist_func | |||
if pred_res_parse is None: | |||
pred_res_parse = lambda x : x["cls"] | |||
self.pred_res_parse = pred_res_parse | |||
    def abduce(self, data, max_address_num, require_more_address, length = -1):
        pred_res, ans = data
        if length == -1:
            length = len(pred_res)
        candidates = self.kb.get_candidates(ans, length)
        pred_res = np.array(pred_res)
        cost_list = self.dist_func(pred_res, candidates)
        address_num = np.min(cost_list)
        threshold = min(address_num + require_more_address, max_address_num)
        idxs = np.where(cost_list <= threshold)[0]
        #return [candidates[idx] for idx in idxs], address_num
        # Abduction is treated as ambiguous (and fails) when several candidates fall
        # within the threshold; otherwise return the single closest candidate.
        if len(idxs) != 1:
            return None
        return candidates[idxs[0]]
def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0): | |||
return [ | |||
self.abduce((y, c), max_address_num, require_more_address)\ | |||
for y, c in zip(self.pred_res_parse(Y), C) | |||
] | |||
def __call__(self, Y, C, max_address_num = 3, require_more_address = 0): | |||
        return self.batch_abduce(Y, C, max_address_num, require_more_address)
if __name__ == "__main__": | |||
#["1+1", "0+1", "1+0", "2+0"] | |||
X = [[1,3,1], [0,3,1], [1,2,0], [3,2,0]] | |||
Y = [2, 1, 1, 2] | |||
kb = RegKB(X, Y) | |||
abd = AbducerBase(kb) | |||
res = abd.abduce(([0,2,0], None), 1, 0) | |||
print(res) | |||
res = abd.abduce(([0, 2, 0], 0.99), 1, 0) | |||
print(res) | |||
A = np.array([[0.5, 0.25, 0.25, 0], [0.3, 0.3, 0.3, 0.1], [0.1, 0.2, 0.3, 0.4]]) | |||
B = [[1, 2, 3], [0, 1, 3]] | |||
res = confidence_dist(A, B) | |||
print(res) | |||
A = np.array([[0.5, 0.25, 0.25, 0], [0.3, 1.0, 0.3, 0.1], [0.1, 0.2, 0.3, 1.0]]) | |||
B = [[0, 1, 3]] | |||
res = confidence_dist(A, B) | |||
print(res) | |||
kb_str = ['10010001011', '00010001100', '00111101011', '11101000011', '11110011001', '11111010001', '10001010010', '11100100001', '10001001100', '11011010001', '00110000100', '11000000111', '01110111111', '11000101100', '10101011010', '00000110110', '11111110010', '11100101100', '10111001111', '10000101100', '01001011101', '01001110000', '01110001110', '01010010001', '10000100010', '01001011011', '11111111100', '01011101101', '00101110101', '11101001101', '10010110000', '10000000011'] | |||
X = [[int(c) for c in s] for s in kb_str] | |||
kb = RegKB(X, len(X) * [None]) | |||
abd = AbducerBase(kb) | |||
res = abd.abduce(((1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1), None), 1, 0) | |||
print(res) |
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2021 LAMDA All rights reserved. | |||
# | |||
# File Name :kb.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/03 | |||
# Description : | |||
# | |||
#================================================================# | |||
import abc | |||
import bisect | |||
import copy | |||
import numpy as np | |||
from collections import defaultdict | |||
class KBBase(abc.ABC): | |||
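    # Knowledge-base interface: candidate sequences are stored by length, and
    # get_candidates(key, length) returns the stored sequences consistent with key.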
def __init__(self, X = None, Y = None): | |||
pass | |||
def get_candidates(self, key = None, length = None): | |||
pass | |||
def get_all_candidates(self): | |||
pass | |||
def _length(self, length): | |||
if length is None: | |||
length = list(self.base.keys()) | |||
if type(length) is int: | |||
length = [length] | |||
return length | |||
def __len__(self): | |||
pass | |||
class ClsKB(KBBase): | |||
def __init__(self, X, Y = None): | |||
super().__init__() | |||
self.base = {} | |||
if X is None: | |||
return | |||
if Y is None: | |||
Y = [None] * len(X) | |||
for x, y in zip(X, Y): | |||
self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
def get_candidates(self, key, length = None): | |||
if key is None: | |||
return self.get_all_candidates() | |||
length = self._length(length) | |||
return sum([self.base[l][key] for l in length], []) | |||
def get_all_candidates(self): | |||
return sum([sum(v.values(), []) for v in self.base.values()], []) | |||
def _dict_len(self, dic): | |||
return sum(len(c) for c in dic.values()) | |||
def __len__(self): | |||
return sum(self._dict_len(v) for v in self.base.values()) | |||
class RegKB(KBBase): | |||
def __init__(self, X, Y = None): | |||
super().__init__() | |||
tmp_dict = {} | |||
for x, y in zip(X, Y): | |||
tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
self.base = {} | |||
for l in tmp_dict.keys(): | |||
data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) | |||
X = [x for y, x in data] | |||
Y = [y for y, x in data] | |||
self.base[l] = (X, Y) | |||
def get_candidates(self, key, length = None): | |||
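        # Return the stored sequences (restricted to the given lengths) whose label is
        # nearest to key; ties within 1e-9 are all kept.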
if key is None: | |||
return self.get_all_candidates() | |||
length = self._length(length) | |||
min_err = 999999 | |||
candidates = [] | |||
for l in length: | |||
X, Y = self.base[l] | |||
idx = bisect.bisect_left(Y, key) | |||
begin = max(0, idx - 1) | |||
end = min(idx + 2, len(X)) | |||
for idx in range(begin, end): | |||
err = abs(Y[idx] - key) | |||
if abs(err - min_err) < 1e-9: | |||
candidates.extend(X[idx]) | |||
elif err < min_err: | |||
candidates = copy.deepcopy(X[idx]) | |||
min_err = err | |||
return candidates | |||
def get_all_candidates(self): | |||
return sum([sum(D[0], []) for D in self.base.values()], []) | |||
def __len__(self): | |||
return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) | |||
if __name__ == "__main__": | |||
X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
Y = [2, 1, 1, 2, 2] | |||
kb = ClsKB(X, Y) | |||
print(len(kb)) | |||
res = kb.get_candidates(2, 5) | |||
print(res) | |||
res = kb.get_candidates(2, 3) | |||
print(res) | |||
res = kb.get_candidates(None) | |||
print(res) | |||
X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
Y = [2, 1, 1, 2, 1.5, 1.5] | |||
kb = RegKB(X, Y) | |||
print(len(kb)) | |||
res = kb.get_candidates(1.6) | |||
print(res) | |||
res = kb.get_candidates(1.6, length = 9) | |||
print(res) | |||
res = kb.get_candidates(None) | |||
print(res) | |||
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2020 Freecss All rights reserved. | |||
# | |||
# File Name :data_generator.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2020/04/02 | |||
# Description : | |||
# | |||
#================================================================# | |||
from itertools import product | |||
import math | |||
import numpy as np | |||
import random | |||
import pickle as pk | |||
import random | |||
from multiprocessing import Pool | |||
import copy | |||
#def hamming_code_generator(data_len, p_len): | |||
# ret = [] | |||
# for data in product((0, 1), repeat=data_len): | |||
# p_idxs = [2 ** i for i in range(p_len)] | |||
# total_len = data_len + p_len | |||
# data_idx = 0 | |||
# hamming_code = [] | |||
# for idx in range(total_len): | |||
# if idx + 1 in p_idxs: | |||
# hamming_code.append(0) | |||
# else: | |||
# hamming_code.append(data[data_idx]) | |||
# data_idx += 1 | |||
# | |||
# for idx in range(total_len): | |||
# if idx + 1 in p_idxs: | |||
# for i in range(total_len): | |||
# if (i + 1) & (idx + 1) != 0: | |||
# hamming_code[idx] ^= hamming_code[i] | |||
# #hamming_code = "".join([str(x) for x in hamming_code]) | |||
# ret.append(hamming_code) | |||
# return ret | |||
def code_generator(code_len, code_num, letter_num = 2): | |||
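    # Randomly sample code_num distinct codes of length code_len over an alphabet of letter_num symbols.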
codes = list(product(list(range(letter_num)), repeat = code_len)) | |||
random.shuffle(codes) | |||
return codes[:code_num] | |||
def hamming_distance_static(codes): | |||
min_dist = len(codes) | |||
avg_dist = 0. | |||
avg_min_dist = 0. | |||
relation_num = 0. | |||
for code1 in codes: | |||
tmp_min_dist = len(codes) | |||
for code2 in codes: | |||
if code1 == code2: | |||
continue | |||
dist = 0 | |||
relation_num += 1 | |||
for c1, c2 in zip(code1, code2): | |||
if c1 != c2: | |||
dist += 1 | |||
avg_dist += dist | |||
if tmp_min_dist > dist: | |||
tmp_min_dist = dist | |||
avg_min_dist += tmp_min_dist | |||
if min_dist > tmp_min_dist: | |||
min_dist = tmp_min_dist | |||
return avg_dist / relation_num, avg_min_dist / len(codes) | |||
def generate_cosin_data(codes, err, repeat, letter_num): | |||
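    # Synthetic data: each symbol is rendered as a 2-D point sampled from the band
    # cos(x) + 2d - 2 < y < cos(x) + 2d (d = the symbol's index), and with probability err
    # a position's symbol is swapped for a different one before sampling (label noise).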
Y = np.random.random(100000) * letter_num * 3 - 3 | |||
X = np.random.random(100000) * 20 - 10 | |||
data_X = np.concatenate((X.reshape(-1, 1), Y.reshape(-1, 1)), axis = 1) | |||
samples = {} | |||
all_sign = list(set(sum([[c for c in code] for code in codes], []))) | |||
for d, sign in enumerate(all_sign): | |||
labels = np.logical_and(Y < np.cos(X) + 2 * d, Y > np.cos(X) + 2 * d - 2) | |||
samples[sign] = data_X[labels] | |||
data = [] | |||
labels = [] | |||
count = 0 | |||
for _ in range(repeat): | |||
if (count > 100000): | |||
break | |||
for code in codes: | |||
tmp = [] | |||
count += 1 | |||
for d in code: | |||
if random.random() < err: | |||
candidates = copy.deepcopy(all_sign) | |||
candidates.remove(d) | |||
d = candidates[random.randint(0, letter_num - 2)] | |||
idx = random.randint(0, len(samples[d]) - 1) | |||
tmp.append(samples[d][idx]) | |||
data.append(tmp) | |||
labels.append(code) | |||
data = np.array(data) | |||
labels = np.array(labels) | |||
return data, labels | |||
#codes = """110011001 | |||
#100011001 | |||
#101101101 | |||
#011111001 | |||
#100100001 | |||
#111111101 | |||
#101110001 | |||
#111100101 | |||
#101000101 | |||
#001001101 | |||
#111110101 | |||
#100101001 | |||
#010010101 | |||
#110100101 | |||
#001111101 | |||
#111111001""" | |||
#codes = codes.split() | |||
def generate_data_via_codes(codes, err, letter_num): | |||
#codes = code_generator(code_len, code_num) | |||
data, labels = generate_cosin_data(codes, err, 100000, letter_num) | |||
return data, labels | |||
def generate_data(params): | |||
code_len = params["code_len"] | |||
times = params["times"] | |||
p = params["p"] | |||
code_num = params["code_num"] | |||
err = p / 20. | |||
codes = code_generator(code_len, code_num) | |||
    # letter_num defaults to 2 here, matching the code_generator call above (assumed binary codes).
    data, labels = generate_cosin_data(codes, err, repeat = 100000, letter_num = 2)
data_name = "code_%d_%d" % (code_len, code_num) | |||
pk.dump((codes, data, labels), open("generated_data/%d_%s_%.2f.pk" % (times, data_name, err), "wb")) | |||
return True | |||
def generate_multi_data(): | |||
pool = Pool(64) | |||
params_list = [] | |||
#for code_len in [7, 9, 11, 13, 15]: | |||
for code_len in [7, 11, 15]: | |||
for times in range(20): | |||
for p in range(0, 11): | |||
for code_num_power in range(1, code_len): | |||
code_num = 2 ** code_num_power | |||
params_list.append({"code_len" : code_len, "times" : times, "p" : p, "code_num" : code_num}) | |||
return list(pool.map(generate_data, params_list)) | |||
def read_lexicon(file_path): | |||
ret = [] | |||
with open(file_path) as fin: | |||
ret = [s.strip() for s in fin] | |||
all_sign = list(set(sum([[c for c in s] for s in ret], []))) | |||
#ret = ["".join(str(all_sign.index(t)) for t in tmp) for tmp in ret] | |||
return ret, len(all_sign) | |||
import os | |||
if __name__ == "__main__": | |||
for root, dirs, files in os.walk("lexicons"): | |||
if root != "lexicons": | |||
continue | |||
for file_name in files: | |||
file_path = os.path.join(root, file_name) | |||
codes, letter_num = read_lexicon(file_path) | |||
data, labels = generate_data_via_codes(codes, 0, letter_num) | |||
save_path = os.path.join("dataset", file_name.split(".")[0] + ".pk") | |||
pk.dump((data, labels, codes), open(save_path, "wb")) | |||
#res = read_lexicon("add2.txt") | |||
#print(res) | |||
exit(0) | |||
generate_multi_data() | |||
exit() |
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :framework.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/07 | |||
# Description : | |||
# | |||
#================================================================# | |||
import pickle as pk | |||
import numpy as np | |||
from utils.plog import INFO, DEBUG, clocker | |||
@clocker | |||
def block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx): | |||
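    # Take the epoch_idx-th block of sample_num examples, cycling through the data,
    # so each abductive-learning epoch works on one slice of the training set.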
part_num = (len(X_bak) // sample_num) | |||
if part_num == 0: | |||
part_num = 1 | |||
seg_idx = epoch_idx % part_num | |||
INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak)) | |||
X = X_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] | |||
Y = Y_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] | |||
C = C_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)] | |||
return X, Y, C | |||
def get_taglist(Y):
    # Collect the distinct label sequences appearing in Y, encoded as strings.
    tmp = ["".join(str(x) for x in label) for label in Y]
    return sorted(set(tmp))
@clocker | |||
def result_statistics(pseudo_Y, Y, abduced_Y): | |||
abd_err_num = 0 | |||
abd_char_num = 0 | |||
abd_char_acc = 0 | |||
abd_failed = 0 | |||
word_err_num = 0 | |||
ori_char_num = 0 | |||
ori_char_acc = 0 | |||
    for tidx, (pseudo_y, y, abduced_y) in enumerate(zip(pseudo_Y, Y, abduced_Y)):
        if abduced_y is not None:
            if sum(abduced_y != y) != 0:
                abd_err_num += 1
            abd_char_num += len(y)
            abd_char_acc += sum(abduced_y == y)
        else:
            abd_failed += 1
        ori_char_num += len(pseudo_y)
        ori_char_acc += sum(pseudo_y == y)
        # A prediction that was already correct but got changed by abduction indicates a bug; dump it for inspection.
        if abduced_y is not None and sum(y != pseudo_y) == 0 and sum(pseudo_y != abduced_y) > 0:
            INFO(pseudo_y, y, abduced_y)
            pk.dump((pseudo_y, y, abduced_y), open("bug.pk", "wb"))
        if sum(pseudo_y != y) != 0:
            word_err_num += 1
INFO("") | |||
INFO("Abd word level accuracy:", 1 - word_err_num / len(pseudo_Y)) | |||
INFO("Abd char level accuracy:", abd_char_acc / abd_char_num) | |||
INFO("Ori char level accuracy:", ori_char_acc / ori_char_num) | |||
INFO("") | |||
result = {"total_word" : len(pseudo_Y), "accuracy_word" : len(pseudo_Y) - word_err_num, | |||
"total_abd_char": abd_char_num, "accuracy_abd_char" : abd_char_acc, | |||
"total_ori_char": ori_char_num, "accuracy_ori_char" : ori_char_acc, | |||
"total_abd_failed": abd_failed} | |||
return result | |||
@clocker | |||
def filter_data(X, abduced_Y): | |||
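    # Keep only the examples whose abduction succeeded (abduced label is not None).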
finetune_Y = [] | |||
finetune_X = [] | |||
for abduced_x, abduced_y in zip(X, abduced_Y): | |||
if abduced_y is not None: | |||
finetune_X.append(abduced_x) | |||
finetune_Y.append(abduced_y) | |||
return finetune_X, finetune_Y | |||
@clocker | |||
def is_all_sublabel_exist(labels, std_label_list): | |||
if not labels: | |||
return False | |||
labels = np.array(labels).T | |||
for idx, (std_label, label) in enumerate(zip(std_label_list, labels)): | |||
std_num = len(set(std_label)) | |||
sublabel_num = len(set(label)) | |||
if std_num != sublabel_num: | |||
INFO(f"sublabel {idx} should have {std_num} class, but data only have {sublabel_num} class", screen=True) | |||
return False | |||
return True | |||
def pretrain(model, X, Y): | |||
pass | |||
def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True): | |||
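    # Abductive learning loop: predict pseudo-labels, abduce consistent labels from the KB,
    # drop the examples whose abduction failed, then retrain the base model on the abduced labels.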
# Set default parameters | |||
if sample_num == -1: | |||
sample_num = len(X) | |||
if verbose < 1: | |||
verbose = epochs | |||
if C is None: | |||
C = [None] * len(X) | |||
# Set function running time recorder | |||
valid_func = clocker(model.valid) | |||
predict_func = clocker(model.predict) | |||
train_func = clocker(model.train) | |||
abduce_func = clocker(abducer.batch_abduce) | |||
X_bak = X | |||
Y_bak = Y | |||
C_bak = C | |||
# Abductive learning train process | |||
res = {} | |||
for epoch_idx in range(epochs): | |||
X, Y, C = block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx) | |||
preds_res = predict_func(X) | |||
abduced_Y = abduce_func(preds_res, C) | |||
finetune_X, finetune_Y = filter_data(X, abduced_Y) | |||
score, score_list = valid_func(X, Y) | |||
if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1): | |||
res = result_statistics(preds_res["cls"], Y, abduced_Y) | |||
INFO(res) | |||
if check_sublabel and (not is_all_sublabel_exist(finetune_Y, model.label_lists)): | |||
INFO("There is some sub label missing", len(finetune_Y)) | |||
break | |||
if len(finetune_X) > 0: | |||
train_func(finetune_X, finetune_Y)#, n_epoch = 10) | |||
else: | |||
INFO("lack of data, all abduced failed", len(finetune_X)) | |||
return res | |||
#return ret | |||
if __name__ == "__main__": | |||
pass |
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2020 Freecss All rights reserved. | |||
# | |||
# File Name :basic_model.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2020/11/21 | |||
# Description : | |||
# | |||
#================================================================# | |||
import sys
sys.path.append("..")

import os
import random
import collections
from multiprocessing import Pool

import numpy as np
import six
from PIL import Image

import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, sampler

import utils.utils as mutils
class resizeNormalize(object): | |||
def __init__(self, size, interpolation=Image.BILINEAR): | |||
self.size = size | |||
self.interpolation = interpolation | |||
self.toTensor = transforms.ToTensor() | |||
self.transform = transforms.Compose([ | |||
#transforms.ToPILImage(), | |||
#transforms.RandomHorizontalFlip(), | |||
#transforms.RandomVerticalFlip(), | |||
#transforms.RandomRotation(30), | |||
#transforms.RandomAffine(30), | |||
transforms.ToTensor(), | |||
]) | |||
def __call__(self, img): | |||
#img = img.resize(self.size, self.interpolation) | |||
#img = self.toTensor(img) | |||
img = self.transform(img) | |||
img.sub_(0.5).div_(0.5) | |||
return img | |||
class XYDataset(Dataset): | |||
def __init__(self, X, Y, transform=None, target_transform=None): | |||
self.X = X | |||
self.Y = Y | |||
self.n_sample = len(X) | |||
self.transform = transform | |||
self.target_transform = target_transform | |||
def __len__(self): | |||
return len(self.X) | |||
def __getitem__(self, index): | |||
assert index < len(self), 'index range error' | |||
img = self.X[index] | |||
if self.transform is not None: | |||
img = self.transform(img) | |||
label = self.Y[index] | |||
if self.target_transform is not None: | |||
label = self.target_transform(label) | |||
return (img, label, index) | |||
class alignCollate(object): | |||
def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1): | |||
self.imgH = imgH | |||
self.imgW = imgW | |||
self.keep_ratio = keep_ratio | |||
self.min_ratio = min_ratio | |||
def __call__(self, batch): | |||
images, labels, img_keys = zip(*batch) | |||
imgH = self.imgH | |||
imgW = self.imgW | |||
if self.keep_ratio: | |||
ratios = [] | |||
for image in images: | |||
w, h = image.shape[:2] | |||
ratios.append(w / float(h)) | |||
ratios.sort() | |||
max_ratio = ratios[-1] | |||
imgW = int(np.floor(max_ratio * imgH)) | |||
            imgW = max(imgH * self.min_ratio, imgW)  # ensure imgW >= imgH * min_ratio
transform = resizeNormalize((imgW, imgH)) | |||
images = [transform(image) for image in images] | |||
images = torch.cat([t.unsqueeze(0) for t in images], 0) | |||
labels = torch.LongTensor(labels) | |||
return images, labels, img_keys | |||
class FakeRecorder(): | |||
def __init__(self): | |||
pass | |||
def print(self, *x): | |||
pass | |||
from torch.nn import init | |||
from torch import nn | |||
def weight_init(m):
if isinstance(m, nn.Conv2d): | |||
init.xavier_uniform_(m.weight.data) | |||
init.constant_(m.bias.data,0.1) | |||
elif isinstance(m, nn.BatchNorm2d): | |||
m.weight.data.fill_(1) | |||
m.bias.data.zero_() | |||
elif isinstance(m, nn.Linear): | |||
m.weight.data.normal_(0,0.01) | |||
m.bias.data.zero_() | |||
class BasicModel(): | |||
def __init__(self, | |||
model, | |||
criterion, | |||
optimizer, | |||
converter, | |||
device, | |||
params, | |||
sign_list, | |||
recorder = None): | |||
self.model = model.to(device) | |||
        self.model.apply(weight_init)
self.criterion = criterion | |||
self.optimizer = optimizer | |||
self.converter = converter | |||
self.device = device | |||
sign_list = sorted(list(set(sign_list))) | |||
self.mapping = dict(zip(sign_list, list(range(len(sign_list))))) | |||
self.remapping = dict(zip(list(range(len(sign_list))), sign_list)) | |||
if recorder is None: | |||
recorder = FakeRecorder() | |||
self.recorder = recorder | |||
self.save_interval = params.saveInterval | |||
self.params = params | |||
pass | |||
def _fit(self, data_loader, n_epoch, stop_loss): | |||
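        # Run training epochs until the loss falls below stop_loss or n_epoch is reached,
        # tracking the minimal loss seen.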
recorder = self.recorder | |||
recorder.print("model fitting") | |||
min_loss = 999999999 | |||
for epoch in range(n_epoch): | |||
loss_value = self.train_epoch(data_loader) | |||
recorder.print(f"{epoch}/{n_epoch} model training loss is {loss_value}") | |||
if loss_value < min_loss: | |||
min_loss = loss_value | |||
if loss_value < stop_loss: | |||
break | |||
recorder.print("Model fitted, minimal loss is ", min_loss) | |||
return loss_value | |||
def str2ints(self, Y): | |||
return [self.mapping[y] for y in Y] | |||
def fit(self, data_loader = None, | |||
X = None, | |||
y = None, | |||
n_epoch = 100, | |||
stop_loss = 0.001): | |||
if data_loader is None: | |||
params = self.params | |||
Y = self.str2ints(y) | |||
train_dataset = XYDataset(X, Y) | |||
sampler = None | |||
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \ | |||
shuffle=True, sampler=sampler, num_workers=int(params.workers), \ | |||
collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio)) | |||
return self._fit(data_loader, n_epoch, stop_loss) | |||
def train_epoch(self, data_loader): | |||
loss_avg = mutils.averager() | |||
for i, data in enumerate(data_loader): | |||
X = data[0] | |||
Y = data[1] | |||
cost = self.train_batch(X, Y) | |||
loss_avg.add(cost) | |||
loss_value = float(loss_avg.val()) | |||
loss_avg.reset() | |||
return loss_value | |||
def train_batch(self, X, Y): | |||
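        # One optimization step on a single batch; returns the batch loss.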
#cpu_images, cpu_texts, _ = data | |||
model = self.model | |||
criterion = self.criterion | |||
optimizer = self.optimizer | |||
converter = self.converter | |||
device = self.device | |||
# set training mode | |||
for p in model.parameters(): | |||
p.requires_grad = True | |||
model.train() | |||
# init training status | |||
torch.autograd.set_detect_anomaly(True) | |||
optimizer.zero_grad() | |||
# model predict | |||
X = X.to(device) | |||
Y = Y.to(device) | |||
pred_Y = model(X) | |||
# calculate loss | |||
loss = criterion(pred_Y, Y) | |||
# back propagation and optimize | |||
loss.backward() | |||
optimizer.step() | |||
return loss | |||
def _predict(self, data_loader): | |||
model = self.model | |||
criterion = self.criterion | |||
converter = self.converter | |||
params = self.params | |||
device = self.device | |||
for p in model.parameters(): | |||
p.requires_grad = False | |||
model.eval() | |||
n_correct = 0 | |||
results = [] | |||
for i, data in enumerate(data_loader): | |||
X = data[0].to(device) | |||
pred_Y = model(X) | |||
results.append(pred_Y) | |||
return torch.cat(results, axis=0) | |||
def predict(self, data_loader = None, X = None, print_prefix = ""): | |||
params = self.params | |||
if data_loader is None: | |||
Y = [0] * len(X) | |||
val_dataset = XYDataset(X, Y) | |||
sampler = None | |||
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \ | |||
shuffle=False, sampler=sampler, num_workers=int(params.workers), \ | |||
collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio)) | |||
recorder = self.recorder | |||
recorder.print('Start Predict ', print_prefix) | |||
Y = self._predict(data_loader).argmax(axis=1) | |||
return [self.remapping[int(y)] for y in Y] | |||
def predict_proba(self, data_loader = None, X = None, print_prefix = ""): | |||
params = self.params | |||
if data_loader is None: | |||
Y = [0] * len(X) | |||
val_dataset = XYDataset(X, Y) | |||
sampler = None | |||
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \ | |||
shuffle=False, sampler=sampler, num_workers=int(params.workers), \ | |||
collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio)) | |||
recorder = self.recorder | |||
recorder.print('Start Predict ', print_prefix) | |||
        return torch.softmax(self._predict(data_loader), dim=1)
def _val(self, data_loader, print_prefix): | |||
model = self.model | |||
criterion = self.criterion | |||
recorder = self.recorder | |||
converter = self.converter | |||
params = self.params | |||
device = self.device | |||
recorder.print('Start val ', print_prefix) | |||
for p in model.parameters(): | |||
p.requires_grad = False | |||
model.eval() | |||
n_correct = 0 | |||
pred_num = 0 | |||
loss_avg = mutils.averager() | |||
for i, data in enumerate(data_loader): | |||
X = data[0].to(device) | |||
Y = data[1].to(device) | |||
pred_Y = model(X) | |||
correct_num = sum(Y == pred_Y.argmax(axis=1)) | |||
loss = criterion(pred_Y, Y) | |||
loss_avg.add(loss) | |||
n_correct += correct_num | |||
pred_num += len(X) | |||
accuracy = float(n_correct) / float(pred_num) | |||
        recorder.print('[%s] Val loss: %f, accuracy: %f' % (print_prefix, loss_avg.val(), accuracy))
return accuracy | |||
def val(self, data_loader = None, X = None, y = None, print_prefix = ""): | |||
params = self.params | |||
if data_loader is None: | |||
y = self.str2ints(y) | |||
val_dataset = XYDataset(X, y) | |||
sampler = None | |||
data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=params.batchSize, \ | |||
shuffle=True, sampler=sampler, num_workers=int(params.workers), \ | |||
collate_fn=alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio)) | |||
return self._val(data_loader, print_prefix) | |||
def score(self, data_loader = None, X = None, y = None, print_prefix = ""): | |||
return self.val(data_loader, X, y, print_prefix) | |||
def save(self, save_dir): | |||
recorder = self.recorder | |||
if not os.path.exists(save_dir): | |||
os.mkdir(save_dir) | |||
recorder.print("Saving model and opter") | |||
save_path = os.path.join(save_dir, "net.pth") | |||
torch.save(self.model.state_dict(), save_path) | |||
save_path = os.path.join(save_dir, "opt.pth") | |||
torch.save(self.optimizer.state_dict(), save_path) | |||
def load(self, load_dir): | |||
recorder = self.recorder | |||
recorder.print("Loading model and opter") | |||
load_path = os.path.join(load_dir, "net.pth") | |||
self.model.load_state_dict(torch.load(load_path)) | |||
load_path = os.path.join(load_dir, "opt.pth") | |||
self.optimizer.load_state_dict(torch.load(load_path)) | |||
if __name__ == "__main__": | |||
pass | |||
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :lenet5.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/03/03 | |||
# Description : | |||
# | |||
#================================================================# | |||
import sys | |||
sys.path.append("..") | |||
import torchvision | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
from torch.autograd import Variable | |||
import torchvision.transforms as transforms | |||
from models.basic_model import BasicModel | |||
import utils.plog as plog | |||
class LeNet5(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv1 = nn.Conv2d(1, 6, 3, padding=1) | |||
self.conv2 = nn.Conv2d(6, 16, 3) | |||
self.conv3 = nn.Conv2d(16, 16, 3) | |||
self.fc1 = nn.Linear(256, 120) | |||
self.fc2 = nn.Linear(120, 84) | |||
self.fc3 = nn.Linear(84, 13) | |||
def forward(self, x): | |||
        '''Forward pass.'''
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) | |||
x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2)) | |||
x = F.relu(self.conv3(x)) | |||
x = x.view(-1, self.num_flat_features(x)) | |||
#print(x.size()) | |||
x = F.relu(self.fc1(x)) | |||
x = F.relu(self.fc2(x)) | |||
x = self.fc3(x) | |||
return x | |||
def num_flat_features(self, x): | |||
        # x.size() here is (batch_size, 16, 4, 4); multiply all dimensions except the batch one.
        size = x.size()[1:]
num_features = 1 | |||
for s in size: | |||
num_features *= s | |||
return num_features | |||
class Params: | |||
imgH = 28 | |||
imgW = 28 | |||
keep_ratio = True | |||
saveInterval = 10 | |||
batchSize = 16 | |||
num_workers = 16 | |||
def get_data():  # load and preprocess MNIST
transform = transforms.Compose([transforms.ToTensor(), | |||
transforms.Normalize((0.5), (0.5))]) | |||
#transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | |||
    # training set
train_set = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True) | |||
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024, shuffle=True, num_workers = 16) | |||
    # test set
test_set = torchvision.datasets.MNIST(root='data/', train=False, transform=transform, download=True) | |||
test_loader = torch.utils.data.DataLoader(test_set, batch_size = 1024, shuffle = False, num_workers = 16) | |||
    classes = tuple(str(i) for i in range(10))  # MNIST digit classes 0-9
return train_loader, test_loader, classes | |||
if __name__ == "__main__": | |||
recorder = plog.ResultRecorder() | |||
cls = LeNet5() | |||
    criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) | |||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
    # BasicModel expects a sign_list before the recorder; for MNIST this is the ten digit classes.
    model = BasicModel(cls, criterion, optimizer, None, device, Params(), list(range(10)), recorder)
train_loader, test_loader, classes = get_data() | |||
#model.val(test_loader, print_prefix = "before training") | |||
model.fit(train_loader, n_epoch = 100) | |||
    model.val(test_loader, print_prefix = "after training")
    res = model.predict(test_loader, print_prefix = "predict")
    # predict() already returns the remapped class labels, so just show the first few.
    print(res[:10])
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2020 Freecss All rights reserved. | |||
# | |||
# File Name :models.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2020/04/02 | |||
# Description : | |||
# | |||
#================================================================# | |||
from itertools import chain | |||
from sklearn.tree import DecisionTreeClassifier | |||
from sklearn.model_selection import cross_val_score | |||
from sklearn.svm import LinearSVC | |||
from sklearn.pipeline import make_pipeline | |||
from sklearn.preprocessing import StandardScaler | |||
from sklearn.svm import SVC | |||
from sklearn.gaussian_process import GaussianProcessClassifier | |||
from sklearn.gaussian_process.kernels import RBF | |||
import pickle as pk | |||
import random | |||
from sklearn.neighbors import KNeighborsClassifier | |||
import numpy as np | |||
def get_part_data(X, i): | |||
return list(map(lambda x : x[i], X)) | |||
def merge_data(X): | |||
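    # Flatten a list of sequences into one flat list, recording each sequence's length
    # so predictions can later be reshaped back per sequence.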
ret_mark = list(map(lambda x : len(x), X)) | |||
ret_X = list(chain(*X)) | |||
return ret_X, ret_mark | |||
def reshape_data(Y, marks): | |||
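    # Inverse of merge_data: split the flat list Y back into sequences according to marks.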
begin_mark = 0 | |||
ret_Y = [] | |||
for mark in marks: | |||
end_mark = begin_mark + mark | |||
ret_Y.append(Y[begin_mark:end_mark]) | |||
begin_mark = end_mark | |||
return ret_Y | |||
class WABLBasicModel: | |||
""" | |||
label_lists 的目标在于为各个符号设置编号,无论方法是给出字典形式的概率还是给出list形式的,都可以通过这种方式解决. | |||
后续可能会考虑更加完善的措施,降低这部分的复杂度 | |||
当模型共享的时候,label_lists 之间的元素也是共享的 | |||
""" | |||
def __init__(self): | |||
pass | |||
def predict(self, X): | |||
if self.share: | |||
data_X, marks = merge_data(X) | |||
prob = self.cls_list[0].predict_proba(X = data_X) | |||
cls = np.array(prob).argmax(axis = 1) | |||
prob = reshape_data(prob, marks) | |||
cls = reshape_data(cls, marks) | |||
else: | |||
cls_result = [] | |||
prob_result = [] | |||
for i in range(self.code_len): | |||
data_X = get_part_data(X, i) | |||
tmp_prob = self.cls_list[i].predict_proba(X = data_X) | |||
cls_result.append(np.array(tmp_prob).argmax(axis = 1)) | |||
prob_result.append(tmp_prob) | |||
cls = list(zip(*cls_result)) | |||
prob = list(zip(*prob_result)) | |||
return {"cls" : cls, "prob" : prob} | |||
def valid(self, X, Y): | |||
if self.share: | |||
data_X, _ = merge_data(X) | |||
data_Y, _ = merge_data(Y) | |||
score = self.cls_list[0].score(X = data_X, y = data_Y) | |||
return score, [score] | |||
else: | |||
score_list = [] | |||
for i in range(self.code_len): | |||
data_X = get_part_data(X, i) | |||
data_Y = get_part_data(Y, i) | |||
score_list.append(self.cls_list[i].score(data_X, data_Y)) | |||
return sum(score_list) / len(score_list), score_list | |||
def train(self, X, Y): | |||
#self.label_lists = [] | |||
if self.share: | |||
data_X, _ = merge_data(X) | |||
data_Y, _ = merge_data(Y) | |||
self.cls_list[0].fit(X = data_X, y = data_Y) | |||
else: | |||
for i in range(self.code_len): | |||
data_X = get_part_data(X, i) | |||
data_Y = get_part_data(Y, i) | |||
self.cls_list[i].fit(data_X, data_Y) | |||
def _set_label_lists(self, label_lists): | |||
label_lists = [sorted(list(set(label_list))) for label_list in label_lists] | |||
self.label_lists = label_lists | |||
class DecisionTree(WABLBasicModel): | |||
def __init__(self, code_len, label_lists, share = False): | |||
self.code_len = code_len | |||
self._set_label_lists(label_lists) | |||
self.cls_list = [] | |||
self.share = share | |||
if share: | |||
            # essentially one and the same classifier, shared across all code positions
self.cls_list.append(DecisionTreeClassifier(random_state = 0, min_samples_leaf = 3)) | |||
self.cls_list = self.cls_list * self.code_len | |||
else: | |||
for _ in range(code_len): | |||
self.cls_list.append(DecisionTreeClassifier(random_state = 0, min_samples_leaf = 3)) | |||
class KNN(WABLBasicModel): | |||
def __init__(self, code_len, label_lists, share = False, k = 3): | |||
self.code_len = code_len | |||
self._set_label_lists(label_lists) | |||
self.cls_list = [] | |||
self.share = share | |||
if share: | |||
            # essentially one and the same classifier, shared across all code positions
self.cls_list.append(KNeighborsClassifier(n_neighbors = k)) | |||
self.cls_list = self.cls_list * self.code_len | |||
else: | |||
for _ in range(code_len): | |||
self.cls_list.append(KNeighborsClassifier(n_neighbors = k)) | |||
class CNN(WABLBasicModel): | |||
def __init__(self, base_model, code_len, label_lists, share = True): | |||
assert share == True, "Not implemented" | |||
label_lists = [sorted(list(set(label_list))) for label_list in label_lists] | |||
self.label_lists = label_lists | |||
self.code_len = code_len | |||
self.cls_list = [] | |||
self.share = share | |||
if share: | |||
self.cls_list.append(base_model) | |||
def train(self, X, Y, n_epoch = 100): | |||
#self.label_lists = [] | |||
if self.share: | |||
            # Since all positions share one classifier, merge the data and train that single classifier once.
data_X, _ = merge_data(X) | |||
data_Y, _ = merge_data(Y) | |||
self.cls_list[0].fit(X = data_X, y = data_Y, n_epoch = n_epoch) | |||
#self.label_lists = [sorted(list(set(data_Y)))] * self.code_len | |||
else: | |||
for i in range(self.code_len): | |||
data_X = get_part_data(X, i) | |||
data_Y = get_part_data(Y, i) | |||
self.cls_list[i].fit(data_X, data_Y) | |||
#self.label_lists.append(sorted(list(set(data_Y)))) | |||
if __name__ == "__main__":
    #data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk"
    data_path = "datasets/generated_data/0_code_7_2_0.00.pk"
    codes, data, labels = pk.load(open(data_path, "rb"))
    # Build the per-position label lists from the loaded labels.
    label_lists = [list(col) for col in zip(*labels)]
    cls = KNN(7, label_lists, share = False, k = 3)
    cls.train(data, labels)
    print(cls.valid(data, labels))
    preds = cls.predict(data)
    print(preds["prob"][0])
    print(preds["cls"][0])
    print("Trained")
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :nonshare_example.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/07 | |||
# Description : | |||
# | |||
#================================================================# | |||
from utils.plog import logger | |||
from models.wabl_models import DecisionTree, KNN | |||
import pickle as pk | |||
import numpy as np | |||
import time | |||
import framework | |||
from multiprocessing import Pool | |||
import os | |||
from datasets.data_generator import generate_data_via_codes, code_generator | |||
from collections import defaultdict | |||
from abducer.abducer_base import AbducerBase | |||
from abducer.kb import ClsKB, RegKB | |||
def run_test(params): | |||
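    # One experiment sweep with per-position (non-shared) classifiers: build a random code KB,
    # train the base models on data with label noise err, then refine them with abductive learning.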
code_len, times, code_num, share, model_type, need_prob, letter_num = params | |||
if share: | |||
result_dir = "share_result" | |||
else: | |||
result_dir = "non_share_result" | |||
recoder_file_path = f"{result_dir}/random_{times}_{code_len}_{code_num}_{model_type}_{need_prob}.pk" | |||
words = code_generator(code_len, code_num, letter_num) | |||
kb = ClsKB(words) | |||
abducer = AbducerBase(kb, dist_func = "confidence", pred_res_parse = lambda x : x["prob"]) | |||
label_lists = [[] for _ in range(code_len)] | |||
for widx, word in enumerate(words): | |||
for cidx, c in enumerate(word): | |||
label_lists[cidx].append(c) | |||
if share: | |||
label_lists = [sum(label_lists, [])] | |||
recoder = logger() | |||
recoder.set_savefile("test.log") | |||
for idx, err in enumerate(range(15, 41)): | |||
start = time.process_time() | |||
err = err / 40. | |||
if 1 - err < (1. / letter_num): | |||
break | |||
print("Start expriment", idx) | |||
if model_type == "KNN": | |||
model = KNN(code_len, label_lists = label_lists, share=share) | |||
elif model_type == "DT": | |||
model = DecisionTree(code_len, label_lists = label_lists, share=share) | |||
pre_X, pre_Y = generate_data_via_codes(words, err, letter_num) | |||
X, Y = generate_data_via_codes(words, 0, letter_num) | |||
str_words = ["".join(str(c) for c in word) for word in words] | |||
recoder.print(str_words) | |||
model.train(pre_X, pre_Y) | |||
abl_epoch = 30 | |||
res = framework.train(model, abducer, X, Y, sample_num = 10000, verbose = 1) | |||
print("Initial data accuracy:", 1 - err) | |||
print("Abd word accuracy: ", res["accuracy_word"] * 1.0 / res["total_word"]) | |||
print("Abd char accuracy: ", res["accuracy_abd_char"] * 1.0 / res["total_abd_char"]) | |||
print("Ori char accuracy: ", res["accuracy_ori_char"] * 1.0 / res["total_ori_char"]) | |||
print("End expriment", idx) | |||
print() | |||
recoder.dump(open(recoder_file_path, "wb")) | |||
return True | |||
if __name__ == "__main__": | |||
os.system("mkdir share_result") | |||
os.system("mkdir non_share_result") | |||
for times in range(5): | |||
for code_num in [32, 64, 128]: | |||
params = [11, times, code_num, False, "KNN", True, 2] | |||
run_test(params) | |||
params = [11, times, code_num, False, "KNN", False, 2] | |||
run_test(params) | |||
#params = [11, 0, 32, False, "DT", False, 2] | |||
#run_test(params) | |||
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2021 Freecss All rights reserved. | |||
# | |||
# File Name :share_example.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2021/06/07 | |||
# Description : | |||
# | |||
#================================================================# | |||
from utils.plog import logger | |||
from models.wabl_models import DecisionTree, KNN | |||
import pickle as pk | |||
import numpy as np | |||
import time | |||
import framework | |||
from multiprocessing import Pool | |||
import os | |||
from datasets.data_generator import generate_data_via_codes, code_generator | |||
from collections import defaultdict | |||
from abducer.abducer_base import AbducerBase | |||
from abducer.kb import ClsKB, RegKB | |||
def run_test(params): | |||
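    # One experiment sweep with a single classifier shared across all positions: build a random
    # code KB, train the base model on data with label noise err, then refine it with abductive learning.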
code_len, times, code_num, share, model_type, need_prob, letter_num = params | |||
if share: | |||
result_dir = "share_result" | |||
else: | |||
result_dir = "non_share_result" | |||
    recoder_file_path = f"{result_dir}/random_{times}_{code_len}_{code_num}_{model_type}_{need_prob}.pk"
words = code_generator(code_len, code_num, letter_num) | |||
kb = ClsKB(words) | |||
abducer = AbducerBase(kb) | |||
label_lists = [[] for _ in range(code_len)] | |||
for widx, word in enumerate(words): | |||
for cidx, c in enumerate(word): | |||
label_lists[cidx].append(c) | |||
if share: | |||
label_lists = [sum(label_lists, [])] | |||
recoder = logger() | |||
recoder.set_savefile("test.log") | |||
for idx, err in enumerate(range(0, 41)): | |||
print("Start expriment", idx) | |||
start = time.process_time() | |||
err = err / 40. | |||
if 1 - err < (1. / letter_num): | |||
break | |||
if model_type == "KNN": | |||
model = KNN(code_len, label_lists = label_lists, share=share) | |||
elif model_type == "DT": | |||
model = DecisionTree(code_len, label_lists = label_lists, share=share) | |||
pre_X, pre_Y = generate_data_via_codes(words, err, letter_num) | |||
X, Y = generate_data_via_codes(words, 0, letter_num) | |||
str_words = ["".join(str(c) for c in word) for word in words] | |||
recoder.print(str_words) | |||
model.train(pre_X, pre_Y) | |||
abl_epoch = 30 | |||
res = framework.train(model, abducer, X, Y, sample_num = 10000, verbose = 1) | |||
print("Initial data accuracy:", 1 - err) | |||
print("Abd word accuracy: ", res["accuracy_word"] * 1.0 / res["total_word"]) | |||
print("Abd char accuracy: ", res["accuracy_abd_char"] * 1.0 / res["total_abd_char"]) | |||
print("Ori char accuracy: ", res["accuracy_ori_char"] * 1.0 / res["total_ori_char"]) | |||
print("End expriment", idx) | |||
print() | |||
recoder.dump(open(recoder_file_path, "wb")) | |||
return True | |||
if __name__ == "__main__": | |||
os.system("mkdir share_result") | |||
os.system("mkdir non_share_result") | |||
for times in range(5): | |||
for code_num in [32, 64, 128]: | |||
params = [11, times, code_num, True, "KNN", True, 2] | |||
run_test(params) | |||
params = [11, times, code_num, True, "KNN", False, 2] | |||
run_test(params) | |||
#params = [11, 0, 32, True, "DT", True, 2] | |||
#run_test(params) | |||
# coding: utf-8 | |||
#================================================================# | |||
# Copyright (C) 2020 Freecss All rights reserved. | |||
# | |||
# File Name :plog.py | |||
# Author :freecss | |||
# Email :karlfreecss@gmail.com | |||
# Created Date :2020/10/23 | |||
# Description : | |||
# | |||
#================================================================# | |||
import time | |||
import logging | |||
import pickle as pk | |||
import os | |||
import functools | |||
log_name = "default_log.txt" | |||
logging.basicConfig(level=logging.INFO, | |||
filename=log_name, | |||
filemode='a', | |||
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') | |||
global recorder | |||
recorder = None | |||
def mkdir(dirpath): | |||
if not os.path.exists(dirpath): | |||
os.makedirs(dirpath) | |||
class ResultRecorder: | |||
def __init__(self, pk_dir = None, pk_filepath = None): | |||
self.set_savefile(pk_dir, pk_filepath) | |||
self.result = {} | |||
logging.info("===========================================================") | |||
logging.info("============= Result Recorder Version: 0.02 ===============") | |||
logging.info("===========================================================\n") | |||
pass | |||
    def set_savefile(self, pk_dir = None, pk_filepath = None):
        if pk_dir is None:
            pk_dir = "result"
        mkdir(pk_dir)
        if pk_filepath is None:
            local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())
            pk_filepath = os.path.join(pk_dir, local_time + ".pk")
        self.save_file = open(pk_filepath, "wb")
        # Redirect logging to a .txt file next to the pickle file
        # (derive the name from pk_filepath so a user-supplied path also works).
        logger = logging.getLogger()
        logger.handlers[0].stream.close()
        logger.removeHandler(logger.handlers[0])
        filename = os.path.splitext(pk_filepath)[0] + ".txt"
        file_handler = logging.FileHandler(filename)
        file_handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
def print(self, *argv, screen = False): | |||
info = "" | |||
for data in argv: | |||
info += str(data) | |||
if screen: | |||
print(info) | |||
logging.info(info) | |||
def print_result(self, *argv): | |||
for data in argv: | |||
info = "#Result# %s" % str(data) | |||
#print(info) | |||
logging.info(info) | |||
def store(self, *argv): | |||
for data in argv: | |||
if data.find(":") < 0: | |||
continue | |||
label, data = data.split(":") | |||
self.store_kv(label, data) | |||
def write_result(self, *argv): | |||
self.print_result(*argv) | |||
self.store(*argv) | |||
def store_kv(self, label, data): | |||
self.result.setdefault(label, []) | |||
self.result[label].append(data) | |||
def write_kv(self, label, data): | |||
self.print_result({label : data}) | |||
#self.print_result(label + ":" + str(data)) | |||
self.store_kv(label, data) | |||
def dump(self, save_file = None): | |||
if save_file is None: | |||
save_file = self.save_file | |||
pk.dump(self.result, save_file) | |||
def clock(self, func): | |||
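        # Decorator: time func with time.perf_counter and record the elapsed time via write_kv.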
@functools.wraps(func) | |||
def clocked(*args, **kwargs): | |||
t0 = time.perf_counter() | |||
result = func(*args, **kwargs) | |||
elapsed = time.perf_counter() - t0 | |||
name = func.__name__ | |||
# arg_str = ','.join(repr(arg) for arg in args) | |||
# context = f"{name}: ({arg_str})=>({result}), cost {elapsed}s" | |||
context = f"{name}: ()=>(), cost {elapsed}s" | |||
self.write_kv("func:", context) | |||
return result | |||
return clocked | |||
def __del__(self): | |||
self.dump() | |||
def clocker(*argv): | |||
global recorder | |||
if recorder is None: | |||
recorder = ResultRecorder() | |||
return recorder.clock(*argv) | |||
def INFO(*argv, screen = False): | |||
global recorder | |||
if recorder is None: | |||
recorder = ResultRecorder() | |||
return recorder.print(*argv, screen = screen) | |||
def DEBUG(*argv, screen = False): | |||
global recorder | |||
if recorder is None: | |||
recorder = ResultRecorder() | |||
return recorder.print(*argv, screen = screen) | |||
def logger(): | |||
global recorder | |||
if recorder is None: | |||
recorder = ResultRecorder() | |||
return recorder | |||
if __name__ == "__main__": | |||
recorder = ResultRecorder() | |||
recorder.write_kv("test", 1) | |||
recorder.set_savefile(pk_dir = "haha") | |||
recorder.write_kv("test", 1) | |||