
Add core/utils/

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 2a3c764d9c
9 changed files with 1414 additions and 0 deletions

  1. +43  -0  fastNLP/core/utils/__init__.py
  2. +310 -0  fastNLP/core/utils/cache_results.py
  3. +4   -0  fastNLP/core/utils/dummy_class.py
  4. +51  -0  fastNLP/core/utils/jittor_utils.py
  5. +89  -0  fastNLP/core/utils/paddle_utils.py
  6. +214 -0  fastNLP/core/utils/rich_progress.py
  7. +49  -0  fastNLP/core/utils/torch_paddle_utils.py
  8. +63  -0  fastNLP/core/utils/torch_utils.py
  9. +591 -0  fastNLP/core/utils/utils.py

+43 -0  fastNLP/core/utils/__init__.py

@@ -0,0 +1,43 @@
__all__ = [
'cache_results',
'is_jittor_dataset',
'jittor_collate_wraps',
'paddle_to',
'paddle_move_data_to_device',
'get_paddle_device_id',
'get_paddle_gpu_str',
'is_in_paddle_dist',
'is_in_fnlp_paddle_dist',
'is_in_paddle_launch_dist',
'f_rich_progress',
'torch_paddle_move_data_to_device',
'torch_move_data_to_device',
'get_fn_arg_names',
'check_fn_not_empty_params',
'auto_param_call',
'check_user_specific_params',
'dataclass_to_dict',
'match_and_substitute_params',
'apply_to_collection',
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'synchronize_safe_rm',
'synchronize_mkdir'
]

from .cache_results import cache_results
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist
from .rich_progress import f_rich_progress
from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir



+310 -0  fastNLP/core/utils/cache_results.py

@@ -0,0 +1,310 @@
from datetime import datetime
import hashlib
import _pickle
import functools
import os
from typing import Callable, List, Any, Optional
import inspect
import ast
from collections import deque

__all__ = [
'cache_results'
]

from fastNLP.core.log.logger import logger
from fastNLP.core.log.highlighter import ColorHighlighter


class FuncCallVisitor(ast.NodeVisitor):
# credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776
def __init__(self):
self._name = deque()

@property
def name(self):
return '.'.join(self._name)

@name.deleter
def name(self):
self._name.clear()

def visit_Name(self, node):
self._name.appendleft(node.id)

def visit_Attribute(self, node):
try:
self._name.appendleft(node.attr)
self._name.appendleft(node.value.id)
except AttributeError:
self.generic_visit(node)


def get_func_calls(tree):
func_calls = []
for node in ast.walk(tree):
if isinstance(node, ast.Call):
callvisitor = FuncCallVisitor()
callvisitor.visit(node.func)
func_calls.append(callvisitor.name)
if isinstance(node, ast.FunctionDef):
if not (node is tree):
func_calls.extend(get_func_calls(node))

return func_calls


def truncate_start_blanks(source: str) -> str:
"""
Dedent every line of ``source`` by the indentation of its first non-empty line.

:param source:
:return:
"""
lines = source.split('\n')
num_blank = 0
# take the indentation of the first non-empty line
for line in lines:
if line:
num_blank = len(line) - len(line.lstrip())
break
new_lines = []
for line in lines:
# strip at most num_blank leading spaces from the line
i = 0
while i < min(len(line), num_blank) and line[i] == ' ':
i += 1
new_lines.append(line[i:])
return '\n'.join(new_lines)


def _get_func_and_its_called_func_source_code(func) -> List[str]:
"""
给定一个func,返回在这个函数里面用到的所有函数的源码。

:param callable func:
:return:
"""
last_frame = inspect.currentframe().f_back.f_back.f_back
last_frame_f_local = last_frame.f_locals
last_frame_loc = {}
if 'loc' in last_frame_f_local:
last_frame_loc = last_frame_f_local['loc']
func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func))))))
func_calls.sort()
sources = []
for _func_name in func_calls:
try:
if _func_name == 'cache_results': # ignore the decorator
continue
if '.' in _func_name:
_funcs = _func_name.split('.')
else:
_funcs = [_func_name]
if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc:
tmp = _funcs.pop(0)
variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp))
while len(_funcs) or variable is not None:
if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__):
try:
sources.append(inspect.getsource(variable.__class__))
except TypeError:
pass
if callable(variable) or inspect.isclass(variable):
sources.append(inspect.getsource(variable))
if len(_funcs):
tmp = _funcs.pop(0)
if hasattr(variable, tmp):
variable = getattr(variable, tmp)
else:
break
else:
variable = None
except:
# some failure
pass
del last_frame #
sources.append(inspect.getsource(func))
return sources


def _prepare_cache_filepath(filepath:str):
r"""
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径

:param filepath: str.
:return: None, if not, this function will raise error
"""
_cache_filepath = os.path.abspath(filepath)
if os.path.isdir(_cache_filepath):
raise RuntimeError("The cache_file_path must be a file, not a directory.")
cache_dir = os.path.dirname(_cache_filepath)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)


class Hasher:
def __init__(self):
self.m = hashlib.sha1()

def update(self, value: Any) -> None:
if isinstance(value, str):
value = [value]
for x in value:
self.m.update(x.encode('utf8'))

def hexdigest(self) -> str:
return self.m.hexdigest()


def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None):
if fn_kwargs is None:
fn_kwargs = {}
hasher = Hasher()
try:
sources = _get_func_and_its_called_func_source_code(fn)
hasher.update(sources)
except:
return "can't be hashed"
for key in sorted(fn_kwargs):
hasher.update(key)
try:
hasher.update(fn_kwargs[key])
except:
pass
return hasher.hexdigest()


def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
r"""
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用::

import time
import numpy as np
from fastNLP import cache_results

@cache_results('cache.pkl')
def process_data():
# 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时
time.sleep(1)
return np.random.randint(10, size=(5,))

start_time = time.time()
print("res =",process_data())
print(time.time() - start_time)

start_time = time.time()
print("res =",process_data())
print(time.time() - start_time)

# 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间
# Save cache to cache.pkl.
# res = [5 4 9 1 8]
# 1.0042750835418701
# Read cache from cache.pkl.
# res = [5 4 9 1 8]
# 0.0040721893310546875

可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理::

# 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可
process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl'

上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的
'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。
上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称::

process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。
# _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache

:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在
函数调用的时候传入_cache_fp这个参数。
:param bool _refresh: 是否重新生成cache。
:param int _verbose: 是否打印cache的信息。
:param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值
与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然
该修改对结果有影响,但无法做出warning。

:return:
"""

def wrapper_(func):
signature = inspect.signature(func)
for key, _ in signature.parameters.items():
if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'):
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

@functools.wraps(func)
def wrapper(*args, **kwargs):
fn_param = kwargs.copy()
if args:
params = [p.name for p in inspect.signature(func).parameters.values()]
fn_param.update(zip(params, args))
if '_cache_fp' in kwargs:
cache_filepath = kwargs.pop('_cache_fp')
assert isinstance(cache_filepath, str), "_cache_fp can only be str."
else:
cache_filepath = _cache_fp
if '_refresh' in kwargs:
refresh = kwargs.pop('_refresh')
assert isinstance(refresh, bool), "_refresh can only be bool."
else:
refresh = _refresh
if '_verbose' in kwargs:
verbose = kwargs.pop('_verbose')
assert isinstance(verbose, int), "_verbose can only be integer."
else:
verbose = _verbose

if '_check_hash' in kwargs:
check_hash = kwargs.pop('_check_hash')
else:
check_hash = _check_hash

refresh_flag = True
new_hash_code = None
if check_hash:
new_hash_code = cal_fn_hash_code(func, fn_param)

if cache_filepath is not None and refresh is False:
# load data
if os.path.exists(cache_filepath):
cache_filepath = os.path.abspath(cache_filepath)
with open(cache_filepath, 'rb') as f:
results = _pickle.load(f)
old_hash_code = results['hash']
save_time = results['save_time']
results = results['results']
if verbose == 1:
logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
if check_hash and old_hash_code != new_hash_code:
logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The "
f"difference may caused by the sourcecode change of the functions by this function.",
extra={'highlighter': ColorHighlighter('red')})
refresh_flag = False

if refresh_flag:
if new_hash_code is None:
new_hash_code = cal_fn_hash_code(func, fn_param)
results = func(*args, **kwargs)
if cache_filepath is not None:
if results is None:
raise RuntimeError("The return value is None. Cannot save None results.")
cache_filepath = os.path.abspath(cache_filepath)
_prepare_cache_filepath(cache_filepath)
_dict = {
'results': results,
'hash': new_hash_code,
'save_time': datetime.now(),
}
with open(cache_filepath, 'wb') as f:
_pickle.dump(_dict, f)
logger.info("Save cache to {}.".format(cache_filepath))

return results

return wrapper

return wrapper_
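
For reference, the cache file written by the decorator above is a pickle of a dict with the keys 'results', 'hash' and 'save_time' (see the wrapper body). A minimal sketch of inspecting one by hand, assuming a cache.pkl produced by the docstring example:

import _pickle

with open('cache.pkl', 'rb') as f:
    payload = _pickle.load(f)
print(payload['save_time'])  # datetime at which the cache was written
print(payload['hash'])       # source hash computed by cal_fn_hash_code
print(payload['results'])    # the decorated function's return value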

+4 -0  fastNLP/core/utils/dummy_class.py

@@ -0,0 +1,4 @@


class DummyClass:
pass

+51 -0  fastNLP/core/utils/jittor_utils.py

@@ -0,0 +1,51 @@
__all__ = [
'is_jittor_dataset',
'jittor_collate_wraps'
]

from collections.abc import Mapping, Callable
from functools import wraps

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor as jt

from fastNLP.core.dataset import Instance



def is_jittor_dataset(dataset) -> bool:
try:
if isinstance(dataset, jt.dataset.Dataset):
return True
else:
return False
except BaseException:
return False


def jittor_collate_wraps(func, auto_collator: Callable):
"""
对jittor的collate_fn进行wrap封装, 如果数据集为mapping类型,那么采用auto_collator,否则还是采用jittor自带的collate_batch

:param func:
:param auto_collator:
:return:
"""
@wraps(func)
def wrapper(batch):
if isinstance(batch[0], Instance):
if auto_collator is not None:
result = auto_collator(batch)
else:
raise ValueError(f"auto_collator is None, but batch exist fastnlp instance!")
elif isinstance(batch[0], Mapping):
if auto_collator is not None:
result = auto_collator(batch)
else:
result = func(batch)
else:
result = func(batch)
return result

return wrapper
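
A minimal usage sketch of the wrapper above; `default_collate` stands in for jittor's own collate function and `my_collator` for a user-provided auto collator (both names are placeholders here):

from fastNLP.core.utils import jittor_collate_wraps

wrapped = jittor_collate_wraps(default_collate, auto_collator=my_collator)
# wrapped(batch) dispatches on the type of batch[0]: Instance and Mapping
# batches go to my_collator, everything else falls back to default_collate
result = wrapped(batch)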

+89 -0  fastNLP/core/utils/paddle_utils.py

@@ -0,0 +1,89 @@
__all__ = [
"paddle_to",
"paddle_move_data_to_device",
"get_paddle_gpu_str",
"get_paddle_device_id",
"is_in_paddle_dist",
"is_in_fnlp_paddle_dist",
"is_in_paddle_launch_dist",
]

import os
from typing import Any, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK

if _NEED_IMPORT_PADDLE:
import paddle

from .utils import apply_to_collection


def paddle_to(data, device: Union[str, int]):
"""
Move a paddle tensor to the given device; ``device`` is either ``cpu``, a gpu id, or a ``gpu:x`` string.
"""
if device == "cpu":
return data.cpu()
else:
return data.cuda(get_paddle_device_id(device))

def get_paddle_gpu_str(device: Union[str, int]):
"""
获得 `gpu:x` 类型的设备名
"""
if isinstance(device, str):
return device.replace("cuda", "gpu")
return f"gpu:{device}"

def get_paddle_device_id(device: Union[str, int]):
"""
获得 gpu 的设备id,注意不要传入 `cpu` 。
"""
if isinstance(device, int):
return device

if device == "cpu":
raise ValueError("Cannot get device id from `cpu`.")

return paddle.device._convert_to_place(device).get_device_id()

def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
data_device: Optional[str] = None) -> Any:
r"""
将数据集合传输到给定设备。只有paddle.Tensor对象会被传输到设备中,其余保持不变

:param batch:
:param device: `cpu`, `gpu` or `gpu:x`
:param data_device:
:return: 相同的集合,但所有包含的张量都驻留在新设备上;
"""
if device is None:
if data_device is not None:
device = data_device
else:
return batch

def batch_to(data: Any) -> Any:
return paddle_to(data, device)

return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)

def is_in_paddle_dist():
"""
判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断
"""
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)

def is_in_fnlp_paddle_dist():
"""
判断是否处于 FastNLP 拉起的分布式进程中
"""
return FASTNLP_DISTRIBUTED_CHECK in os.environ

def is_in_paddle_launch_dist():
"""
判断是否处于 launch 启动的分布式进程中
"""
return 'PADDLE_RANK_IN_NODE' in os.environ and \
'FLAGS_selected_gpus' in os.environ and \
FASTNLP_DISTRIBUTED_CHECK not in os.environ
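
A quick sketch of the device helpers above; the first two calls are pure string manipulation, while resolving a 'gpu:x' string to an id goes through paddle.device._convert_to_place and thus needs a paddle installation:

from fastNLP.core.utils import get_paddle_gpu_str, get_paddle_device_id

assert get_paddle_gpu_str(0) == 'gpu:0'
assert get_paddle_gpu_str('cuda:1') == 'gpu:1'
assert get_paddle_device_id(2) == 2
assert get_paddle_device_id('gpu:1') == 1  # requires paddle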

+214 -0  fastNLP/core/utils/rich_progress.py

@@ -0,0 +1,214 @@
"""
该文件用于为fastNLP提供一个统一的progress bar管理,通过共用一个Task对象,trainer中的progress bar和evaluation中的progress bar才能
不冲突

"""
import sys
from typing import Any, Union, Optional

from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn

__all__ = [
'f_rich_progress'
]

from fastNLP.envs import get_global_rank


class Singleton(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


# used when nothing should be printed; it turns the whole progress bar into a no-op
class DummyFRichProgress:
def __getattr__(self, item):
return DummyFRichProgress()

def __call__(self, *args, **kwargs):
# guard against calls such as DummyFRichProgress().console.print()
return None


class FRichProgress(Progress, metaclass=Singleton):
"""
fastNLP 使用的 progress bar ,新增了 new_progress 函数,通过此函数即可定制 fastNLP 中所有 progress 的样式。

"""

def new_progess(self, *columns: Union[str, ProgressColumn],
console: Optional[Console] = None,
auto_refresh: bool = True,
refresh_per_second: float = 10,
speed_estimate_period: float = 30.0,
transient: bool = True,
redirect_stdout: bool = True,
redirect_stderr: bool = True,
get_time: Optional[GetTimeCallable] = None,
disable: bool = False,
expand: bool = False):
"""
重新初始化一个rich bar。如果columns不传入,则继续使用之前的column内容。

:param progress:
:return:
"""
for task_id in self.task_ids: # 首先移除已有的
self.remove_task(task_id)

assert (
refresh_per_second is None or refresh_per_second > 0
), "refresh_per_second must be > 0"

# stop previous columns
self.stop()

# do not change these variables
# self._lock = RLock()
# self._tasks: Dict[TaskID, Task] = {}
# self._task_index: TaskID = TaskID(0)

if len(columns) != 0:
self.columns = columns

self.speed_estimate_period = speed_estimate_period

self.disable = disable
self.expand = expand

self.live = Live(
console=console or get_console(),
auto_refresh=auto_refresh,
refresh_per_second=refresh_per_second,
transient=transient,
redirect_stdout=redirect_stdout,
redirect_stderr=redirect_stderr,
get_renderable=self.get_renderable,
)
self.get_time = get_time or self.console.get_time
self.print = self.console.print
self.log = self.console.log

# start new
self.start()
return self

def set_transient(self, transient: bool = True):
"""
Set whether the bar stays open (is not closed) after it finishes running.

:param transient:
:return:
"""
self.new_progress(transient=transient)

def set_disable(self, flag: bool = True):
"""
设置当前 progress bar 的状态,如果为 True ,则不会显示进度条了。

:param flag:
:return:
"""
self.disable = flag

def add_task(
self,
description: str,
start: bool = True,
total: float = 100.0,
completed: int = 0,
visible: bool = True,
**fields: Any,
) -> TaskID:
if self.live._started is False:
self.start()
post_desc = fields.pop('post_desc', '')
return super().add_task(description=description,
start=start,
total=total,
completed=completed,
visible=visible,
post_desc=post_desc,
**fields)

def stop_task(self, task_id: TaskID) -> None:
if task_id in self._tasks:
super().stop_task(task_id)

def remove_task(self, task_id: TaskID) -> None:
if task_id in self._tasks:
super().remove_task(task_id)

def destroy_task(self, task_id: TaskID):
if task_id in self._tasks:
super().stop_task(task_id)
super().remove_task(task_id)


if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progress(
"[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(),
TimeElapsedColumn(),
"/",
TimeRemainingColumn(),
TextColumn("{task.fields[post_desc]}", justify="right"),
transient=True,
disable=False,
speed_estimate_period=10
)
else:
f_rich_progress = DummyFRichProgress()


if __name__ == '__main__':
f = DummyFRichProgress()
f.console.print('xxx')
f.console.print.print('xxx')
# test creating a task
import time

n_steps = 10

task_id = f_rich_progress.add_task(description='test', total=n_steps)
for i in range(n_steps):
f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True)
print(f"test:{i}")
time.sleep(0.3)
f_rich_progress.remove_task(task_id)

# test inner/outer bars
n_steps = 5
f_rich_progress.start()
outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps)
inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps)
for i in range(n_steps):
f_rich_progress.reset(inner_task_id, total=n_steps)
f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True)
for j in range(n_steps):
f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True,
post_desc='Loss: 0.334332323')
print(f"Outer:{i}, Inner:{j}")
time.sleep(0.3)

# test customizing the bar
f_rich_progress = FRichProgress().new_progress(
BarColumn(),
"[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%",
TimeElapsedColumn(),
transient=True)
n_steps = 10
task_id = f_rich_progress.add_task(description='test', total=n_steps)
for i in range(n_steps):
f_rich_progress.update(task_id, description=f'test:{i}', advance=1)
print(f"test:{i}")
time.sleep(0.3)
f_rich_progress.remove_task(task_id)
f_rich_progress.stop()

+49 -0  fastNLP/core/utils/torch_paddle_utils.py

@@ -0,0 +1,49 @@
from typing import Any, Optional

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH

if _NEED_IMPORT_PADDLE:
import paddle

if _NEED_IMPORT_TORCH:
import torch

__all__ = [
"torch_paddle_move_data_to_device",
]

from .utils import apply_to_collection
from .paddle_utils import paddle_to


def torch_paddle_move_data_to_device(batch: Any, device: Optional[str] = None, non_blocking: Optional[bool] = True,
data_device: Optional[str] = None) -> Any:
r"""
将数据集合传输到给定设备。只有paddle.Tensor和torch.Tensor对象会被传输到设备中,其余保持不变

:param batch:
:param device:
:param non_blocking:
:param data_device:
:return: 相同的集合,但所有包含的张量都驻留在新设备上;
"""

if device is None:
if data_device is not None:
device = data_device
else:
return batch

torch_device = device.replace("gpu", "cuda")
paddle_device = device.replace("cuda", "gpu")

def batch_to(data: Any) -> Any:
if isinstance(data, torch.Tensor):
data = data.to(torch_device, non_blocking=non_blocking)
elif isinstance(data, paddle.Tensor):
data = paddle_to(data, paddle_device)
return data

return apply_to_collection(batch, dtype=(paddle.Tensor, torch.Tensor), function=batch_to)
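
A minimal sketch of moving a mixed batch, assuming both torch and paddle are installed with GPU support; 'gpu:0' is rewritten to 'cuda:0' for the torch tensor and kept as 'gpu:0' for the paddle one:

import paddle
import torch

from fastNLP.core.utils import torch_paddle_move_data_to_device

batch = {'t': torch.rand(2, 3), 'p': paddle.rand((2, 3))}
batch = torch_paddle_move_data_to_device(batch, device='gpu:0')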

+63 -0  fastNLP/core/utils/torch_utils.py

@@ -0,0 +1,63 @@
from abc import ABC
from typing import Any, Union, Optional
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch

__all__ = [
'torch_move_data_to_device'
]

from .utils import apply_to_collection


class TorchTransferableDataType(ABC):
"""
A custom type for data that can be moved to a torch device via `.to(...)`.
Example:
>>> isinstance(dict, TorchTransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TorchTransferableDataType)
True
>>> class CustomObject:
... def __init__(self):
... self.x = torch.rand(2, 2)
... def to(self, device):
... self.x = self.x.to(device)
... return self
>>> isinstance(CustomObject(), TorchTransferableDataType)
True
"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is TorchTransferableDataType:
to = getattr(subclass, "to", None)
return callable(to)
return NotImplemented


def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
non_blocking: Optional[bool] = True) -> Any:
r"""
将数据集合传输到给定设备。任何定义方法 “to(device)” 的对象都将被移动并且集合中的所有其他对象将保持不变;

:param batch: 应当迁移的数据;
:param device: 数据应当迁移到的设备;当该参数的值为 None 时,表示迁移数据的操作由用户自己完成,我们不需要经管;
:param non_blocking: pytorch 的迁移数据方法 `to` 的参数;
:return: 相同的集合,但所有包含的张量都驻留在新设备上;
"""
if device is None:
return batch

def batch_to(data: Any) -> Any:
kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {}
data_output = data.to(device, **kwargs)
if data_output is not None:
return data_output
# user wrongly implemented the `TransferableDataType` and forgot to return `self`.
return data

dtype = TorchTransferableDataType
return apply_to_collection(batch, dtype=dtype, function=batch_to)
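
A minimal sketch, assuming torch with CUDA is available; anything without a to() method, such as the plain list below, is returned unchanged:

import torch

from fastNLP.core.utils import torch_move_data_to_device

batch = {'input': torch.rand(4, 8), 'seq_len': [4, 3, 2, 1]}
batch = torch_move_data_to_device(batch, device='cuda:0')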

+591 -0  fastNLP/core/utils/utils.py

@@ -0,0 +1,591 @@
import inspect
from inspect import Parameter
import dataclasses
import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional, Tuple
from time import sleep

try:
from typing import Literal, Final
except ImportError:
from typing_extensions import Literal, Final
import os
from contextlib import contextmanager
from functools import wraps
from prettytable import PrettyTable
import numpy as np
from pathlib import Path

from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_GLOBAL_RANK



__all__ = [
'get_fn_arg_names',
'check_fn_not_empty_params',
'auto_param_call',
'check_user_specific_params',
'dataclass_to_dict',
'match_and_substitute_params',
'apply_to_collection',
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'synchronize_safe_rm',
'synchronize_mkdir'
]


def get_fn_arg_names(fn: Callable) -> List[str]:
r"""
返回一个函数的所有参数的名字;

:param fn: 需要查询的函数;

:return: 一个列表,其中的元素则是查询函数的参数的字符串名字;
"""
return list(inspect.signature(fn).parameters)
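
# For illustration:
#     get_fn_arg_names(lambda x, y=1: x)  ->  ['x', 'y']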


def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
r"""
检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数;
用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可;

:param fn: 传入的用以代替 Loop 中 'step' 函数的函数;
:param param_num: 检测的函数的应当的没有默认值的参数的个数;

:return: bool,表示传入的 `batch_step_fn` 是否正确;
"""

if fn is None:
return True
if not callable(fn):
return False
else:
params = inspect.signature(fn).parameters
not_default_params = {}
for _name, _param in params.items():
if _param.default == Parameter.empty:
not_default_params[_name] = _param
return len(not_default_params) == param_num


def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
r"""
1.该函数用来提供给用户根据字符串匹配从而实现自动计算;
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来;
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性;
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值;
4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同;

:param fn: 用来进行实际计算的函数,其参数可以包含有默认值;
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数;
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取
参数值后,再传给 `fn` 进行实际的运算;
:param mapping: 一个字典,用来更改其前面的字典的键值;

:return: 返回 `fn` 运行的结果;

Examples:
>>> # 1
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred);
>>> batch = {"x": 20, "y": 1}
>>> output = {"pred": 0}
>>> acc = auto_param_call(loss_fn, batch, output)

>>> # 2
>>> def test_fn(x, y, a, b=10):
>>> return x + y + a + b
>>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240
"""
if signature_fn is not None:
if not callable(signature_fn):
raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
_need_params = OrderedDict(inspect.signature(signature_fn).parameters)
else:
_need_params = OrderedDict(inspect.signature(fn).parameters)
_kwargs = None
for _name, _param in _need_params.items():
if _param.kind == Parameter.VAR_POSITIONAL:
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
if _param.kind == Parameter.VAR_KEYWORD:
_kwargs = (_name, _param)

if _kwargs is not None:
_need_params.pop(_kwargs[0])

_default_params = {}
for _name, _param in _need_params.items():
if _param.default != Parameter.empty:
_default_params[_name] = _param.default

if mapping is not None:
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

_has_params = {}
duplicate_names = []
for arg in args:
assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
for _name, _value in arg.items():
if mapping is not None and _name in mapping:
_name = mapping[_name]

if _name not in _has_params:
if _kwargs is not None or _name in _need_params:
_has_params[_name] = _value
# the same parameter appears in several input sources, which would be ambiguous;
elif _name in _need_params and not (_has_params[_name] is _value):
duplicate_names.append(_name)
if duplicate_names:
raise ValueError(f"The following keys are present in several inputs: {duplicate_names}")

# fill in the default values of parameters that the inputs did not override;
for _name, _value in _default_params.items():
if _name not in _has_params:
_has_params[_name] = _value

if len(_has_params) < len(_need_params):
miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
raise ValueError(f"The parameters `{miss_params}` needed by function `{fn.__name__}` are not found in the input.")

return fn(**_has_params)


def check_user_specific_params(user_params: Dict, fn: Callable):
"""
该函数使用用户的输入来对指定函数的参数进行赋值;
主要用于一些用户无法直接调用函数的情况;
该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误;

:param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值;
:param fn: 会被调用的函数;
:return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值;
"""

fn_arg_names = get_fn_arg_names(fn)
for arg_name, arg_value in user_params.items():
if arg_name not in fn_arg_names:
logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.")
return user_params


def dataclass_to_dict(data: "dataclass") -> Dict:
if not is_dataclass(data):
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
_dict = dict()
for _key in data.__dataclass_fields__:
_dict[_key] = getattr(data, _key)
return _dict
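
# For illustration, with a hypothetical dataclass:
#     @dataclass
#     class Point:
#         x: int = 0
#         y: int = 1
#     dataclass_to_dict(Point())  ->  {'x': 0, 'y': 1}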


def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
r"""
用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能;
该函数应用于 `input_mapping` 和 `output_mapping`;
对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用;
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用;

转换的逻辑按优先级依次为:
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`;
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`];
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key];
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换;
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。

:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。
:param data: 需要被转换的对象;
:return: 返回转换好的结果;
"""
if mapping is None:
return data
if callable(mapping):
# note that `Trainer.extract_loss_from_outputs` checks the outputs; currently only `Dict` and `dataclass` are supported as output types;
return mapping(data)

if not isinstance(mapping, Dict):
raise ValueError(
f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused "
f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.")
if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence):
raise ValueError("Parameter `data` should be of type `Dict`, `dataclass` or `Sequence` when the other "
"parameter `mapping` is of type `Dict`.")

# if `data` is a dataclass, first convert it to a `Dict`;
if is_dataclass(data):
data = dataclass_to_dict(data)
# if `data` is a Sequence, likewise convert it first to a `Dict` of the form {"_0": list[0], "_1": list[1], ...};
elif isinstance(data, Sequence):
data = {"_" + str(i): data[i] for i in range(len(data))}

_new_data = {}
for _name, _value in data.items():
if _name in mapping:
_new_data[mapping[_name]] = _value
else:
_new_data[_name] = _value
return _new_data
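
# For illustration:
#     match_and_substitute_params({"x": "input"}, {"x": 1, "y": 2})  ->  {"input": 1, "y": 2}
#     match_and_substitute_params({"_0": "first"}, [10, 20])         ->  {"first": 10, "_1": 20}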


def _is_namedtuple(obj: object) -> bool:
# https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def _is_dataclass_instance(obj: object) -> bool:
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def apply_to_collection(
data: Any,
dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
*args: Any,
wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
include_none: bool = True,
**kwargs: Any,
) -> Any:
"""将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。

this function credit to: https://github.com/PyTorchLightning/pytorch-lightning
Args:
data: the collection to apply the function to
dtype: the given function will be applied to all elements of this dtype
function: the function to apply
*args: positional arguments (will be forwarded to calls of ``function``)
wrong_dtype: the given function won't be applied if this type is specified and the given collections
is of the ``wrong_dtype`` even if it is of type ``dtype``
include_none: Whether to include an element if the output of ``function`` is ``None``.
**kwargs: keyword arguments (will be forwarded to calls of ``function``)

Returns:
The resulting collection
"""
# Breaking condition
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
return function(data, *args, **kwargs)

elem_type = type(data)

# Recursively apply to collection items
if isinstance(data, Mapping):
out = []
for k, v in data.items():
v = apply_to_collection(
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append((k, v))
if isinstance(data, defaultdict):
return elem_type(data.default_factory, OrderedDict(out))
return elem_type(OrderedDict(out))

is_namedtuple = _is_namedtuple(data)
is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
if is_namedtuple or is_sequence:
out = []
for d in data:
v = apply_to_collection(
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
)
if include_none or v is not None:
out.append(v)
return elem_type(*out) if is_namedtuple else elem_type(out)

if _is_dataclass_instance(data):
# make a deepcopy of the data,
# but do not deepcopy mapped fields since the computation would
# be wasted on values that likely get immediately overwritten
fields = {}
memo = {}
for field in dataclasses.fields(data):
field_value = getattr(data, field.name)
fields[field.name] = (field_value, field.init)
memo[id(field_value)] = field_value
result = deepcopy(data, memo=memo)
# apply function to each field
for field_name, (field_value, field_init) in fields.items():
if field_init:
v = apply_to_collection(
field_value,
dtype,
function,
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs,
)
if not field_init or (not include_none and v is None): # retain old value
v = getattr(data, field_name)
setattr(result, field_name, v)
return result

# data is neither of dtype, nor a collection
return data
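
# For illustration, doubling every int in a nested collection:
#     apply_to_collection({"a": 1, "b": [2, 3]}, dtype=int, function=lambda x: x * 2)
#     ->  {'a': 2, 'b': [4, 6]}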


@contextmanager
def nullcontext():
r"""
用来实现一个什么 dummy 的 context 上下文环境;
"""
yield


def sub_column(string: str, c: int, c_size: int, title: str) -> str:
r"""
:param string: 要被截断的字符串
:param c: 命令行列数
:param c_size: instance或dataset field数
:param title: 列名
:return: 对一个过长的列进行截断的结果
"""
avg = max(int(c / c_size / 2), len(title))
string = str(string)
res = ""
counter = 0
for char in string:
if ord(char) > 255:
counter += 2
else:
counter += 1
res += char
if counter > avg:
res = res + "..."
break
return res


def _is_iterable(value):
# check whether the value is iterable, via duck typing
try:
iter(value)
return True
except BaseException:
return False


def pretty_table_printer(dataset_or_ins) -> PrettyTable:
r"""
:param dataset_or_ins: 传入一个dataSet或者instance
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
+-----------+-----------+-----------------+
| field_1 | field_2 | field_3 |
+-----------+-----------+-----------------+
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
+-----------+-----------+-----------------+
:return: 以 pretty table的形式返回根据terminal大小进行自动截断
"""
x = PrettyTable()
try:
sz = os.get_terminal_size()
column = sz.columns
row = sz.lines
except OSError:
column = 144
row = 11

if type(dataset_or_ins).__name__ == "DataSet":
x.field_names = list(dataset_or_ins.field_arrays.keys())
c_size = len(x.field_names)
for ins in dataset_or_ins:
x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names])
row -= 1
if row < 0:
x.add_row(["..." for _ in range(c_size)])
break
elif type(dataset_or_ins).__name__ == "Instance":
x.field_names = list(dataset_or_ins.fields.keys())
c_size = len(x.field_names)
x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names])

else:
raise Exception("only accept DataSet and Instance")
x.align = "l"

return x


class Option(dict):
r"""a dict can treat keys as attributes"""

def __getattr__(self, item):
try:
return self.__getitem__(item)
except KeyError:
raise AttributeError(item)

def __setattr__(self, key, value):
if key.startswith('__') and key.endswith('__'):
raise AttributeError(key)
self.__setitem__(key, value)

def __delattr__(self, item):
try:
self.pop(item)
except KeyError:
raise AttributeError(item)

def __getstate__(self):
return self

def __setstate__(self, state):
self.update(state)
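
# For illustration:
#     opt = Option(lr=0.1)
#     opt.lr          ->  0.1
#     opt.epochs = 3  # equivalent to opt['epochs'] = 3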


def indice_collate_wrapper(func):
"""
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。

:param func: 需要修饰的函数
:return:
"""

def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)

return wrapper
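
# For illustration, with a hypothetical collate function `my_collate_fn`:
#     wrapped = indice_collate_wrapper(my_collate_fn)
#     indices, batch = wrapped([(0, ins_a), (1, ins_b)])  # indices == [0, 1]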


_emitted_deprecation_warnings = set()


def deprecated(help_message: Optional[str] = None):
"""Decorator to mark a function as deprecated.

Args:
help_message (`Optional[str]`): An optional message to guide the user on how to
switch to non-deprecated usage of the library.
"""

def decorator(deprecated_function: Callable):
global _emitted_deprecation_warnings
warning_msg = (
f"{deprecated_function.__name__} is deprecated and will be removed "
"in the next major version of fastNLP."
) + (f" {help_message}" if help_message else "")

@wraps(deprecated_function)
def wrapper(*args, **kwargs):
func_hash = hash(deprecated_function)
if func_hash not in _emitted_deprecation_warnings:
warnings.warn(warning_msg, category=FutureWarning, stacklevel=2)
_emitted_deprecation_warnings.add(func_hash)
return deprecated_function(*args, **kwargs)

wrapper._decorator_name_ = "deprecated"
return wrapper

return decorator
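
# For illustration, with a hypothetical replacement `new_fn`:
#     @deprecated("Use `new_fn` instead.")
#     def old_fn():
#         ...
#     old_fn()  # emits a FutureWarning, once per decorated function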


def seq_len_to_mask(seq_len, max_len=None):
r"""

将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。
转变 1-d seq_len到2-d mask.

.. code-block::

>>> seq_len = torch.arange(2, 16)
>>> mask = seq_len_to_mask(seq_len)
>>> print(mask.size())
torch.Size([14, 15])
>>> seq_len = np.arange(2, 16)
>>> mask = seq_len_to_mask(seq_len)
>>> print(mask.shape)
(14, 15)
>>> seq_len = torch.arange(2, 16)
>>> mask = seq_len_to_mask(seq_len, max_len=100)
>>>print(mask.size())
torch.Size([14, 100])

:param np.ndarray,torch.LongTensor seq_len: shape将是(B,)
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。
:return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8
"""
if isinstance(seq_len, np.ndarray):
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
max_len = int(max_len) if max_len else int(seq_len.max())
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
mask = broad_cast_seq_len < seq_len.reshape(-1, 1)

else:
raise TypeError("Only support 1-d numpy.ndarray.")

return mask


def wait_to_success(fn, no=False):
# spin until fn() returns True (or, when no=True, until it returns False)
while True:
sleep(0.01)
if (no and not fn()) or (not no and fn()):
break


# On a distributed file system the following can go wrong: rank 0 issues the delete, reports success
# and moves on, but the actual deletion, forwarded to the remote file system, has not finished yet;
# from rank 0's point of view the file is indeed gone, while rank 1 still sees it when it reads;
def synchronize_safe_rm(path: Optional[Union[str, Path]]):
if path is None:
return
if isinstance(path, str):
path = Path(path)
if not path.exists():
return
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
_recursive_rm(path)
wait_to_success(path.exists, no=True)


def _recursive_rm(path: Path):
if path.is_file() or path.is_symlink():
if path.exists():
try:
path.unlink()
except Exception:
pass
return
for sub_path in list(path.iterdir()):
_recursive_rm(sub_path)
path.rmdir()


def synchronize_mkdir(path: Optional[Union[str, Path]]):
"""
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数;
"""
if path is None:
return
if isinstance(path, str):
path = Path(path)

if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
path.mkdir(parents=True, exist_ok=True)

wait_to_success(path.exists)



