diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index e599ec7b..5296b0bf 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -193,6 +193,11 @@ def _prepare_metrics(metrics): if isinstance(metric, type): metric = metric() if isinstance(metric, MetricBase): + metric_name = metric.__class__.__name__ + if not callable(metric.evaluate): + raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.") + if not callable(metric.get_metric): + raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.") _metrics.append(metric) else: raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")