|
|
@@ -193,6 +193,11 @@ def _prepare_metrics(metrics): |
|
|
|
if isinstance(metric, type): |
|
|
|
metric = metric() |
|
|
|
if isinstance(metric, MetricBase): |
|
|
|
metric_name = metric.__class__.__name__ |
|
|
|
if not callable(metric.evaluate): |
|
|
|
raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.") |
|
|
|
if not callable(metric.get_metric): |
|
|
|
raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.") |
|
|
|
_metrics.append(metric) |
|
|
|
else: |
|
|
|
raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.") |
|
|
|