`Learn the Basics `_ ||
`Quick Start `_ ||
`Dataset & Data Structure `_ ||
`Learning Part `_ ||
`Reasoning Part `_ ||
**Evaluation Metrics** ||
`Bridge `_
Evaluation Metrics
==================
In this section, we will look at how to build evaluation metrics.
.. code:: python

    from ablkit.data.evaluation import BaseMetric, SymbolAccuracy, ReasoningMetric
ABL Kit separates the evaluation process from model training and testing into an independent class, ``BaseMetric``. The training and testing processes are implemented in the ``BaseBridge`` class, so metrics are used by this class and its sub-classes. After building a ``bridge`` with a list of ``BaseMetric`` instances, these metrics will be used by the ``bridge.valid`` method to evaluate the model performance during training and testing.
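For instance, in the MNIST Addition task, a list of metrics can be built and handed to the bridge as follows. This is a minimal sketch, assuming a knowledge base ``kb``, an ``ABLModel`` ``model``, and a ``Reasoner`` ``reasoner`` have already been constructed as described in the other sections.

.. code:: python

    from ablkit.bridge import SimpleBridge
    from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy

    # kb, model, and reasoner are assumed to have been built already
    # (see the Learning Part, Reasoning Part, and Bridge sections).
    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
    bridge = SimpleBridge(model, reasoner, metric_list)

    # The bridge then uses these metrics when validating and testing the model.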
To customize our own metrics, we need to inherit from ``BaseMetric`` and implement the ``process`` and ``compute_metrics`` methods.
- The ``process`` method accepts a batch of model predictions and, after processing this batch, saves the relevant information to the ``self.results`` property.
- The ``compute_metrics`` method uses all the information saved in ``self.results`` to calculate and return a dict that holds the evaluation results.
In addition, we can assign a ``str`` to the ``prefix`` argument of the ``__init__`` function. This string is automatically prepended to the output metric names. For example, if we set ``prefix="mnist_add"``, the output metric name will be ``mnist_add/character_accuracy``.
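As a quick illustration (a minimal sketch; ``mnist_add`` is simply the prefix used in the MNIST Addition example):

.. code:: python

    # Instantiate the metric with a task-specific prefix
    char_acc_metric = SymbolAccuracy(prefix="mnist_add")

    # The bridge will report this metric under the key
    # "mnist_add/character_accuracy" rather than plain "character_accuracy".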
We provide two basic metrics, namely ``SymbolAccuracy`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the final reasoning results, respectively. Using ``SymbolAccuracy`` as an example, the following code shows how to implement a custom metric.
.. code:: python

    from typing import Optional, Sequence

    from ablkit.data.evaluation import BaseMetric


    class SymbolAccuracy(BaseMetric):
        def __init__(self, prefix: Optional[str] = None) -> None:
            # prefix is used to distinguish different metrics
            super().__init__(prefix)

        def process(self, data_examples: Sequence[dict]) -> None:
            # pred_pseudo_label and gt_pseudo_label are both of type List[List[Any]]
            # and have the same length
            pred_pseudo_label = data_examples.pred_pseudo_label
            gt_pseudo_label = data_examples.gt_pseudo_label

            # Record the symbol-level accuracy of each example in the batch
            for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
                correct_num = 0
                for pred_symbol, symbol in zip(pred_z, z):
                    if pred_symbol == symbol:
                        correct_num += 1
                self.results.append(correct_num / len(z))

        def compute_metrics(self, results: list) -> dict:
            # Average the per-example accuracies accumulated by process
            metrics = dict()
            metrics["character_accuracy"] = sum(results) / len(results)
            return metrics
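To check a custom metric outside of a bridge, we can also call ``process`` and ``compute_metrics`` directly. The following is a minimal sketch, assuming ``ListData`` from ``ablkit.data.structures`` is used to hold the batch (see the Dataset & Data Structure section); the pseudo-label values are made up purely for illustration.

.. code:: python

    from ablkit.data.structures import ListData

    metric = SymbolAccuracy(prefix="mnist_add")

    # Wrap a toy batch of predicted and ground-truth pseudo-labels
    data_examples = ListData()
    data_examples.pred_pseudo_label = [[1, 3], [2, 4]]
    data_examples.gt_pseudo_label = [[1, 3], [2, 5]]

    metric.process(data_examples)
    print(metric.compute_metrics(metric.results))
    # {'character_accuracy': 0.75}
    # (the prefix is prepended when the bridge reports the results)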