You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

share_example.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :share_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. from utils.plog import logger
  13. from models.wabl_models import DecisionTree, KNN
  14. import pickle as pk
  15. import numpy as np
  16. import time
  17. import framework
  18. from multiprocessing import Pool
  19. import os
  20. from datasets.data_generator import generate_data_via_codes, code_generator
  21. from collections import defaultdict
  22. from abducer.abducer_base import AbducerBase
  23. from abducer.kb import ClsKB, RegKB
  def run_test(params):
      """Run one noise-sweep experiment and dump the recorder to disk.

      For each noise level ``err`` in ``0/40, 1/40, ..., 40/40`` a fresh model
      is trained on data generated with that noise level, then passed to
      ``framework.train`` together with noise-free data, and the resulting
      accuracies are printed. The sweep stops early once ``1 - err`` drops
      below ``1 / letter_num``.

      Args:
          params: A 7-item sequence unpacked, in order, as
              ``code_len, times, code_num, share, model_type, need_prob,
              letter_num``.

              * ``code_len``, ``code_num``, ``letter_num`` are forwarded to
                ``code_generator``; ``letter_num`` is also forwarded to
                ``generate_data_via_codes``.
              * ``share`` selects the result directory and, when truthy,
                merges the per-position label lists into a single list.
              * ``model_type`` must be ``"KNN"`` or ``"DT"``.
              * ``times`` and ``need_prob`` are only used in the output
                file name within this function.

      Returns:
          ``True`` once the recorder has been dumped.

      Raises:
          ValueError: If ``model_type`` is neither ``"KNN"`` nor ``"DT"``
              (or if ``params`` does not contain exactly seven items).
      """
      code_len, times, code_num, share, model_type, need_prob, letter_num = params

      # Fail fast: previously an unknown model type left ``model`` unbound and
      # only surfaced as an UnboundLocalError inside the sweep.
      if model_type not in ("KNN", "DT"):
          raise ValueError(
              f"Unsupported model_type {model_type!r}; expected 'KNN' or 'DT'"
          )
      model_cls = KNN if model_type == "KNN" else DecisionTree

      result_dir = "share_result" if share else "non_share_result"
      # Make sure the output directory exists so the final dump cannot fail
      # after the whole sweep has already been computed.
      os.makedirs(result_dir, exist_ok=True)
      recoder_file_path = f"{result_dir}/random_{times}_{code_len}_{code_num}_{model_type}_{need_prob}.pk"

      words = code_generator(code_len, code_num, letter_num)
      kb = ClsKB(words)
      abducer = AbducerBase(kb)

      # label_lists[i] collects the i-th symbol of every word.
      label_lists = [[] for _ in range(code_len)]
      for word in words:
          for cidx, c in enumerate(word):
              label_lists[cidx].append(c)

      if share:
          # One shared label list for all positions (linear-time flatten).
          label_lists = [[c for labels in label_lists for c in labels]]

      recoder = logger()
      recoder.set_savefile("test.log")

      # The string form of the words does not change between noise levels.
      str_words = ["".join(str(c) for c in word) for word in words]

      for idx in range(41):
          print("Start expriment", idx)
          err = idx / 40.
          # Stop once the initial accuracy (1 - err) falls below 1 / letter_num.
          if 1 - err < (1. / letter_num):
              break

          model = model_cls(code_len, label_lists=label_lists, share=share)

          pre_X, pre_Y = generate_data_via_codes(words, err, letter_num)
          X, Y = generate_data_via_codes(words, 0, letter_num)

          recoder.print(str_words)

          model.train(pre_X, pre_Y)
          res = framework.train(model, abducer, X, Y, sample_num=10000, verbose=1)

          print("Initial data accuracy:", 1 - err)
          print("Abd word accuracy: ", res["accuracy_word"] * 1.0 / res["total_word"])
          print("Abd char accuracy: ", res["accuracy_abd_char"] * 1.0 / res["total_abd_char"])
          print("Ori char accuracy: ", res["accuracy_ori_char"] * 1.0 / res["total_ori_char"])
          print("End expriment", idx)
          print()

      # NOTE(review): indentation was lost in this copy of the file; the dump is
      # assumed to run once after the sweep -- confirm against the upstream file.
      # The context manager guarantees the file handle is closed.
      with open(recoder_file_path, "wb") as dump_file:
          recoder.dump(dump_file)
      return True
  if __name__ == "__main__":
      # Create the output directories portably. Unlike shelling out to
      # ``mkdir``, this works on every platform and is silent when the
      # directory already exists.
      for result_dir in ("share_result", "non_share_result"):
          os.makedirs(result_dir, exist_ok=True)

      # Five repetitions for each vocabulary size.
      # params order: code_len, times, code_num, share, model_type,
      # need_prob, letter_num.
      for times in range(5):
          for code_num in [32, 64, 128]:
              params = [11, times, code_num, True, "KNN", True, 2]
              run_test(params)
              params = [11, times, code_num, True, "KNN", False, 2]
              run_test(params)

      # Decision-tree configuration, kept for reference:
      # params = [11, 0, 32, True, "DT", True, 2]
      # run_test(params)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.