From 64c7ce54682ae404e7566ccd1a9409317ff49230 Mon Sep 17 00:00:00 2001
From: yhcc
Date: Tue, 31 May 2022 22:29:58 +0800
Subject: [PATCH] 1. Fix rich not printing when running in Jupyter; 2. Fix several other bugs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/callbacks/more_evaluate_callback.py  |  3 ++-
 fastNLP/core/collators/collator.py            |  4 +--
 fastNLP/core/collators/padders/exceptions.py  |  9 ++++++-
 fastNLP/core/collators/padders/get_padder.py  | 17 ++++++-------
 .../core/collators/padders/jittor_padder.py   |  2 +-
 .../core/collators/padders/torch_padder.py    |  2 +-
 .../core/dataloaders/jittor_dataloader/fdl.py |  2 +-
 .../core/dataloaders/paddle_dataloader/fdl.py |  2 +-
 .../core/dataloaders/prepare_dataloader.py    | 20 ++++++---------
 .../core/dataloaders/torch_dataloader/fdl.py  | 12 ++++-----
 fastNLP/core/utils/rich_progress.py           | 25 ++++++++++++++++---
 fastNLP/models/torch/seq2seq_generator.py     |  2 +-
 fastNLP/models/torch/seq2seq_model.py         |  2 +-
 .../torch/generator/seq2seq_generator.py      |  4 +--
 tests/core/collators/test_collator.py         |  7 ++++++
 .../dataloaders/test_prepare_dataloader.py    | 17 ++++++++++++-
 16 files changed, 85 insertions(+), 45 deletions(-)

diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py
index e11bacde..690146a2 100644
--- a/fastNLP/core/callbacks/more_evaluate_callback.py
+++ b/fastNLP/core/callbacks/more_evaluate_callback.py
@@ -81,7 +81,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
                  **kwargs):
         super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better,
                                                    must_have_monitor=False)
-
+        if watch_monitor is not None and evaluate_every == -1:  # get rid of evaluate_every.
+            evaluate_every = None
         if watch_monitor is None and evaluate_every is None:
             raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
         if watch_monitor is not None and evaluate_every is not None:
diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py
index 7fc11ec8..dab5028c 100644
--- a/fastNLP/core/collators/collator.py
+++ b/fastNLP/core/collators/collator.py
@@ -176,8 +176,8 @@ class Collator:
             self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:])))  # sort so that _0, _1 keep their order
         try:
             for key, padder in self.padders.items():
-                batch = unpack_batch.get(key)
-                pad_batch[key] = padder(batch)
+                batch = unpack_batch.get(key)
+                pad_batch[key] = padder(batch)
         except BaseException as e:
             try:
                 logger.error(f"The following exception happens when try to pad the `{key}` field with padder:{padder}:")
diff --git a/fastNLP/core/collators/padders/exceptions.py b/fastNLP/core/collators/padders/exceptions.py
index 8b08683d..a2b97cbf 100644
--- a/fastNLP/core/collators/padders/exceptions.py
+++ b/fastNLP/core/collators/padders/exceptions.py
@@ -3,7 +3,8 @@ __all__ = [
     'EleDtypeUnsupportedError',
     'EleDtypeDtypeConversionError',
     'DtypeUnsupportedError',
-    "DtypeError"
+    "DtypeError",
+    "NoProperPadderError"
 ]


@@ -22,6 +23,12 @@ class DtypeError(BaseException):
         self.msg = msg


+class NoProperPadderError(BaseException):
+    def __init__(self, msg, *args):
+        super(NoProperPadderError, self).__init__(msg, *args)
+        self.msg = msg
+
+
 class EleDtypeUnsupportedError(DtypeError):
     """
     Raised when the dtype of the elements in the batch cannot be padded at all.
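The new NoProperPadderError changes how get_padder (next diff) backs off when a field cannot be padded: instead of logging and returning a NullPadder at every check site, each site now raises, and a single except clause logs the reason at debug level. A minimal sketch of the resulting control flow, with simplified checks; only NoProperPadderError and NullPadder are fastNLP names, the rest is illustrative:

    import logging

    logger = logging.getLogger(__name__)

    class NoProperPadderError(BaseException):
        def __init__(self, msg, *args):
            super().__init__(msg, *args)
            self.msg = msg

    class NullPadder:
        def __call__(self, batch):
            return batch  # no-op fallback: leave the field unpadded

    def get_padder(batch_field, must_pad=False):
        try:
            # One of several consistency checks; every check site now just raises.
            if len({len(f) for f in batch_field}) > 1:
                msg = "Field elements have inconsistent lengths."
                if must_pad:
                    raise RuntimeError(msg)
                raise NoProperPadderError(msg)
            # ... shape / dtype / depth checks raise the same way ...
            return lambda batch: batch  # stand-in for a real framework padder
        except NoProperPadderError as e:
            logger.debug(e.msg)  # the "cannot pad" reason is logged in one place
        return NullPadder()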
diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py
index 66d2eee2..dfc228a3 100644
--- a/fastNLP/core/collators/padders/get_padder.py
+++ b/fastNLP/core/collators/padders/get_padder.py
@@ -49,8 +49,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
               f"information please set logger's level to DEBUG."
         if must_pad:
             raise InconsistencyError(msg)
-        logger.debug(msg)
-        return NullPadder()
+        raise NoProperPadderError(msg)

     # Then check whether the shapes of all elements are consistent.
     shape_lens = set([len(v[0]) for v in catalog.values()])
@@ -60,8 +59,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
               f"information please set logger's level to DEBUG."
         if must_pad:
             raise InconsistencyError(msg)
-        logger.debug(msg)
-        return NullPadder()
+        raise NoProperPadderError(msg)

     # Then check whether the types of all elements are consistent.
     try:
@@ -74,8 +72,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
               f"information please set logger's level to DEBUG."
         if must_pad:
             raise InconsistencyError(msg)
-        logger.debug(msg)
-        return NullPadder()
+        raise NoProperPadderError(msg)

     depth = depths.pop()
     shape_len = shape_lens.pop()
@@ -131,8 +128,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
             msg = "Does not support pad tensor under nested list. If you need this, please report."
             if must_pad:
                 raise RuntimeError(msg)
-            logger.debug(msg)
-            return NullPadder()
+            raise NoProperPadderError(msg)

     except DtypeError as e:
         msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
@@ -140,8 +136,9 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
         if must_pad:
             logger.error(msg)
             raise type(e)(msg=msg)
-        logger.debug(msg)
-        return NullPadder()
+
+    except NoProperPadderError as e:
+        logger.debug(f"{e.msg}")

     except BaseException as e:
         raise e
diff --git a/fastNLP/core/collators/padders/jittor_padder.py b/fastNLP/core/collators/padders/jittor_padder.py
index d85893f1..c9b36b89 100644
--- a/fastNLP/core/collators/padders/jittor_padder.py
+++ b/fastNLP/core/collators/padders/jittor_padder.py
@@ -188,7 +188,7 @@ def fill_tensor(batch_field, padded_batch, dtype):
                     padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype))
     elif padded_batch.ndim == 4:
         try:  # probably an image, so direct conversion should just work.
-            padded_batch = np.array(batch_field)
+            padded_batch = jittor.Var(batch_field)
         except:
             for i, content_i in enumerate(batch_field):
                 for j, content_ii in enumerate(content_i):
diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py
index c208adca..911c7d8c 100644
--- a/fastNLP/core/collators/padders/torch_padder.py
+++ b/fastNLP/core/collators/padders/torch_padder.py
@@ -175,7 +175,7 @@ def fill_tensor(batch_field, padded_batch, dtype):
                     padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype)
     elif padded_batch.ndim == 4:
         try:  # probably an image, so direct conversion should just work.
-            padded_batch = np.array(batch_field)
+            padded_batch = torch.tensor(batch_field)
         except:
             for i, content_i in enumerate(batch_field):
                 for j, content_ii in enumerate(content_i):
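In both fill_tensor implementations above, the 4-D ("image-like") branch used to rebind padded_batch to a NumPy array, so the pre-built torch/jittor tensor was silently replaced by an ndarray of the wrong framework; the fix converts to the target framework directly. A small torch illustration of that branch, with made-up values:

    import torch

    batch_field = [
        [[[0, 1], [2, 3]]],   # one 4-D element of shape (1, 2, 2)
        [[[4, 5], [6, 7]]],   # same shape, so no per-row fill loop is needed
    ]

    # Old code: padded_batch = np.array(batch_field), so a NumPy array escaped.
    padded_batch = torch.tensor(batch_field)
    print(padded_batch.shape)  # torch.Size([2, 1, 2, 2])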
diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
index 96f6747b..9f1d5e6f 100644
--- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -203,7 +203,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
                               drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                               stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                               collate_fn: Union[None, str, Callable] = "auto",
-                              non_train_batch_size: int = 16) \
+                              non_train_batch_size: int = None) \
     -> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]:
     """
     ``prepare_jittor_dataloader`` converts one or more input datasets into ``JittorDataLoader`` objects at once; see :class:`~fastNLP.core.dataloaders.JittorDataLoader`.
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index 3f1b6acd..50bac34b 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -254,7 +254,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                               num_workers: int = 0, use_buffer_reader: bool = True,
                               use_shared_memory: bool = True, timeout: int = 0,
                               worker_init_fn: Callable = None, persistent_workers=False,
-                              non_train_batch_size: int = 16) \
+                              non_train_batch_size: int = None) \
     -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
     """
     ``prepare_paddle_dataloader`` converts one or more input datasets into ``PaddleDataLoader`` objects at once; see :class:`~fastNLP.core.dataloaders.PaddleDataLoader`.
diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py
index 358578fc..f717841d 100644
--- a/fastNLP/core/dataloaders/prepare_dataloader.py
+++ b/fastNLP/core/dataloaders/prepare_dataloader.py
@@ -6,7 +6,6 @@ from typing import Union, Callable
 import os
 import sys

-from ..samplers import RandomBatchSampler, RandomSampler
 from .torch_dataloader import prepare_torch_dataloader
 from .paddle_dataloader import prepare_paddle_dataloader
 from .jittor_dataloader import prepare_jittor_dataloader
@@ -16,7 +15,7 @@ from ..log import logger


 def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
-                       seed: int = 0, backend: str = 'auto'):
+                       backend: str = 'auto'):
     """
     Automatically create a suitable ``DataLoader`` object. For example, if the current environment is detected to be ``torch``, a ``TorchDataLoader`` is returned; if it is ``paddle``, a ``PaddleDataLoader`` is returned. If more parameters need to be customized, please use the corresponding ``prepare`` function directly, e.g.
@@ -43,7 +42,6 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro
         * When ``None``, each framework's default DataLoader ``collate_fn`` is used.

     :param num_workers: how many worker processes to use for fetching the data.
-    :param seed: the random seed to use.
     :param backend: currently four options are supported: ``["auto", "torch", "paddle", "jittor"]``.

         * When ``auto``,
@@ -61,18 +59,14 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro
     if backend == 'auto':
         backend = _get_backend()
     if backend == 'torch':
-        batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
-                                           drop_last=drop_last, seed=seed)
-        return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
-                                        num_workers=num_workers, shuffle=False, sampler=None)
+        return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
+                                        num_workers=num_workers, shuffle=shuffle, sampler=None,
+                                        batch_size=batch_size)
     elif backend == 'paddle':
-        batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
-                                           drop_last=drop_last, seed=seed)
-        return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
-                                         num_workers=num_workers)
+        return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
+                                         num_workers=num_workers, batch_size=batch_size, shuffle=shuffle)
     elif backend == 'jittor':
-        sampler = RandomSampler(dataset=dataset, shuffle=shuffle, seed=seed)
-        prepare_jittor_dataloader(ds_or_db=dataset, sampler=sampler, collate_fn=collate_fn,
+        prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn,
                                   num_workers=num_workers, batch_size=batch_size, shuffle=shuffle,
                                   drop_last=drop_last)
     else:
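With the RandomBatchSampler / RandomSampler construction removed, prepare_dataloader now forwards batch_size and shuffle straight to the backend-specific prepare_* function, and the seed parameter is gone. A hedged usage sketch (the dataset contents are made up):

    from fastNLP import DataSet, prepare_dataloader

    ds = DataSet({"x": [[1, 2], [2, 3, 4]] * 5, "y": [1, 0] * 5})
    # backend='auto' picks torch / paddle / jittor from the current environment.
    dl = prepare_dataloader(ds, batch_size=2, shuffle=True, backend='auto')
    for batch in dl:
        pass  # batch['x'] arrives padded by the 'auto' collate_fn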
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 4e208b3f..0ef98ae2 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -222,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db,
                              multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                              persistent_workers: bool = False,
                              non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
-                             non_train_batch_size: int = 16) \
+                             non_train_batch_size: int = None) \
     -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
     """
     ``prepare_torch_dataloader`` converts one or more input datasets into ``TorchDataLoader`` objects at once; see :class:`~fastNLP.core.dataloaders.TorchDataLoader`.
@@ -254,13 +254,13 @@ def prepare_torch_dataloader(ds_or_db,
     :param non_train_batch_size: batch size of the ``TorchDataLoader`` for non-'train' datasets; defaults to ``16``, and only takes effect when batch_sampler is None.
     :param shuffle: whether to shuffle the dataset; defaults to ``False``.
     :param sampler: an instantiated object implementing __len__() and __iter__(), whose __iter__() method returns one dataset index at a time;
-        defaults to None; when it is not None, the shuffle parameter has no effect.
+        defaults to None; when it is not None, the shuffle parameter has no effect.
     :param non_train_sampler: for non-'train' datasets, an instantiated object implementing __len__() and __iter__(), whose __iter__() method returns one dataset index at a time;
-        defaults to None; when it is not None, the shuffle parameter has no effect.
+        defaults to None; when it is not None, the shuffle parameter has no effect.
     :param batch_sampler: an instantiated object implementing __len__() and __iter__(), whose __iter__() method returns a List of
-        dataset indices each time; defaults to None; when it is not None, the batch_size, sampler and shuffle parameters all have no effect.
+        dataset indices each time; defaults to None; when it is not None, the batch_size, sampler and shuffle parameters all have no effect.
     :param num_workers: when ``num_workers > 0``, the ``TorchDataLoader`` spawns num_workers subprocesses to process the data, which can speed
-        up data processing but also consumes a lot of memory. When ``num_workers=0``, no subprocess is spawned. Defaults to ``0``.
+        up data processing but also consumes a lot of memory. When ``num_workers=0``, no subprocess is spawned. Defaults to ``0``.
     :param collate_fn: a Callable used to pack up a batch of data fetched from the dataset; its value should be one of the following three: ``[None, "auto", Callable]``.

         * When collate_fn is ``None``: note that the dataset passed in must then not be of type :class:`~fastNLP.core.dataset.DataSet`; when collate_fn is ``None``,
@@ -273,7 +273,7 @@ def prepare_torch_dataloader(ds_or_db,

     :param pin_memory: if ``True``, the ``TorchDataLoader`` will copy data tensors into CUDA pinned memory before returning them.
     :param drop_last: when ``drop_last=True``, the ``TorchDataLoader`` drops the last batch whose length is smaller than ``batch_size``;
-        when ``drop_last=False``, that batch is returned as well. Defaults to ``False``.
+        when ``drop_last=False``, that batch is returned as well. Defaults to ``False``.
     :param timeout: timeout value for fetching data from the subprocesses' output queue
     :param worker_init_fn: init function; if not set to None, it will be called when each subprocess is initialized
     :param multiprocessing_context: the multiprocessing context
diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py
index d8e9d45b..b28bb3f7 100644
--- a/fastNLP/core/utils/rich_progress.py
+++ b/fastNLP/core/utils/rich_progress.py
@@ -46,7 +46,6 @@ class DummyFRichProgress:
 class FRichProgress(Progress, metaclass=Singleton):
     def new_progess(self,
                     *columns: Union[str, ProgressColumn],
-                    console: Optional[Console] = None,
                     # auto_refresh is turned off here to avoid spawning a separate thread, and also to avoid continuous refreshing while in pdb
                     auto_refresh: bool = False,
                     refresh_per_second: float = 10,
@@ -81,7 +80,7 @@ class FRichProgress(Progress, metaclass=Singleton):
         self.expand = expand

         self.live = Live(
-            console=console or get_console(),
+            console=get_console(),
             auto_refresh=auto_refresh,
             refresh_per_second=refresh_per_second,
             transient=transient,
@@ -92,6 +91,12 @@ class FRichProgress(Progress, metaclass=Singleton):
         self.get_time = get_time or self.console.get_time
         self.print = self.console.print
         self.log = self.console.log
+        self.auto_refresh = auto_refresh
+        self.transient = transient
+        self.redirect_stdout = redirect_stdout
+        self.redirect_stderr = redirect_stderr
+        self.refresh_per_second = refresh_per_second
+        self._need_renew_live = False

         return self

@@ -125,7 +130,19 @@ class FRichProgress(Progress, metaclass=Singleton):
         from .tqdm_progress import f_tqdm_progress
         assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop."

-        if self.live._started is False:
+        # If a renewal is needed, it is because the Live was swapped out on destroy.
+        if self._need_renew_live:
+            self.live = Live(
+                console=get_console(),
+                auto_refresh=self.auto_refresh,
+                refresh_per_second=self.refresh_per_second,
+                transient=self.transient,
+                redirect_stdout=self.redirect_stdout,
+                redirect_stderr=self.redirect_stderr,
+                get_renderable=self.get_renderable,
+            )
+            self._need_renew_live = False
+        if not self.live.is_started:
             self.start()
         post_desc = fields.pop('post_desc', '')
         return super().add_task(description=description,
@@ -155,6 +172,8 @@ class FRichProgress(Progress, metaclass=Singleton):
         setattr(self.live.console, 'line', lambda *args,**kwargs:...)
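The rich fix hinges on one observation: once a Live has been stopped inside Jupyter, reusing it swallows all further output, so FRichProgress now remembers the Live constructor arguments and lazily rebuilds a fresh Live the next time a task is added. A condensed sketch of that renew-on-destroy pattern; this is not the fastNLP class, and the notebook check is a stand-in:

    from rich.console import Console
    from rich.live import Live

    def is_notebook() -> bool:
        return False  # assumption: replace with real Jupyter detection

    class ProgressWrapper:
        def __init__(self):
            self._live_kwargs = dict(console=Console(), auto_refresh=False)
            self.live = Live(**self._live_kwargs)
            self._need_renew_live = False

        def add_task(self):
            if self._need_renew_live:
                # A stopped Live prints nothing in Jupyter; rebuild it from the
                # remembered constructor arguments instead of restarting it.
                self.live = Live(**self._live_kwargs)
                self._need_renew_live = False
            if not self.live.is_started:
                self.live.start()

        def destroy(self):
            self.live.stop()
            self._need_renew_live = is_notebook()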
         self.live.stop()
         setattr(self.live.console, 'line', old_line)
+        # In Jupyter, the Live needs to be replaced; otherwise nothing gets printed afterwards.
+        self._need_renew_live = True if is_notebook() else False

     def start(self) -> None:
         super().start()
diff --git a/fastNLP/models/torch/seq2seq_generator.py b/fastNLP/models/torch/seq2seq_generator.py
index 9ee723e5..68b405ba 100755
--- a/fastNLP/models/torch/seq2seq_generator.py
+++ b/fastNLP/models/torch/seq2seq_generator.py
@@ -65,7 +65,7 @@ class SequenceGeneratorModel(nn.Module):
         if tgt_seq_len is not None:
             mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
             tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
-        loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
+        loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:])
         return {'loss': loss}

     def evaluate_step(self, src_tokens, src_seq_len=None):
diff --git a/fastNLP/models/torch/seq2seq_model.py b/fastNLP/models/torch/seq2seq_model.py
index 057fb93b..1420375e 100755
--- a/fastNLP/models/torch/seq2seq_model.py
+++ b/fastNLP/models/torch/seq2seq_model.py
@@ -59,7 +59,7 @@ class Seq2SeqModel(nn.Module):
         if tgt_seq_len is not None:
             mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
             tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
-        loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
+        loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:])
         return {'loss': loss}

     def prepare_state(self, src_tokens, src_seq_len=None):
diff --git a/fastNLP/modules/torch/generator/seq2seq_generator.py b/fastNLP/modules/torch/generator/seq2seq_generator.py
index cf9c5306..b54eea28 100755
--- a/fastNLP/modules/torch/generator/seq2seq_generator.py
+++ b/fastNLP/modules/torch/generator/seq2seq_generator.py
@@ -368,13 +368,13 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
             _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
             next_tokens = _tokens.gather(dim=1, index=ids)  # (batch_size, 2*num_beams)
-            from_which_beam = ids // (num_beams + 1)  # (batch_size, 2*num_beams)
+            from_which_beam = torch.floor(ids / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
         else:
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
             _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
             next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
-            from_which_beam = ids // vocab_size  # (batch_size, 2*num_beams)
+            from_which_beam = torch.floor(ids / vocab_size).long()  # (batch_size, 2*num_beams)
             next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)

         # Next, assemble the results for the next batch.
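The beam-search hunk above replaces `ids // n`, which on integer tensors triggers deprecation warnings (and a behaviour change) in newer PyTorch releases, with an explicit floor of a float division. A small check of the index arithmetic, with made-up numbers:

    import torch

    vocab_size = 10
    ids = torch.tensor([[7, 23, 38, 15]])  # flat indices into num_beams * vocab_size

    from_which_beam = torch.floor(ids / vocab_size).long()  # tensor([[0, 2, 3, 1]])
    next_tokens = ids % vocab_size                          # tensor([[7, 3, 8, 5]])

    # On PyTorch >= 1.8 an integer-exact equivalent is also available:
    assert torch.equal(from_which_beam,
                       torch.div(ids, vocab_size, rounding_mode='floor'))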
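The two seq2seq hunks earlier in this patch change the training loss so that the prediction at position t is scored against the token at position t+1, the usual one-step shift for a decoder that is fed tgt_tokens and asked to predict the next token. A worked shape check with made-up dimensions:

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab = 2, 5, 11
    pred = torch.randn(batch, seq_len, vocab)              # decoder output
    tgt_tokens = torch.randint(0, vocab, (batch, seq_len))

    # Old: F.cross_entropy(pred.transpose(1, 2), tgt_tokens) scored pred[t]
    # against tgt[t], i.e. against the token the decoder had just been fed.
    # New: drop the last prediction and the first target, so pred[t] is
    # scored against tgt[t + 1].
    loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:])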
diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py
index 09ec4af8..8443ef92 100644
--- a/tests/core/collators/test_collator.py
+++ b/tests/core/collators/test_collator.py
@@ -318,6 +318,13 @@ class TestCollator:
             with pytest.raises(KeyError):
                 collator.set_pad((1, 2))

+    @pytest.mark.torch
+    def test_torch_4d(self):
+        collator = Collator(backend='torch')
+        data = [{'x': [[[0,1], [2,3]]]}, {'x': [[[0,1]]]}]
+        output = collator(data)
+        assert output['x'].size() == (2, 1, 2, 2)
+

 @pytest.mark.torch
 def test_torch_dl():
diff --git a/tests/core/dataloaders/test_prepare_dataloader.py b/tests/core/dataloaders/test_prepare_dataloader.py
index 223b7880..be6d5feb 100644
--- a/tests/core/dataloaders/test_prepare_dataloader.py
+++ b/tests/core/dataloaders/test_prepare_dataloader.py
@@ -2,6 +2,7 @@ import pytest

 from fastNLP import prepare_dataloader
 from fastNLP import DataSet
+from fastNLP.io import DataBundle


 @pytest.mark.torch
@@ -10,4 +11,18 @@ def test_torch():
     ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
     dl = prepare_dataloader(ds, batch_size=2, shuffle=True)
     for batch in dl:
-        assert isinstance(batch['x'], torch.Tensor)
\ No newline at end of file
+        assert isinstance(batch['x'], torch.Tensor)
+
+
+@pytest.mark.torch
+def test_torch_data_bundle():
+    import torch
+    ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+    dl = DataBundle()
+    dl.set_dataset(dataset=ds, name='train')
+    dl.set_dataset(dataset=ds, name='test')
+    dls = prepare_dataloader(dl, batch_size=2, shuffle=True)
+    for dl in dls.values():
+        for batch in dl:
+            assert isinstance(batch['x'], torch.Tensor)
+            assert batch['x'].size(0) == 2