
[bugfix] auto convert tensor type when batching

tags/v0.5.5
yunfan committed 5 years ago
commit 2b41e4dd29
1 changed file with 18 additions and 9 deletions:
  fastNLP/core/batch.py  +18  -9

fastNLP/core/batch.py

@@ -69,13 +69,20 @@ class DataSetGetter:

         def may_to_tensor(data):
             dtype, dim = _get_ele_type_and_dim(data)
-            print(dtype, type(dtype))
+            # print(dtype, type(dtype), str(dtype))
             if not self.as_numpy:
                 try:
                     data, flag = _to_tensor(data, dtype)
                 except TypeError as e:
                     logger.error(f"Field {n} cannot be converted to torch.tensor.")
                     raise e
+                # if torch.is_tensor(data):
+                #     str_dtype = str(dtype)
+                #     if 'float' in str_dtype:
+                #         data = data.float()
+                #     elif 'int' in str_dtype:
+                #         data = data.long()
+                # print(data.dtype)
             return data

         def pad(batch_dict):
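
For context, the method touched here is the per-field hook in DataSetGetter's collate path: each batched field is handed to _to_tensor unless the getter was constructed with as_numpy=True, in which case the raw value is returned untouched. A minimal standalone sketch of that flow, assuming nothing from fastNLP beyond what this hunk shows (the to_tensor stand-in, the free-function signature, and the sample "words" field are illustrative):

    import numpy as np
    import torch

    def to_tensor(data, dtype):
        # stand-in for fastNLP's _to_tensor (see the next hunk):
        # numeric fields become tensors, everything else passes through
        if isinstance(dtype, type) and issubclass(dtype, (int, float, np.number)):
            return torch.as_tensor(data), True
        return data, False

    def may_to_tensor(data, dtype, as_numpy=False):
        # mirrors the nested method above: skip conversion when numpy output is requested
        if not as_numpy:
            data, _ = to_tensor(data, dtype)
        return data

    words = np.stack([np.array([1, 2, 3]), np.array([4, 5, 6])])   # one batched field
    print(may_to_tensor(words, words.dtype.type))                  # integer tensor of shape (2, 3)
    print(may_to_tensor(words, words.dtype.type, as_numpy=True))   # stays a numpy array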
@@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype):
         if field_dtype is not None and isinstance(field_dtype, type)\
                 and issubclass(field_dtype, Number) \
                 and not isinstance(batch, torch.Tensor):
-            if issubclass(field_dtype, np.floating):
-                new_batch = torch.as_tensor(batch).float()  # use float32 by default
-            elif issubclass(field_dtype, np.integer):
-                new_batch = torch.as_tensor(batch).long()  # as_tensor reuses the memory, avoiding a copy
-            else:
-                new_batch = torch.as_tensor(batch)
-            return new_batch, True
+            new_batch = torch.as_tensor(batch)
+            flag = True
         else:
-            return batch, False
+            new_batch = batch
+            flag = False
+        if torch.is_tensor(new_batch):
+            if 'float' in new_batch.dtype.__repr__():
+                new_batch = new_batch.float()
+            elif 'int' in new_batch.dtype.__repr__():
+                new_batch = new_batch.long()
+        return new_batch, flag
     except Exception as e:
         raise e
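
The net effect of the rewritten branch is that _to_tensor no longer keys the cast off the numpy field dtype; it converts once with torch.as_tensor and then normalizes any floating tensor to torch.float32 and any integer tensor to torch.int64. A standalone sketch of that rule, assuming only what the hunk shows (normalize_dtype is an illustrative name, not part of fastNLP):

    import numpy as np
    import torch

    def normalize_dtype(batch):
        new_batch = torch.as_tensor(batch)        # shares memory with the numpy input when it can
        if 'float' in repr(new_batch.dtype):
            new_batch = new_batch.float()         # float16 / float64 -> float32
        elif 'int' in repr(new_batch.dtype):
            new_batch = new_batch.long()          # int32 / uint8 -> int64
        return new_batch

    print(normalize_dtype(np.zeros(3, dtype=np.float64)).dtype)   # torch.float32
    print(normalize_dtype(np.zeros(3, dtype=np.int32)).dtype)     # torch.int64

One consequence of the substring check is that uint8 tensors are promoted to int64 as well, while bool tensors match neither branch and pass through unchanged.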
