diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
new file mode 100644
index 00000000..1d1c9d16
--- /dev/null
+++ b/fastNLP/core/utils/__init__.py
@@ -0,0 +1,43 @@
+__all__ = [
+    'cache_results',
+    'is_jittor_dataset',
+    'jittor_collate_wraps',
+    'paddle_to',
+    'paddle_move_data_to_device',
+    'get_paddle_device_id',
+    'get_paddle_gpu_str',
+    'is_in_paddle_dist',
+    'is_in_fnlp_paddle_dist',
+    'is_in_paddle_launch_dist',
+    'f_rich_progress',
+    'torch_paddle_move_data_to_device',
+    'torch_move_data_to_device',
+    'get_fn_arg_names',
+    'check_fn_not_empty_params',
+    'auto_param_call',
+    'check_user_specific_params',
+    'dataclass_to_dict',
+    'match_and_substitute_params',
+    'apply_to_collection',
+    'nullcontext',
+    'pretty_table_printer',
+    'Option',
+    'indice_collate_wrapper',
+    'deprecated',
+    'seq_len_to_mask',
+    'synchronize_safe_rm',
+    'synchronize_mkdir'
+]
+
+from .cache_results import cache_results
+from .jittor_utils import is_jittor_dataset, jittor_collate_wraps
+from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
+    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist
+from .rich_progress import f_rich_progress
+from .torch_paddle_utils import torch_paddle_move_data_to_device
+from .torch_utils import torch_move_data_to_device
+from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
+    dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
+    indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir
+
diff --git a/fastNLP/core/utils/cache_results.py b/fastNLP/core/utils/cache_results.py
new file mode 100644
index 00000000..ff253f3e
--- /dev/null
+++ b/fastNLP/core/utils/cache_results.py
@@ -0,0 +1,310 @@
+from datetime import datetime
+import hashlib
+import _pickle
+import functools
+import os
+from typing import Callable, List, Any, Optional
+import inspect
+import ast
+from collections import deque
+
+__all__ = [
+    'cache_results'
+]
+
+from fastNLP.core.log.logger import logger
+from fastNLP.core.log.highlighter import ColorHighlighter
+
+
+class FuncCallVisitor(ast.NodeVisitor):
+    # credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776
+    def __init__(self):
+        self._name = deque()
+
+    @property
+    def name(self):
+        return '.'.join(self._name)
+
+    @name.deleter
+    def name(self):
+        self._name.clear()
+
+    def visit_Name(self, node):
+        self._name.appendleft(node.id)
+
+    def visit_Attribute(self, node):
+        try:
+            self._name.appendleft(node.attr)
+            self._name.appendleft(node.value.id)
+        except AttributeError:
+            self.generic_visit(node)
+
+
+def get_func_calls(tree):
+    func_calls = []
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Call):
+            callvisitor = FuncCallVisitor()
+            callvisitor.visit(node.func)
+            func_calls.append(callvisitor.name)
+        if isinstance(node, ast.FunctionDef):
+            if not (node is tree):
+                func_calls.extend(get_func_calls(node))
+
+    return func_calls
+
+
+def truncate_start_blanks(source: str) -> str:
+    """
+    Dedent every line of ``source`` by the indentation of its first non-empty line.
+
+    :param source:
+    :return:
+    """
+    lines = source.split('\n')
+    num_blank = 0
+    # get the indentation of the first non-empty line
+    for line in lines:
+        if line:
+            num_blank = len(line) - len(line.lstrip())
+            break
+    new_lines = []
+    for line in lines:
+        # strip at most num_blank leading spaces from each line
+        i = 0
+        while i < min(len(line), num_blank) and line[i] == ' ':
+            i += 1
+        new_lines.append(line[i:])
+    return '\n'.join(new_lines)
+
+
+def _get_func_and_its_called_func_source_code(func) -> List[str]:
+    """
+    Given a function, return the source code of the function itself and of every
+    function called inside it.
+
+    :param callable func:
+    :return:
+    """
+    last_frame = inspect.currentframe().f_back.f_back.f_back
+    last_frame_f_local = last_frame.f_locals
+    last_frame_loc = {}
+    if 'loc' in last_frame_f_local:
+        last_frame_loc = last_frame_f_local['loc']
+    func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func))))))
+    func_calls.sort()
+    sources = []
+    for _func_name in func_calls:
+        try:
+            if _func_name == 'cache_results':  # ignore the decorator
+                continue
+            if '.' in _func_name:
+                _funcs = _func_name.split('.')
+            else:
+                _funcs = [_func_name]
+            if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc:
+                tmp = _funcs.pop(0)
+                variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp))
+                while len(_funcs) or variable is not None:
+                    if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__):
+                        try:
+                            sources.append(inspect.getsource(variable.__class__))
+                        except TypeError:
+                            pass
+                    if callable(variable) or inspect.isclass(variable):
+                        sources.append(inspect.getsource(variable))
+                    if len(_funcs):
+                        tmp = _funcs.pop(0)
+                        if hasattr(variable, tmp):
+                            variable = getattr(variable, tmp)
+                        else:
+                            break
+                    else:
+                        variable = None
+        except Exception:
+            # retrieving the source of some object failed; skip it
+            pass
+    del last_frame  # drop the reference to the frame object
+    sources.append(inspect.getsource(func))
+    return sources
+
+
+def _prepare_cache_filepath(filepath: str):
+    r"""
+    Check whether ``filepath`` is a valid cache file path; missing parent
+    directories are created automatically.
+
+    :param filepath: str.
+    :return: None; raises an error if the path is not valid
+    """
+    _cache_filepath = os.path.abspath(filepath)
+    if os.path.isdir(_cache_filepath):
+        raise RuntimeError("The cache_file_path must be a file, not a directory.")
+    cache_dir = os.path.dirname(_cache_filepath)
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir, exist_ok=True)
+
+
+class Hasher:
+    def __init__(self):
+        self.m = hashlib.sha1()
+
+    def update(self, value: Any) -> None:
+        if isinstance(value, str):
+            value = [value]
+        for x in value:
+            self.m.update(x.encode('utf8'))
+
+    def hexdigest(self) -> str:
+        return self.m.hexdigest()
+
+
+def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None):
+    if fn_kwargs is None:
+        fn_kwargs = {}
+    hasher = Hasher()
+    try:
+        sources = _get_func_and_its_called_func_source_code(fn)
+        hasher.update(sources)
+    except Exception:
+        return "can't be hashed"
+    for key in sorted(fn_kwargs):
+        hasher.update(key)
+        try:
+            hasher.update(fn_kwargs[key])
+        except Exception:
+            pass
+    return hasher.hexdigest()
+
+
+def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
+    r"""
+    cache_results is the decorator fastNLP provides for caching data. The example below shows how to use it::
+
+        import time
+        import numpy as np
+        from fastNLP import cache_results
+
+        @cache_results('cache.pkl')
+        def process_data():
+            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
+            time.sleep(1)
+            return np.random.randint(10, size=(5,))
+
+        start_time = time.time()
+        print("res =", process_data())
+        print(time.time() - start_time)
+
+        start_time = time.time()
+        print("res =", process_data())
+        print(time.time() - start_time)
+
+        # The output is as follows: both results are identical, and the second call takes almost no time
+        # Save cache to cache.pkl.
+        # res = [5 4 9 1 8]
+        # 1.0042750835418701
+        # Read cache from cache.pkl.
+        # res = [5 4 9 1 8]
+        # 0.0040721893310546875
+
+    The second run takes only about 0.004s because it reads the result directly from cache.pkl instead of
+    running the preprocessing again::
+
+        # Continuing the example above: to generate a separate cache, e.g. for another dataset, call it like this
+        process_data(_cache_fp='cache2.pkl')  # does not affect the previous 'cache.pkl' at all
+
+    `_cache_fp` above is a parameter recognized by cache_results: the data is cached to / read from 'cache2.pkl',
+    i.e. 'cache2.pkl' here overrides the default 'cache.pkl'. Decorating a function with @cache_results() adds four
+    parameters to it: [_cache_fp, _refresh, _verbose, _check_hash]. These parameters are not passed through to your
+    function, so your own function must not have parameters with these names::
+
+        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the cache to be regenerated here
+        # _verbose controls the output: 0 prints nothing; 1 reports whether the cache was read or newly generated
+
+    :param str _cache_fp: where to cache the return value, or where to read the cache from. If None, cache_results
+        has no effect at all, unless _cache_fp is passed at call time.
+    :param bool _refresh: whether to regenerate the cache.
+    :param int _verbose: whether to print cache information.
+    :param bool _check_hash: if True, compare the hash of the source code of the decorated function, and of the
+        functions it calls, against the hash stored with the cache. A warning is raised when the saved hash differs
+        from the current one. The warning may be a false positive that does not actually affect the result (e.g.
+        adding or removing blank lines); conversely, no warning can be raised when a change does not touch the
+        source code, even though such a change may affect the result.
+
+    :return:
+    """
+
+    def wrapper_(func):
+        signature = inspect.signature(func)
+        for key, _ in signature.parameters.items():
+            if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'):
+                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            fn_param = kwargs.copy()
+            if args:
+                params = [p.name for p in inspect.signature(func).parameters.values()]
+                fn_param.update(zip(params, args))
+            if '_cache_fp' in kwargs:
+                cache_filepath = kwargs.pop('_cache_fp')
+                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
+            else:
+                cache_filepath = _cache_fp
+            if '_refresh' in kwargs:
+                refresh = kwargs.pop('_refresh')
+                assert isinstance(refresh, bool), "_refresh can only be bool."
+            else:
+                refresh = _refresh
+            if '_verbose' in kwargs:
+                verbose = kwargs.pop('_verbose')
+                assert isinstance(verbose, int), "_verbose can only be integer."
+            else:
+                verbose = _verbose
+
+            if '_check_hash' in kwargs:
+                check_hash = kwargs.pop('_check_hash')
+            else:
+                check_hash = _check_hash
+
+            refresh_flag = True
+            new_hash_code = None
+            if check_hash:
+                new_hash_code = cal_fn_hash_code(func, fn_param)
+
+            if cache_filepath is not None and refresh is False:
+                # load data
+                if os.path.exists(cache_filepath):
+                    cache_filepath = os.path.abspath(cache_filepath)
+                    with open(cache_filepath, 'rb') as f:
+                        results = _pickle.load(f)
+                    old_hash_code = results['hash']
+                    save_time = results['save_time']
+                    results = results['results']
+                    if verbose == 1:
+                        logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
+                    if check_hash and old_hash_code != new_hash_code:
+                        logger.warning(f"The function `{func.__name__}` is different from its last cache (Saved on {save_time}). The "
+                                       f"difference may be caused by changes in the source code of the functions called by this function.",
+                                       extra={'highlighter': ColorHighlighter('red')})
+                    refresh_flag = False
+
+            if refresh_flag:
+                if new_hash_code is None:
+                    new_hash_code = cal_fn_hash_code(func, fn_param)
+                results = func(*args, **kwargs)
+                if cache_filepath is not None:
+                    if results is None:
+                        raise RuntimeError("The return value is None. Cannot save None results.")
+                    cache_filepath = os.path.abspath(cache_filepath)
+                    _prepare_cache_filepath(cache_filepath)
+                    _dict = {
+                        'results': results,
+                        'hash': new_hash_code,
+                        'save_time': datetime.now(),
+                    }
+                    with open(cache_filepath, 'wb') as f:
+                        _pickle.dump(_dict, f)
+                    logger.info("Save cache to {}.".format(cache_filepath))
+
+            return results
+
+        return wrapper
+
+    return wrapper_
\ No newline at end of file
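
As the wrapper above shows, a cache file is just a pickled dict bundling the results with the source hash and the save time, so an existing cache can be inspected directly; a minimal sketch, assuming the 'cache.pkl' from the docstring example exists::

    import _pickle

    with open('cache.pkl', 'rb') as f:
        cached = _pickle.load(f)
    print(cached['save_time'])   # datetime when the cache was written
    print(cached['hash'])        # source hash recorded by cal_fn_hash_code
    results = cached['results']  # the actual return value of the decorated function
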
diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py
new file mode 100644
index 00000000..2e97c3e4
--- /dev/null
+++ b/fastNLP/core/utils/dummy_class.py
@@ -0,0 +1,4 @@
+
+
+class DummyClass:
+    pass
\ No newline at end of file
diff --git a/fastNLP/core/utils/jittor_utils.py b/fastNLP/core/utils/jittor_utils.py
new file mode 100644
index 00000000..3784f991
--- /dev/null
+++ b/fastNLP/core/utils/jittor_utils.py
@@ -0,0 +1,51 @@
+__all__ = [
+    'is_jittor_dataset',
+    'jittor_collate_wraps'
+]
+
+from collections.abc import Mapping, Callable
+from functools import wraps
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+if _NEED_IMPORT_JITTOR:
+    import jittor as jt
+
+from fastNLP.core.dataset import Instance
+
+
+def is_jittor_dataset(dataset) -> bool:
+    try:
+        if isinstance(dataset, jt.dataset.Dataset):
+            return True
+        else:
+            return False
+    except BaseException:
+        return False
+
+
+def jittor_collate_wraps(func, auto_collator: Callable):
+    """
+    Wrap jittor's collate_fn: if the batch elements are mapping-like, use auto_collator;
+    otherwise fall back to jittor's own collate_batch.
+
+    :param func:
+    :param auto_collator:
+    :return:
+    """
+    @wraps(func)
+    def wrapper(batch):
+        if isinstance(batch[0], Instance):
+            if auto_collator is not None:
+                result = auto_collator(batch)
+            else:
+                raise ValueError("auto_collator is None, but the batch contains fastNLP Instance objects!")
+        elif isinstance(batch[0], Mapping):
+            if auto_collator is not None:
+                result = auto_collator(batch)
+            else:
+                result = func(batch)
+        else:
+            result = func(batch)
+        return result
+
+    return wrapper
diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py
new file mode 100644
index 00000000..8af6efc9
--- /dev/null
+++ b/fastNLP/core/utils/paddle_utils.py
@@ -0,0 +1,89 @@
+__all__ = [
+    "paddle_to",
+    "paddle_move_data_to_device",
+    "get_paddle_gpu_str",
+    "get_paddle_device_id",
+    "is_in_paddle_dist",
+    "is_in_fnlp_paddle_dist",
+    "is_in_paddle_launch_dist",
+]
+
+import os
+from typing import Any, Optional, Union
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+from .utils import apply_to_collection
+
+
+def paddle_to(data, device: Union[str, int]):
+    # move a single paddle tensor to the given device
+    if device == "cpu":
+        return data.cpu()
+    else:
+        return data.cuda(get_paddle_device_id(device))
+
+
+def get_paddle_gpu_str(device: Union[str, int]):
+    """
+    Get the device name in the form `gpu:x`.
+    """
+    if isinstance(device, str):
+        return device.replace("cuda", "gpu")
+    return f"gpu:{device}"
+
+
+def get_paddle_device_id(device: Union[str, int]):
+    """
+    Get the device id of a gpu device. Note: do not pass in `cpu`.
+    """
+    if isinstance(device, int):
+        return device
+
+    if device == "cpu":
+        raise ValueError("Cannot get device id from `cpu`.")
+
+    return paddle.device._convert_to_place(device).get_device_id()
+
+
+def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
+                               data_device: Optional[str] = None) -> Any:
+    r"""
+    Move a collection of data to the given device. Only paddle.Tensor objects are moved;
+    everything else is left unchanged.
+
+    :param batch:
+    :param device: `cpu`, `gpu` or `gpu:x`
+    :param data_device:
+    :return: the same collection, but with all contained tensors residing on the new device;
+    """
+    if device is None:
+        if data_device is not None:
+            device = data_device
+        else:
+            return batch
+
+    def batch_to(data: Any) -> Any:
+        return paddle_to(data, device)
+
+    return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)
+
+
+def is_in_paddle_dist():
+    """
+    Check whether this is a paddle distributed process, judged by the PADDLE_RANK_IN_NODE
+    and FLAGS_selected_gpus environment variables.
+    """
+    return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)
+
+
+def is_in_fnlp_paddle_dist():
+    """
+    Check whether this is a distributed process launched by FastNLP itself.
+    """
+    return FASTNLP_DISTRIBUTED_CHECK in os.environ
+
+
+def is_in_paddle_launch_dist():
+    """
+    Check whether this is a distributed process started by `paddle.distributed.launch`.
+    """
+    return 'PADDLE_RANK_IN_NODE' in os.environ and \
+           'FLAGS_selected_gpus' in os.environ and \
+           FASTNLP_DISTRIBUTED_CHECK not in os.environ
\ No newline at end of file
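
A quick sketch of the helpers above; paddle_move_data_to_device additionally walks nested containers (this assumes a working paddle install with a visible GPU)::

    import paddle
    from fastNLP.core.utils import (get_paddle_gpu_str, get_paddle_device_id,
                                    paddle_move_data_to_device)

    print(get_paddle_gpu_str(0))          # 'gpu:0'
    print(get_paddle_gpu_str("cuda:1"))   # 'gpu:1'
    print(get_paddle_device_id("gpu:1"))  # 1

    batch = {"input": paddle.randn([2, 3]), "label": paddle.to_tensor([0, 1])}
    batch = paddle_move_data_to_device(batch, device="gpu:0")  # every tensor now on gpu:0
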
diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py
new file mode 100644
index 00000000..20330d02
--- /dev/null
+++ b/fastNLP/core/utils/rich_progress.py
@@ -0,0 +1,214 @@
+"""
+This file provides fastNLP with a unified progress bar manager. By sharing a single Task object,
+the progress bar in the trainer and the one in evaluation can coexist without conflicts.
+"""
+import sys
+from typing import Any, Union, Optional
+
+from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live
+from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn
+
+__all__ = [
+    'f_rich_progress'
+]
+
+from fastNLP.envs import get_global_rank
+
+
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+# when nothing should be printed, the whole progress bar becomes a no-op
+class DummyFRichProgress:
+    def __getattr__(self, item):
+        return DummyFRichProgress()
+
+    def __call__(self, *args, **kwargs):
+        # swallow calls such as DummyFRichProgress().console.print()
+        return None
+
+
+class FRichProgress(Progress, metaclass=Singleton):
+    """
+    The progress bar used by fastNLP. It adds a new_progess function through which the style
+    of every progress bar in fastNLP can be customized.
+    """
+
+    def new_progess(self, *columns: Union[str, ProgressColumn],
+                    console: Optional[Console] = None,
+                    auto_refresh: bool = True,
+                    refresh_per_second: float = 10,
+                    speed_estimate_period: float = 30.0,
+                    transient: bool = True,
+                    redirect_stdout: bool = True,
+                    redirect_stderr: bool = True,
+                    get_time: Optional[GetTimeCallable] = None,
+                    disable: bool = False,
+                    expand: bool = False):
+        """
+        Re-initialize the rich bar. If columns is not passed, the previous column settings are kept.
+
+        :return: self
+        """
+        for task_id in self.task_ids:  # first remove all existing tasks
+            self.remove_task(task_id)
+
+        assert (
+            refresh_per_second is None or refresh_per_second > 0
+        ), "refresh_per_second must be > 0"
+
+        # stop previous columns
+        self.stop()
+
+        # do not change these variables
+        # self._lock = RLock()
+        # self._tasks: Dict[TaskID, Task] = {}
+        # self._task_index: TaskID = TaskID(0)
+
+        if len(columns) != 0:
+            self.columns = columns
+
+        self.speed_estimate_period = speed_estimate_period
+
+        self.disable = disable
+        self.expand = expand
+
+        self.live = Live(
+            console=console or get_console(),
+            auto_refresh=auto_refresh,
+            refresh_per_second=refresh_per_second,
+            transient=transient,
+            redirect_stdout=redirect_stdout,
+            redirect_stderr=redirect_stderr,
+            get_renderable=self.get_renderable,
+        )
+        self.get_time = get_time or self.console.get_time
+        self.print = self.console.print
+        self.log = self.console.log
+
+        # start new
+        self.start()
+        return self
+
+    def set_transient(self, transient: bool = True):
+        """
+        Set whether the bar is cleared from the terminal after it finishes running.
+
+        :param transient:
+        :return:
+        """
+        self.new_progess(transient=transient)
+
+    def set_disable(self, flag: bool = True):
+        """
+        Set the state of the progress bar; if True, the bar is no longer displayed.
+
+        :param flag:
+        :return:
+        """
+        self.disable = flag
+
+    def add_task(
+        self,
+        description: str,
+        start: bool = True,
+        total: float = 100.0,
+        completed: int = 0,
+        visible: bool = True,
+        **fields: Any,
+    ) -> TaskID:
+        if self.live._started is False:
+            self.start()
+        post_desc = fields.pop('post_desc', '')
+        return super().add_task(description=description,
+                                start=start,
+                                total=total,
+                                completed=completed,
+                                visible=visible,
+                                post_desc=post_desc,
+                                **fields)
+
+    def stop_task(self, task_id: TaskID) -> None:
+        if task_id in self._tasks:
+            super().stop_task(task_id)
+
+    def remove_task(self, task_id: TaskID) -> None:
+        if task_id in self._tasks:
+            super().remove_task(task_id)
+
+    def destroy_task(self, task_id: TaskID):
+        if task_id in self._tasks:
+            super().stop_task(task_id)
+            super().remove_task(task_id)
+
+
+if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
+    f_rich_progress = FRichProgress().new_progess(
+        "[progress.description]{task.description}",
+        "[progress.percentage]{task.percentage:>3.0f}%",
+        BarColumn(),
+        TimeElapsedColumn(),
+        "/",
+        TimeRemainingColumn(),
+        TextColumn("{task.fields[post_desc]}", justify="right"),
+        transient=True,
+        disable=False,
+        speed_estimate_period=10
+    )
+else:
+    f_rich_progress = DummyFRichProgress()
+
+
+if __name__ == '__main__':
+    f = DummyFRichProgress()
+    f.console.print('xxx')
+    f.console.print.print('xxx')
+    # test creating a task
+    import time
+
+    n_steps = 10
+
+    task_id = f_rich_progress.add_task(description='test', total=n_steps)
+    for i in range(n_steps):
+        f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True)
+        print(f"test:{i}")
+        time.sleep(0.3)
+    f_rich_progress.remove_task(task_id)
+
+    # test inner/outer bars
+    n_steps = 5
+    f_rich_progress.start()
+    outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps)
+    inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps)
+    for i in range(n_steps):
+        f_rich_progress.reset(inner_task_id, total=n_steps)
+        f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True)
+        for j in range(n_steps):
+            f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True,
+                                   post_desc='Loss: 0.334332323')
+            print(f"Outer:{i}, Inner:{j}")
+            time.sleep(0.3)
+
+    # test customizing the bar
+    f_rich_progress = FRichProgress().new_progess(
+        BarColumn(),
+        "[progress.description]{task.description}",
+        "[progress.percentage]{task.percentage:>3.0f}%",
+        TimeElapsedColumn(),
+        transient=True)
+    n_steps = 10
+    task_id = f_rich_progress.add_task(description='test', total=n_steps)
+    for i in range(n_steps):
+        f_rich_progress.update(task_id, description=f'test:{i}', advance=1)
+        print(f"test:{i}")
+        time.sleep(0.3)
+    f_rich_progress.remove_task(task_id)
+    f_rich_progress.stop()
diff --git a/fastNLP/core/utils/torch_paddle_utils.py b/fastNLP/core/utils/torch_paddle_utils.py
new file mode 100644
index 00000000..9201548d
--- /dev/null
+++ b/fastNLP/core/utils/torch_paddle_utils.py
@@ -0,0 +1,49 @@
+from typing import Any, Optional
+
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+__all__ = [
+    "torch_paddle_move_data_to_device",
+]
+
+from .utils import apply_to_collection
+from .paddle_utils import paddle_to
+
+
+def torch_paddle_move_data_to_device(batch: Any, device: Optional[str] = None, non_blocking: Optional[bool] = True,
+                                     data_device: Optional[str] = None) -> Any:
+    r"""
+    Move a collection of data to the given device. Only paddle.Tensor and torch.Tensor objects
+    are moved; everything else is left unchanged.
+
+    :param batch:
+    :param device:
+    :param non_blocking:
+    :param data_device:
+    :return: the same collection, but with all contained tensors residing on the new device;
+    """
+    if device is None:
+        if data_device is not None:
+            device = data_device
+        else:
+            return batch
+
+    torch_device = device.replace("gpu", "cuda")
+    paddle_device = device.replace("cuda", "gpu")
+
+    def batch_to(data: Any) -> Any:
+        if isinstance(data, torch.Tensor):
+            data = data.to(torch_device, non_blocking=non_blocking)
+        elif isinstance(data, paddle.Tensor):
+            data = paddle_to(data, paddle_device)
+
+        return data
+
+    return apply_to_collection(batch, dtype=(paddle.Tensor, torch.Tensor), function=batch_to)
\ No newline at end of file
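
A sketch of moving a mixed batch with torch_paddle_move_data_to_device (assuming both frameworks are installed and a GPU is visible): torch tensors get the `cuda` spelling of the device, paddle tensors the `gpu` spelling::

    import paddle
    import torch
    from fastNLP.core.utils import torch_paddle_move_data_to_device

    batch = {"torch_x": torch.randn(2, 3), "paddle_x": paddle.randn([2, 3])}
    batch = torch_paddle_move_data_to_device(batch, device="gpu:0")
    print(batch["torch_x"].device)  # cuda:0
    print(batch["paddle_x"].place)  # gpu:0
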
diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py
new file mode 100644
index 00000000..9dea93dd
--- /dev/null
+++ b/fastNLP/core/utils/torch_utils.py
@@ -0,0 +1,63 @@
+from abc import ABC
+from typing import Any, Union, Optional
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+__all__ = [
+    'torch_move_data_to_device'
+]
+
+from .utils import apply_to_collection
+
+
+class TorchTransferableDataType(ABC):
+    """
+    A custom type for data that can be moved to a torch device via `.to(...)`.
+    Example:
+        >>> isinstance(dict, TorchTransferableDataType)
+        False
+        >>> isinstance(torch.rand(2, 3), TorchTransferableDataType)
+        True
+        >>> class CustomObject:
+        ...     def __init__(self):
+        ...         self.x = torch.rand(2, 2)
+        ...     def to(self, device):
+        ...         self.x = self.x.to(device)
+        ...         return self
+        >>> isinstance(CustomObject(), TorchTransferableDataType)
+        True
+    """
+
+    @classmethod
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+        if cls is TorchTransferableDataType:
+            to = getattr(subclass, "to", None)
+            return callable(to)
+        return NotImplemented
+
+
+def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
+                              non_blocking: Optional[bool] = True) -> Any:
+    r"""
+    Move a collection of data to the given device. Any object that defines a method ``to(device)``
+    is moved; all other objects in the collection are left unchanged.
+
+    :param batch: the data to move;
+    :param device: the device to move the data to; if None, moving the data is left to the user
+        and nothing is done here;
+    :param non_blocking: argument of pytorch's data-moving method `to`;
+    :return: the same collection, but with all contained tensors residing on the new device;
+    """
+    if device is None:
+        return batch
+
+    def batch_to(data: Any) -> Any:
+        kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {}
+        data_output = data.to(device, **kwargs)
+        if data_output is not None:
+            return data_output
+        # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
+        return data
+
+    dtype = TorchTransferableDataType
+    return apply_to_collection(batch, dtype=dtype, function=batch_to)
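
Because of the TorchTransferableDataType hook, anything with a callable `to` is moved, not just tensors; a minimal sketch::

    import torch
    from fastNLP.core.utils import torch_move_data_to_device

    class WithState:
        def __init__(self):
            self.x = torch.zeros(2)

        def to(self, device, **kwargs):
            self.x = self.x.to(device, **kwargs)
            return self  # forgetting this return triggers the fallback in batch_to

    batch = {"obj": WithState(), "t": torch.ones(3)}
    batch = torch_move_data_to_device(batch, device="cpu")  # any device string works
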
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
new file mode 100644
index 00000000..66159f24
--- /dev/null
+++ b/fastNLP/core/utils/utils.py
@@ -0,0 +1,591 @@
+import inspect
+from inspect import Parameter
+import dataclasses
+import warnings
+from dataclasses import is_dataclass
+from copy import deepcopy
+from collections import defaultdict, OrderedDict
+from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional, Tuple
+from time import sleep
+
+try:
+    from typing import Literal, Final
+except ImportError:
+    from typing_extensions import Literal, Final
+import os
+from contextlib import contextmanager
+from functools import wraps
+from prettytable import PrettyTable
+import numpy as np
+from pathlib import Path
+
+from fastNLP.core.log import logger
+from fastNLP.envs import FASTNLP_GLOBAL_RANK
+
+
+__all__ = [
+    'get_fn_arg_names',
+    'check_fn_not_empty_params',
+    'auto_param_call',
+    'check_user_specific_params',
+    'dataclass_to_dict',
+    'match_and_substitute_params',
+    'apply_to_collection',
+    'nullcontext',
+    'pretty_table_printer',
+    'Option',
+    'indice_collate_wrapper',
+    'deprecated',
+    'seq_len_to_mask',
+    'synchronize_safe_rm',
+    'synchronize_mkdir'
+]
+
+
+def get_fn_arg_names(fn: Callable) -> List[str]:
+    r"""
+    Return the names of all parameters of a function;
+
+    :param fn: the function to inspect;
+
+    :return: a list containing the string names of the function's parameters;
+    """
+    return list(inspect.signature(fn).parameters)
+
+
+def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
+    r"""
+    Check that the given batch_step_fn is valid: (1) it is callable; (2) it has exactly the
+    specified number of parameters without default values.
+    Users may also pass in a partial function, as long as it keeps the parameter slots for
+    `trainer` and `batch`;
+
+    :param fn: the function passed in to replace the 'step' function of a Loop;
+    :param param_num: the expected number of parameters without default values;
+
+    :return: bool, whether the given `batch_step_fn` is valid;
+    """
+    if fn is None:
+        return True
+    if not callable(fn):
+        return False
+    else:
+        params = inspect.signature(fn).parameters
+        not_default_params = {}
+        for _name, _param in params.items():
+            if _param.default == Parameter.empty:
+                not_default_params[_name] = _param
+        return len(not_default_params) == param_num
+
+
+def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
+                    mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
+    r"""
+    1. This function calls `fn` automatically by matching its parameter names against the keys of the input dicts;
+    2. Note that `mapping` defaults to None. If you want to control how the input keys map to the parameters of the
+       called function, pass a dict as `mapping`; when `mapping` is not None, the keys of the input dicts are always
+       renamed through it first, so make sure the mapping you pass in is correct;
+    3. If a parameter of the called function has a default value and none of the inputs provides a value for it, the
+       default value is used; otherwise the value from the inputs takes precedence;
+    4. If the called function is a `partial` function, it is handled the same way as in 3., i.e. like default
+       parameters;
+
+    :param fn: the function that does the actual computation; its parameters may have default values;
+    :param args: a sequence of positional arguments, which should all be dicts; the actual arguments needed by `fn`
+        are extracted from them;
+    :param signature_fn: a function used to replace the signature of `fn`; if not None, the signature is taken from
+        this function instead, the argument values are extracted according to that signature, and then passed to
+        `fn` for the actual computation;
+    :param mapping: a dict used to rename the keys of the preceding dicts;
+
+    :return: the result of running `fn`;
+
+    Examples:
+        >>> # 1
+        >>> loss_fn = CrossEntropyLoss()  # suppose its signature is def CrossEntropyLoss(y, pred);
+        >>> batch = {"x": 20, "y": 1}
+        >>> output = {"pred": 0}
+        >>> acc = auto_param_call(loss_fn, batch, output)

+        >>> # 2
+        >>> def test_fn(x, y, a, b=10):
+        >>>     return x + y + a + b
+        >>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30}))  # res: 70
+        >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20}))  # res: 140
+        >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200}))  # res: 240
+    """
+    if signature_fn is not None:
+        if not callable(signature_fn):
+            raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
+        _need_params = OrderedDict(inspect.signature(signature_fn).parameters)
+    else:
+        _need_params = OrderedDict(inspect.signature(fn).parameters)
+    _kwargs = None
+    for _name, _param in _need_params.items():
+        if _param.kind == Parameter.VAR_POSITIONAL:
+            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
+        if _param.kind == Parameter.VAR_KEYWORD:
+            _kwargs = (_name, _param)
+
+    if _kwargs is not None:
+        _need_params.pop(_kwargs[0])
+
+    _default_params = {}
+    for _name, _param in _need_params.items():
+        if _param.default != Parameter.empty:
+            _default_params[_name] = _param.default
+
+    if mapping is not None:
+        assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."
+
+    _has_params = {}
+    duplicate_names = []
+    for arg in args:
+        assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
+        for _name, _value in arg.items():
+            if mapping is not None and _name in mapping:
+                _name = mapping[_name]
+
+            if _name not in _has_params:
+                if _kwargs is not None or _name in _need_params:
+                    _has_params[_name] = _value
+            # the same parameter occurs in two of the input sources, which is ambiguous;
+            elif _name in _need_params and not (_has_params[_name] is _value):
+                duplicate_names.append(_name)
+    if duplicate_names:
+        raise ValueError(f"The following key present in several inputs:{duplicate_names}")
+
+    # pass in the default values that were not overridden by the inputs;
+    for _name, _value in _default_params.items():
+        if _name not in _has_params:
+            _has_params[_name] = _value
+
+    if len(_has_params) < len(_need_params):
+        miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
+        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.")
+
+    return fn(**_has_params)
+
+
+def check_user_specific_params(user_params: Dict, fn: Callable):
+    """
+    Check the parameters specified by the user against the signature of `fn`, and warn about
+    any names that `fn` does not accept.
+
+    :param user_params: the parameter values specified by the user, as a dict mapping parameter names to values;
+    :param fn: the function that will be called;
+    :return: the `user_params` dict, unchanged;
+    """
+    fn_arg_names = get_fn_arg_names(fn)
+    not_match_params = []
+    for arg_name in user_params:
+        if arg_name not in fn_arg_names:
+            not_match_params.append(arg_name)
+    if len(not_match_params) > 0:
+        logger.warning(f"Notice parameters:{not_match_params} may not be used by function:`{fn.__name__}`.")
+    return user_params
+
+
+def dataclass_to_dict(data) -> Dict:
+    if not is_dataclass(data):
+        raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
+    _dict = dict()
+    for _key in data.__dataclass_fields__:
+        _dict[_key] = getattr(data, _key)
+    return _dict
+
+
+def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
+    r"""
+    Rename the keys of the input `batch` or the output `outputs` according to `mapping`;
+    this function implements `input_mapping` and `output_mapping`;
+    for `input_mapping`, it is called in `TrainBatchLoop` right after a batch of data is fetched;
+    for `output_mapping`, it is called in `Trainer.train_step` and `Evaluator.train_step` right after the results
+    are obtained;
+
+    The conversion logic, in order of priority:
+    1. if `mapping` is a function, `mapping(data)` is returned directly;
+    2. if `mapping` is a `Dict`, then `data` may only be of one of the following three types:
+       [`Dict`, `dataclass`, `Sequence`];
+       if `data` is a `Dict`, this function replaces each key of `data` with mapping[key];
+       if `data` is a `dataclass`, it is first converted into a `Dict` via the `dataclass_to_dict` function and
+       then converted as above;
+       if `data` is a `Sequence`, it is first converted into the corresponding `Dict`:
+       {"_0": list[0], "_1": list[1], ...}, which is then converted through the mapping; keys that do not match
+       anything in the mapping keep their "_number" form.
+
+    :param mapping: the dict or function used for the conversion; if mapping is a function, its return value must
+        be a dict.
+    :param data: the object to convert;
+    :return: the converted result;
+    """
+    if mapping is None:
+        return data
+    if callable(mapping):
+        # note that `Trainer.extract_loss_from_outputs` checks the outputs; currently only `Dict` and `dataclass`
+        # are supported as output types;
+        return mapping(data)
+
+    if not isinstance(mapping, Dict):
+        raise ValueError(
+            f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused "
+            f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.")
+    if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence):
+        raise ValueError("Parameter `data` should be of type `Dict`, `dataclass` or `Sequence` when the other "
+                         "parameter `mapping` is of type `Dict`.")
+
+    # if `data` is a dataclass, first convert it into a `Dict`;
+    if is_dataclass(data):
+        data = dataclass_to_dict(data)
+    # if `data` is a Sequence, likewise first convert it into a `Dict`: {"_0": list[0], "_1": list[1], ...};
+    elif isinstance(data, Sequence):
+        data = {"_" + str(i): data[i] for i in range(len(data))}
+
+    _new_data = {}
+    for _name, _value in data.items():
+        if _name in mapping:
+            _new_data[mapping[_name]] = _value
+        else:
+            _new_data[_name] = _value
+    return _new_data
+
+
+def _is_namedtuple(obj: object) -> bool:
+    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
+    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+
+
+def _is_dataclass_instance(obj: object) -> bool:
+    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
+    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+
+
+def apply_to_collection(
+    data: Any,
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
+    function: Callable,
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+    include_none: bool = True,
+    **kwargs: Any,
+) -> Any:
+    """Recursively apply ``function`` to the elements of ``data``, but only to elements of type ``dtype``.
+
+    this function credit to: https://github.com/PyTorchLightning/pytorch-lightning
+    Args:
+        data: the collection to apply the function to
+        dtype: the given function will be applied to all elements of this dtype
+        function: the function to apply
+        *args: positional arguments (will be forwarded to calls of ``function``)
+        wrong_dtype: the given function won't be applied if this type is specified and the given collections
+            is of the ``wrong_dtype`` even if it is of type ``dtype``
+        include_none: Whether to include an element if the output of ``function`` is ``None``.
+        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+    Returns:
+        The resulting collection
+    """
+    # Breaking condition
+    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
+        return function(data, *args, **kwargs)
+
+    elem_type = type(data)
+
+    # Recursively apply to collection items
+    if isinstance(data, Mapping):
+        out = []
+        for k, v in data.items():
+            v = apply_to_collection(
+                v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
+            )
+            if include_none or v is not None:
+                out.append((k, v))
+        if isinstance(data, defaultdict):
+            return elem_type(data.default_factory, OrderedDict(out))
+        return elem_type(OrderedDict(out))
+
+    is_namedtuple = _is_namedtuple(data)
+    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
+    if is_namedtuple or is_sequence:
+        out = []
+        for d in data:
+            v = apply_to_collection(
+                d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
+            )
+            if include_none or v is not None:
+                out.append(v)
+        return elem_type(*out) if is_namedtuple else elem_type(out)
+
+    if _is_dataclass_instance(data):
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        fields = {}
+        memo = {}
+        for field in dataclasses.fields(data):
+            field_value = getattr(data, field.name)
+            fields[field.name] = (field_value, field.init)
+            memo[id(field_value)] = field_value
+        result = deepcopy(data, memo=memo)
+        # apply function to each field
+        for field_name, (field_value, field_init) in fields.items():
+            if field_init:
+                v = apply_to_collection(
+                    field_value,
+                    dtype,
+                    function,
+                    *args,
+                    wrong_dtype=wrong_dtype,
+                    include_none=include_none,
+                    **kwargs,
+                )
+            if not field_init or (not include_none and v is None):  # retain old value
+                v = getattr(data, field_name)
+            setattr(result, field_name, v)
+        return result
+
+    # data is neither of dtype, nor a collection
+    return data
+
+
+@contextmanager
+def nullcontext():
+    r"""
+    A dummy context manager that does nothing;
+    """
+    yield
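
apply_to_collection is the workhorse behind all of the move-to-device helpers above; a small sketch of how it walks nested containers, with torch.Tensor merely serving as the example dtype::

    import torch
    from fastNLP.core.utils import apply_to_collection

    batch = {"x": torch.ones(2), "meta": {"ids": [torch.ones(1), "keep-me"]}}
    doubled = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t * 2)
    # tensors at any nesting depth are doubled; the string is returned untouched
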
+
+
+def sub_column(string: str, c: int, c_size: int, title: str) -> str:
+    r"""
+    :param string: the string to be truncated
+    :param c: the number of terminal columns
+    :param c_size: the number of fields of the instance or dataset
+    :param title: the column name
+    :return: the result of truncating an over-long column
+    """
+    avg = max(int(c / c_size / 2), len(title))
+    string = str(string)
+    res = ""
+    counter = 0
+    for char in string:
+        if ord(char) > 255:
+            counter += 2
+        else:
+            counter += 1
+        res += char
+        if counter > avg:
+            res = res + "..."
+            break
+    return res
+
+
+def _is_iterable(value):
+    # check whether the value is iterable, via duck typing
+    try:
+        iter(value)
+        return True
+    except BaseException:
+        return False
+
+
+def pretty_table_printer(dataset_or_ins) -> PrettyTable:
+    r"""
+    :param dataset_or_ins: a DataSet or an Instance
+        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
+        +-----------+-----------+-----------------+
+        |  field_1  |  field_2  |     field_3     |
+        +-----------+-----------+-----------------+
+        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
+        +-----------+-----------+-----------------+
+    :return: a pretty table, automatically truncated according to the terminal size
+    """
+    x = PrettyTable()
+    try:
+        sz = os.get_terminal_size()
+        column = sz.columns
+        row = sz.lines
+    except OSError:
+        column = 144
+        row = 11
+
+    if type(dataset_or_ins).__name__ == "DataSet":
+        x.field_names = list(dataset_or_ins.field_arrays.keys())
+        c_size = len(x.field_names)
+        for ins in dataset_or_ins:
+            x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names])
+            row -= 1
+            if row < 0:
+                x.add_row(["..." for _ in range(c_size)])
+                break
+    elif type(dataset_or_ins).__name__ == "Instance":
+        x.field_names = list(dataset_or_ins.fields.keys())
+        c_size = len(x.field_names)
+        x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names])
+
+    else:
+        raise Exception("only accept DataSet and Instance")
+    x.align = "l"
+
+    return x
+
+
+class Option(dict):
+    r"""a dict whose keys can be accessed as attributes"""
+
+    def __getattr__(self, item):
+        try:
+            return self.__getitem__(item)
+        except KeyError:
+            raise AttributeError(item)
+
+    def __setattr__(self, key, value):
+        if key.startswith('__') and key.endswith('__'):
+            raise AttributeError(key)
+        self.__setitem__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            self.pop(item)
+        except KeyError:
+            raise AttributeError(item)
+
+    def __getstate__(self):
+        return self
+
+    def __setstate__(self, state):
+        self.update(state)
+
+
+def indice_collate_wrapper(func):
+    """
+    Wrap a collate_fn one more level, so that the tuples fetched from the dataset are split up:
+    the idx values are packed into indices, and the instances are passed on to the wrapped function.
+
+    :param func: the function to wrap
+    :return:
+    """
+    def wrapper(tuple_data):
+        indice, ins_list = [], []
+        for idx, ins in tuple_data:
+            indice.append(idx)
+            ins_list.append(ins)
+        return indice, func(ins_list)
+
+    return wrapper
+
+
+_emitted_deprecation_warnings = set()
+
+
+def deprecated(help_message: Optional[str] = None):
+    """Decorator to mark a function as deprecated.
+
+    Args:
+        help_message (`Optional[str]`): An optional message to guide the user on how to
+            switch to non-deprecated usage of the library.
+    """
+
+    def decorator(deprecated_function: Callable):
+        global _emitted_deprecation_warnings
+        warning_msg = (
+            (
+                f"{deprecated_function.__name__} is deprecated and will be removed "
+                "in the next major version of datasets."
+            )
+            + f" {help_message}"
+            if help_message
+            else ""
+        )
+
+        @wraps(deprecated_function)
+        def wrapper(*args, **kwargs):
+            func_hash = hash(deprecated_function)
+            if func_hash not in _emitted_deprecation_warnings:
+                warnings.warn(warning_msg, category=FutureWarning, stacklevel=2)
+                _emitted_deprecation_warnings.add(func_hash)
+            return deprecated_function(*args, **kwargs)
+
+        wrapper._decorator_name_ = "deprecated"
+        return wrapper
+
+    return decorator
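
Option in action, keys doubling as attributes; a quick illustration::

    from fastNLP.core.utils import Option

    opt = Option(lr=0.1)
    opt.epochs = 10               # same as opt['epochs'] = 10
    print(opt.lr, opt['epochs'])  # 0.1 10
    del opt.lr                    # same as opt.pop('lr')
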
+
+
+def seq_len_to_mask(seq_len, max_len=None):
+    r"""
+    Convert a 1-d array of sequence lengths into a 2-d mask, with the excluded positions set to 0.
+
+    .. code-block::
+
+        >>> seq_len = np.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.shape)
+        (14, 15)
+        >>> mask = seq_len_to_mask(seq_len, max_len=100)
+        >>> print(mask.shape)
+        (14, 100)
+
+    :param np.ndarray seq_len: shape is (B,)
+    :param int max_len: pad the mask to this length. By default (None) the longest length in seq_len is used.
+        Under nn.DataParallel, seq_len may differ between cards, so pass max_len to pad the mask to that length.
+    :return: np.ndarray with shape (B, max_len); the elements are of type bool
+    """
+    if isinstance(seq_len, np.ndarray):
+        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
+        max_len = int(max_len) if max_len else int(seq_len.max())
+        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+
+    else:
+        raise TypeError("Only support 1-d numpy.ndarray.")
+
+    return mask
+
+
+def wait_to_success(fn, no=False):
+    # poll fn until it returns True (or, with no=True, until it returns False)
+    while True:
+        sleep(0.01)
+        if (no and not fn()) or (not no and fn()):
+            break
+
+
+# This exists because of a pitfall of distributed file systems: after rank 0 issues a delete and moves on, the
+# actual deletion still has to be forwarded to the remote file system and executed there. At that moment the
+# delete has succeeded from rank 0's point of view, but the operation has not finished on the remote file system
+# yet, so rank 1 may still see the file when it reads;
+def synchronize_safe_rm(path: Optional[Union[str, Path]]):
+    if path is None:
+        return
+    if isinstance(path, str):
+        path = Path(path)
+    if not path.exists():
+        return
+    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
+        _recursive_rm(path)
+    wait_to_success(path.exists, no=True)
+
+
+def _recursive_rm(path: Path):
+    if path.is_file() or path.is_symlink():
+        if path.exists():
+            try:
+                path.unlink()
+            except Exception:
+                pass
+        return
+    for sub_path in list(path.iterdir()):
+        _recursive_rm(sub_path)
+    path.rmdir()
+
+
+def synchronize_mkdir(path: Optional[Union[str, Path]]):
+    """
+    Note that this function creates a directory; do not use it to create a file;
+    """
+    if path is None:
+        return
+    if isinstance(path, str):
+        path = Path(path)
+
+    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
+        path.mkdir(parents=True, exist_ok=True)
+
+    wait_to_success(path.exists)
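
Both synchronization helpers are meant to be called by every rank: rank 0 performs the filesystem operation and the other ranks poll until the change is visible. A usage sketch (the 'outputs/checkpoints' path is just for illustration)::

    from fastNLP.core.utils import synchronize_mkdir, synchronize_safe_rm

    synchronize_mkdir("outputs/checkpoints")    # all ranks return once the directory exists
    # ... training, saving checkpoints ...
    synchronize_safe_rm("outputs/checkpoints")  # all ranks return once it is gone
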