From 84024aaaa4a2a6be91fec1162250d5a03fe30bc7 Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 2 Dec 2018 10:36:20 +0800 Subject: [PATCH] =?UTF-8?q?=5Fprepare=5Fmetric=E5=87=BD=E6=95=B0=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=A3=80=E6=9F=A5evaluate=E4=B8=8Eget=5Fmetric?= =?UTF-8?q?=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/metrics.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index e599ec7b..5296b0bf 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -193,6 +193,11 @@ def _prepare_metrics(metrics): if isinstance(metric, type): metric = metric() if isinstance(metric, MetricBase): + metric_name = metric.__class__.__name__ + if not callable(metric.evaluate): + raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.") + if not callable(metric.get_metric): + raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.") _metrics.append(metric) else: raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")