@@ -5,12 +5,10 @@ __all__ = [
     'JittorDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
-    'prepare_torch_dataloader',
-    'indice_collate_wrapper'
+    'prepare_torch_dataloader'
 ]

 from .mix_dataloader import MixDataLoader
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .utils import indice_collate_wrapper
@@ -8,7 +8,7 @@ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, Li
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
-from fastNLP.io.data_bundle import DataBundle
+# from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
@@ -164,7 +164,7 @@ class TorchDataLoader(DataLoader):
         return self.cur_batch_indices


-def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
+def prepare_torch_dataloader(ds_or_db,
                              batch_size: int = 16,
                              shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +197,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param non_train_sampler: Sampler used for the non-'train' datasets, and for the second and later DataSets of a Sequence
     :param non_train_batch_size:
     """
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
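
Review note: together with the commented-out top-level import two hunks up, moving the `DataBundle` import inside `prepare_torch_dataloader` is the standard way to break an import cycle (`fastNLP.io` presumably imports from `fastNLP.core`, so `fastNLP.core.dataloaders` cannot import it back at module level); it is also why the `DataBundle`-typed hint on `ds_or_db` had to go. A minimal sketch of the pattern, with hypothetical module names standing in for fastNLP.io and fastNLP.core:

    # pkg/io.py -- depends on pkg.core at import time
    from pkg.core import process

    class Bundle:
        pass

    # pkg/core.py -- importing pkg.io at the top level would close the cycle
    def process(obj):
        from pkg.io import Bundle       # deferred: resolved at call time,
        return isinstance(obj, Bundle)  # after both modules are initialised
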
@@ -15,16 +15,16 @@ import numpy as np
 from threading import Thread

 try:
-    import multiprocess as mp
-    from multiprocess import RLock
+    import multiprocessing as mp
 except:
     pass

 from .field import FieldArray
 from .instance import Instance
-from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+from fastNLP.core.utils.utils import pretty_table_printer
 from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
+from fastNLP.core.log import logger


 class ApplyResultException(Exception):
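
A behavioural note on this import swap: the stdlib `multiprocessing` serialises every callable it ships to `Pool` workers with `pickle`, which rejects lambdas, while the `multiprocess` fork serialises with `dill`, which accepts them. Since `DataSet.apply` is routinely fed lambdas (see the tests at the end of this diff), the choice of module is load-bearing. A minimal, fastNLP-independent sketch of the constraint:

    import multiprocessing as mp

    def char_len(x):          # module-level functions pickle fine
        return len(x)

    if __name__ == '__main__':
        with mp.Pool(2) as pool:
            print(pool.map(char_len, ['ab', 'abc']))       # [2, 3]
            # pool.map(lambda x: len(x), ['ab', 'abc'])    # PicklingError under
            #                                              # the stdlib pickle
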
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
     except BaseException as e:
         if idx != -1:
-            print("Exception happens at the `{}`th instance.".format(idx))
+            logger.error("Exception happens at the `{}`th instance.".format(idx))
         raise e
     finally:
         if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b
             if nums == total_len:
                 break
     f_rich_progress.destroy_task(main_pro)
+    # pb_main.close()
@@ -519,7 +520,7 @@ class DataSet:
         # start the process pool and the progress-bar thread
         main_thread.start()
         pool = mp.Pool(processes=num_proc)
-        pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
+        pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                      for proc_id, ds in enumerate(shard_data)]
         pool.close()
         pool.join()
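
The dropped `proc_id` kwarg aligns the call with `_apply_single`, which presumably no longer takes that parameter; each worker just maps its shard and pushes progress counts through the pipe that the `_progress_bar` thread drains. (The leftover `proc_id` in the comprehension is now unused.) The `__main__` block added in the next hunk exercises this pipe-thread-pool path directly, outside `DataSet.apply`.
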
@@ -775,3 +776,35 @@ class DataSet:
         if self._collator is None:
             self._collator = Collator()
         return self._collator
+
+
+if __name__ == '__main__':
+    # from fastNLP import DataSet
+    # if __name__=='__main__':
+    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)
+    import multiprocess as mp
+    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
+    from functools import partial
+    from threading import Thread
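+
+    # two shards of 200 toy instances each; the total of 400 matches the
+    # progress-bar length passed to _progress_bar below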
+    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
+                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
+    parent, child = mp.Pipe()
+    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
+                                 pipe=child, show_progress_bar=False)
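+    # a plain thread drains the progress counts that pool workers push into the pipe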
+    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
+    thread.start()
+    pool = mp.Pool(processes=6)
+    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
+                 for proc_id, ds in enumerate(shard_data)]
+    pool.close()
+    pool.join()
+    thread.join()
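+    # each async result holds the mapped values for one shard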
+    results = []
+    for async_result in pool_outs:
+        data = async_result.get()
+        results.extend(data)
+    print(results)
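
If `_apply_single` returns the list of mapped values for its shard, this demo should print a 400-element list whose entries are all 18, the character length of the toy sentence 'xxxxas1w xw zxw xz'.
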
@@ -184,7 +184,7 @@ class TestDataSetMethods:
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
         assert ds.field_arrays["y"].content[0] == 2

-        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
+        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
         assert (isinstance(res, list) and len(res) > 0) == True
         assert res[0] == 4
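
One caveat on flipping `num_proc` from 0 to 2: the callable is a lambda, which only survives the trip to the worker processes if the dill-based `multiprocess` handles serialisation (see the import hunk in dataset.py above). Should the stdlib `multiprocessing` end up in use, a module-level function is the safe form — a sketch with a hypothetical helper:

    def ins_x_len(ins):          # picklable because it lives at module level
        return len(ins["x"])

    res = ds.apply(ins_x_len, num_proc=2, progress_desc="len")
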
@@ -375,6 +375,10 @@ class TestDataSetMethods:
         ds.add_seq_len('x')
         print(ds)

+    def test_apply_proc(self):
+        data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
+
 class TestFieldArrayInit:
     """