@@ -9,22 +9,18 @@ __all__ = [

+import _pickle as pickle
 from copy import deepcopy
 from typing import Optional, List, Callable, Union, Dict, Any, Mapping
-from functools import partial
 from types import LambdaType
 import sys
+import time

 import numpy as np
-from threading import Thread
-
-try:
-    import multiprocessing as mp
-except:
-    pass

 from .field import FieldArray
 from .instance import Instance
-from fastNLP.core.utils.utils import pretty_table_printer
+from fastNLP.core.utils.utils import pretty_table_printer, deprecated
 from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
-from fastNLP.core.log import logger
+from ..log import logger


 class ApplyResultException(Exception):
@@ -35,14 +31,13 @@ class ApplyResultException(Exception):


 def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True,
-                  pipe=None, desc: str = None) -> list:
+                  desc: str = None) -> list:
     """
     A wrapper around dataset processing so that it can be used with multiprocessing.

     :param ds: the dataset
     :param _apply_field: the field_name of the dataset to process
     :param func: the user-defined func
-    :param pipe: the pipe used to report progress
     :param desc: description text for the progress bar
     :param show_progress_bar: whether to show the subprocess progress bar
     :return:
     """
@@ -60,8 +55,6 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
                 results.append(func(ins[_apply_field]))
             else:
                 results.append(func(ins))
-            if pipe is not None:
-                pipe.send([idx + 1])
             if show_progress_bar:
                 f_rich_progress.update(pg_main, advance=1)
@@ -75,31 +68,36 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
     return results

|
|
|
-def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None:
+def _multi_proc(ds, _apply_field, func, counter, queue):
     """
-    Display the main process's progress bar under multiprocessing.
+    A wrapper around dataset processing so that it can be used with multiprocessing.

-    :param parent: the process pipe
-    :param total_len: total length of the dataset
-    :param desc: description text for the progress bar
-    :param show_progress_bar: whether to show the progress bar
+    :param ds: the dataset
+    :param _apply_field: the field_name of the dataset to process
+    :param func: the user-defined func
+    :param counter: the shared counter
+    :param queue: under multiprocessing, results are put into this queue
     :return:
     """
-    desc = desc if desc else "Main"
-
-    main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar)
-    # pb_main = tqdm(total=total_len, desc=desc, position=0)
-    nums = 0
-    while True:
-        msg = parent.recv()[0]
-        if msg is not None:
-            f_rich_progress.update(main_pro, advance=1)
-            nums += 1
-
-        if nums == total_len:
-            break
-    f_rich_progress.destroy_task(main_pro)
-    # pb_main.close()
+    idx = -1
+    import contextlib
+    with contextlib.redirect_stdout(None):  # avoid prints triggering rich's lock
+        logger.set_stdout(stdout='raw')
+        results = []
+        try:
+            for idx, ins in enumerate(ds):
+                if _apply_field is not None:
+                    res = func(ins[_apply_field])
+                else:
+                    res = func(ins)
+                results.append(res)
+                with counter.get_lock():
+                    counter.value += 1
+        except BaseException as e:
+            if idx != -1:
+                logger.error("Exception happens at the `{}`th instance.".format(idx))
+            raise e
+    queue.put(pickle.dumps(results))
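For orientation, the parent side of this handshake lives in the multiprocess branch of `DataSet.apply` further down in this patch. A minimal sketch of the protocol for a single worker, assuming a fork-capable platform; `shard` and `func` are stand-in names for one data shard and a user callable, not part of the patch:

    import _pickle as pickle
    import multiprocessing as mp

    ctx = mp.get_context('fork')
    counter = ctx.Value('i', 0, lock=True)  # the worker increments this once per processed instance
    queue = ctx.SimpleQueue()               # the worker puts exactly one pickled result list here
    proc = ctx.Process(target=_multi_proc, args=(shard, None, func, counter, queue))
    proc.start()
    shard_results = pickle.loads(queue.get())  # blocks until the worker publishes its results
    proc.join()                                # join only after draining the queue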
|
|
|
 class DataSet:
@@ -114,7 +112,7 @@ class DataSet:
         Each element should be an :class:`~fastNLP.Instance` with the same fields.
         """
         self.field_arrays = {}
-        self._collator = Collator(backend="numpy")
+        self._collator = Collator()
         if data is not None:
             if isinstance(data, Dict):
                 length_set = set()
@@ -127,7 +125,6 @@ class DataSet:
                 for ins in data:
                     assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
                     self.append(ins)
-
             else:
                 raise ValueError("data can only be dict or list type.")
@@ -263,7 +260,7 @@ class DataSet:
                 try:
                     self.field_arrays[name].append(field)
                 except Exception as e:
-                    print(f"Cannot append to field:{name}.")
+                    logger.error(f"Cannot append to field:{name}.")
                     raise e

     def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None:
@@ -469,9 +466,7 @@ class DataSet:

         except Exception as e:
             if idx != -1:
-                if isinstance(e, ApplyResultException):
-                    print(e.msg)
-                print("Exception happens at the `{}`th instance.".format(idx + 1))
+                logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
             raise e

         if modify_fields is True:
@@ -490,18 +485,19 @@ class DataSet:
         :param show_progress_bar: whether to display the progress bar; shown by default
         :param progress_desc: description text for the progress bar; defaults to 'Main'
         """
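+        # guard clauses for the multiprocess path below: it relies on `fork`, and lambdas
+        # are rejected outright (presumably because they cannot be pickled)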
+        if isinstance(func, LambdaType) and num_proc > 1 and func.__name__ == "<lambda>":
+            raise TypeError("Lambda function does not support multiple processes, please set `num_proc=0`.")
+        if num_proc > 1 and sys.platform in ('win32', 'msys', 'cygwin'):
+            raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`")

-        if num_proc == 0:
+        if num_proc < 2:
             results = _apply_single(ds=self, _apply_field=_apply_field, func=func,
                                     desc=progress_desc, show_progress_bar=show_progress_bar)
         else:
+            # TODO 1. revise desc: each subprocess's desc should be set to Process 1 / Process 2, etc.
-            results = []
-            if num_proc > len(self):
-                num_proc = len(self)
-                print(
-                    f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}."
-                )
+            import multiprocessing as mp
+            ctx = mp.get_context('fork')
+            num_proc = min(num_proc, len(self))
             # split the dataset into shards
             shard_len = len(self) // num_proc
             num_left_sample = len(self) % num_proc
@@ -511,24 +507,32 @@ class DataSet:
                 end = shard_len + int(_i < num_left_sample) + start
                 shard_data.append(self[start:end])
                 start = end
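+            # e.g. with len(self) == 10 and num_proc == 3: shard_len == 3 and num_left_sample == 1,
+            # so the three shards receive 4, 3 and 3 instances respectively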
-            # set up a pipe and a thread so that the main progress bar can update in real time
-            parent, child = mp.Pipe()
-            main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc,
-                                                             show_progress_bar))
-            partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func,
-                                         pipe=child, show_progress_bar=False)
-            # start the process pool and the progress thread
-            main_thread.start()
-            pool = mp.Pool(processes=num_proc)
-            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
-                         for proc_id, ds in enumerate(shard_data)]
-            pool.close()
-            pool.join()
-            main_thread.join()
-
-            for async_result in pool_outs:
-                data = async_result.get()
-                results.extend(data)
+            # set up shared state so that the main progress bar can update in real time
+            counter = ctx.Value('i', 0, lock=True)
+            pool = []
+            queues = []
+            results = []
+            for i in range(num_proc):
+                queue = ctx.SimpleQueue()
+                proc = ctx.Process(target=_multi_proc, args=(shard_data[i], _apply_field, func, counter, queue))
+                proc.start()
+                pool.append(proc)
+                queues.append(queue)
+
+            total_len = len(self)
+            task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar)
+            last_count = -1
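+            # poll the shared counter instead of receiving per-instance messages,
+            # advancing the main bar by however much the workers progressed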
+            while counter.value < total_len or last_count == -1:
+                while counter.value == last_count:
+                    time.sleep(0.1)
+                advance = counter.value - last_count
+                last_count = counter.value
+                f_rich_progress.update(task_id, advance=advance, refresh=True)
+
+            for idx, proc in enumerate(pool):
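+                # each SimpleQueue carries exactly one pickled payload: the full, in-order
+                # result list for that worker's shard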
+                results.extend(pickle.loads(queues[idx].get()))
+                proc.join()
+            f_rich_progress.destroy_task(task_id)
         return results

     def apply_more(self, func: Callable = None, modify_fields: bool = True,
@@ -552,8 +556,7 @@ class DataSet:
         :param progress_desc: when show_progress_bar is True, the description shown on the progress bar
         :return Dict[str:Field]: returns a dict
         """
-        # returns a dict; check that it stays consistent throughout
-        assert callable(func), "The func you provide is not callable."
+        assert callable(func), "The func is not callable."
         assert len(self) != 0, "Null DataSet cannot use apply()."
         assert num_proc >= 0, "num_proc must >= 0"
         idx = -1
@@ -577,9 +580,7 @@ class DataSet:

         except Exception as e:
             if idx != -1:
-                if isinstance(e, ApplyResultException):
-                    print(e.msg)
-                print("Exception happens at the `{}`th instance.".format(idx + 1))
+                logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
             raise e

         if modify_fields is True:
@@ -665,8 +666,7 @@ class DataSet:
         np.random.shuffle(all_indices)
         split = int(ratio * len(self))
         if split == 0:
-            error_msg = f'Dev DataSet has {split} instance after split.'
-            print(error_msg)
+            error_msg = f'Dev DataSet has `{split}` instance after split.'
             raise IndexError(error_msg)
         dev_indices = all_indices[:split]
         train_indices = all_indices[split:]
@@ -776,35 +776,3 @@ class DataSet:
         if self._collator is None:
             self._collator = Collator()
         return self._collator
-
-
-if __name__ == '__main__':
-    # from fastNLP import DataSet
-
-    # if __name__=='__main__':
-    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
-    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)
-
-    import multiprocess as mp
-    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
-    from functools import partial
-    from threading import Thread
-
-    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
-                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
-    parent, chid = mp.Pipe()
-    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
-                                 pipe=chid, show_progress_bar=False)
-    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
-    thread.start()
-    pool = mp.Pool(processes=6)
-    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
-                 for proc_id, ds in enumerate(shard_data)]
-    pool.close()
-    pool.join()
-    thread.join()
-    results = []
-    for async_result in pool_outs:
-        data = async_result.get()
-        results.extend(data)
-    print(results)
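The `__main__` demo removed above exercised the old Pool/Pipe path. A minimal usage sketch for the code path this patch introduces, adapted from the commented-out example in the removed block; `char_len` is a stand-in name, and a named function is used because lambdas are rejected when num_proc > 1:

    from fastNLP import DataSet

    def char_len(x):
        # a named, module-level function rather than a lambda, so the multiprocess path accepts it
        return len(x)

    data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
    data.apply_field(char_len, field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)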