|
|
@@ -19,6 +19,9 @@ from collections import defaultdict |
|
|
|
|
|
|
|
from .dataset import DataSet |
|
|
|
from .sampler import SequentialSampler |
|
|
|
from .field import _get_ele_type_and_dim |
|
|
|
from ._logger import logger |
|
|
|
|
|
|
|
|
|
|
|
_python_is_exit = False |
|
|
|
|
|
|
@@ -31,6 +34,21 @@ def _set_python_is_exit(): |
|
|
|
atexit.register(_set_python_is_exit) |
|
|
|
|
|
|
|
|
|
|
|
def may_to_tensor(data, as_numpy, fn): |
|
|
|
if not as_numpy: |
|
|
|
dtype, dim = _get_ele_type_and_dim(data) |
|
|
|
try: |
|
|
|
data, flag = _to_tensor(data, dtype) |
|
|
|
except TypeError as e: |
|
|
|
logger.error(f"Field {fn} cannot be converted to torch.tensor.") |
|
|
|
raise e |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
def convert_tensor(batch_dict, as_numpy): |
|
|
|
for n, v in batch_dict.items(): |
|
|
|
batch_dict[n] = may_to_tensor(v, as_numpy, n) |
|
|
|
|
|
|
|
class DataSetGetter: |
|
|
|
""" |
|
|
|
传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。 |
|
|
@@ -80,6 +98,8 @@ class DataSetGetter: |
|
|
|
|
|
|
|
sin_x = pad(sin_x) |
|
|
|
sin_y = pad(sin_y) |
|
|
|
convert_tensor(sin_x, self.as_numpy) |
|
|
|
convert_tensor(sin_y, self.as_numpy) |
|
|
|
|
|
|
|
if not self.dataset.collector.is_empty(): |
|
|
|
bx, by = self.dataset._collect_batch(ins_list) |
|
|
|