Browse Source

convert numpy to tensor

tags/v0.5.5
yh_cc 5 years ago
parent
commit
bb88bb4f54
2 changed files with 21 additions and 1 deletions
  1. +20
    -0
      fastNLP/core/batch.py
  2. +1
    -1
      fastNLP/core/collect_fn.py

+ 20
- 0
fastNLP/core/batch.py View File

@@ -19,6 +19,9 @@ from collections import defaultdict

from .dataset import DataSet
from .sampler import SequentialSampler
from .field import _get_ele_type_and_dim
from ._logger import logger


_python_is_exit = False

@@ -31,6 +34,21 @@ def _set_python_is_exit():
atexit.register(_set_python_is_exit)


def may_to_tensor(data, as_numpy, fn):
if not as_numpy:
dtype, dim = _get_ele_type_and_dim(data)
try:
data, flag = _to_tensor(data, dtype)
except TypeError as e:
logger.error(f"Field {fn} cannot be converted to torch.tensor.")
raise e
return data


def convert_tensor(batch_dict, as_numpy):
for n, v in batch_dict.items():
batch_dict[n] = may_to_tensor(v, as_numpy, n)

class DataSetGetter:
"""
传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。
@@ -80,6 +98,8 @@ class DataSetGetter:

sin_x = pad(sin_x)
sin_y = pad(sin_y)
convert_tensor(sin_x, self.as_numpy)
convert_tensor(sin_y, self.as_numpy)

if not self.dataset.collector.is_empty():
bx, by = self.dataset._collect_batch(ins_list)


+ 1
- 1
fastNLP/core/collect_fn.py View File

@@ -95,7 +95,7 @@ class Collector:
def copy_from(self, col):
assert isinstance(col, Collector)
new_col = Collector()
new_col.collect_fns = deepcopy(col)
new_col.collect_fns = deepcopy(col.collect_fns)
return new_col




Loading…
Cancel
Save