|
|
@@ -0,0 +1,85 @@ |
|
|
|
# coding: utf-8 |
|
|
|
#================================================================# |
|
|
|
# Copyright (C) 2021 Freecss All rights reserved. |
|
|
|
# |
|
|
|
# File Name :share_example.py |
|
|
|
# Author :freecss |
|
|
|
# Email :karlfreecss@gmail.com |
|
|
|
# Created Date :2021/06/07 |
|
|
|
# Description : |
|
|
|
# |
|
|
|
#================================================================# |
|
|
|
|
|
|
|
from utils.plog import logger |
|
|
|
from models.wabl_models import DecisionTree, KNN |
|
|
|
import pickle as pk |
|
|
|
import numpy as np |
|
|
|
import time |
|
|
|
import framework |
|
|
|
import utils.plog as plog |
|
|
|
import torch.nn as nn |
|
|
|
import torch |
|
|
|
|
|
|
|
from models.lenet5 import LeNet5 |
|
|
|
from models.basic_model import BasicModel |
|
|
|
from models.wabl_models import MyModel |
|
|
|
|
|
|
|
from multiprocessing import Pool |
|
|
|
import os |
|
|
|
from datasets.data_generator import generate_data_via_codes, code_generator |
|
|
|
from collections import defaultdict |
|
|
|
from abducer.abducer_base import AbducerBase |
|
|
|
from abducer.kb import add_KB, hwf_KB |
|
|
|
from datasets.mnist_add.get_mnist_add import get_mnist_add |
|
|
|
from datasets.hwf.get_hwf import get_hwf |
|
|
|
|
|
|
|
class Params: |
|
|
|
imgH = 45 |
|
|
|
imgW = 45 |
|
|
|
keep_ratio = True |
|
|
|
saveInterval = 10 |
|
|
|
batchSize = 16 |
|
|
|
workers = 16 |
|
|
|
n_epoch = 10 |
|
|
|
stop_loss = None |
|
|
|
|
|
|
|
def run_test(): |
|
|
|
|
|
|
|
result_dir = 'results' |
|
|
|
|
|
|
|
recorder_file_path = f"{result_dir}/1116.pk"# |
|
|
|
|
|
|
|
# words = code_generator(code_len, code_num, letter_num) |
|
|
|
kb = add_KB() |
|
|
|
abducer = AbducerBase(kb) |
|
|
|
|
|
|
|
recorder = logger() |
|
|
|
recorder.set_savefile("test.log") |
|
|
|
|
|
|
|
|
|
|
|
train_X, train_Y, test_X, test_Y = get_mnist_add() |
|
|
|
# train_X, train_Y, test_X, test_Y = get_hwf() |
|
|
|
|
|
|
|
|
|
|
|
recorder = plog.ResultRecorder() |
|
|
|
cls = LeNet5() |
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss(size_average=True) |
|
|
|
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
sign_list = list(range(10)) |
|
|
|
base_model = BasicModel(cls, criterion, optimizer, device, Params(), sign_list, recorder=recorder) |
|
|
|
model = MyModel(base_model) |
|
|
|
|
|
|
|
res = framework.train(model, abducer, train_X, train_Y, logic_forward = kb.logic_forward, sample_num = 10000, verbose = 1) |
|
|
|
print(res) |
|
|
|
|
|
|
|
|
|
|
|
recorder.dump(open(recorder_file_path, "wb")) |
|
|
|
return True |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
os.system("mkdir results") |
|
|
|
|
|
|
|
run_test() |
|
|
|
|