@@ -118,6 +118,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
elif backend == 'numpy': | elif backend == 'numpy': | ||||
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | ||||
elif backend == 'torch': | elif backend == 'torch': | ||||
# 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor | |||||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | ||||
@@ -132,6 +132,9 @@ class PaddleTensorPadder(Padder): | |||||
try: | try: | ||||
if not isinstance(batch_field[0], paddle.Tensor): | if not isinstance(batch_field[0], paddle.Tensor): | ||||
batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] | batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] | ||||
else: | |||||
if dtype is None: | |||||
dtype = batch_field[0].dtype | |||||
except AttributeError: | except AttributeError: | ||||
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | ||||
f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
@@ -118,6 +118,8 @@ class TorchTensorPadder(Padder): | |||||
batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] | batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] | ||||
else: | else: | ||||
device = batch_field[0].device | device = batch_field[0].device | ||||
if dtype is None: | |||||
dtype = batch_field[0].dtype | |||||
except AttributeError: | except AttributeError: | ||||
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | ||||
f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
@@ -169,14 +169,14 @@ class Evaluator: | |||||
raise e | raise e | ||||
finally: | finally: | ||||
self.finally_progress_bar() | self.finally_progress_bar() | ||||
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) | |||||
if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 | |||||
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) | |||||
if self.verbose: | |||||
if self.progress_bar == 'rich': | |||||
f_rich_progress.print(metric_results) | |||||
else: | |||||
logger.info(metric_results) | |||||
self.driver.set_model_mode(mode='train') | self.driver.set_model_mode(mode='train') | ||||
if self.verbose: | |||||
if self.progress_bar == 'rich': | |||||
f_rich_progress.print(metric_results) | |||||
else: | |||||
logger.info(metric_results) | |||||
return metric_results | return metric_results | ||||