diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 5c7be44b..db48011b 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -118,6 +118,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> elif backend == 'numpy': return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'torch': + # 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'paddle': return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 5432b17a..7e91ec42 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -132,6 +132,9 @@ class PaddleTensorPadder(Padder): try: if not isinstance(batch_field[0], paddle.Tensor): batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] + else: + if dtype is None: + dtype = batch_field[0].dtype except AttributeError: raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index d6d07dcd..b67aeff8 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -118,6 +118,8 @@ class TorchTensorPadder(Padder): batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] else: device = batch_field[0].device + if dtype is None: + dtype = batch_field[0].dtype except AttributeError: raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 48aee094..6aeaed6b 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -169,14 +169,14 @@ class Evaluator: raise e finally: self.finally_progress_bar() - - metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) + if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 + metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) + if self.verbose: + if self.progress_bar == 'rich': + f_rich_progress.print(metric_results) + else: + logger.info(metric_results) self.driver.set_model_mode(mode='train') - if self.verbose: - if self.progress_bar == 'rich': - f_rich_progress.print(metric_results) - else: - logger.info(metric_results) return metric_results