From 5d564a58cf96cf85e33e2b0c51547ffd11cd8929 Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Tue, 3 May 2022 22:20:13 +0800
Subject: [PATCH 1/4] Modify jittor fdl
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/jittor_dataloader/fdl.py    | 1 +
 tests/core/dataloaders/jittor_dataloader/test_fdl.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
index 3e9cf17a..787fcb69 100644
--- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -91,6 +91,7 @@ class JittorDataLoader:
             self.dataset.dataset.set_attrs(batch_size=1)
         # if the user passes a collate_fn, it automatically replaces the collate_batch function provided by jittor
         # self._collate_fn = _collate_fn
+        self.cur_batch_indices = None
 
     def __iter__(self):
         # TODO collate_fn cannot be set after the first iteration; setting it then has no effect
diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
index 92b49c09..b3124397 100644
--- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
@@ -42,6 +42,7 @@ class TestJittor:
         jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
         # jtl.set_pad_val('x', 'y')
         # jtl.set_input('x')
+        print(str(jittor.Var([0])))
         for batch in jtl:
             print(batch)
             print(jtl.get_batch_indices())
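The patch above gives `JittorDataLoader` a `cur_batch_indices` attribute that is `None` until iteration starts and is refreshed by `__iter__` before each batch is yielded, which is what makes `get_batch_indices()` in the test meaningful. A minimal sketch of the pattern in plain Python (the class and names below are illustrative, not fastNLP's actual implementation):

# Sketch: a loader that records which sample indices formed the current batch.
class IndexTrackingLoader:
    def __init__(self, data, batch_size=2):
        self.data = data
        self.batch_size = batch_size
        self.cur_batch_indices = None  # defined up front; None before the first batch

    def __iter__(self):
        for start in range(0, len(self.data), self.batch_size):
            indices = list(range(start, min(start + self.batch_size, len(self.data))))
            self.cur_batch_indices = indices  # refreshed before each batch is yielded
            yield [self.data[i] for i in indices]

    def get_batch_indices(self):
        return self.cur_batch_indices

loader = IndexTrackingLoader(list("abcdef"), batch_size=4)
for batch in loader:
    print(batch, loader.get_batch_indices())  # e.g. ['a', 'b', 'c', 'd'] [0, 1, 2, 3]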
From d61bfe8e477bcbeaba5057e50786f153b468e2f3 Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Tue, 3 May 2022 23:09:44 +0800
Subject: [PATCH 2/4] Modify jittor fdl
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/__init__.py                 | 4 +++-
 fastNLP/core/dataloaders/jittor_dataloader/fdl.py    | 2 +-
 fastNLP/core/dataloaders/utils.py                    | 5 +++++
 tests/core/dataloaders/jittor_dataloader/test_fdl.py | 1 -
 4 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py
index 40dd7b1c..e9dc51b4 100644
--- a/fastNLP/core/dataloaders/__init__.py
+++ b/fastNLP/core/dataloaders/__init__.py
@@ -5,10 +5,12 @@ __all__ = [
     'JittorDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
-    'prepare_torch_dataloader'
+    'prepare_torch_dataloader',
+    'indice_collate_wrapper'
 ]
 
 from .mix_dataloader import MixDataLoader
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
+from .utils import indice_collate_wrapper
diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
index 507073a4..2345a9b9 100644
--- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -12,7 +12,7 @@ if _NEED_IMPORT_JITTOR:
     from jittor.dataset import Dataset
 else:
     from fastNLP.core.dataset import DataSet as Dataset
-from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
+
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index a71dc50c..2305cebe 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -1,3 +1,8 @@
+__all__ = [
+    "indice_collate_wrapper"
+]
+
+
 def indice_collate_wrapper(func):
     """
     Wraps an extra layer around collate_fn: it splits apart the tuple data fetched from the dataset and packs the idx values into indices.
diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
index b3124397..92b49c09 100644
--- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py
@@ -42,7 +42,6 @@ class TestJittor:
         jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
         # jtl.set_pad_val('x', 'y')
         # jtl.set_input('x')
-        print(str(jittor.Var([0])))
         for batch in jtl:
             print(batch)
             print(jtl.get_batch_indices())
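`indice_collate_wrapper`, made public above, is the piece that lets every loader pair each batch with the dataset indices it came from: the wrapped dataset yields `(idx, sample)` tuples, the wrapper splits them apart, and the loader's `__iter__` consumes `(indices, batch)` pairs. A rough sketch of that contract (simplified; the real helper also guards against double wrapping):

def indice_collate_wrapper_sketch(collate_fn):
    # Each item coming from the dataset is assumed to be an (idx, sample) tuple.
    def _wrapped(tuple_data):
        indices, samples = [], []
        for idx, sample in tuple_data:
            indices.append(idx)
            samples.append(sample)
        return indices, collate_fn(samples)  # the indices ride along with the batch
    return _wrapped

batchify = indice_collate_wrapper_sketch(lambda samples: samples)
indices, batch = batchify([(0, 'a'), (1, 'b')])
print(indices, batch)  # [0, 1] ['a', 'b']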
From 97a7532a71241709688839e394ebc62e171b121a Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Thu, 5 May 2022 13:28:15 +0800
Subject: [PATCH 3/4] Modify dl, dataset
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/__init__.py          |  4 +-
 .../core/dataloaders/torch_dataloader/fdl.py  |  6 +--
 fastNLP/core/dataset/dataset.py               | 43 ++++++++++++++++---
 tests/core/dataset/test_dataset.py            |  6 ++-
 4 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py
index e9dc51b4..40dd7b1c 100644
--- a/fastNLP/core/dataloaders/__init__.py
+++ b/fastNLP/core/dataloaders/__init__.py
@@ -5,12 +5,10 @@ __all__ = [
     'JittorDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
-    'prepare_torch_dataloader',
-    'indice_collate_wrapper'
+    'prepare_torch_dataloader'
 ]
 
 from .mix_dataloader import MixDataLoader
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .utils import indice_collate_wrapper
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index ff2d1e65..1616fb85 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -8,7 +8,7 @@ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, Li
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
-from fastNLP.io.data_bundle import DataBundle
+# from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
 
@@ -164,7 +164,7 @@ class TorchDataLoader(DataLoader):
         return self.cur_batch_indices
 
 
-def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
+def prepare_torch_dataloader(ds_or_db,
                              batch_size: int = 16,
                              shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +197,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param non_train_sampler: the Sampler used for non-'train' data, and by the second and later datasets when a Sequence is passed in
     :param non_train_batch_size:
     """
-
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index 3b9f027e..fa330854 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -15,16 +15,16 @@ import numpy as np
 from threading import Thread
 
 try:
-    import multiprocess as mp
-    from multiprocess import RLock
+    import multiprocessing as mp
 except:
     pass
 
 from .field import FieldArray
 from .instance import Instance
-from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+from fastNLP.core.utils.utils import pretty_table_printer
 from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
+from fastNLP.core.log import logger
 
 
 class ApplyResultException(Exception):
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
 
     except BaseException as e:
         if idx != -1:
-            print("Exception happens at the `{}`th instance.".format(idx))
+            logger.error("Exception happens at the `{}`th instance.".format(idx))
         raise e
     finally:
         if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b
 
         if nums == total_len:
             break
+    f_rich_progress.destroy_task(main_pro)
     # pb_main.close()
 
@@ -519,7 +520,7 @@ class DataSet:
             # start the process pool and the progress thread
             main_thread.start()
             pool = mp.Pool(processes=num_proc)
-            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
+            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                          for proc_id, ds in enumerate(shard_data)]
             pool.close()
             pool.join()
@@ -775,3 +776,35 @@ class DataSet:
         if self._collator is None:
             self._collator = Collator()
         return self._collator
+
+
+if __name__ == '__main__':
+    # from fastNLP import DataSet
+
+    # if __name__=='__main__':
+    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)
+
+    import multiprocess as mp
+    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
+    from functools import partial
+    from threading import Thread
+
+    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
+                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
+    parent, chid = mp.Pipe()
+    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
+                                 pipe=chid, show_progress_bar=False)
+    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
+    thread.start()
+    pool = mp.Pool(processes=6)
+    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
+                 for proc_id, ds in enumerate(shard_data)]
+    pool.close()
+    pool.join()
+    thread.join()
+    results = []
+    for async_result in pool_outs:
+        data = async_result.get()
+        results.extend(data)
+    print(results)
diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py
index a2540ecf..ded60465 100644
--- a/tests/core/dataset/test_dataset.py
+++ b/tests/core/dataset/test_dataset.py
@@ -184,7 +184,7 @@ class TestDataSetMethods:
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
         assert ds.field_arrays["y"].content[0] == 2
 
-        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
+        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
         assert (isinstance(res, list) and len(res) > 0) == True
         assert res[0] == 4
 
@@ -375,6 +375,10 @@ class TestDataSetMethods:
         ds.add_seq_len('x')
         print(ds)
 
+    def test_apply_proc(self):
+        data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
+
 
 class TestFieldArrayInit:
     """
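The `dataset.py` changes above drive multi-process `apply` through a pipe: each worker processes one shard and reports one tick per instance through the child end of the connection, while `_progress_bar` tallies ticks on the parent end until `total_len` is reached. The sketch below reproduces that flow with only the standard library; it swaps the `Pipe` for a managed queue (safe to pass into pool workers), whereas the patch itself relies on the `multiprocess` package:

import multiprocessing as mp
from threading import Thread

def worker(shard, queue):
    out = []
    for item in shard:
        out.append(len(item))
        queue.put(1)  # one tick per processed instance
    return out

def progress(queue, total):
    done = 0
    while done < total:
        done += queue.get()  # blocks until a worker reports
    print(f"processed {done}/{total} instances")

if __name__ == '__main__':
    shards = [['aa', 'bbb'], ['c', 'dddd']]
    queue = mp.Manager().Queue()
    thread = Thread(target=progress, args=(queue, 4))
    thread.start()
    with mp.Pool(processes=2) as pool:
        outs = [pool.apply_async(worker, (shard, queue)) for shard in shards]
        results = [x for out in outs for x in out.get()]
    thread.join()
    print(results)  # [2, 3, 1, 4]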
From da25fbee3f954d936ce1c3cb6e311a9be51af4fa Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Thu, 5 May 2022 19:24:50 +0800
Subject: [PATCH 4/4] Modify fdl
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/dataloaders/paddle_dataloader/fdl.py | 79 ++++++++++++++-----
 .../core/dataloaders/torch_dataloader/fdl.py  | 18 +++--
 fastNLP/core/dataloaders/utils.py             |  2 +-
 .../dataloaders/torch_dataloader/test_fdl.py  | 49 ++++++++++--
 4 files changed, 110 insertions(+), 38 deletions(-)

diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index 952759f7..b157dd68 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -63,13 +63,6 @@ class PaddleDataLoader(DataLoader):
             shuffle = False
             drop_last = False
 
-        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
-                                               return_list=return_list, batch_sampler=batch_sampler,
-                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
-                                               collate_fn=None, num_workers=num_workers,
-                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
-                                               timeout=timeout, worker_init_fn=worker_init_fn,
-                                               persistent_workers=persistent_workers)
         if isinstance(collate_fn, str):
             if collate_fn == 'auto':
                 if isinstance(dataset.dataset, FDataSet):
@@ -80,11 +73,15 @@ class PaddleDataLoader(DataLoader):
 
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
-        elif isinstance(collate_fn, Callable):
-            if collate_fn is not default_collate_fn:
-                self._collate_fn = collate_fn
-        else:
-            self._collate_fn = default_collate_fn
+
+        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
+                                               return_list=return_list, batch_sampler=batch_sampler,
+                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
+                                               collate_fn=collate_fn, num_workers=num_workers,
+                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
+                                               timeout=timeout, worker_init_fn=worker_init_fn,
+                                               persistent_workers=persistent_workers)
+
         # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
         # if collate_fn is not None:
         #     _collate_fn.add_collator(collate_fn)
@@ -96,13 +93,13 @@ class PaddleDataLoader(DataLoader):
         # if len(self._collate_fn.get_collators()) == 0:
         #     self._collate_fn.add_collator(default_collate_fn)
         # self._collate_fn = default_collate_fn
-        self.collate_fn = indice_collate_wrapper(self._collate_fn)
+        self.collate_fn = indice_collate_wrapper(self.collate_fn)
         for indices, data in super().__iter__():
             self.cur_batch_indices = indices
             yield data
 
-    def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
-                pad_fn:Callable=None) -> Collator:
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> Collator:
         """
         Use this function if the content of a field needs special handling.
 
@@ -121,12 +118,26 @@ class PaddleDataLoader(DataLoader):
             form, and the output will be used directly as the result.
         :return: the Collator itself
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
-            return self._collate_fn
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
 
+    def _get_collator(self):
+        """
+        If the collate_fn is a Collator object, return it; otherwise return None.
+
+        :return:
+        """
+        collator = None
+        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
+            collator = self.collate_fn.__wrapped__
+        elif isinstance(self.collate_fn, Collator):
+            collator = self.collate_fn
+        return collator
+
     def set_ignore(self, *field_names) -> Collator:
         """
         If some content should not be output, it can be configured here; the configured fields will be ignored in the batch output.
@@ -138,9 +149,10 @@ class PaddleDataLoader(DataLoader):
             if __getitem__ returns a Sequence, '_0' and '_1' can be used to refer to the 0th or 1st element in the sequence.
         :return: the Collator itself
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_ignore(*field_names)
-            return self._collate_fn
+        collator = self._get_collator()
+        if isinstance(collator, Collator):
+            collator.set_ignore(*field_names)
+            return collator
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
 
@@ -163,6 +175,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                               worker_init_fn: Callable = None, persistent_workers=False,
                               non_train_batch_size: int = 16) \
         -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, Dataset):
         dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
                               batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
@@ -170,6 +183,30 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                               use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                               timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
         return dl
+    elif isinstance(ds_or_db, DataBundle):
+        dl_bundle = {}
+        for name, ds in ds_or_db.iter_datasets():
+            if 'train' in name:
+                dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
+                                                   return_list=return_list,
+                                                   batch_sampler=batch_sampler, batch_size=train_batch_size,
+                                                   shuffle=shuffle,
+                                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
+                                                   use_shared_memory=use_shared_memory,
+                                                   use_buffer_reader=use_buffer_reader,
+                                                   timeout=timeout, worker_init_fn=worker_init_fn,
+                                                   persistent_workers=persistent_workers)
+            else:
+                dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
+                                                   return_list=return_list,
+                                                   batch_sampler=batch_sampler, batch_size=non_train_batch_size,
+                                                   shuffle=shuffle,
+                                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
+                                                   use_shared_memory=use_shared_memory,
+                                                   use_buffer_reader=use_buffer_reader,
+                                                   timeout=timeout, worker_init_fn=worker_init_fn,
+                                                   persistent_workers=persistent_workers)
+        return dl_bundle
     elif isinstance(ds_or_db, Sequence):
         ds_seq = []
         for ds in ds_or_db:
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 923f6415..b827e1ab 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -14,7 +14,6 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler,
 
 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader, Sampler
-    from torch.utils.data._utils.collate import default_collate
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
 
@@ -73,12 +72,15 @@ class TorchDataLoader(DataLoader):
         :param prefetch_factor:
         :param persistent_workers:
         """
+        if isinstance(dataset, DataSet) and collate_fn is None:
+            raise ValueError("When using a fastNLP DataSet, collate_fn must not be None")
+
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)
 
         if sampler is None and batch_sampler is None:
             sampler = RandomSampler(dataset, shuffle=shuffle)
-            shuffle=False
+            shuffle = False
 
         if isinstance(collate_fn, str):
             if collate_fn == 'auto':
                 if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is in use
@@ -107,8 +109,8 @@ class TorchDataLoader(DataLoader):
             self.cur_batch_indices = indices
             yield data
 
-    def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
-                pad_fn:Callable=None) -> Collator:
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> Collator:
         """
         Use this function if the content of a field needs special handling.
 
@@ -174,12 +176,12 @@ class TorchDataLoader(DataLoader):
 
         return self.cur_batch_indices
 
-
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 16,
-                             shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
+                             shuffle: bool = True,
+                             sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
-                             num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
+                             num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',
                              pin_memory: bool = False, drop_last: bool = False,
                              timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                              multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
@@ -220,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
         )
         return dl
 
-    elif type(ds_or_db, DataBundle):
+    elif isinstance(ds_or_db, DataBundle):
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 0ee496c5..39ce5983 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -10,7 +10,7 @@ def indice_collate_wrapper(func):
     :param func: the function to be wrapped
     :return:
     """
-    if func.__name__ == '_indice_collate_wrapper':  # if it has already been wrapped
+    if hasattr(func, '__name__') and func.__name__ == '_indice_collate_wrapper':  # if it has already been wrapped
        return func

     def _indice_collate_wrapper(tuple_data):  # functools.wraps cannot be used here, otherwise this check would fail
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index 52fe48ff..8aa12ab6 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -3,6 +3,10 @@ import pytest
 from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from fastNLP.core.dataset import DataSet
 from fastNLP.io.data_bundle import DataBundle
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
 
 
 @pytest.mark.torch
@@ -22,20 +26,16 @@ class TestFdl:
         fdl = TorchDataLoader(ds, batch_size=3)
         fdl.set_pad("x", -1)
         for batch in fdl:
-            print(batch)
-        # fdl.set_pad_val("x", val=-2)
-        # for batch in fdl:
-        #     print(batch)
+            assert batch['x'].shape == torch.Size([3, 4])
 
     def test_get_batch_indices(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
         fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
         for batch in fdl:
-            print(fdl.get_batch_indices())
+            assert len(fdl.get_batch_indices()) == 3
 
     def test_other_dataset(self):
         import numpy as np
-
         class _DataSet:
 
             def __init__(self):
@@ -55,8 +55,41 @@ class TestFdl:
         # dl.set_inputs('data', 'labels')
         # dl.set_pad_val('labels', val=None)
         for batch in dl:
-            print(batch)
-            print(dl.get_batch_indices())
+            assert batch[0].shape == torch.Size([2, 5])
+            assert batch[1].shape == torch.Size([2, 2, 3])
+
+    def test_default_collate_fn(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with pytest.raises(ValueError):
+            fdl = TorchDataLoader(ds, batch_size=3, collate_fn=None)
+        import numpy as np
+        class _DataSet:
+
+            def __init__(self):
+                pass
+
+            def __getitem__(self, item):
+                return np.random.randn(5), [[1, 2], [2, 3, 4]]
+
+            def __len__(self):
+                return 10
+
+        fdl = TorchDataLoader(_DataSet(), batch_size=3, collate_fn=None, drop_last=True)
+        for batch in fdl:
+            assert batch[0].shape == torch.Size([3, 5])
+
+    def test_my_collate_fn(self):
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        def collate_fn(batch):
+            res = {'x': [], 'y': []}
+            for ins in batch:
+                res['x'].append(ins['x'])
+                res['y'].append(ins['y'])
+            return res
+        fdl = TorchDataLoader(ds, collate_fn=collate_fn, batch_size=3, drop_last=True)
+        for batch in fdl:
+            assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
+            assert batch['y'] == [1, 0, 1]
 
     def test_prepare_torch_dataloader(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
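The `_get_collator` helper introduced in the last patch depends on a small convention: `indice_collate_wrapper` deliberately avoids `functools.wraps` yet still exposes the original callable, so the loader can reach through one layer of wrapping, recover a fastNLP `Collator`, and forward `set_pad`/`set_ignore` to it. A condensed illustration of that convention (the `__wrapped__` assignment is inferred from the `hasattr` check in the diff; `SimpleCollator` is a stand-in, not the fastNLP class):

class SimpleCollator:
    def __init__(self):
        self.pad_vals = {}

    def set_pad(self, field_name, pad_val):
        self.pad_vals[field_name] = pad_val

    def __call__(self, samples):
        return samples

def wrap(func):
    def _wrapped(tuple_data):
        indices, samples = zip(*tuple_data)
        return list(indices), func(list(samples))
    _wrapped.__wrapped__ = func  # expose the wrapped callable for later discovery
    return _wrapped

def get_collator(collate_fn):
    # Mirror _get_collator: look through one wrapper layer, then check directly.
    if hasattr(collate_fn, '__wrapped__') and isinstance(collate_fn.__wrapped__, SimpleCollator):
        return collate_fn.__wrapped__
    if isinstance(collate_fn, SimpleCollator):
        return collate_fn
    return None

collate_fn = wrap(SimpleCollator())
get_collator(collate_fn).set_pad('x', -1)
print(get_collator(collate_fn).pad_vals)  # {'x': -1}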