|
|
@@ -39,7 +39,6 @@ class Tester(object): |
|
|
|
|
|
|
|
for req_key in required_args: |
|
|
|
if req_key not in kwargs: |
|
|
|
logger.error("Tester lacks argument {}".format(req_key)) |
|
|
|
raise ValueError("Tester lacks argument {}".format(req_key)) |
|
|
|
|
|
|
|
for key in default_args: |
|
|
@@ -49,7 +48,6 @@ class Tester(object): |
|
|
|
else: |
|
|
|
msg = "Argument %s type mismatch: expected %s while get %s" % ( |
|
|
|
key, type(default_args[key]), type(kwargs[key])) |
|
|
|
logger.error(msg) |
|
|
|
raise ValueError(msg) |
|
|
|
else: |
|
|
|
# Tester doesn't care about extra arguments |
|
|
@@ -85,8 +83,7 @@ class Tester(object): |
|
|
|
for k, v in batch_y.items(): |
|
|
|
truths[k].append(v) |
|
|
|
eval_results = self.evaluate(**output, **truths) |
|
|
|
# print("[tester] {}".format(self.print_eval_results(eval_results))) |
|
|
|
# logger.info("[tester] {}".format(self.print_eval_results(eval_results))) |
|
|
|
print("[tester] {}".format(self.print_eval_results(eval_results))) |
|
|
|
self.mode(network, is_test=False) |
|
|
|
self.metrics = eval_results |
|
|
|
return eval_results |
|
|
|