
[MNT] resolve comments in evaluation folder

pull/1/head
Gao Enhao 1 year ago
commit 900c71aab8
10 changed files with 27 additions and 25 deletions
1. abl/evaluation/__init__.py (+2, -2)
2. abl/evaluation/reasoning_metric.py (+5, -3)
3. abl/evaluation/symbol_metric.py (+4, -4)
4. docs/API/abl.evaluation.rst (+1, -1)
5. docs/Intro/Basics.rst (+1, -1)
6. docs/Intro/Evaluation.rst (+1, -1)
7. docs/Intro/Quick-Start.rst (+7, -7)
8. examples/hed/hed_example.ipynb (+2, -2)
9. examples/hwf/hwf_example.ipynb (+2, -2)
10. examples/mnist_add/mnist_add_example.ipynb (+2, -2)

abl/evaluation/__init__.py (+2, -2)

@@ -1,5 +1,5 @@
 from .base_metric import BaseMetric
-from .semantics_metric import SemanticsMetric
+from .reasoning_metric import ReasoningMetric
 from .symbol_metric import SymbolMetric

-__all__ = ["BaseMetric", "SemanticsMetric", "SymbolMetric"]
+__all__ = ["BaseMetric", "ReasoningMetric", "SymbolMetric"]

abl/evaluation/semantics_metric.py → abl/evaluation/reasoning_metric.py

@@ -5,7 +5,7 @@ from ..structures import ListData
 from .base_metric import BaseMetric


-class SemanticsMetric(BaseMetric):
+class ReasoningMetric(BaseMetric):
     def __init__(self, kb: KBBase = None, prefix: Optional[str] = None) -> None:
         super().__init__(prefix)
         self.kb = kb
@@ -15,7 +15,9 @@ class SemanticsMetric(BaseMetric):
         y_list = data_samples.Y
         x_list = data_samples.X
         for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
-            if self.kb._check_equal(self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y):
+            if self.kb._check_equal(
+                self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()), y
+            ):
                 self.results.append(1)
             else:
                 self.results.append(0)
@@ -23,5 +25,5 @@ class SemanticsMetric(BaseMetric):
     def compute_metrics(self) -> dict:
         results = self.results
         metrics = dict()
-        metrics["semantics_accuracy"] = sum(results) / len(results)
+        metrics["reasoning_accuracy"] = sum(results) / len(results)
         return metrics
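
The per-sample check that ReasoningMetric.process performs can be read off the hunk above: the knowledge base's logic_forward is applied to the predicted pseudo-labels (forwarding the raw input x only when logic_forward takes two arguments), and the result is compared with the reasoning target y. A standalone, illustrative sketch of that logic (the helper name reasoning_correct is not part of the package):

    def reasoning_correct(kb, pred_pseudo_label, y, x=None):
        # Mirrors the check in ReasoningMetric.process above: forward x only
        # when the knowledge base's logic_forward expects two arguments.
        if kb._num_args == 2:
            result = kb.logic_forward(pred_pseudo_label, x)
        else:
            result = kb.logic_forward(pred_pseudo_label)
        return kb._check_equal(result, y)

The reported reasoning_accuracy is then simply the mean of these 0/1 outcomes over all processed samples, as compute_metrics shows.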

abl/evaluation/symbol_metric.py (+4, -4)

@@ -1,5 +1,7 @@
 from typing import Optional

+import numpy as np
+
 from ..structures import ListData
 from .base_metric import BaseMetric

@@ -15,10 +17,8 @@ class SymbolMetric(BaseMetric):
         if not len(pred_pseudo_label_list) == len(gt_pseudo_label_list):
             raise ValueError("lengthes of pred_pseudo_label and gt_pseudo_label should be equal")

-        correct_num = 0
-        for pred_pseudo_label, gt_pseudo_label in zip(pred_pseudo_label_list, gt_pseudo_label_list):
-            if pred_pseudo_label == gt_pseudo_label:
-                correct_num += 1
+        correct_num = np.sum(np.array(pred_pseudo_label_list) == np.array(gt_pseudo_label_list))
+
         self.results.append((correct_num, len(pred_pseudo_label_list)))

     def compute_metrics(self) -> dict:
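
The vectorised line introduced here should count the same matches as the loop it replaces, provided pred_pseudo_label_list and gt_pseudo_label_list are flat, equal-length sequences of symbols. A quick self-contained check on toy data:

    import numpy as np

    pred = [1, 2, 3, 4, 5]
    gt = [1, 2, 0, 4, 5]

    # Old form: explicit loop over paired symbols.
    loop_count = sum(1 for p, g in zip(pred, gt) if p == g)

    # New form: element-wise comparison summed by NumPy.
    vec_count = np.sum(np.array(pred) == np.array(gt))

    assert loop_count == vec_count == 4
    print(vec_count / len(pred))  # symbol accuracy for this toy sample: 0.8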


docs/API/abl.evaluation.rst (+1, -1)

@@ -13,4 +13,4 @@ Metrics

 BaseMetric
 SymbolMetric
-SemanticsMetric
+ReasoningMetric

docs/Intro/Basics.rst (+1, -1)

@@ -26,7 +26,7 @@ It first features class ``ListData`` (inherited from base class
 ``BaseDataElement``), which defines the data structures used in
 Abductive Learning, and comprises common data operations like insertion,
 deletion, retrieval, slicing, etc. Additionally, a series of Evaluation
-Metrics, including class ``SymbolMetric`` and ``SemanticsMetric`` (both
+Metrics, including class ``SymbolMetric`` and ``ReasoningMetric`` (both
 specialized metrics derived from base class ``BaseMetric``), outline
 methods for evaluating model quality from a data perspective.
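
As a small companion to the paragraph above: within this commit, ListData is only exercised through attribute-style access (data_samples.X, data_samples.Y). The snippet below is a sketch under that assumption; the field values are made up, and no-argument construction is assumed to behave as for the BaseDataElement parent class.

    from abl.structures import ListData

    data_samples = ListData()
    data_samples.X = [[1, 2], [3, 4]]  # raw inputs, one entry per example
    data_samples.Y = [3, 7]            # reasoning targets, one entry per example
    print(data_samples.Y)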



docs/Intro/Evaluation.rst (+1, -1)

@@ -18,7 +18,7 @@ To customize our own metrics, we need to inherit from ``BaseMetric`` and impleme
 - The ``compute_metrics`` method uses all the information saved in ``self.results`` to calculate and return a dict that holds the evaluation results.

 Besides, we can assign a ``str`` to the ``prefix`` argument of the ``__init__`` method. This string is automatically prefixed to the output metric names. For example, if we set ``prefix="mnist_add"``, the output metric name will be ``character_accuracy``.
-We provide two basic metrics, namely ``SymbolMetric`` and ``SemanticsMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively. Using ``SymbolMetric`` as an example, the following code shows how to implement a custom metrics.
+We provide two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively. Using ``SymbolMetric`` as an example, the following code shows how to implement a custom metrics.

 .. code:: python


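The excerpt above stops just before the docs' own example, so here is a hedged sketch of a custom metric that follows the BaseMetric pattern visible elsewhere in this commit. The process signature and the ListData field names pred_pseudo_label and gt_pseudo_label are assumptions, not the documented example.

    from typing import Optional

    from abl.evaluation import BaseMetric


    class SequenceMetric(BaseMetric):
        """Hypothetical metric: fraction of samples whose whole pseudo-label
        sequence is predicted correctly."""

        def __init__(self, prefix: Optional[str] = None) -> None:
            super().__init__(prefix)

        def process(self, data_samples) -> None:
            # Field names assumed by analogy with SymbolMetric above.
            pred_list = data_samples.pred_pseudo_label
            gt_list = data_samples.gt_pseudo_label
            for pred, gt in zip(pred_list, gt_list):
                self.results.append(int(pred == gt))

        def compute_metrics(self) -> dict:
            return {"sequence_accuracy": sum(self.results) / len(self.results)}
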

docs/Intro/Quick-Start.rst (+7, -7)

@@ -162,13 +162,13 @@ Read more about `building the reasoning part <Reasoning.html>`_.
 Building Evaluation Metrics
 ---------------------------

-ABL-Package provides two basic metrics, namely ``SymbolMetric`` and ``SemanticsMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively.
+ABL-Package provides two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively.

 .. code:: python

-    from abl.evaluation import SemanticsMetric, SymbolMetric
+    from abl.evaluation import ReasoningMetric, SymbolMetric

-    metric_list = [SymbolMetric(prefix="mnist_add"), SemanticsMetric(kb=kb, prefix="mnist_add")]
+    metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

 Read more about `building evaluation metrics <Evaluation.html>`_

@@ -203,7 +203,7 @@ Training log would be similar to this:
 abl - INFO - loop(train) [1/5] segment(train) [3/3]
 abl - INFO - model loss: 1.33183
 abl - INFO - Evaluation start: loop(val) [1]
-abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.450 mnist_add/semantics_accuracy: 0.237
+abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.450 mnist_add/reasoning_accuracy: 0.237
 abl - INFO - Saving model: loop(save) [1]
 abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_1.pth
 abl - INFO - loop(train) [2/5] segment(train) [1/3]
@@ -213,7 +213,7 @@ Training log would be similar to this:
 abl - INFO - loop(train) [2/5] segment(train) [3/3]
 abl - INFO - model loss: 0.11282
 abl - INFO - Evaluation start: loop(val) [2]
-abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.976 mnist_add/semantics_accuracy: 0.954
+abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.976 mnist_add/reasoning_accuracy: 0.954
 abl - INFO - Saving model: loop(save) [2]
 abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_2.pth
 ...
@@ -224,9 +224,9 @@ Training log would be similar to this:
 abl - INFO - loop(train) [5/5] segment(train) [3/3]
 abl - INFO - model loss: 0.03423
 abl - INFO - Evaluation start: loop(val) [5]
-abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.992 mnist_add/semantics_accuracy: 0.984
+abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.992 mnist_add/reasoning_accuracy: 0.984
 abl - INFO - Saving model: loop(save) [5]
 abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_5.pth
-abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/semantics_accuracy: 0.975
+abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975

 Read more about `bridging machine learning and reasoning <Bridge.html>`_.

examples/hed/hed_example.ipynb (+2, -2)

@@ -12,7 +12,7 @@
 "import torch\n",
 "import torch.nn as nn\n",
 "\n",
-"from abl.evaluation import SemanticsMetric, SymbolMetric\n",
+"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
 "from abl.learning import ABLModel, BasicNN\n",
 "from abl.reasoning import PrologKB, Reasoner\n",
 "from abl.utils import ABLLogger, print_log, reform_list\n",
@@ -210,7 +210,7 @@
 "outputs": [],
 "source": [
 "# Set up metrics\n",
-"metric_list = [SymbolMetric(prefix=\"hed\"), SemanticsMetric(prefix=\"hed\")]"
+"metric_list = [SymbolMetric(prefix=\"hed\"), ReasoningMetric(prefix=\"hed\")]"
 ]
 },
 {


examples/hwf/hwf_example.ipynb (+2, -2)

@@ -14,7 +14,7 @@
 "from abl.reasoning import Reasoner, KBBase\n",
 "from abl.learning import BasicNN, ABLModel\n",
 "from abl.bridge import SimpleBridge\n",
-"from abl.evaluation import SymbolMetric, SemanticsMetric\n",
+"from abl.evaluation import SymbolMetric, ReasoningMetric\n",
 "from abl.utils import ABLLogger, print_log\n",
 "\n",
 "from examples.models.nn import SymbolNet\n",
@@ -146,7 +146,7 @@
 "outputs": [],
 "source": [
 "# Add metric\n",
-"metric_list = [SymbolMetric(prefix=\"hwf\"), SemanticsMetric(kb=kb, prefix=\"hwf\")]"
+"metric_list = [SymbolMetric(prefix=\"hwf\"), ReasoningMetric(kb=kb, prefix=\"hwf\")]"
 ]
 },
 {


examples/mnist_add/mnist_add_example.ipynb (+2, -2)

@@ -12,7 +12,7 @@
 "import torch.nn as nn\n",
 "\n",
 "from abl.bridge import SimpleBridge\n",
-"from abl.evaluation import SemanticsMetric, SymbolMetric\n",
+"from abl.evaluation import ReasoningMetric, SymbolMetric\n",
 "from abl.learning import ABLModel, BasicNN\n",
 "from abl.reasoning import KBBase, Reasoner\n",
 "from abl.utils import ABLLogger, print_log\n",
@@ -138,7 +138,7 @@
 "outputs": [],
 "source": [
 "# Set up metrics\n",
-"metric_list = [SymbolMetric(prefix=\"mnist_add\"), SemanticsMetric(kb=kb, prefix=\"mnist_add\")]"
+"metric_list = [SymbolMetric(prefix=\"mnist_add\"), ReasoningMetric(kb=kb, prefix=\"mnist_add\")]"
 ]
 },
 {

