From 97a7532a71241709688839e394ebc62e171b121a Mon Sep 17 00:00:00 2001
From: MorningForest <2297662686@qq.com>
Date: Thu, 5 May 2022 13:28:15 +0800
Subject: [PATCH] modify dl, dataset
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/__init__.py          |  4 +-
 .../core/dataloaders/torch_dataloader/fdl.py  |  6 +--
 fastNLP/core/dataset/dataset.py               | 43 ++++++++++++++++---
 tests/core/dataset/test_dataset.py            |  6 ++-
 4 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py
index e9dc51b4..40dd7b1c 100644
--- a/fastNLP/core/dataloaders/__init__.py
+++ b/fastNLP/core/dataloaders/__init__.py
@@ -5,12 +5,10 @@ __all__ = [
     'JittorDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
-    'prepare_torch_dataloader',
-    'indice_collate_wrapper'
+    'prepare_torch_dataloader'
 ]

 from .mix_dataloader import MixDataLoader
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .utils import indice_collate_wrapper

diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index ff2d1e65..1616fb85 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -8,7 +8,7 @@ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, Li
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
-from fastNLP.io.data_bundle import DataBundle
+# from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler

@@ -164,7 +164,7 @@ class TorchDataLoader(DataLoader):
         return self.cur_batch_indices

-def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
+def prepare_torch_dataloader(ds_or_db,
                              batch_size: int = 16,
                              shuffle: bool = True,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +197,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param non_train_sampler: Sampler used for non-'train' data, and for the second and later ds in a Sequence
     :param non_train_batch_size:
     """
-
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index 3b9f027e..fa330854 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -15,16 +15,16 @@
 import numpy as np
 from threading import Thread

 try:
-    import multiprocess as mp
-    from multiprocess import RLock
+    import multiprocessing as mp
 except:
     pass
 from .field import FieldArray
 from .instance import Instance
-from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+from fastNLP.core.utils.utils import pretty_table_printer
 from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
+from fastNLP.core.log import logger


 class ApplyResultException(Exception):
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s

     except BaseException as e:
         if idx != -1:
-            print("Exception happens at the `{}`th instance.".format(idx))
+            logger.error("Exception happens at the `{}`th instance.".format(idx))
         raise e
     finally:
         if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b

         if nums == total_len:
             break
+    f_rich_progress.destroy_task(main_pro)
     # pb_main.close()

@@ -519,7 +520,7 @@
         # start the process pool and the progress thread
         main_thread.start()
         pool = mp.Pool(processes=num_proc)
-        pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
+        pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                      for proc_id, ds in enumerate(shard_data)]
         pool.close()
         pool.join()
@@ -775,3 +776,35 @@ class DataSet:
         if self._collator is None:
             self._collator = Collator()
         return self._collator
+
+
+if __name__ == '__main__':
+    # from fastNLP import DataSet
+
+    # if __name__ == '__main__':
+    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)
+
+    import multiprocess as mp
+    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
+    from functools import partial
+    from threading import Thread
+
+    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
+                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
+    parent, child = mp.Pipe()
+    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
+                                 pipe=child, show_progress_bar=False)
+    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
+    thread.start()
+    pool = mp.Pool(processes=6)
+    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
+                 for proc_id, ds in enumerate(shard_data)]
+    pool.close()
+    pool.join()
+    thread.join()
+    results = []
+    for async_result in pool_outs:
+        data = async_result.get()
+        results.extend(data)
+    print(results)
diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py
index a2540ecf..ded60465 100644
--- a/tests/core/dataset/test_dataset.py
+++ b/tests/core/dataset/test_dataset.py
@@ -184,7 +184,7 @@ class TestDataSetMethods:
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
         assert ds.field_arrays["y"].content[0] == 2

-        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
+        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
         assert (isinstance(res, list) and len(res) > 0) == True
         assert res[0] == 4

@@ -375,6 +375,10 @@ class TestDataSetMethods:
         ds.add_seq_len('x')
         print(ds)

+    def test_apply_proc(self):
+        data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
+

 class TestFieldArrayInit:
     """
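
Below is a minimal usage sketch of the behavior this patch exercises: DataSet.apply_field with num_proc > 1 shards the instances, maps the function in worker processes, and merges the shard results back into the new field. The dataset literal mirrors the patch's own test_apply_proc; the final print line and its expected value are illustrative assumptions, not part of the patch.

# Minimal sketch, assuming fastNLP with this patch applied.
from fastNLP import DataSet

data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100,
                'y': [0, 1] * 100})
# num_proc=2 splits the instances into shards and applies the function in
# two worker processes; the merged results land in the new 'len_x' field.
data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x',
                 num_proc=2)
print(data[0]['len_x'])  # 18, i.e. len('xxxxas1w xw zxw xz')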
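For the fdl.py change, a sketch of the call path that keeps working after the DataBundle import moved into the body of prepare_torch_dataloader (presumably so that importing fastNLP.core.dataloaders no longer pulls in fastNLP.io at module load time). It assumes torch is installed; the dataset values are illustrative.

# Minimal sketch, assuming the patched fastNLP.core.dataloaders and torch.
from fastNLP import DataSet
from fastNLP.core.dataloaders import prepare_torch_dataloader

ds = DataSet({'x': [[1, 2], [3, 4]] * 8, 'y': [0, 1] * 8})
# A plain DataSet takes the isinstance(ds_or_db, DataSet) branch, so the
# lazily imported DataBundle is never touched on this path.
dl = prepare_torch_dataloader(ds, batch_size=4, shuffle=False)
for batch in dl:
    print(batch)  # one collated batch, e.g. {'x': ..., 'y': ...}
    break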