You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :share_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. from utils.plog import logger
  13. from models.wabl_models import DecisionTree, KNN
  14. import pickle as pk
  15. import numpy as np
  16. import time
  17. import framework
  18. import utils.plog as plog
  19. import torch.nn as nn
  20. import torch
  21. from models.lenet5 import LeNet5
  22. from models.basic_model import BasicModel
  23. from models.wabl_models import MyModel
  24. from multiprocessing import Pool
  25. import os
  26. from datasets.data_generator import generate_data_via_codes, code_generator
  27. from collections import defaultdict
  28. from abducer.abducer_base import AbducerBase
  29. from abducer.kb import add_KB, hwf_KB
  30. from datasets.mnist_add.get_mnist_add import get_mnist_add
  31. from datasets.hwf.get_hwf import get_hwf
class Params:
    """Hyper-parameter container; an instance is passed to ``BasicModel`` in ``run_test``.

    The meaning of each attribute is defined by ``BasicModel``, which is not
    visible here; the per-field notes below are inferred from the names and
    should be confirmed against that class.
    """

    imgH: int = 45  # presumably input image height in pixels -- TODO confirm in BasicModel
    imgW: int = 45  # presumably input image width in pixels -- TODO confirm
    keep_ratio: bool = True  # presumably keep aspect ratio when resizing -- TODO confirm
    saveInterval: int = 10  # presumably steps/epochs between saves -- TODO confirm
    batchSize: int = 16  # presumably training mini-batch size -- TODO confirm
    workers: int = 16  # presumably number of data-loading workers -- TODO confirm
    n_epoch: int = 10  # presumably number of training epochs -- TODO confirm
    stop_loss = None  # unset by default; presumably an optional early-stop loss threshold -- TODO confirm
  41. def run_test():
  42. result_dir = 'results'
  43. recorder_file_path = f"{result_dir}/1116.pk"#
  44. # words = code_generator(code_len, code_num, letter_num)
  45. kb = add_KB()
  46. abducer = AbducerBase(kb)
  47. recorder = logger()
  48. recorder.set_savefile("test.log")
  49. train_X, train_Y, test_X, test_Y = get_mnist_add()
  50. # train_X, train_Y, test_X, test_Y = get_hwf()
  51. recorder = plog.ResultRecorder()
  52. cls = LeNet5()
  53. criterion = nn.CrossEntropyLoss(size_average=True)
  54. optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
  55. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  56. sign_list = list(range(10))
  57. base_model = BasicModel(cls, criterion, optimizer, device, Params(), sign_list, recorder=recorder)
  58. model = MyModel(base_model)
  59. res = framework.train(model, abducer, train_X, train_Y, logic_forward = kb.logic_forward, sample_num = 10000, verbose = 1)
  60. print(res)
  61. recorder.dump(open(recorder_file_path, "wb"))
  62. return True
  63. if __name__ == "__main__":
  64. os.system("mkdir results")
  65. run_test()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.