@@ -0,0 +1,43 @@ | |||
__all__ = [ | |||
'cache_results', | |||
'is_jittor_dataset', | |||
'jittor_collate_wraps', | |||
'paddle_to', | |||
'paddle_move_data_to_device', | |||
'get_paddle_device_id', | |||
'get_paddle_gpu_str', | |||
'is_in_paddle_dist', | |||
'is_in_fnlp_paddle_dist', | |||
'is_in_paddle_launch_dist', | |||
'f_rich_progress', | |||
'torch_paddle_move_data_to_device', | |||
'torch_move_data_to_device', | |||
'get_fn_arg_names', | |||
'check_fn_not_empty_params', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
'match_and_substitute_params', | |||
'apply_to_collection', | |||
'nullcontext', | |||
'pretty_table_printer', | |||
'Option', | |||
'indice_collate_wrapper', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
'synchronize_safe_rm', | |||
'synchronize_mkdir' | |||
] | |||
from .cache_results import cache_results | |||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps | |||
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | |||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist | |||
from .rich_progress import f_rich_progress | |||
from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
from .torch_utils import torch_move_data_to_device | |||
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ | |||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
@@ -0,0 +1,310 @@ | |||
from datetime import datetime | |||
import hashlib | |||
import _pickle | |||
import functools | |||
import os | |||
from typing import Callable, List, Any, Optional | |||
import inspect | |||
import ast | |||
from collections import deque | |||
__all__ = [ | |||
'cache_results' | |||
] | |||
from fastNLP.core.log.logger import logger | |||
from fastNLP.core.log.highlighter import ColorHighlighter | |||
class FuncCallVisitor(ast.NodeVisitor): | |||
# credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776 | |||
def __init__(self): | |||
self._name = deque() | |||
@property | |||
def name(self): | |||
return '.'.join(self._name) | |||
@name.deleter | |||
def name(self): | |||
self._name.clear() | |||
def visit_Name(self, node): | |||
self._name.appendleft(node.id) | |||
def visit_Attribute(self, node): | |||
try: | |||
self._name.appendleft(node.attr) | |||
self._name.appendleft(node.value.id) | |||
except AttributeError: | |||
self.generic_visit(node) | |||
def get_func_calls(tree): | |||
func_calls = [] | |||
for node in ast.walk(tree): | |||
if isinstance(node, ast.Call): | |||
callvisitor = FuncCallVisitor() | |||
callvisitor.visit(node.func) | |||
func_calls.append(callvisitor.name) | |||
if isinstance(node, ast.FunctionDef): | |||
if not (node is tree): | |||
func_calls.extend(get_func_calls(node)) | |||
return func_calls | |||
def truncate_start_blanks(source:str)->str: | |||
""" | |||
将source中的每一行按照第一行的indent删掉多余的空格 | |||
:param source: | |||
:return: | |||
""" | |||
lines = source.split('\n') | |||
num_blank = 0 | |||
# get the top blank line | |||
for line in lines: | |||
if line: | |||
num_blank = len(line) - len(line.lstrip()) | |||
new_lines = [] | |||
for line in lines: | |||
i = -1 | |||
for i in range(min(len(line), num_blank)): | |||
if line[i] == ' ': | |||
continue | |||
else: | |||
break | |||
line = line[i:] | |||
new_lines.append(line) | |||
return '\n'.join(new_lines) | |||
def _get_func_and_its_called_func_source_code(func) -> List[str]: | |||
""" | |||
给定一个func,返回在这个函数里面用到的所有函数的源码。 | |||
:param callable func: | |||
:return: | |||
""" | |||
last_frame = inspect.currentframe().f_back.f_back.f_back | |||
last_frame_f_local = last_frame.f_locals | |||
last_frame_loc = {} | |||
if 'loc' in last_frame_f_local: | |||
last_frame_loc = last_frame_f_local['loc'] | |||
func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func)))))) | |||
func_calls.sort() | |||
sources = [] | |||
for _func_name in func_calls: | |||
try: | |||
if _func_name == 'cache_results': # ignore the decorator | |||
continue | |||
if '.' in _func_name: | |||
_funcs = _func_name.split('.') | |||
else: | |||
_funcs = [_func_name] | |||
if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc: | |||
tmp = _funcs.pop(0) | |||
variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp)) | |||
while len(_funcs) or variable is not None: | |||
if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__): | |||
try: | |||
sources.append(inspect.getsource(variable.__class__)) | |||
except TypeError: | |||
pass | |||
if callable(variable) or inspect.isclass(variable): | |||
sources.append(inspect.getsource(variable)) | |||
if len(_funcs): | |||
tmp = _funcs.pop(0) | |||
if hasattr(variable, tmp): | |||
variable = getattr(variable, tmp) | |||
else: | |||
break | |||
else: | |||
variable = None | |||
except: | |||
# some failure | |||
pass | |||
del last_frame # | |||
sources.append(inspect.getsource(func)) | |||
return sources | |||
def _prepare_cache_filepath(filepath:str): | |||
r""" | |||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | |||
:param filepath: str. | |||
:return: None, if not, this function will raise error | |||
""" | |||
_cache_filepath = os.path.abspath(filepath) | |||
if os.path.isdir(_cache_filepath): | |||
raise RuntimeError("The cache_file_path must be a file, not a directory.") | |||
cache_dir = os.path.dirname(_cache_filepath) | |||
if not os.path.exists(cache_dir): | |||
os.makedirs(cache_dir, exist_ok=True) | |||
class Hasher: | |||
def __init__(self): | |||
self.m = hashlib.sha1() | |||
def update(self, value: Any) -> None: | |||
if isinstance(value, str): | |||
value = [value] | |||
for x in value: | |||
self.m.update(x.encode('utf8')) | |||
def hexdigest(self) -> str: | |||
return self.m.hexdigest() | |||
def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None): | |||
if fn_kwargs is None: | |||
fn_kwargs = {} | |||
hasher = Hasher() | |||
try: | |||
sources = _get_func_and_its_called_func_source_code(fn) | |||
hasher.update(sources) | |||
except: | |||
return "can't be hashed" | |||
for key in sorted(fn_kwargs): | |||
hasher.update(key) | |||
try: | |||
hasher.update(fn_kwargs[key]) | |||
except: | |||
pass | |||
return hasher.hexdigest() | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||
r""" | |||
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | |||
import time | |||
import numpy as np | |||
from fastNLP import cache_results | |||
@cache_results('cache.pkl') | |||
def process_data(): | |||
# 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | |||
time.sleep(1) | |||
return np.random.randint(10, size=(5,)) | |||
start_time = time.time() | |||
print("res =",process_data()) | |||
print(time.time() - start_time) | |||
start_time = time.time() | |||
print("res =",process_data()) | |||
print(time.time() - start_time) | |||
# 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间 | |||
# Save cache to cache.pkl. | |||
# res = [5 4 9 1 8] | |||
# 1.0042750835418701 | |||
# Read cache from cache.pkl. | |||
# res = [5 4 9 1 8] | |||
# 0.0040721893310546875 | |||
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理:: | |||
# 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 | |||
process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl' | |||
上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 | |||
'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 | |||
上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称:: | |||
process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。 | |||
# _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache | |||
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | |||
函数调用的时候传入_cache_fp这个参数。 | |||
:param bool _refresh: 是否重新生成cache。 | |||
:param int _verbose: 是否打印cache的信息。 | |||
:param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值 | |||
与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然 | |||
该修改对结果有影响,但无法做出warning。 | |||
:return: | |||
""" | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
for key, _ in signature.parameters.items(): | |||
if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'): | |||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||
@functools.wraps(func) | |||
def wrapper(*args, **kwargs): | |||
fn_param = kwargs.copy() | |||
if args: | |||
params = [p.name for p in inspect.signature(func).parameters.values()] | |||
fn_param.update(zip(params, args)) | |||
if '_cache_fp' in kwargs: | |||
cache_filepath = kwargs.pop('_cache_fp') | |||
assert isinstance(cache_filepath, str), "_cache_fp can only be str." | |||
else: | |||
cache_filepath = _cache_fp | |||
if '_refresh' in kwargs: | |||
refresh = kwargs.pop('_refresh') | |||
assert isinstance(refresh, bool), "_refresh can only be bool." | |||
else: | |||
refresh = _refresh | |||
if '_verbose' in kwargs: | |||
verbose = kwargs.pop('_verbose') | |||
assert isinstance(verbose, int), "_verbose can only be integer." | |||
else: | |||
verbose = _verbose | |||
if '_check_hash' in kwargs: | |||
check_hash = kwargs.pop('_check_hash') | |||
else: | |||
check_hash = _check_hash | |||
refresh_flag = True | |||
new_hash_code = None | |||
if check_hash: | |||
new_hash_code = cal_fn_hash_code(func, fn_param) | |||
if cache_filepath is not None and refresh is False: | |||
# load data | |||
if os.path.exists(cache_filepath): | |||
cache_filepath = os.path.abspath(cache_filepath) | |||
with open(cache_filepath, 'rb') as f: | |||
results = _pickle.load(f) | |||
old_hash_code = results['hash'] | |||
save_time = results['save_time'] | |||
results = results['results'] | |||
if verbose == 1: | |||
logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | |||
if check_hash and old_hash_code != new_hash_code: | |||
logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | |||
f"difference may caused by the sourcecode change of the functions by this function.", | |||
extra={'highlighter': ColorHighlighter('red')}) | |||
refresh_flag = False | |||
if refresh_flag: | |||
if new_hash_code is None: | |||
new_hash_code = cal_fn_hash_code(func, fn_param) | |||
results = func(*args, **kwargs) | |||
if cache_filepath is not None: | |||
if results is None: | |||
raise RuntimeError("The return value is None. Cannot save None results.") | |||
cache_filepath = os.path.abspath(cache_filepath) | |||
_prepare_cache_filepath(cache_filepath) | |||
_dict = { | |||
'results': results, | |||
'hash': new_hash_code, | |||
'save_time': datetime.now(), | |||
} | |||
with open(cache_filepath, 'wb') as f: | |||
_pickle.dump(_dict, f) | |||
logger.info("Save cache to {}.".format(cache_filepath)) | |||
return results | |||
return wrapper | |||
return wrapper_ |
@@ -0,0 +1,4 @@ | |||
class DummyClass: | |||
pass |
@@ -0,0 +1,51 @@ | |||
__all__ = [ | |||
'is_jittor_dataset', | |||
'jittor_collate_wraps' | |||
] | |||
from collections.abc import Mapping, Callable | |||
from functools import wraps | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from fastNLP.core.dataset import Instance | |||
def is_jittor_dataset(dataset) -> bool: | |||
try: | |||
if isinstance(dataset, jt.dataset.Dataset): | |||
return True | |||
else: | |||
return False | |||
except BaseException: | |||
return False | |||
def jittor_collate_wraps(func, auto_collator: Callable): | |||
""" | |||
对jittor的collate_fn进行wrap封装, 如果数据集为mapping类型,那么采用auto_collator,否则还是采用jittor自带的collate_batch | |||
:param func: | |||
:param auto_collator: | |||
:return: | |||
""" | |||
@wraps(func) | |||
def wrapper(batch): | |||
if isinstance(batch[0], Instance): | |||
if auto_collator is not None: | |||
result = auto_collator(batch) | |||
else: | |||
raise ValueError(f"auto_collator is None, but batch exist fastnlp instance!") | |||
elif isinstance(batch[0], Mapping): | |||
if auto_collator is not None: | |||
result = auto_collator(batch) | |||
else: | |||
result = func(batch) | |||
else: | |||
result = func(batch) | |||
return result | |||
return wrapper |
@@ -0,0 +1,89 @@ | |||
__all__ = [ | |||
"paddle_to", | |||
"paddle_move_data_to_device", | |||
"get_paddle_gpu_str", | |||
"get_paddle_device_id", | |||
"is_in_paddle_dist", | |||
"is_in_fnlp_paddle_dist", | |||
"is_in_paddle_launch_dist", | |||
] | |||
import os | |||
from typing import Any, Optional, Union | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from .utils import apply_to_collection | |||
def paddle_to(data, device: Union[str, int]): | |||
if device == "cpu": | |||
return data.cpu() | |||
else: | |||
return data.cuda(get_paddle_device_id(device)) | |||
def get_paddle_gpu_str(device: Union[str, int]): | |||
""" | |||
获得 `gpu:x` 类型的设备名 | |||
""" | |||
if isinstance(device, str): | |||
return device.replace("cuda", "gpu") | |||
return f"gpu:{device}" | |||
def get_paddle_device_id(device: Union[str, int]): | |||
""" | |||
获得 gpu 的设备id,注意不要传入 `cpu` 。 | |||
""" | |||
if isinstance(device, int): | |||
return device | |||
if device == "cpu": | |||
raise ValueError("Cannot get device id from `cpu`.") | |||
return paddle.device._convert_to_place(device).get_device_id() | |||
def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | |||
data_device: Optional[str] = None) -> Any: | |||
r""" | |||
将数据集合传输到给定设备。只有paddle.Tensor对象会被传输到设备中,其余保持不变 | |||
:param batch: | |||
:param device: `cpu`, `gpu` or `gpu:x` | |||
:param data_device: | |||
:return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
""" | |||
if device is None: | |||
if data_device is not None: | |||
device = data_device | |||
else: | |||
return batch | |||
def batch_to(data: Any) -> Any: | |||
return paddle_to(data, device) | |||
return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | |||
def is_in_paddle_dist(): | |||
""" | |||
判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 | |||
""" | |||
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | |||
def is_in_fnlp_paddle_dist(): | |||
""" | |||
判断是否处于 FastNLP 拉起的分布式进程中 | |||
""" | |||
return FASTNLP_DISTRIBUTED_CHECK in os.environ | |||
def is_in_paddle_launch_dist(): | |||
""" | |||
判断是否处于 launch 启动的分布式进程中 | |||
""" | |||
return 'PADDLE_RANK_IN_NODE' in os.environ and \ | |||
'FLAGS_selected_gpus' in os.environ and \ | |||
FASTNLP_DISTRIBUTED_CHECK not in os.environ |
@@ -0,0 +1,214 @@ | |||
""" | |||
该文件用于为fastNLP提供一个统一的progress bar管理,通过共用一个Task对象,trainer中的progress bar和evaluation中的progress bar才能 | |||
不冲突 | |||
""" | |||
import sys | |||
from typing import Any, Union, Optional | |||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | |||
__all__ = [ | |||
'f_rich_progress' | |||
] | |||
from fastNLP.envs import get_global_rank | |||
class Singleton(type): | |||
_instances = {} | |||
def __call__(cls, *args, **kwargs): | |||
if cls not in cls._instances: | |||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
return cls._instances[cls] | |||
# 如果不打印的时候,使得整个 progress 没有任何意义 | |||
class DummyFRichProgress: | |||
def __getattr__(self, item): | |||
return DummyFRichProgress() | |||
def __call__(self, *args, **kwargs): | |||
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 | |||
return None | |||
class FRichProgress(Progress, metaclass=Singleton): | |||
""" | |||
fastNLP 使用的 progress bar ,新增了 new_progress 函数,通过此函数即可定制 fastNLP 中所有 progress 的样式。 | |||
""" | |||
def new_progess(self, *columns: Union[str, ProgressColumn], | |||
console: Optional[Console] = None, | |||
auto_refresh: bool = True, | |||
refresh_per_second: float = 10, | |||
speed_estimate_period: float = 30.0, | |||
transient: bool = True, | |||
redirect_stdout: bool = True, | |||
redirect_stderr: bool = True, | |||
get_time: Optional[GetTimeCallable] = None, | |||
disable: bool = False, | |||
expand: bool = False): | |||
""" | |||
重新初始化一个rich bar。如果columns不传入,则继续使用之前的column内容。 | |||
:param progress: | |||
:return: | |||
""" | |||
for task_id in self.task_ids: # 首先移除已有的 | |||
self.remove_task(task_id) | |||
assert ( | |||
refresh_per_second is None or refresh_per_second > 0 | |||
), "refresh_per_second must be > 0" | |||
# stop previous columns | |||
self.stop() | |||
# do not change these variables | |||
# self._lock = RLock() | |||
# self._tasks: Dict[TaskID, Task] = {} | |||
# self._task_index: TaskID = TaskID(0) | |||
if len(columns) != 0: | |||
self.columns = columns | |||
self.speed_estimate_period = speed_estimate_period | |||
self.disable = disable | |||
self.expand = expand | |||
self.live = Live( | |||
console=console or get_console(), | |||
auto_refresh=auto_refresh, | |||
refresh_per_second=refresh_per_second, | |||
transient=transient, | |||
redirect_stdout=redirect_stdout, | |||
redirect_stderr=redirect_stderr, | |||
get_renderable=self.get_renderable, | |||
) | |||
self.get_time = get_time or self.console.get_time | |||
self.print = self.console.print | |||
self.log = self.console.log | |||
# start new | |||
self.start() | |||
return self | |||
def set_transient(self, transient: bool = True): | |||
""" | |||
设置是否在bar运行结束之后不关闭 | |||
:param transient: | |||
:return: | |||
""" | |||
self.new_progess(transient=transient) | |||
def set_disable(self, flag: bool = True): | |||
""" | |||
设置当前 progress bar 的状态,如果为 True ,则不会显示进度条了。 | |||
:param flag: | |||
:return: | |||
""" | |||
self.disable = flag | |||
def add_task( | |||
self, | |||
description: str, | |||
start: bool = True, | |||
total: float = 100.0, | |||
completed: int = 0, | |||
visible: bool = True, | |||
**fields: Any, | |||
) -> TaskID: | |||
if self.live._started is False: | |||
self.start() | |||
post_desc = fields.pop('post_desc', '') | |||
return super().add_task(description=description, | |||
start=start, | |||
total=total, | |||
completed=completed, | |||
visible=visible, | |||
post_desc=post_desc, | |||
**fields) | |||
def stop_task(self, task_id: TaskID) -> None: | |||
if task_id in self._tasks: | |||
super().stop_task(task_id) | |||
def remove_task(self, task_id: TaskID) -> None: | |||
if task_id in self._tasks: | |||
super().remove_task(task_id) | |||
def destroy_task(self, task_id: TaskID): | |||
if task_id in self._tasks: | |||
super().stop_task(task_id) | |||
super().remove_task(task_id) | |||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
f_rich_progress = FRichProgress().new_progess( | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
BarColumn(), | |||
TimeElapsedColumn(), | |||
"/", | |||
TimeRemainingColumn(), | |||
TextColumn("{task.fields[post_desc]}", justify="right"), | |||
transient=True, | |||
disable=False, | |||
speed_estimate_period=10 | |||
) | |||
else: | |||
f_rich_progress = DummyFRichProgress() | |||
if __name__ == '__main__': | |||
f = DummyFRichProgress() | |||
f.console.print('xxx') | |||
f.console.print.print('xxx') | |||
# 测试创建 | |||
import time | |||
n_steps = 10 | |||
task_id = f_rich_progress.add_task(description='test', total=n_steps) | |||
for i in range(n_steps): | |||
f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True) | |||
print(f"test:{i}") | |||
time.sleep(0.3) | |||
f_rich_progress.remove_task(task_id) | |||
# 测试一下 inner/outer | |||
n_steps = 5 | |||
f_rich_progress.start() | |||
outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps) | |||
inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps) | |||
for i in range(n_steps): | |||
f_rich_progress.reset(inner_task_id, total=n_steps) | |||
f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True) | |||
for j in range(n_steps): | |||
f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True, | |||
post_desc='Loss: 0.334332323') | |||
print(f"Outer:{i}, Inner:{j}") | |||
time.sleep(0.3) | |||
# 测试一下修改bar | |||
f_rich_progress = FRichProgress().new_progess( | |||
BarColumn(), | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
TimeElapsedColumn(), | |||
transient=True) | |||
n_steps = 10 | |||
task_id = f_rich_progress.add_task(description='test', total=n_steps) | |||
for i in range(n_steps): | |||
f_rich_progress.update(task_id, description=f'test:{i}', advance=1) | |||
print(f"test:{i}") | |||
time.sleep(0.3) | |||
f_rich_progress.remove_task(task_id) | |||
f_rich_progress.stop() |
@@ -0,0 +1,49 @@ | |||
from typing import Any, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
__all__ = [ | |||
"torch_paddle_move_data_to_device", | |||
] | |||
from .utils import apply_to_collection | |||
from .paddle_utils import paddle_to | |||
def torch_paddle_move_data_to_device(batch: Any, device: Optional[str] = None, non_blocking: Optional[bool] = True, | |||
data_device: Optional[str] = None) -> Any: | |||
r""" | |||
将数据集合传输到给定设备。只有paddle.Tensor和torch.Tensor对象会被传输到设备中,其余保持不变 | |||
:param batch: | |||
:param device: | |||
:param non_blocking: | |||
:param data_device: | |||
:return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
""" | |||
if device is None: | |||
if data_device is not None: | |||
device = data_device | |||
else: | |||
return batch | |||
torch_device = device.replace("gpu", "cuda") | |||
paddle_device = device.replace("cuda", "gpu") | |||
def batch_to(data: Any) -> Any: | |||
if isinstance(data, torch.Tensor): | |||
data = data.to(torch_device, non_blocking=non_blocking) | |||
elif isinstance(data, paddle.Tensor): | |||
data = paddle_to(data, paddle_device) | |||
return data | |||
return apply_to_collection(batch, dtype=(paddle.Tensor, torch.Tensor), function=batch_to) |
@@ -0,0 +1,63 @@ | |||
from abc import ABC | |||
from typing import Any, Union, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
__all__ = [ | |||
'torch_move_data_to_device' | |||
] | |||
from .utils import apply_to_collection | |||
class TorchTransferableDataType(ABC): | |||
""" | |||
A custom type for data that can be moved to a torch device via `.to(...)`. | |||
Example: | |||
>>> isinstance(dict, TorchTransferableDataType) | |||
False | |||
>>> isinstance(torch.rand(2, 3), TorchTransferableDataType) | |||
True | |||
>>> class CustomObject: | |||
... def __init__(self): | |||
... self.x = torch.rand(2, 2) | |||
... def to(self, device): | |||
... self.x = self.x.to(device) | |||
... return self | |||
>>> isinstance(CustomObject(), TorchTransferableDataType) | |||
True | |||
""" | |||
@classmethod | |||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
if cls is TorchTransferableDataType: | |||
to = getattr(subclass, "to", None) | |||
return callable(to) | |||
return NotImplemented | |||
def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None, | |||
non_blocking: Optional[bool] = True) -> Any: | |||
r""" | |||
将数据集合传输到给定设备。任何定义方法 “to(device)” 的对象都将被移动并且集合中的所有其他对象将保持不变; | |||
:param batch: 应当迁移的数据; | |||
:param device: 数据应当迁移到的设备;当该参数的值为 None 时,表示迁移数据的操作由用户自己完成,我们不需要经管; | |||
:param non_blocking: pytorch 的迁移数据方法 `to` 的参数; | |||
:return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
""" | |||
if device is None: | |||
return batch | |||
def batch_to(data: Any) -> Any: | |||
kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {} | |||
data_output = data.to(device, **kwargs) | |||
if data_output is not None: | |||
return data_output | |||
# user wrongly implemented the `TransferableDataType` and forgot to return `self`. | |||
return data | |||
dtype = TorchTransferableDataType | |||
return apply_to_collection(batch, dtype=dtype, function=batch_to) |
@@ -0,0 +1,591 @@ | |||
import inspect | |||
from inspect import Parameter | |||
import dataclasses | |||
import warnings | |||
from dataclasses import is_dataclass | |||
from copy import deepcopy | |||
from collections import defaultdict, OrderedDict | |||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional | |||
from typing import Tuple, Optional | |||
from time import sleep | |||
try: | |||
from typing import Literal, Final | |||
except ImportError: | |||
from typing_extensions import Literal, Final | |||
import os | |||
from contextlib import contextmanager | |||
from functools import wraps | |||
from prettytable import PrettyTable | |||
import numpy as np | |||
from pathlib import Path | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||
__all__ = [ | |||
'get_fn_arg_names', | |||
'check_fn_not_empty_params', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
'match_and_substitute_params', | |||
'apply_to_collection', | |||
'nullcontext', | |||
'pretty_table_printer', | |||
'Option', | |||
'indice_collate_wrapper', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
'synchronize_safe_rm', | |||
'synchronize_mkdir' | |||
] | |||
def get_fn_arg_names(fn: Callable) -> List[str]: | |||
r""" | |||
返回一个函数的所有参数的名字; | |||
:param fn: 需要查询的函数; | |||
:return: 一个列表,其中的元素则是查询函数的参数的字符串名字; | |||
""" | |||
return list(inspect.signature(fn).parameters) | |||
def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: | |||
r""" | |||
检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数; | |||
用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可; | |||
:param fn: 传入的用以代替 Loop 中 'step' 函数的函数; | |||
:param param_num: 检测的函数的应当的没有默认值的参数的个数; | |||
:return: bool,表示传入的 `batch_step_fn` 是否正确; | |||
""" | |||
if fn is None: | |||
return True | |||
if not callable(fn): | |||
return False | |||
else: | |||
params = inspect.signature(fn).parameters | |||
not_default_params = {} | |||
for _name, _param in params.items(): | |||
if _param.default == Parameter.empty: | |||
not_default_params[_name] = _param | |||
return len(not_default_params) == param_num | |||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
r""" | |||
1.该函数用来提供给用户根据字符串匹配从而实现自动计算; | |||
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | |||
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | |||
4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同; | |||
:param fn: 用来进行实际计算的函数,其参数可以包含有默认值; | |||
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数; | |||
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | |||
参数值后,再传给 `fn` 进行实际的运算; | |||
:param mapping: 一个字典,用来更改其前面的字典的键值; | |||
:return: 返回 `fn` 运行的结果; | |||
Examples: | |||
>>> # 1 | |||
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred); | |||
>>> batch = {"x": 20, "y": 1} | |||
>>> output = {"pred": 0} | |||
>>> acc = auto_param_call(loss_fn, batch, output) | |||
>>> # 2 | |||
>>> def test_fn(x, y, a, b=10): | |||
>>> return x + y + a + b | |||
>>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70 | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | |||
""" | |||
if signature_fn is not None: | |||
if not callable(signature_fn): | |||
raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | |||
_need_params = OrderedDict(inspect.signature(signature_fn).parameters) | |||
else: | |||
_need_params = OrderedDict(inspect.signature(fn).parameters) | |||
_kwargs = None | |||
for _name, _param in _need_params.items(): | |||
if _param.kind == Parameter.VAR_POSITIONAL: | |||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||
if _param.kind == Parameter.VAR_KEYWORD: | |||
_kwargs = (_name, _param) | |||
if _kwargs is not None: | |||
_need_params.pop(_kwargs[0]) | |||
_default_params = {} | |||
for _name, _param in _need_params.items(): | |||
if _param.default != Parameter.empty: | |||
_default_params[_name] = _param.default | |||
if mapping is not None: | |||
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
_has_params = {} | |||
duplicate_names = [] | |||
for arg in args: | |||
assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." | |||
for _name, _value in arg.items(): | |||
if mapping is not None and _name in mapping: | |||
_name = mapping[_name] | |||
if _name not in _has_params: | |||
if _kwargs is not None or _name in _need_params: | |||
_has_params[_name] = _value | |||
# 同一参数对象在两个输入的资源中都出现,造成混淆; | |||
elif _name in _need_params and not (_has_params[_name] is _value): | |||
duplicate_names.append(_name) | |||
if duplicate_names: | |||
raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||
# 将具有默认值但是没有被输入修改过的参数值传进去; | |||
for _name, _value in _default_params.items(): | |||
if _name not in _has_params: | |||
_has_params[_name] = _value | |||
if len(_has_params)<len(_need_params): | |||
miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | |||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||
return fn(**_has_params) | |||
def check_user_specific_params(user_params: Dict, fn: Callable): | |||
""" | |||
该函数使用用户的输入来对指定函数的参数进行赋值; | |||
主要用于一些用户无法直接调用函数的情况; | |||
该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误; | |||
:param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值; | |||
:param fn: 会被调用的函数; | |||
:return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值; | |||
""" | |||
fn_arg_names = get_fn_arg_names(fn) | |||
for arg_name, arg_value in user_params.items(): | |||
if arg_name not in fn_arg_names: | |||
logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
return user_params | |||
def dataclass_to_dict(data: "dataclass") -> Dict: | |||
if not is_dataclass(data): | |||
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") | |||
_dict = dict() | |||
for _key in data.__dataclass_fields__: | |||
_dict[_key] = getattr(data, _key) | |||
return _dict | |||
def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any: | |||
r""" | |||
用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能; | |||
该函数应用于 `input_mapping` 和 `output_mapping`; | |||
对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用; | |||
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; | |||
转换的逻辑按优先级依次为: | |||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; | |||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; | |||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; | |||
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; | |||
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 | |||
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 | |||
:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 | |||
:param data: 需要被转换的对象; | |||
:return: 返回转换好的结果; | |||
""" | |||
if mapping is None: | |||
return data | |||
if callable(mapping): | |||
# 注意我们在 `Trainer.extract_loss_from_outputs` 函数里会检查 outputs 的输出,outputs 的类型目前只支持 `Dict` 和 `dataclass`; | |||
return mapping(data) | |||
if not isinstance(mapping, Dict): | |||
raise ValueError( | |||
f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused" | |||
f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.") | |||
if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence): | |||
raise ValueError("Parameter `data` should be type `Dict` or `dataclass` when the other parameter `mapping` is " | |||
"type `Dict`.") | |||
# 如果 `data` 是一个 dataclass,那么先将其转换为一个 `Dict`; | |||
if is_dataclass(data): | |||
data = dataclass_to_dict(data) | |||
# 如果 `data` 是一个 List,那么我们同样先将其转换为一个 `Dict`,为 {"_0": list[0], "_1": list[1], ...}; | |||
elif isinstance(data, Sequence): | |||
data = {"_" + str(i): data[i] for i in range(len(data))} | |||
_new_data = {} | |||
for _name, _value in data.items(): | |||
if _name in mapping: | |||
_new_data[mapping[_name]] = _value | |||
else: | |||
_new_data[_name] = _value | |||
return _new_data | |||
def _is_namedtuple(obj: object) -> bool: | |||
# https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 | |||
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") | |||
def _is_dataclass_instance(obj: object) -> bool: | |||
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions | |||
return dataclasses.is_dataclass(obj) and not isinstance(obj, type) | |||
def apply_to_collection( | |||
data: Any, | |||
dtype: Union[type, Any, Tuple[Union[type, Any]]], | |||
function: Callable, | |||
*args: Any, | |||
wrong_dtype: Optional[Union[type, Tuple[type]]] = None, | |||
include_none: bool = True, | |||
**kwargs: Any, | |||
) -> Any: | |||
"""将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。 | |||
this function credit to: https://github.com/PyTorchLightning/pytorch-lightning | |||
Args: | |||
data: the collection to apply the function to | |||
dtype: the given function will be applied to all elements of this dtype | |||
function: the function to apply | |||
*args: positional arguments (will be forwarded to calls of ``function``) | |||
wrong_dtype: the given function won't be applied if this type is specified and the given collections | |||
is of the ``wrong_dtype`` even if it is of type ``dtype`` | |||
include_none: Whether to include an element if the output of ``function`` is ``None``. | |||
**kwargs: keyword arguments (will be forwarded to calls of ``function``) | |||
Returns: | |||
The resulting collection | |||
""" | |||
# Breaking condition | |||
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): | |||
return function(data, *args, **kwargs) | |||
elem_type = type(data) | |||
# Recursively apply to collection items | |||
if isinstance(data, Mapping): | |||
out = [] | |||
for k, v in data.items(): | |||
v = apply_to_collection( | |||
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs | |||
) | |||
if include_none or v is not None: | |||
out.append((k, v)) | |||
if isinstance(data, defaultdict): | |||
return elem_type(data.default_factory, OrderedDict(out)) | |||
return elem_type(OrderedDict(out)) | |||
is_namedtuple = _is_namedtuple(data) | |||
is_sequence = isinstance(data, Sequence) and not isinstance(data, str) | |||
if is_namedtuple or is_sequence: | |||
out = [] | |||
for d in data: | |||
v = apply_to_collection( | |||
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs | |||
) | |||
if include_none or v is not None: | |||
out.append(v) | |||
return elem_type(*out) if is_namedtuple else elem_type(out) | |||
if _is_dataclass_instance(data): | |||
# make a deepcopy of the data, | |||
# but do not deepcopy mapped fields since the computation would | |||
# be wasted on values that likely get immediately overwritten | |||
fields = {} | |||
memo = {} | |||
for field in dataclasses.fields(data): | |||
field_value = getattr(data, field.name) | |||
fields[field.name] = (field_value, field.init) | |||
memo[id(field_value)] = field_value | |||
result = deepcopy(data, memo=memo) | |||
# apply function to each field | |||
for field_name, (field_value, field_init) in fields.items(): | |||
if field_init: | |||
v = apply_to_collection( | |||
field_value, | |||
dtype, | |||
function, | |||
*args, | |||
wrong_dtype=wrong_dtype, | |||
include_none=include_none, | |||
**kwargs, | |||
) | |||
if not field_init or (not include_none and v is None): # retain old value | |||
v = getattr(data, field_name) | |||
setattr(result, field_name, v) | |||
return result | |||
# data is neither of dtype, nor a collection | |||
return data | |||
@contextmanager | |||
def nullcontext(): | |||
r""" | |||
用来实现一个什么 dummy 的 context 上下文环境; | |||
""" | |||
yield | |||
def sub_column(string: str, c: int, c_size: int, title: str) -> str: | |||
r""" | |||
:param string: 要被截断的字符串 | |||
:param c: 命令行列数 | |||
:param c_size: instance或dataset field数 | |||
:param title: 列名 | |||
:return: 对一个过长的列进行截断的结果 | |||
""" | |||
avg = max(int(c / c_size / 2), len(title)) | |||
string = str(string) | |||
res = "" | |||
counter = 0 | |||
for char in string: | |||
if ord(char) > 255: | |||
counter += 2 | |||
else: | |||
counter += 1 | |||
res += char | |||
if counter > avg: | |||
res = res + "..." | |||
break | |||
return res | |||
def _is_iterable(value): | |||
# 检查是否是iterable的, duck typing | |||
try: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||
r""" | |||
:param dataset_or_ins: 传入一个dataSet或者instance | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||
+-----------+-----------+-----------------+ | |||
| field_1 | field_2 | field_3 | | |||
+-----------+-----------+-----------------+ | |||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | | |||
+-----------+-----------+-----------------+ | |||
:return: 以 pretty table的形式返回根据terminal大小进行自动截断 | |||
""" | |||
x = PrettyTable() | |||
try: | |||
sz = os.get_terminal_size() | |||
column = sz.columns | |||
row = sz.lines | |||
except OSError: | |||
column = 144 | |||
row = 11 | |||
if type(dataset_or_ins).__name__ == "DataSet": | |||
x.field_names = list(dataset_or_ins.field_arrays.keys()) | |||
c_size = len(x.field_names) | |||
for ins in dataset_or_ins: | |||
x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names]) | |||
row -= 1 | |||
if row < 0: | |||
x.add_row(["..." for _ in range(c_size)]) | |||
break | |||
elif type(dataset_or_ins).__name__ == "Instance": | |||
x.field_names = list(dataset_or_ins.fields.keys()) | |||
c_size = len(x.field_names) | |||
x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names]) | |||
else: | |||
raise Exception("only accept DataSet and Instance") | |||
x.align = "l" | |||
return x | |||
class Option(dict): | |||
r"""a dict can treat keys as attributes""" | |||
def __getattr__(self, item): | |||
try: | |||
return self.__getitem__(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __setattr__(self, key, value): | |||
if key.startswith('__') and key.endswith('__'): | |||
raise AttributeError(key) | |||
self.__setitem__(key, value) | |||
def __delattr__(self, item): | |||
try: | |||
self.pop(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __getstate__(self): | |||
return self | |||
def __setstate__(self, state): | |||
self.update(state) | |||
def indice_collate_wrapper(func): | |||
""" | |||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||
:param func: 需要修饰的函数 | |||
:return: | |||
""" | |||
def wrapper(tuple_data): | |||
indice, ins_list = [], [] | |||
for idx, ins in tuple_data: | |||
indice.append(idx) | |||
ins_list.append(ins) | |||
return indice, func(ins_list) | |||
return wrapper | |||
_emitted_deprecation_warnings = set() | |||
def deprecated(help_message: Optional[str] = None): | |||
"""Decorator to mark a function as deprecated. | |||
Args: | |||
help_message (`Optional[str]`): An optional message to guide the user on how to | |||
switch to non-deprecated usage of the library. | |||
""" | |||
def decorator(deprecated_function: Callable): | |||
global _emitted_deprecation_warnings | |||
warning_msg = ( | |||
( | |||
f"{deprecated_function.__name__} is deprecated and will be removed " | |||
"in the next major version of datasets." | |||
) | |||
+ f" {help_message}" | |||
if help_message | |||
else "" | |||
) | |||
@wraps(deprecated_function) | |||
def wrapper(*args, **kwargs): | |||
func_hash = hash(deprecated_function) | |||
if func_hash not in _emitted_deprecation_warnings: | |||
warnings.warn(warning_msg, category=FutureWarning, stacklevel=2) | |||
_emitted_deprecation_warnings.add(func_hash) | |||
return deprecated_function(*args, **kwargs) | |||
wrapper._decorator_name_ = "deprecated" | |||
return wrapper | |||
return decorator | |||
def seq_len_to_mask(seq_len, max_len=None): | |||
r""" | |||
将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | |||
转变 1-d seq_len到2-d mask. | |||
.. code-block:: | |||
>>> seq_len = torch.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len) | |||
>>> print(mask.size()) | |||
torch.Size([14, 15]) | |||
>>> seq_len = np.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len) | |||
>>> print(mask.shape) | |||
(14, 15) | |||
>>> seq_len = torch.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len, max_len=100) | |||
>>>print(mask.size()) | |||
torch.Size([14, 100]) | |||
:param np.ndarray,torch.LongTensor seq_len: shape将是(B,) | |||
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 | |||
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 | |||
:return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8 | |||
""" | |||
if isinstance(seq_len, np.ndarray): | |||
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | |||
max_len = int(max_len) if max_len else int(seq_len.max()) | |||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
else: | |||
raise TypeError("Only support 1-d numpy.ndarray.") | |||
return mask | |||
def wait_to_success(fn, no=False): | |||
while True: | |||
sleep(0.01) | |||
if (no and not fn()) or (not no and fn()): | |||
break | |||
# 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||
# 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||
def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if not path.exists(): | |||
return | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
_recursive_rm(path) | |||
wait_to_success(path.exists, no=True) | |||
def _recursive_rm(path: Path): | |||
if path.is_file() or path.is_symlink(): | |||
if path.exists(): | |||
try: | |||
path.unlink() | |||
except Exception: | |||
pass | |||
return | |||
for sub_path in list(path.iterdir()): | |||
_recursive_rm(sub_path) | |||
path.rmdir() | |||
def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
""" | |||
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | |||
""" | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
path.mkdir(parents=True, exist_ok=True) | |||
wait_to_success(path.exists) | |||