|
|
@@ -69,13 +69,20 @@ class DataSetGetter: |
|
|
|
|
|
|
|
def may_to_tensor(data): |
|
|
|
dtype, dim = _get_ele_type_and_dim(data) |
|
|
|
print(dtype, type(dtype)) |
|
|
|
# print(dtype, type(dtype), str(dtype)) |
|
|
|
if not self.as_numpy: |
|
|
|
try: |
|
|
|
data, flag = _to_tensor(data, dtype) |
|
|
|
except TypeError as e: |
|
|
|
logger.error(f"Field {n} cannot be converted to torch.tensor.") |
|
|
|
raise e |
|
|
|
# if torch.is_tensor(data): |
|
|
|
# str_dtype = str(dtype) |
|
|
|
# if 'float' in str_dtype: |
|
|
|
# data = data.float() |
|
|
|
# elif 'int' in str_dtype: |
|
|
|
# data = data.long() |
|
|
|
# print(data.dtype) |
|
|
|
return data |
|
|
|
|
|
|
|
def pad(batch_dict): |
|
|
@@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype): |
|
|
|
if field_dtype is not None and isinstance(field_dtype, type)\ |
|
|
|
and issubclass(field_dtype, Number) \ |
|
|
|
and not isinstance(batch, torch.Tensor): |
|
|
|
if issubclass(field_dtype, np.floating): |
|
|
|
new_batch = torch.as_tensor(batch).float() # 默认使用float32 |
|
|
|
elif issubclass(field_dtype, np.integer): |
|
|
|
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 |
|
|
|
else: |
|
|
|
new_batch = torch.as_tensor(batch) |
|
|
|
return new_batch, True |
|
|
|
new_batch = torch.as_tensor(batch) |
|
|
|
flag = True |
|
|
|
else: |
|
|
|
return batch, False |
|
|
|
new_batch = batch |
|
|
|
flag = False |
|
|
|
if torch.is_tensor(new_batch): |
|
|
|
if 'float' in new_batch.dtype.__repr__(): |
|
|
|
new_batch = new_batch.float() |
|
|
|
elif 'int' in new_batch.dtype.__repr__(): |
|
|
|
new_batch = new_batch.long() |
|
|
|
return new_batch, flag |
|
|
|
except Exception as e: |
|
|
|
raise e |