# MNIST Addition

This example shows a simple implementation of the [MNIST Addition](https://arxiv.org/abs/1805.10872) task, where pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base describing how addition works. The task is to recognize the digits in the handwritten images and accurately determine their sum.
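To make the task concrete, here is a minimal, self-contained sketch of the abduction step this kind of task relies on: given the base model's digit probabilities for two images and the observed sum, pick the most probable digit pair consistent with the addition knowledge. This is illustrative only, not the API used in `main.py`; all names are hypothetical.

```python
import itertools

import numpy as np


def abduce_digit_pair(probs_a, probs_b, target_sum):
    """Most probable digit pair (a, b) with a + b == target_sum.

    probs_a, probs_b: length-10 arrays of per-digit probabilities
    predicted by the base model for the two images.
    """
    best_pair, best_score = None, -1.0
    for a, b in itertools.product(range(10), repeat=2):
        # The knowledge base here is just the addition constraint.
        score = probs_a[a] * probs_b[b]
        if a + b == target_sum and score > best_score:
            best_pair, best_score = (a, b), score
    return best_pair


# The model leans toward (3, 5) and the given sum is 8.
probs_a = np.full(10, 0.02); probs_a[3] = 0.82
probs_b = np.full(10, 0.02); probs_b[5] = 0.82
print(abduce_digit_pair(probs_a, probs_b, 8))  # -> (3, 5)
```

The abduced pair then serves as a pseudo-label for retraining the base model, which is what allows learning to proceed without per-image digit labels.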
## Run

```bash
pip install -r requirements.txt
python main.py
```
## Usage

```bash
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS]
               [--label_smoothing LABEL_SMOOTHING] [--lr LR]
               [--alpha ALPHA] [--batch-size BATCH_SIZE]
               [--loops LOOPS] [--segment_size SEGMENT_SIZE]
               [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION]
               [--require-more-revision REQUIRE_MORE_REVISION]
               [--prolog | --ground]

MNIST Addition example

optional arguments:
  -h, --help            show this help message and exit
  --no-cuda             disables CUDA training
  --epochs EPOCHS       number of epochs in each learning loop iteration
                        (default: 1)
  --label_smoothing LABEL_SMOOTHING
                        label smoothing in cross-entropy loss (default: 0.2)
  --lr LR               base model learning rate (default: 0.001)
  --alpha ALPHA         alpha in RMSprop (default: 0.9)
  --batch-size BATCH_SIZE
                        base model batch size (default: 32)
  --loops LOOPS         number of loop iterations (default: 5)
  --segment_size SEGMENT_SIZE
                        segment size (default: 1/3)
  --save_interval SAVE_INTERVAL
                        save interval (default: 1)
  --max-revision MAX_REVISION
                        maximum revision in reasoner (default: -1)
  --require-more-revision REQUIRE_MORE_REVISION
                        require more revision in reasoner (default: 0)
  --prolog              use PrologKB (default: False)
  --ground              use GroundKB (default: False)
```
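For example, to train with the Prolog-based knowledge base for three learning loops (the flag values here are only illustrative; defaults are listed above):

```bash
python main.py --loops 3 --prolog
```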
## Environment

All experiments were run on a single Linux server; its specifications are listed in the table below.
<table class="tg" style="margin-left: auto; margin-right: auto;">
<thead>
<tr>
<th>CPU</th>
<th>GPU</th>
<th>Memory</th>
<th>OS</th>
</tr>
</thead>
<tbody>
<tr>
<td>2 × Xeon Platinum 8358, 32 cores, 2.6 GHz base frequency</td>
<td>A100 80GB</td>
<td>512GB</td>
<td>Ubuntu 20.04</td>
</tr>
</tbody>
</table>
## Performance

We report the results of ABL below: the reasoning accuracy (the proportion of equations whose sums are correctly predicted), the training time needed to reach that accuracy, and the average memory usage during training. These results are compared with the following methods:
- [**NeurASP**](https://github.com/azreasoners/NeurASP): an extension of answer set programs that treats neural network outputs as probability distributions over atomic facts;
- [**DeepProbLog**](https://github.com/ML-KULeuven/deepproblog): an extension of ProbLog that introduces neural predicates into probabilistic logic programming;
- [**LTN**](https://github.com/logictensornetworks/logictensornetworks): a neural-symbolic framework that uses a differentiable first-order logic language to incorporate both data and logic;
- [**DeepStochLog**](https://github.com/ML-KULeuven/deepstochlog): a neural-symbolic framework based on stochastic logic programs.
<table class="tg" style="margin-left: auto; margin-right: auto;">
<thead>
<tr>
<th>Method</th>
<th>Accuracy (%)</th>
<th>Time to Achieve the Acc. (s)</th>
<th>Average Memory Usage (MB)</th>
</tr>
</thead>
<tbody>
<tr>
<td>NeurASP</td>
<td>96.2</td>
<td>966</td>
<td>3552</td>
</tr>
<tr>
<td>DeepProbLog</td>
<td>97.1</td>
<td>2045</td>
<td>3521</td>
</tr>
<tr>
<td>LTN</td>
<td>97.4</td>
<td>251</td>
<td>3860</td>
</tr>
<tr>
<td>DeepStochLog</td>
<td>97.5</td>
<td>257</td>
<td>3545</td>
</tr>
<tr>
<td>ABL</td>
<td><span style="font-weight:bold">98.1</span></td>
<td><span style="font-weight:bold">47</span></td>
<td><span style="font-weight:bold">2482</span></td>
</tr>
</tbody>
</table>
