Evaluation Metrics
==================

In this section, we will look at how to build evaluation metrics.

.. code:: python

    from ablkit.data.evaluation import BaseMetric, SymbolAccuracy, ReasoningMetric

ABL Kit separates the evaluation process from model training and testing and places it in an independent class, ``BaseMetric``. Training and testing are implemented in the ``BaseBridge`` class, so metrics are used by this class and its subclasses. After a ``bridge`` is built with a list of ``BaseMetric`` instances, the ``bridge.valid`` method uses these metrics to evaluate model performance during training and testing.

To customize our own metrics, we need to inherit from ``BaseMetric`` and implement the ``process`` and ``compute_metrics`` methods.

- The ``process`` method accepts a batch of model predictions, processes it, and saves the resulting information to the ``self.results`` attribute.
- The ``compute_metrics`` method uses all the information saved in ``self.results`` to calculate and return a dict that holds the evaluation results.

In addition, we can pass a ``str`` to the ``prefix`` argument of ``__init__``. This string is automatically prepended to the output metric names. For example, if we set ``prefix="mnist_add"``, the output metric name will be ``mnist_add/character_accuracy``.

We provide two basic metrics, ``SymbolAccuracy`` and ``ReasoningMetric``. They evaluate the accuracy of the machine learning model's predictions and the accuracy of the final reasoning results, respectively. Using ``SymbolAccuracy`` as an example, the following code shows how to implement a custom metric.

.. code:: python

    from typing import Optional

    from ablkit.data.evaluation import BaseMetric
    from ablkit.data.structures import ListData


    class SymbolAccuracy(BaseMetric):
        def __init__(self, prefix: Optional[str] = None) -> None:
            # prefix is used to distinguish different metrics
            super().__init__(prefix)

        def process(self, data_examples: ListData) -> None:
            # pred_pseudo_label and gt_pseudo_label are both of type List[List[Any]]
            # and have the same length
            pred_pseudo_label = data_examples.pred_pseudo_label
            gt_pseudo_label = data_examples.gt_pseudo_label

            for pred_z, z in zip(pred_pseudo_label, gt_pseudo_label):
                correct_num = 0
                for pred_symbol, symbol in zip(pred_z, z):
                    if pred_symbol == symbol:
                        correct_num += 1
                # Store the symbol accuracy of this single example
                self.results.append(correct_num / len(z))

        def compute_metrics(self) -> dict:
            # Average the per-example accuracies collected by ``process``
            results = self.results
            metrics = dict()
            metrics["character_accuracy"] = sum(results) / len(results)
            return metrics
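
The following is a minimal, dependency-free sketch that shows how the ``process`` / ``compute_metrics`` contract and the ``prefix`` naming fit together, without relying on ABL Kit internals. ``SketchBaseMetric``, its ``evaluate`` method, and ``SequenceAccuracy`` are hypothetical names introduced only for illustration. The sketch assumes that the prefix is joined to each metric name with ``/`` and that results are cleared after each evaluation. ``types.SimpleNamespace`` stands in for the real data container. In real code, you would subclass ``ablkit.data.evaluation.BaseMetric`` instead.

.. code:: python

    from abc import ABC, abstractmethod
    from types import SimpleNamespace
    from typing import Any, List, Optional


    class SketchBaseMetric(ABC):
        """Hypothetical stand-in for ``BaseMetric``, for illustration only."""

        def __init__(self, prefix: Optional[str] = None) -> None:
            self.prefix = prefix or ""
            self.results: List[Any] = []

        @abstractmethod
        def process(self, data_examples) -> None:
            """Consume one batch and append per-example information to self.results."""

        @abstractmethod
        def compute_metrics(self) -> dict:
            """Aggregate everything in self.results into a dict of metrics."""

        def evaluate(self) -> dict:
            # Assumed behaviour: aggregate, prepend "<prefix>/" to each name,
            # then clear the buffer so the next evaluation starts fresh.
            metrics = self.compute_metrics()
            if self.prefix:
                metrics = {f"{self.prefix}/{k}": v for k, v in metrics.items()}
            self.results.clear()
            return metrics


    class SequenceAccuracy(SketchBaseMetric):
        """Hypothetical metric: fraction of examples whose symbols are all correct."""

        def process(self, data_examples) -> None:
            for pred_z, z in zip(data_examples.pred_pseudo_label,
                                 data_examples.gt_pseudo_label):
                # 1 if every predicted symbol matches its ground truth, else 0
                self.results.append(int(list(pred_z) == list(z)))

        def compute_metrics(self) -> dict:
            return {"sequence_accuracy": sum(self.results) / len(self.results)}


    if __name__ == "__main__":
        metric = SequenceAccuracy(prefix="mnist_add")
        # SimpleNamespace mimics attribute access on a batch of data examples
        batch_1 = SimpleNamespace(pred_pseudo_label=[[1, 2], [3, 4]],
                                  gt_pseudo_label=[[1, 2], [3, 5]])
        batch_2 = SimpleNamespace(pred_pseudo_label=[[0, 7]],
                                  gt_pseudo_label=[[0, 7]])
        for batch in (batch_1, batch_2):
            metric.process(batch)
        print(metric.evaluate())  # {'mnist_add/sequence_accuracy': 0.6666666666666666}

Keeping per-example bookkeeping in ``process`` and aggregation in ``compute_metrics`` lets results from several batches accumulate before a single call produces the final number.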