|
@@ -364,16 +364,16 @@ class _MetricsWrapper: |
|
|
else: |
|
|
else: |
|
|
args.append(batch) |
|
|
args.append(batch) |
|
|
if not isinstance(outputs, dict): |
|
|
if not isinstance(outputs, dict): |
|
|
raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" |
|
|
|
|
|
|
|
|
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" |
|
|
f" return a dict from your model or use `output_mapping` to convert it into dict type.") |
|
|
f" return a dict from your model or use `output_mapping` to convert it into dict type.") |
|
|
if isinstance(metric, Metric): |
|
|
if isinstance(metric, Metric): |
|
|
auto_param_call(metric.update, batch, *args) |
|
|
|
|
|
|
|
|
auto_param_call(metric.update, outputs, *args) |
|
|
elif _is_torchmetrics_metric(metric): |
|
|
elif _is_torchmetrics_metric(metric): |
|
|
auto_param_call(metric.update, batch, *args) |
|
|
|
|
|
|
|
|
auto_param_call(metric.update, outputs, *args) |
|
|
elif _is_allennlp_metric(metric): |
|
|
elif _is_allennlp_metric(metric): |
|
|
auto_param_call(metric.__call__, batch, *args) |
|
|
|
|
|
|
|
|
auto_param_call(metric.__call__, outputs, *args) |
|
|
elif _is_paddle_metric(metric): |
|
|
elif _is_paddle_metric(metric): |
|
|
res = auto_param_call(metric.compute, batch, *args) |
|
|
|
|
|
|
|
|
res = auto_param_call(metric.compute, outputs, *args) |
|
|
metric.update(res) |
|
|
metric.update(res) |
|
|
|
|
|
|
|
|
def reset(self): |
|
|
def reset(self): |
|
|