
Fix a bug in the arguments passed to _MetricsWrapper.update

tags/v1.0.0alpha
x54-729 committed 3 years ago
commit d2439fe443
1 changed file with 5 additions and 5 deletions
  1. fastNLP/core/controllers/evaluator.py  +5 −5

--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py

@@ -364,16 +364,16 @@ class _MetricsWrapper:
             else:
                 args.append(batch)
             if not isinstance(outputs, dict):
-                raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly"
+                raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                    f" return a dict from your model or use `output_mapping` to convert it into dict type.")
             if isinstance(metric, Metric):
-                auto_param_call(metric.update, batch, *args)
+                auto_param_call(metric.update, outputs, *args)
             elif _is_torchmetrics_metric(metric):
-                auto_param_call(metric.update, batch, *args)
+                auto_param_call(metric.update, outputs, *args)
             elif _is_allennlp_metric(metric):
-                auto_param_call(metric.__call__, batch, *args)
+                auto_param_call(metric.__call__, outputs, *args)
             elif _is_paddle_metric(metric):
-                res = auto_param_call(metric.compute, batch, *args)
+                res = auto_param_call(metric.compute, outputs, *args)
                 metric.update(res)
 
     def reset(self):
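
For context, the fix matters because `auto_param_call` fills a function's parameters by name from the dicts it is given, so a metric's `update` only receives the model's predictions when `outputs` is passed instead of the raw `batch`. Below is a minimal sketch of that dispatch idea, assuming a simplified stand-in for `auto_param_call` and a hypothetical `AccuracyMetric`; it is not fastNLP's actual implementation.

import inspect
from typing import Any, Callable, Dict

def auto_param_call_sketch(fn: Callable, *dicts: Dict[str, Any]) -> Any:
    # Simplified stand-in for fastNLP's auto_param_call: collect the entries
    # whose keys match fn's parameter names and call fn with them.
    params = inspect.signature(fn).parameters
    kwargs = {}
    for mapping in dicts:
        for name, value in mapping.items():
            if name in params:
                kwargs[name] = value
    return fn(**kwargs)

class AccuracyMetric:
    # Hypothetical metric: update() expects predictions plus targets by name.
    def __init__(self):
        self.correct, self.total = 0, 0

    def update(self, pred, target):
        self.correct += sum(int(p == t) for p, t in zip(pred, target))
        self.total += len(target)

batch = {"input_ids": [[1, 2], [3, 4]], "target": [0, 1]}  # raw dataloader batch
outputs = {"pred": [0, 1]}                                  # dict returned by the model

metric = AccuracyMetric()
# Passing `outputs` alongside the batch lets update() receive `pred`; passing
# only `batch`, as the pre-fix code effectively did, never supplies the predictions.
auto_param_call_sketch(metric.update, outputs, batch)
print(metric.correct, metric.total)  # -> 2 2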

