@@ -5,12 +5,10 @@ __all__ = [
     'JittorDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
-    'prepare_torch_dataloader',
-    'indice_collate_wrapper'
+    'prepare_torch_dataloader'
 ]
 
 from .mix_dataloader import MixDataLoader
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .utils import indice_collate_wrapper
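With `indice_collate_wrapper` dropped from `__all__` and the `.utils` re-export removed, callers presumably import the helper from its defining module instead, as the dataloader hunk below already does. A minimal sketch of the updated import, assuming the helper stays in `fastNLP.core.dataloaders.utils`:

    # hypothetical caller-side update: the package __init__ no longer re-exports the helper
    from fastNLP.core.dataloaders.utils import indice_collate_wrapper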
@@ -8,7 +8,7 @@ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, Li
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
-from fastNLP.io.data_bundle import DataBundle
+# from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
 
@@ -164,7 +164,7 @@ class TorchDataLoader(DataLoader):
         return self.cur_batch_indices
 
 
-def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
+def prepare_torch_dataloader(ds_or_db,
                              batch_size: int = 16,
                              shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +197,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param non_train_sampler: Sampler used for the non-'train' datasets, and for the second and later datasets when a Sequence is passed
     :param non_train_batch_size:
     """
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
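Taken together, the three hunks above drop the module-level `DataBundle` import, remove the annotation that referenced it, and re-import the class inside `prepare_torch_dataloader`, presumably to break an import cycle with `fastNLP.io`. A rough sketch of the resulting shape under that assumption (the `DataBundle` branch and `**dl_kwargs` passthrough are hypothetical abbreviations, not the real body):

    from fastNLP.core.dataset import DataSet

    def prepare_torch_dataloader(ds_or_db, batch_size=16, **dl_kwargs):
        # Deferred import: resolving fastNLP.io only at call time means importing this
        # module never pulls in fastNLP.io at import time (assumption: the two packages
        # would otherwise import each other while still initialising).
        from fastNLP.io.data_bundle import DataBundle

        if isinstance(ds_or_db, DataSet):
            # TorchDataLoader is defined earlier in the same module
            return TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, **dl_kwargs)
        if isinstance(ds_or_db, DataBundle):
            # hypothetical handling: one loader per named split held by the bundle
            return {name: TorchDataLoader(dataset=sub_ds, batch_size=batch_size, **dl_kwargs)
                    for name, sub_ds in ds_or_db.iter_datasets()}
        raise TypeError(f"Unsupported type for ds_or_db: {type(ds_or_db)}")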
@@ -15,16 +15,16 @@ import numpy as np
 from threading import Thread
 
 try:
-    import multiprocess as mp
-    from multiprocess import RLock
+    import multiprocessing as mp
 except:
     pass
 
 from .field import FieldArray
 from .instance import Instance
-from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+from fastNLP.core.utils.utils import pretty_table_printer
 from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
+from fastNLP.core.log import logger
 
 
 class ApplyResultException(Exception):
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
     except BaseException as e:
         if idx != -1:
-            print("Exception happens at the `{}`th instance.".format(idx))
+            logger.error("Exception happens at the `{}`th instance.".format(idx))
         raise e
     finally:
         if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b
             if nums == total_len:
                 break
+    f_rich_progress.destroy_task(main_pro)
     # pb_main.close()
@@ -519,7 +520,7 @@ class DataSet:
             # start the process pool and the progress thread
             main_thread.start()
             pool = mp.Pool(processes=num_proc)
-            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
+            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                          for proc_id, ds in enumerate(shard_data)]
             pool.close()
             pool.join()
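With the `proc_id` kwarg gone, a worker no longer identifies itself; each shard just streams per-instance ticks through the `Pipe`, and the `_progress_bar` thread aggregates them against the total. A minimal, self-contained sketch of that pattern (hypothetical worker and monitor names, standard-library `multiprocessing`, not the fastNLP implementations):

    import multiprocessing as mp
    from functools import partial
    from threading import Thread

    def double(x):
        return x * 2

    def apply_shard(ds, func, pipe):
        # stand-in for _apply_single: process one shard, report each finished instance
        out = []
        for item in ds:
            out.append(func(item))
            pipe.send(1)                      # one tick per processed instance
        return out

    def monitor(conn, total):
        # stand-in for _progress_bar: drain ticks until every instance is counted
        done = 0
        while done < total:
            done += conn.recv()
        print(f"processed {done}/{total}")

    if __name__ == '__main__':
        shards = [list(range(50)), list(range(50, 100))]
        parent, child = mp.Pipe()
        worker = partial(apply_shard, func=double, pipe=child)
        progress = Thread(target=monitor, args=(parent, 100))
        progress.start()
        pool = mp.Pool(processes=2)
        outs = [pool.apply_async(worker, kwds={'ds': shard}) for shard in shards]
        pool.close()
        pool.join()
        progress.join()
        results = [item for out in outs for item in out.get()]
        print(len(results))                   # 100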
@@ -775,3 +776,35 @@ class DataSet:
         if self._collator is None:
             self._collator = Collator()
         return self._collator
+
+
+if __name__ == '__main__':
+    # from fastNLP import DataSet
+    # if __name__=='__main__':
+    # data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+    # data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)
+    import multiprocess as mp
+    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
+    from functools import partial
+    from threading import Thread
+
+    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
+                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
+    parent, chid = mp.Pipe()
+    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
+                                 pipe=chid, show_progress_bar=False)
+    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
+    thread.start()
+    pool = mp.Pool(processes=6)
+    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
+                 for proc_id, ds in enumerate(shard_data)]
+    pool.close()
+    pool.join()
+    thread.join()
+    results = []
+    for async_result in pool_outs:
+        data = async_result.get()
+        results.extend(data)
+    print(results)
@@ -184,7 +184,7 @@ class TestDataSetMethods:
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
         assert ds.field_arrays["y"].content[0] == 2
 
-        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
+        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
         assert (isinstance(res, list) and len(res) > 0) == True
         assert res[0] == 4
@@ -375,6 +375,10 @@ class TestDataSetMethods:
         ds.add_seq_len('x')
         print(ds)
 
+    def test_apply_proc(self):
+        data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
+
 
 class TestFieldArrayInit:
     """