
[ENH] add performance in HWF

pull/5/head
troyyyyy, 1 year ago
commit 261634910d
5 changed files with 205 additions and 74 deletions
  1. docs/Examples/HED.rst (+2, -2)
  2. docs/Examples/HWF.rst (+106, -66)
  3. docs/Examples/MNISTAdd.rst (+5, -3)
  4. docs/Examples/Zoo.rst (+2, -1)
  5. examples/hwf/hwf.ipynb (+90, -2)

docs/Examples/HED.rst (+2, -2)

@@ -272,8 +272,8 @@ respectively.
# Set up metrics
metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]

Bridge Learning and Reasoning
-----------------------------
Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
proceed with this step by creating an instance of ``HedBridge``, which is


docs/Examples/HWF.rst (+106, -66)

@@ -162,8 +162,9 @@ Out:

We may see that, in the 1001st data example, the length of the
formula is 3, while in the 3001st data example, the length of the
formula is 5. In the HWF dataset, the length of the formula varies from
1 to 7.
formula is 5. In the HWF dataset, the lengths of the formulas are
1, 3, 5, and 7 (specifically, 10% of the equations have a length of 1,
10% have a length of 3, 20% have a length of 5, and 60% have a length of 7).
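
To sanity-check this distribution, here is a minimal sketch (an illustration added here, not part of this commit). It assumes ``train_data`` is the (images, symbol labels, results) tuple loaded earlier in this example, with the second element holding the ground-truth symbol list of each formula:

.. code:: python

    from collections import Counter

    # Assumption: train_data = (images, symbol_labels, results) as loaded
    # earlier in this example; symbol_labels[i] is the list of ground-truth
    # symbols (digits and operators) of the i-th formula.
    _, symbol_labels, _ = train_data
    counts = Counter(len(symbols) for symbols in symbol_labels)
    total = sum(counts.values())
    for length, count in sorted(counts.items()):
        print(f"length {length}: {count / total:.0%}")
    # Expected, per the note above: roughly 10%, 10%, 20%, 60%
    # for lengths 1, 3, 5, and 7.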

Building the Learning Part
--------------------------
@@ -369,8 +370,8 @@ respectively.

metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

Bridge Learning and Reasoning
-----------------------------
Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
proceed with this step by creating an instance of ``SimpleBridge``.
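
For reference, a hedged sketch of this construction (assumed from context, since the instantiation itself is not shown in this hunk; the import path may differ across library versions):

.. code:: python

    # Assumption: model, reasoner, and metric_list were built in the
    # preceding sections of this example.
    from ablkit.bridge import SimpleBridge

    bridge = SimpleBridge(model=model, reasoner=reasoner, metric_list=metric_list)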
@@ -389,10 +390,12 @@ methods of ``SimpleBridge``.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
bridge.train(train_data, train_data, loops=3, segment_size=1000, save_dir=weights_dir)
bridge.train(train_data, loops=3, segment_size=1000, save_dir=weights_dir)
bridge.test(test_data)

Out:
The log will appear similar to the following:

Log:
.. code:: none
:class: code-out

@@ -400,73 +403,110 @@ Out:
abl - INFO - loop(train) [1/3] segment(train) [1/10]
abl - INFO - model loss: 0.00024
abl - INFO - loop(train) [1/3] segment(train) [2/10]
abl - INFO - model loss: 0.00053
abl - INFO - model loss: 0.00011
abl - INFO - loop(train) [1/3] segment(train) [3/10]
abl - INFO - model loss: 0.00260
abl - INFO - model loss: 0.00332
abl - INFO - loop(train) [1/3] segment(train) [4/10]
abl - INFO - model loss: 0.00162
abl - INFO - model loss: 0.00218
abl - INFO - loop(train) [1/3] segment(train) [5/10]
abl - INFO - model loss: 0.00073
abl - INFO - model loss: 0.00162
abl - INFO - loop(train) [1/3] segment(train) [6/10]
abl - INFO - model loss: 0.00055
abl - INFO - model loss: 0.00140
abl - INFO - loop(train) [1/3] segment(train) [7/10]
abl - INFO - model loss: 0.00148
abl - INFO - model loss: 0.00736
abl - INFO - loop(train) [1/3] segment(train) [8/10]
abl - INFO - model loss: 0.00034
abl - INFO - model loss: 0.00532
abl - INFO - loop(train) [1/3] segment(train) [9/10]
abl - INFO - model loss: 0.00167
abl - INFO - model loss: 0.00504
abl - INFO - loop(train) [1/3] segment(train) [10/10]
abl - INFO - model loss: 0.00185
abl - INFO - Evaluation start: loop(val) [1]
abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 0.999
abl - INFO - Saving model: loop(save) [1]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
abl - INFO - model loss: 0.00259
abl - INFO - Eval start: loop(val) [1]
abl - INFO - Evaluation ended, hwf/character_accuracy: 0.997 hwf/reasoning_accuracy: 0.985
abl - INFO - loop(train) [2/3] segment(train) [1/10]
abl - INFO - model loss: 0.00219
abl - INFO - loop(train) [2/3] segment(train) [2/10]
abl - INFO - model loss: 0.00069
abl - INFO - loop(train) [2/3] segment(train) [3/10]
abl - INFO - model loss: 0.00013
abl - INFO - loop(train) [2/3] segment(train) [4/10]
abl - INFO - model loss: 0.00013
abl - INFO - loop(train) [2/3] segment(train) [5/10]
abl - INFO - model loss: 0.00248
abl - INFO - loop(train) [2/3] segment(train) [6/10]
abl - INFO - model loss: 0.00010
abl - INFO - loop(train) [2/3] segment(train) [7/10]
abl - INFO - model loss: 0.00020
abl - INFO - loop(train) [2/3] segment(train) [8/10]
abl - INFO - model loss: 0.00076
abl - INFO - loop(train) [2/3] segment(train) [9/10]
abl - INFO - model loss: 0.00061
abl - INFO - loop(train) [2/3] segment(train) [10/10]
abl - INFO - model loss: 0.00117
abl - INFO - Evaluation start: loop(val) [2]
abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 1.000
abl - INFO - Saving model: loop(save) [2]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
abl - INFO - model loss: 0.00126
...
abl - INFO - Eval start: loop(val) [2]
abl - INFO - Evaluation ended, hwf/character_accuracy: 0.998 hwf/reasoning_accuracy: 0.989
abl - INFO - loop(train) [3/3] segment(train) [1/10]
abl - INFO - model loss: 0.00120
abl - INFO - loop(train) [3/3] segment(train) [2/10]
abl - INFO - model loss: 0.00114
abl - INFO - loop(train) [3/3] segment(train) [3/10]
abl - INFO - model loss: 0.00071
abl - INFO - loop(train) [3/3] segment(train) [4/10]
abl - INFO - model loss: 0.00027
abl - INFO - loop(train) [3/3] segment(train) [5/10]
abl - INFO - model loss: 0.00017
abl - INFO - loop(train) [3/3] segment(train) [6/10]
abl - INFO - model loss: 0.00018
abl - INFO - loop(train) [3/3] segment(train) [7/10]
abl - INFO - model loss: 0.00141
abl - INFO - loop(train) [3/3] segment(train) [8/10]
abl - INFO - model loss: 0.00099
abl - INFO - loop(train) [3/3] segment(train) [9/10]
abl - INFO - model loss: 0.00145
abl - INFO - loop(train) [3/3] segment(train) [10/10]
abl - INFO - model loss: 0.00215
abl - INFO - Evaluation start: loop(val) [3]
abl - INFO - Evaluation ended, hwf/character_accuracy: 1.000 hwf/reasoning_accuracy: 1.000
abl - INFO - Saving model: loop(save) [3]
abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
abl - INFO - Evaluation ended, hwf/character_accuracy: 0.996 hwf/reasoning_accuracy: 0.977
abl - INFO - model loss: 0.00030
...
abl - INFO - Eval start: loop(val) [3]
abl - INFO - Evaluation ended, hwf/character_accuracy: 0.999 hwf/reasoning_accuracy: 0.996
abl - INFO - Test start:
abl - INFO - Evaluation ended, hwf/character_accuracy: 0.997 hwf/reasoning_accuracy: 0.986

Performance
-----------

We present the results of ABL below, including the reasoning accuracy (for different equation lengths in the HWF dataset) and the training time (to achieve that accuracy using all equation lengths). These results are compared with the following methods:

- `DeepProbLog <https://github.com/ML-KULeuven/deepproblog/tree/master>`_: An extension of ProbLog by introducing neural predicates in Probabilistic Logic Programming;

- `DeepStochLog <https://github.com/ML-KULeuven/deepstochlog/tree/main>`_: A neural-symbolic framework based on stochastic logic programs;

- `NGS <https://github.com/liqing-ustc/NGS>`_: A neural-symbolic framework that uses a grammar model and a back-search algorithm to improve its computing process.

.. raw:: html

<style type="text/css">
.tg {border-collapse:collapse;border-spacing:0;margin-bottom:20px;}
.tg td, .tg th {border:1px solid #ddd;padding:10px 15px;text-align:center;}
.tg th {background-color:#f5f5f5;color:#333333;}
.tg tr:nth-child(even) {background-color:#f9f9f9;}
.tg tr:nth-child(odd) {background-color:#ffffff;}
</style>
<table class="tg" style="margin-left: auto; margin-right: auto;">
<thead>
<tr>
<th rowspan="2"></th>
<th colspan="5">Reasoning Accuracy<br><span style="font-weight: normal; font-size: smaller;">(for different equation lengths)</span></th>
<th rowspan="2">Training Time (s)<br><span style="font-weight: normal; font-size: smaller;">(to achieve the Acc. using all lengths)</span></th>
</tr>
<tr>
<th>1</th>
<th>3</th>
<th>5</th>
<th>7</th>
<th>All</th>
</tr>
</thead>
<tbody>
<tr>
<td>NGS</td>
<td>91.2</td>
<td>89.1</td>
<td>92.7</td>
<td>5.2</td>
<td>98.4</td>
<td>426.2</td>
</tr>
<tr>
<td>DeepProbLog</td>
<td>90.8</td>
<td>85.6</td>
<td>timeout*</td>
<td>timeout</td>
<td>timeout</td>
<td>timeout</td>
</tr>
<tr>
<td>DeepStochLog</td>
<td>92.8</td>
<td>87.5</td>
<td>92.1</td>
<td>timeout</td>
<td>timeout</td>
<td>timeout</td>
</tr>
<tr>
<td>ABL</td>
<td><span style="font-weight:bold">94.0</span></td>
<td><span style="font-weight:bold">89.7</span></td>
<td><span style="font-weight:bold">96.5</span></td>
<td><span style="font-weight:bold">97.2</span></td>
<td><span style="font-weight:bold">98.6</span></td>
<td><span style="font-weight:bold">77.3</span></td>
</tr>
</tbody>
</table>
<p style="font-size: 13px;">* timeout: needs more than 1 hour to execute</p>
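
To reproduce the Training Time measurement in the table above, a rough sketch (an assumption, not part of this commit) wraps the training call shown earlier in a wall-clock timer; absolute numbers will vary with hardware:

.. code:: python

    import time

    # Times the same call shown in the hunk above. The 77.3 s figure in the
    # table is the authors' measurement and will differ on other machines.
    start = time.time()
    bridge.train(train_data, loops=3, segment_size=1000, save_dir=weights_dir)
    print(f"Training time: {time.time() - start:.1f} s")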

docs/Examples/MNISTAdd.rst (+5, -3)

@@ -320,8 +320,8 @@ respectively.

metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

Bridge Learning and Reasoning
-----------------------------
Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
proceed with this step by creating an instance of ``SimpleBridge``.
@@ -343,7 +343,9 @@ methods of ``SimpleBridge``.
bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)
bridge.test(test_data)

Out:
The log will appear similar to the following:

Log:
.. code:: none
:class: code-out



docs/Examples/Zoo.rst (+2, -1)

@@ -227,8 +227,9 @@ methods of ``SimpleBridge``.
print_log("------- Test the final model -----------", logger="current")
bridge.test(test_data)

The log will appear similar to the following:

Out:
Log:
.. code:: none
:class: code-out



examples/hwf/hwf.ipynb (+90, -2)

@@ -140,7 +140,7 @@
"source": [
"Note: The symbols in the HWF dataset can be one of digits or operators '+', '-', '×', '÷'. \n",
"\n",
"Note: We may see that, in the 1001st data example, the length of the formula is 3, while in the 3001st data example, the length of the formula is 5. In the HWF dataset, the length of the formula varies from 1 to 7."
"Note: We may see that, in the 1001st data example, the length of the formula is 3, while in the 3001st data example, the length of the formula is 5. In the HWF dataset, the lengths of the formulas are 1, 3, 5, and 7 (Specifically, 10% of the equations have a length of 1, 10% have a length of 3, 20% have a length of 5, and 60% have a length of 7)."
]
},
{
@@ -419,9 +419,97 @@
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")\n",
"\n",
"bridge.train(train_data, train_data, loops=3, segment_size=1000, save_dir=weights_dir)\n",
"bridge.train(train_data, loops=3, segment_size=1000, save_dir=weights_dir)\n",
"bridge.test(test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We present the results of ABL as follows, which include the reasoning accuracy (for different equation lengths in the HWF dataset), and the training time (to achieve the accuracy using all equation lengths). These results are compared with the following methods:\n",
"\n",
"- [**DeepProbLog**](https://github.com/ML-KULeuven/deepproblog/tree/master): An extension of ProbLog by introducing neural predicates in Probabilistic Logic Programming;\n",
"\n",
"- [**DeepStochLog**](https://github.com/ML-KULeuven/deepstochlog/tree/main): A neural-symbolic framework based on stochastic logic program;\n",
"\n",
"- [**NGS**](https://github.com/liqing-ustc/NGS): A neural-symbolic framework that uses a grammar model and a back-search algorithm to improve its computing process."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<style type=\"text/css\">\n",
".tg {border-collapse:collapse;border-spacing:0;margin-bottom:20px;}\n",
".tg td, .tg th {border:1px solid #ddd;padding:10px 15px;text-align:center;}\n",
".tg th {background-color:#f5f5f5;color:#333333;}\n",
".tg tr:nth-child(even) {background-color:#f9f9f9;}\n",
".tg tr:nth-child(odd) {background-color:#ffffff;}\n",
"</style>\n",
"<table class=\"tg\" style=\"margin-left: auto; margin-right: auto;\">\n",
"<thead>\n",
" <tr>\n",
" <th rowspan=\"2\"></th>\n",
" <th colspan=\"5\">Reasoning Accuracy<br><span style=\"font-weight: normal; font-size: smaller;\">(for different equation lengths)</span></th>\n",
" <th rowspan=\"2\">Training Time (s)<br><span style=\"font-weight: normal; font-size: smaller;\">(to achieve the Acc. using all lengths)</span></th>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <th>3</th>\n",
" <th>5</th>\n",
" <th>7</th>\n",
" <th>All</th>\n",
" </tr>\n",
"</thead>\n",
"<tbody>\n",
" <tr>\n",
" <td>NGS</td>\n",
" <td>91.2</td>\n",
" <td>89.1</td>\n",
" <td>92.7</td>\n",
" <td>5.2</td>\n",
" <td>98.4</td>\n",
" <td>426.2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>DeepProbLog</td>\n",
" <td>90.8</td>\n",
" <td>85.6</td>\n",
" <td>timeout*</td>\n",
" <td>timeout</td>\n",
" <td>timeout</td>\n",
" <td>timeout</td>\n",
" </tr>\n",
" <tr>\n",
" <td>DeepStochLog</td>\n",
" <td>92.8</td>\n",
" <td>87.5</td>\n",
" <td>92.1</td>\n",
" <td>timeout</td>\n",
" <td>timeout</td>\n",
" <td>timeout</td>\n",
" </tr>\n",
" <tr>\n",
" <td>ABL</td>\n",
" <td><span style=\"font-weight:bold\">94.0</span></td>\n",
" <td><span style=\"font-weight:bold\">89.7</span></td>\n",
" <td><span style=\"font-weight:bold\">96.5</span></td>\n",
" <td><span style=\"font-weight:bold\">97.2</span></td>\n",
" <td><span style=\"font-weight:bold\">98.6</span></td>\n",
" <td><span style=\"font-weight:bold\">77.3</span></td>\n",
" </tr>\n",
"</tbody>\n",
"</table>\n",
"<p style=\"font-size: 13px;\">* timeout: need more than 1 hour to execute</p>"
]
}
],
"metadata": {

