You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hed_knn_example.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :hed_knn_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :Abductive learning (ABL) example on the HED dataset using a KNN base model
  10. #
  11. # ================================================================#
  12. import sys
  13. sys.path.append("../")
  14. from abl.utils.plog import logger, INFO
  15. from abl.utils.utils import reduce_dimension
  16. import torch.nn as nn
  17. import torch
  18. from abl.models.nn import LeNet5, SymbolNet
  19. from abl.models.basic_model import BasicModel, BasicDataset
  20. from abl.models.wabl_models import DecisionTree, WABLBasicModel
  21. from sklearn.neighbors import KNeighborsClassifier
  22. from abl.abducer.abducer_base import AbducerBase
  23. from abl.abducer.kb import add_KB, HWF_KB, prolog_KB
  24. from datasets.mnist_add.get_mnist_add import get_mnist_add
  25. from datasets.hwf.get_hwf import get_hwf
  26. from datasets.hed.get_hed import get_hed, split_equation
  27. from abl import framework_hed_knn
  28. def run_test():
  29. # kb = add_KB(True)
  30. # kb = HWF_KB(True)
  31. # abducer = AbducerBase(kb)
  32. kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
  33. abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)
  34. recorder = logger()
  35. total_train_data = get_hed(train=True)
  36. train_data, val_data = split_equation(total_train_data, 3, 1)
  37. test_data = get_hed(train=False)
  38. # ========================= KNN model ============================ #
  39. reduce_dimension(train_data)
  40. reduce_dimension(val_data)
  41. reduce_dimension(test_data)
  42. base_model = KNeighborsClassifier(n_neighbors=3)
  43. pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model)
  44. model = WABLBasicModel(base_model, kb.pseudo_label_list)
  45. model, mapping = framework_hed_knn.train_with_rule(
  46. model, abducer, train_data, val_data, (pretrain_data_X, pretrain_data_Y), select_num=10, min_len=5, max_len=8
  47. )
  48. framework_hed_knn.hed_test(
  49. model, abducer, mapping, train_data, test_data, min_len=5, max_len=8
  50. )
  51. # ============================ End =============================== #
  52. recorder.dump()
  53. return True
  54. if __name__ == "__main__":
  55. run_test()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.