diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index cbc9429d..89662c36 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -19,6 +19,9 @@ from collections import defaultdict from .dataset import DataSet from .sampler import SequentialSampler +from .field import _get_ele_type_and_dim +from ._logger import logger + _python_is_exit = False @@ -31,6 +34,21 @@ def _set_python_is_exit(): atexit.register(_set_python_is_exit) +def may_to_tensor(data, as_numpy, fn): + if not as_numpy: + dtype, dim = _get_ele_type_and_dim(data) + try: + data, flag = _to_tensor(data, dtype) + except TypeError as e: + logger.error(f"Field {fn} cannot be converted to torch.tensor.") + raise e + return data + + +def convert_tensor(batch_dict, as_numpy): + for n, v in batch_dict.items(): + batch_dict[n] = may_to_tensor(v, as_numpy, n) + class DataSetGetter: """ 传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。 @@ -80,6 +98,8 @@ class DataSetGetter: sin_x = pad(sin_x) sin_y = pad(sin_y) + convert_tensor(sin_x, self.as_numpy) + convert_tensor(sin_y, self.as_numpy) if not self.dataset.collector.is_empty(): bx, by = self.dataset._collect_batch(ins_list) diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collect_fn.py index d80db154..0660ae33 100644 --- a/fastNLP/core/collect_fn.py +++ b/fastNLP/core/collect_fn.py @@ -95,7 +95,7 @@ class Collector: def copy_from(self, col): assert isinstance(col, Collector) new_col = Collector() - new_col.collect_fns = deepcopy(col) + new_col.collect_fns = deepcopy(col.collect_fns) return new_col