
Modify dl, dataset

tags/v1.0.0alpha
MorningForest committed 3 years ago
commit 97a7532a71
4 changed files with 47 additions and 12 deletions
  1. fastNLP/core/dataloaders/__init__.py (+1, -3)
  2. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -3)
  3. fastNLP/core/dataset/dataset.py (+38, -5)
  4. tests/core/dataset/test_dataset.py (+5, -1)

fastNLP/core/dataloaders/__init__.py (+1, -3)

@@ -5,12 +5,10 @@ __all__ = [
'JittorDataLoader',
'prepare_jittor_dataloader',
'prepare_paddle_dataloader',
- 'prepare_torch_dataloader',
- 'indice_collate_wrapper'
+ 'prepare_torch_dataloader'
]

from .mix_dataloader import MixDataLoader
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
- from .utils import indice_collate_wrapper
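
Since the wrapper is no longer re-exported from the package, callers that still need it import it from the utils module directly, as the fdl.py hunk below already does; a one-line sketch:

from fastNLP.core.dataloaders.utils import indice_collate_wrapper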

fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -3)

@@ -8,7 +8,7 @@ from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, Li
from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
- from fastNLP.io.data_bundle import DataBundle
+ # from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler

@@ -164,7 +164,7 @@ class TorchDataLoader(DataLoader):
return self.cur_batch_indices


- def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
+ def prepare_torch_dataloader(ds_or_db,
batch_size: int = 16,
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +197,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
:param non_train_sampler: the Sampler used for non-'train' data, and the Sampler used by the second and later ds in a Sequence
:param non_train_batch_size:
"""
+ from fastNLP.io.data_bundle import DataBundle
if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
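
The DataBundle import moves out of the module header and out of the type hint, and is instead done lazily inside prepare_torch_dataloader, presumably to avoid a circular import between fastNLP.io and the dataloader package. A minimal usage sketch under that reading (the toy data is illustrative; the call signature is the one shown above):

from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders import prepare_torch_dataloader

ds = DataSet({'x': ['a b c', 'd e f'] * 8, 'y': [0, 1] * 8})
# ds_or_db is untyped after this change, but a DataSet (or DataBundle, or a sequence/mapping of DataSets) is still what it expects
dl = prepare_torch_dataloader(ds, batch_size=16, shuffle=True)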


fastNLP/core/dataset/dataset.py (+38, -5)

@@ -15,16 +15,16 @@ import numpy as np
from threading import Thread

try:
- import multiprocess as mp
- from multiprocess import RLock
+ import multiprocessing as mp
except:
pass

from .field import FieldArray
from .instance import Instance
- from fastNLP.core.utils.utils import pretty_table_printer, deprecated
+ from fastNLP.core.utils.utils import pretty_table_printer
from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.log import logger


class ApplyResultException(Exception):
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s

except BaseException as e:
if idx != -1:
print("Exception happens at the `{}`th instance.".format(idx))
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
finally:
if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b

if nums == total_len:
break
+ f_rich_progress.destroy_task(main_pro)
# pb_main.close()


@@ -519,7 +520,7 @@ class DataSet:
# start the process pool and the progress thread
main_thread.start()
pool = mp.Pool(processes=num_proc)
- pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
+ pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
for proc_id, ds in enumerate(shard_data)]
pool.close()
pool.join()
@@ -775,3 +776,35 @@ class DataSet:
if self._collator is None:
self._collator = Collator()
return self._collator


if __name__ == '__main__':
    # from fastNLP import DataSet

    # if __name__=='__main__':
    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)

    import multiprocess as mp
    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
    from functools import partial
    from threading import Thread

    # two shards of the same toy data, one per worker process
    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
    # workers report progress through the child end of the pipe; the parent end feeds the progress-bar thread
    parent, chid = mp.Pipe()
    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
                                 pipe=chid, show_progress_bar=False)
    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
    thread.start()
    pool = mp.Pool(processes=6)
    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                 for proc_id, ds in enumerate(shard_data)]
    pool.close()
    pool.join()
    thread.join()
    # collect the per-shard results in submission order
    results = []
    for async_result in pool_outs:
        data = async_result.get()
        results.extend(data)
    print(results)
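
The user-facing counterpart of this demo is DataSet.apply_field with num_proc, which the new test below also exercises; a minimal sketch (the field contents are just toy strings):

from fastNLP.core.dataset import DataSet

data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
# compute len() of every 'x' value in two worker processes and store it in a new field
data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)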

tests/core/dataset/test_dataset.py (+5, -1)

@@ -184,7 +184,7 @@ class TestDataSetMethods:
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
assert ds.field_arrays["y"].content[0] == 2

- res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
+ res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
assert (isinstance(res, list) and len(res) > 0) == True
assert res[0] == 4

@@ -375,6 +375,10 @@ class TestDataSetMethods:
ds.add_seq_len('x')
print(ds)

+ def test_apply_proc(self):
+     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
+     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)


class TestFieldArrayInit:
"""

