
dataset multiprocessing

tags/v1.0.0alpha
yh_cc 3 years ago
parent commit 58a21c2b63
2 changed files with 73 additions and 104 deletions
  1. +72 -104 fastNLP/core/dataset/dataset.py
  2. +1 -0 fastNLP/core/log/logger.py

+72 -104 fastNLP/core/dataset/dataset.py

@@ -9,22 +9,18 @@ __all__ = [
import _pickle as pickle
from copy import deepcopy
from typing import Optional, List, Callable, Union, Dict, Any, Mapping
from functools import partial
from types import LambdaType
import sys
import time

import numpy as np
from threading import Thread

try:
import multiprocessing as mp
except:
pass

from .field import FieldArray
from .instance import Instance
from fastNLP.core.utils.utils import pretty_table_printer
from fastNLP.core.utils.utils import pretty_table_printer, deprecated
from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.log import logger
from ..log import logger


class ApplyResultException(Exception):
@@ -35,14 +31,13 @@ class ApplyResultException(Exception):


def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True,
pipe=None, desc: str = None) -> list:
desc: str = None) -> list:
"""
Wrapper function that processes the dataset, for use with multiprocessing

:param ds: the dataset
:param _apply_field: field_name of the dataset field to process
:param func: user-defined func
:param pipe: the pipe
:param desc: description string for the progress bar
:param show_progress_bar: whether to show the subprocess progress bar
:return:
@@ -60,8 +55,6 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
results.append(func(ins[_apply_field]))
else:
results.append(func(ins))
if pipe is not None:
pipe.send([idx + 1])
if show_progress_bar:
f_rich_progress.update(pg_main, advance=1)

@@ -75,31 +68,36 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
return results


def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None:
def _multi_proc(ds, _apply_field, func, counter, queue):
"""
Display the main process's progress bar under multiprocessing
Wrapper function that processes the dataset, for use with multiprocessing

:param parent: the process pipe
:param total_len: total length of the dataset
:param desc: description of the progress bar
:param show_progress_bar: whether to show the progress bar
:param ds: the dataset
:param _apply_field: field_name of the dataset field to process
:param func: user-defined func
:param counter: the shared counter
:param queue: in multiprocessing, results are put into this queue
:return:
"""
desc = desc if desc else "Main"

main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar)
# pb_main = tqdm(total=total_len, desc=desc, position=0)
nums = 0
while True:
msg = parent.recv()[0]
if msg is not None:
f_rich_progress.update(main_pro, advance=1)
nums += 1

if nums == total_len:
break
f_rich_progress.destroy_task(main_pro)
# pb_main.close()
idx = -1
import contextlib
with contextlib.redirect_stdout(None):  # avoid prints triggering rich's lock
logger.set_stdout(stdout='raw')
results = []
try:
for idx, ins in enumerate(ds):
if _apply_field is not None:
res = func(ins[_apply_field])
else:
res = func(ins)
results.append(res)
with counter.get_lock():
counter.value += 1
except BaseException as e:
if idx != -1:
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
queue.put(pickle.dumps(results))


class DataSet:
@@ -114,7 +112,7 @@ class DataSet:
Each element should be an :class:`~fastNLP.Instance` with the same fields.
"""
self.field_arrays = {}
self._collator = Collator(backend="numpy")
self._collator = Collator()
if data is not None:
if isinstance(data, Dict):
length_set = set()
@@ -127,7 +125,6 @@ class DataSet:
for ins in data:
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
self.append(ins)

else:
raise ValueError("data can only be of dict or list type.")

@@ -263,7 +260,7 @@ class DataSet:
try:
self.field_arrays[name].append(field)
except Exception as e:
print(f"Cannot append to field:{name}.")
logger.error(f"Cannot append to field:{name}.")
raise e

def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None:
@@ -469,9 +466,7 @@ class DataSet:

except Exception as e:
if idx != -1:
if isinstance(e, ApplyResultException):
print(e.msg)
print("Exception happens at the `{}`th instance.".format(idx + 1))
logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
raise e

if modify_fields is True:
@@ -490,18 +485,19 @@ class DataSet:
:param show_progress_bar: whether to show the progress bar; shown by default
:param progress_desc: description string for the progress bar; defaults to 'Main'
"""
if isinstance(func, LambdaType) and num_proc > 1 and func.__name__ == "<lambda>":
raise TypeError("Lambda function does not support multiple processes, please set `num_proc=0`.")
if num_proc > 1 and sys.platform in ('win32', 'msys', 'cygwin'):
raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`")

if num_proc == 0:
if num_proc < 2:
results = _apply_single(ds=self, _apply_field=_apply_field, func=func,
desc=progress_desc, show_progress_bar=show_progress_bar)
else:
# TODO 1. desc needs adjusting: the subprocess desc should be changed to Process 1 / Process 2
results = []
if num_proc > len(self):
num_proc = len(self)
print(
f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}."
)
import multiprocessing as mp
ctx = mp.get_context('fork')
num_proc = min(num_proc, len(self))
# partition the dataset
shard_len = len(self) // num_proc
num_left_sample = len(self) % num_proc
@@ -511,24 +507,32 @@ class DataSet:
end = shard_len + int(_i<num_left_sample) + start
shard_data.append(self[start:end])
start = end
# set up a pipe and a thread so the main progress bar can update in real time.
parent, child = mp.Pipe()
main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc,
show_progress_bar))
partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func,
pipe=child, show_progress_bar=False)
# start the process pool and the thread
main_thread.start()
pool = mp.Pool(processes=num_proc)
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
for proc_id, ds in enumerate(shard_data)]
pool.close()
pool.join()
main_thread.join()

for async_result in pool_outs:
data = async_result.get()
results.extend(data)
# set up shared state so the main progress bar can update in real time.
counter = ctx.Value('i', 0, lock=True)
pool = []
queues = []
results = []
for i in range(num_proc):
queue = ctx.SimpleQueue()
proc = ctx.Process(target=_multi_proc, args=(shard_data[i], _apply_field, func, counter, queue))
proc.start()
pool.append(proc)
queues.append(queue)

total_len = len(self)
task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar)
last_count = -1
while counter.value < total_len or last_count == -1:
while counter.value == last_count:
time.sleep(0.1)
advance = counter.value - last_count
last_count = counter.value
f_rich_progress.update(task_id, advance=advance, refresh=True)

for idx, proc in enumerate(pool):
results.extend(pickle.loads(queues[idx].get()))
proc.join()
f_rich_progress.destroy_task(task_id)
return results
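
For orientation, here is a minimal self-contained sketch of the shared-counter / queue pattern the new code adopts, replacing the old Pipe-plus-thread approach. All names here are illustrative, not fastNLP code, and it assumes a fork-capable platform (the commit itself rejects win32/msys/cygwin):

import pickle
import time
import multiprocessing as mp

def _worker(shard, counter, queue):
    # stand-in for _multi_proc: process one shard, bump the shared
    # counter per instance, then ship the pickled results back
    results = []
    for item in shard:
        results.append(item * 2)          # stand-in for func(ins)
        with counter.get_lock():
            counter.value += 1
    queue.put(pickle.dumps(results))

if __name__ == '__main__':
    ctx = mp.get_context('fork')          # as in the new apply code
    data = list(range(100))
    shards = [data[:50], data[50:]]
    counter = ctx.Value('i', 0, lock=True)
    queues = [ctx.SimpleQueue() for _ in shards]
    procs = [ctx.Process(target=_worker, args=(s, counter, q))
             for s, q in zip(shards, queues)]
    for p in procs:
        p.start()
    while counter.value < len(data):      # parent polls the counter for progress
        time.sleep(0.1)
    results = []
    for p, q in zip(procs, queues):
        results.extend(pickle.loads(q.get()))  # drain the queue before join
        p.join()
    print(len(results))                   # -> 100

Draining each queue before joining its process mirrors the order in the commit and avoids blocking a child on a full pipe buffer.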

def apply_more(self, func: Callable = None, modify_fields: bool = True,
@@ -552,8 +556,7 @@ class DataSet:
:param progress_desc: when show_progress_bar is True, the description shown for the current progress bar
:return Dict[str:Field]: returns a dict
"""
# returns a dict; check that it stays the same throughout
assert callable(func), "The func you provide is not callable."
assert callable(func), "The func is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()."
assert num_proc >= 0, "num_proc must be >= 0"
idx = -1
@@ -577,9 +580,7 @@ class DataSet:

except Exception as e:
if idx != -1:
if isinstance(e, ApplyResultException):
print(e.msg)
print("Exception happens at the `{}`th instance.".format(idx + 1))
logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
raise e

if modify_fields is True:
@@ -665,8 +666,7 @@ class DataSet:
np.random.shuffle(all_indices)
split = int(ratio * len(self))
if split == 0:
error_msg = f'Dev DataSet has {split} instance after split.'
print(error_msg)
error_msg = f'Dev DataSet has `{split}` instance after split.'
raise IndexError(error_msg)
dev_indices = all_indices[:split]
train_indices = all_indices[split:]
@@ -776,35 +776,3 @@ class DataSet:
if self._collator is None:
self._collator = Collator()
return self._collator


if __name__ == '__main__':
# from fastNLP import DataSet

# if __name__=='__main__':
# data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
# data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)

import multiprocess as mp
# from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
from functools import partial
from threading import Thread

shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
parent, chid = mp.Pipe()
partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
pipe=chid, show_progress_bar=False)
thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
thread.start()
pool = mp.Pool(processes=6)
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
for proc_id, ds in enumerate(shard_data)]
pool.close()
pool.join()
thread.join()
results = []
for async_result in pool_outs:
data = async_result.get()
results.extend(data)
print(results)
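
For reference, a minimal usage sketch of the reworked API, adapted from the demo block removed above. A named function is required here, since the commit now rejects lambdas when num_proc > 1:

from fastNLP import DataSet

def char_len(x):
    return len(x)

if __name__ == '__main__':
    data = DataSet({'x': ['xxxxas1w xw zxw xz'] * 200, 'y': [0, 1] * 100})
    # with num_proc >= 2 the dataset is sharded across forked worker processes
    data.apply_field(char_len, field_name='x', new_field_name='len_x',
                     num_proc=2, show_progress_bar=True)
    print(data[0])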

+1 -0 fastNLP/core/log/logger.py

@@ -302,6 +302,7 @@ def _set_stdout_handler(_logger, stdout='raw', level='INFO'):
break
if stream_handler is not None:
_logger.removeHandler(stream_handler)
del stream_handler

# Stream Handler
if stdout == 'raw':
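
The logger change itself is one line: after detaching the old stream handler, the local reference is deleted as well. A rough sketch of the surrounding pattern in _set_stdout_handler, reconstructed from the context lines above (so treat it as approximate, not the library's exact code):

import logging

def _reset_stream_handler(_logger: logging.Logger) -> None:
    # find an existing StreamHandler, detach it, and drop the reference
    stream_handler = None
    for handler in _logger.handlers:
        if isinstance(handler, logging.StreamHandler):
            stream_handler = handler
            break
    if stream_handler is not None:
        _logger.removeHandler(stream_handler)
        del stream_handler  # the line this commit adds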

