From 2b41e4dd298350aae97ee55e6b41ed671d2957d5 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sun, 15 Mar 2020 15:19:42 +0800 Subject: [PATCH] [bugfix] auto convert tensor type when batching --- fastNLP/core/batch.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 7f0c858b..7090ea01 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -69,13 +69,20 @@ class DataSetGetter: def may_to_tensor(data): dtype, dim = _get_ele_type_and_dim(data) - print(dtype, type(dtype)) + # print(dtype, type(dtype), str(dtype)) if not self.as_numpy: try: data, flag = _to_tensor(data, dtype) except TypeError as e: logger.error(f"Field {n} cannot be converted to torch.tensor.") raise e + # if torch.is_tensor(data): + # str_dtype = str(dtype) + # if 'float' in str_dtype: + # data = data.float() + # elif 'int' in str_dtype: + # data = data.long() + # print(data.dtype) return data def pad(batch_dict): @@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype): if field_dtype is not None and isinstance(field_dtype, type)\ and issubclass(field_dtype, Number) \ and not isinstance(batch, torch.Tensor): - if issubclass(field_dtype, np.floating): - new_batch = torch.as_tensor(batch).float() # 默认使用float32 - elif issubclass(field_dtype, np.integer): - new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 - else: - new_batch = torch.as_tensor(batch) - return new_batch, True + new_batch = torch.as_tensor(batch) + flag = True else: - return batch, False + new_batch = batch + flag = False + if torch.is_tensor(new_batch): + if 'float' in new_batch.dtype.__repr__(): + new_batch = new_batch.float() + elif 'int' in new_batch.dtype.__repr__(): + new_batch = new_batch.long() + return new_batch, flag except Exception as e: raise e