From 80a7dbfbda8721ffe54fbfe96e4a6b858a9781df Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sat, 7 May 2022 21:58:04 +0800 Subject: [PATCH] =?UTF-8?q?=E9=98=B2=E6=AD=A2=20dtype=20=E4=B8=BA=E7=A9=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/padders/get_padder.py | 1 + fastNLP/core/collators/padders/paddle_padder.py | 3 +++ fastNLP/core/collators/padders/torch_padder.py | 2 ++ fastNLP/core/controllers/evaluator.py | 14 +++++++------- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 5c7be44b..db48011b 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -118,6 +118,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> elif backend == 'numpy': return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'torch': + # 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'paddle': return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 5432b17a..7e91ec42 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -132,6 +132,9 @@ class PaddleTensorPadder(Padder): try: if not isinstance(batch_field[0], paddle.Tensor): batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] + else: + if dtype is None: + dtype = batch_field[0].dtype except AttributeError: raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index d6d07dcd..b67aeff8 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -118,6 +118,8 @@ class TorchTensorPadder(Padder): batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] else: device = batch_field[0].device + if dtype is None: + dtype = batch_field[0].dtype except AttributeError: raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 382a9405..47301955 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -169,14 +169,14 @@ class Evaluator: raise e finally: self.finally_progress_bar() - - metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) + if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。 + metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False) + if self.verbose: + if self.progress_bar == 'rich': + f_rich_progress.print(metric_results) + else: + logger.info(metric_results) self.driver.set_model_mode(mode='train') - if self.verbose: - if self.progress_bar == 'rich': - f_rich_progress.print(metric_results) - else: - logger.info(metric_results) return metric_results