
Modify dl and dataset

tags/v1.0.0alpha
MorningForest 3 years ago
parent commit 97a7532a71
4 changed files with 47 additions and 12 deletions

  1. +1  -3   fastNLP/core/dataloaders/__init__.py
  2. +3  -3   fastNLP/core/dataloaders/torch_dataloader/fdl.py
  3. +38 -5   fastNLP/core/dataset/dataset.py
  4. +5  -1   tests/core/dataset/test_dataset.py

+1 -3   fastNLP/core/dataloaders/__init__.py

@@ -5,12 +5,10 @@ __all__ = [
     'JittorDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
-    'prepare_torch_dataloader',
-    'indice_collate_wrapper'
+    'prepare_torch_dataloader'
 ]

 from .mix_dataloader import MixDataLoader
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
-from .utils import indice_collate_wrapper
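Note: with this change `indice_collate_wrapper` is no longer re-exported from the `fastNLP.core.dataloaders` package. Code that used the package-level name would have to import the helper from its defining module instead; a minimal sketch, assuming it stays in fastNLP/core/dataloaders/utils.py (as fdl.py below already does):

    # import the wrapper from its defining module rather than the package __init__
    from fastNLP.core.dataloaders.utils import indice_collate_wrapper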

+3 -3   fastNLP/core/dataloaders/torch_dataloader/fdl.py

@@ -8,7 +8,7 @@ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, Li
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
-from fastNLP.io.data_bundle import DataBundle
+# from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler

@@ -164,7 +164,7 @@ class TorchDataLoader(DataLoader):
         return self.cur_batch_indices


-def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
+def prepare_torch_dataloader(ds_or_db,
                              batch_size: int = 16,
                              shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +197,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param non_train_sampler: Sampler used for non-'train' datasets, and for the second and later datasets in a Sequence
     :param non_train_batch_size:
     """
-
+    from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
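Note: the module-level DataBundle import is commented out, the DataBundle type hint is dropped from the prepare_torch_dataloader signature, and the import is redone inside the function body, presumably so that fastNLP.core.dataloaders no longer needs fastNLP.io at import time (the usual way to break a circular import). A minimal, self-contained sketch of the deferred-import pattern:

    def dump_pretty(obj):
        # the import runs only when the function is called, not when this module loads,
        # which is how prepare_torch_dataloader now defers its DataBundle dependency
        import json
        return json.dumps(obj, indent=2)

    print(dump_pretty({'ok': True}))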


+38 -5   fastNLP/core/dataset/dataset.py

@@ -15,16 +15,16 @@ import numpy as np
 from threading import Thread

 try:
-    import multiprocess as mp
-    from multiprocess import RLock
+    import multiprocessing as mp
 except:
     pass

 from .field import FieldArray
 from .instance import Instance
-from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+from fastNLP.core.utils.utils import pretty_table_printer
 from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
+from fastNLP.core.log import logger


 class ApplyResultException(Exception):
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s

     except BaseException as e:
         if idx != -1:
-            print("Exception happens at the `{}`th instance.".format(idx))
+            logger.error("Exception happens at the `{}`th instance.".format(idx))
         raise e
     finally:
         if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b

         if nums == total_len:
             break
+    f_rich_progress.destroy_task(main_pro)
     # pb_main.close()


@@ -519,7 +520,7 @@
         # start the process pool and worker thread
         main_thread.start()
         pool = mp.Pool(processes=num_proc)
-        pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
+        pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                      for proc_id, ds in enumerate(shard_data)]
         pool.close()
         pool.join()
@@ -775,3 +776,35 @@
         if self._collator is None:
             self._collator = Collator()
         return self._collator
+
+
+if __name__ == '__main__':
+    # from fastNLP import DataSet
+
+    # if __name__=='__main__':
+    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)
+
+    import multiprocess as mp
+    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
+    from functools import partial
+    from threading import Thread
+
+    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
+                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
+    parent, chid = mp.Pipe()
+    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
+                                 pipe=chid, show_progress_bar=False)
+    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
+    thread.start()
+    pool = mp.Pool(processes=6)
+    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
+                 for proc_id, ds in enumerate(shard_data)]
+    pool.close()
+    pool.join()
+    thread.join()
+    results = []
+    for async_result in pool_outs:
+        data = async_result.get()
+        results.extend(data)
+    print(results)
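The __main__ block above drives _apply_single and _progress_bar directly, the same machinery DataSet.apply_field uses when num_proc is set: the dataset is split into shards, each shard is mapped in a process pool, and progress counts flow back to the main thread through a pipe. From the public API the same path is exercised roughly like this (a minimal usage sketch mirroring the new test below; field names are just examples):

    from fastNLP.core.dataset import DataSet

    data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
    # num_proc=2 shards the dataset and applies the function in two worker processes
    data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
    print(data.field_arrays['len_x'].content[0])  # length of the first 'x' string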

+5 -1   tests/core/dataset/test_dataset.py

@@ -184,7 +184,7 @@ class TestDataSetMethods:
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
         assert ds.field_arrays["y"].content[0] == 2

-        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
+        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
         assert (isinstance(res, list) and len(res) > 0) == True
         assert res[0] == 4

@@ -375,6 +375,10 @@ class TestDataSetMethods:
         ds.add_seq_len('x')
         print(ds)

+    def test_apply_proc(self):
+        data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
+

 class TestFieldArrayInit:
     """

