@@ -0,0 +1,43 @@
__all__ = [
    'cache_results',
    'is_jittor_dataset',
    'jittor_collate_wraps',
    'paddle_to',
    'paddle_move_data_to_device',
    'get_paddle_device_id',
    'get_paddle_gpu_str',
    'is_in_paddle_dist',
    'is_in_fnlp_paddle_dist',
    'is_in_paddle_launch_dist',
    'f_rich_progress',
    'torch_paddle_move_data_to_device',
    'torch_move_data_to_device',
    'get_fn_arg_names',
    'check_fn_not_empty_params',
    'auto_param_call',
    'check_user_specific_params',
    'dataclass_to_dict',
    'match_and_substitute_params',
    'apply_to_collection',
    'nullcontext',
    'pretty_table_printer',
    'Option',
    'indice_collate_wrapper',
    'deprecated',
    'seq_len_to_mask',
    'synchronize_safe_rm',
    'synchronize_mkdir'
]

from .cache_results import cache_results
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \
    is_in_fnlp_paddle_dist, is_in_paddle_launch_dist
from .rich_progress import f_rich_progress
from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
    dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
    indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir
@@ -0,0 +1,310 @@
from datetime import datetime
import hashlib
import _pickle
import functools
import os
from typing import Callable, List, Any, Optional
import inspect
import ast
from collections import deque

__all__ = [
    'cache_results'
]

from fastNLP.core.log.logger import logger
from fastNLP.core.log.highlighter import ColorHighlighter


class FuncCallVisitor(ast.NodeVisitor):
    # credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776
    def __init__(self):
        self._name = deque()

    @property
    def name(self):
        return '.'.join(self._name)

    @name.deleter
    def name(self):
        self._name.clear()

    def visit_Name(self, node):
        self._name.appendleft(node.id)

    def visit_Attribute(self, node):
        try:
            self._name.appendleft(node.attr)
            self._name.appendleft(node.value.id)
        except AttributeError:
            self.generic_visit(node)


def get_func_calls(tree):
    func_calls = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            callvisitor = FuncCallVisitor()
            callvisitor.visit(node.func)
            func_calls.append(callvisitor.name)
        if isinstance(node, ast.FunctionDef):
            if node is not tree:
                func_calls.extend(get_func_calls(node))
    return func_calls
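# A minimal illustration of what get_func_calls extracts; the snippet below is
# illustrative, not from fastNLP:
# >>> get_func_calls(ast.parse("os.path.join(a, b)"))
# ['os.path.join']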
def truncate_start_blanks(source: str) -> str:
    """
    Dedent every line of ``source`` by the indentation of its first non-empty line.

    :param source:
    :return:
    """
    lines = source.split('\n')
    num_blank = 0
    # take the indentation of the first non-empty line
    for line in lines:
        if line:
            num_blank = len(line) - len(line.lstrip())
            break
    new_lines = []
    for line in lines:
        # strip at most num_blank leading spaces from each line
        prefix = line[:num_blank]
        num_strip = len(prefix) - len(prefix.lstrip(' '))
        new_lines.append(line[num_strip:])
    return '\n'.join(new_lines)
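# Behaviour of the dedent helper on an indented snippet (illustrative):
# >>> truncate_start_blanks("    def f():\n        return 1")
# 'def f():\n    return 1'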
def _get_func_and_its_called_func_source_code(func) -> List[str]:
    """
    Given ``func``, return the source code of all functions used inside it.

    :param callable func:
    :return:
    """
    last_frame = inspect.currentframe().f_back.f_back.f_back
    last_frame_f_local = last_frame.f_locals
    last_frame_loc = {}
    if 'loc' in last_frame_f_local:
        last_frame_loc = last_frame_f_local['loc']
    func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func))))))
    func_calls.sort()
    sources = []
    for _func_name in func_calls:
        try:
            if _func_name == 'cache_results':  # ignore the decorator itself
                continue
            if '.' in _func_name:
                _funcs = _func_name.split('.')
            else:
                _funcs = [_func_name]
            if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc:
                tmp = _funcs.pop(0)
                variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp))
                while len(_funcs) or variable is not None:
                    if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__):
                        try:
                            sources.append(inspect.getsource(variable.__class__))
                        except TypeError:
                            pass
                    if callable(variable) or inspect.isclass(variable):
                        sources.append(inspect.getsource(variable))
                    if len(_funcs):
                        tmp = _funcs.pop(0)
                        if hasattr(variable, tmp):
                            variable = getattr(variable, tmp)
                        else:
                            break
                    else:
                        variable = None
        except Exception:
            # fetching this callee's source failed; skip it
            pass
    del last_frame  # break the reference cycle created by holding a frame
    sources.append(inspect.getsource(func))
    return sources
def _prepare_cache_filepath(filepath: str):
    r"""
    Check whether ``filepath`` is usable as a cache file; if so, create its parent directory when needed.

    :param filepath: str.
    :return: None; raises an error if the path is not usable.
    """
    _cache_filepath = os.path.abspath(filepath)
    if os.path.isdir(_cache_filepath):
        raise RuntimeError("The cache_file_path must be a file, not a directory.")
    cache_dir = os.path.dirname(_cache_filepath)
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir, exist_ok=True)
class Hasher:
    def __init__(self):
        self.m = hashlib.sha1()

    def update(self, value: Any) -> None:
        if isinstance(value, str):
            value = [value]
        for x in value:
            self.m.update(x.encode('utf8'))

    def hexdigest(self) -> str:
        return self.m.hexdigest()
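# Hasher accepts a single string or an iterable of strings; e.g. the standard
# SHA-1 test vector:
# >>> h = Hasher()
# >>> h.update("abc")
# >>> h.hexdigest()
# 'a9993e364706816aba3e25717850c26c9cd0d89d'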
def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None):
    if fn_kwargs is None:
        fn_kwargs = {}
    hasher = Hasher()
    try:
        sources = _get_func_and_its_called_func_source_code(fn)
        hasher.update(sources)
    except Exception:
        return "can't be hashed"
    for key in sorted(fn_kwargs):
        hasher.update(key)
        try:
            hasher.update(fn_kwargs[key])
        except Exception:
            pass
    return hasher.hexdigest()
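# Rough usage sketch: the digest covers the function's source code (and that of
# its callees) plus the given kwargs; on failure a sentinel string is returned,
# so the result is always a str:
# >>> isinstance(cal_fn_hash_code(cal_fn_hash_code), str)
# True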
def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
    r"""
    cache_results is the decorator fastNLP uses to cache data. The following example shows how to use it::

        import time
        import numpy as np
        from fastNLP import cache_results

        @cache_results('cache.pkl')
        def process_data():
            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
            time.sleep(1)
            return np.random.randint(10, size=(5,))

        start_time = time.time()
        print("res =", process_data())
        print(time.time() - start_time)

        start_time = time.time()
        print("res =", process_data())
        print(time.time() - start_time)

        # The output is shown below. Both runs return the same result, and the second run takes almost no time:
        # Save cache to cache.pkl.
        # res = [5 4 9 1 8]
        # 1.0042750835418701
        # Read cache from cache.pkl.
        # res = [5 4 9 1 8]
        # 0.0040721893310546875

    The second run is fast because it reads the result directly from cache.pkl instead of preprocessing again::

        # Continuing the example above: to generate a separate cache, e.g. for another dataset, call it like this
        process_data(_cache_fp='cache2.pkl')  # does not affect the previous 'cache.pkl' at all

    `_cache_fp` above is a parameter that cache_results recognizes: the data is cached to / read from 'cache2.pkl',
    i.e. 'cache2.pkl' here overrides the default 'cache.pkl'. Decorating a function with @cache_results() adds four
    parameters to it: [_cache_fp, _refresh, _verbose, _check_hash]. The example above shows the use of `_cache_fp`;
    these parameters are never passed into your own function, so your own parameter names must not collide with
    them::

        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the preprocessing cache to be regenerated
        # _verbose controls logging: 0 prints nothing; 1 reports whether the cache was read or newly generated

    :param str _cache_fp: where to cache the returned result, or where to read the cache from. If None,
        cache_results has no effect, unless `_cache_fp` is passed in at call time.
    :param bool _refresh: whether to regenerate the cache.
    :param int _verbose: whether to print cache information.
    :param bool _check_hash: if True, compare the hash of the decorated function's source code (and that of the
        functions it calls) against the hash recorded when the cache was saved, and warn on any difference. The
        warning can be a false positive that does not actually affect the result (e.g. adding or removing blank
        lines); conversely, a change outside the source code that does affect the result cannot trigger a warning.
    :return:
    """
    def wrapper_(func):
        signature = inspect.signature(func)
        for key, _ in signature.parameters.items():
            if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            fn_param = kwargs.copy()
            if args:
                params = [p.name for p in inspect.signature(func).parameters.values()]
                fn_param.update(zip(params, args))
            if '_cache_fp' in kwargs:
                cache_filepath = kwargs.pop('_cache_fp')
                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
            else:
                cache_filepath = _cache_fp
            if '_refresh' in kwargs:
                refresh = kwargs.pop('_refresh')
                assert isinstance(refresh, bool), "_refresh can only be bool."
            else:
                refresh = _refresh
            if '_verbose' in kwargs:
                verbose = kwargs.pop('_verbose')
                assert isinstance(verbose, int), "_verbose can only be integer."
            else:
                verbose = _verbose
            if '_check_hash' in kwargs:
                check_hash = kwargs.pop('_check_hash')
            else:
                check_hash = _check_hash

            refresh_flag = True
            new_hash_code = None
            if check_hash:
                new_hash_code = cal_fn_hash_code(func, fn_param)

            if cache_filepath is not None and refresh is False:
                # load data
                if os.path.exists(cache_filepath):
                    cache_filepath = os.path.abspath(cache_filepath)
                    with open(cache_filepath, 'rb') as f:
                        results = _pickle.load(f)
                    old_hash_code = results['hash']
                    save_time = results['save_time']
                    results = results['results']
                    if verbose == 1:
                        logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
                    if check_hash and old_hash_code != new_hash_code:
                        logger.warning(f"The function `{func.__name__}` is different from its last cache (Saved on {save_time}). The "
                                       f"difference may be caused by changes to the source code of this function or of the functions it calls.",
                                       extra={'highlighter': ColorHighlighter('red')})
                    refresh_flag = False

            if refresh_flag:
                if new_hash_code is None:
                    new_hash_code = cal_fn_hash_code(func, fn_param)
                results = func(*args, **kwargs)
                if cache_filepath is not None:
                    if results is None:
                        raise RuntimeError("The return value is None. Cannot save None results.")
                    cache_filepath = os.path.abspath(cache_filepath)
                    _prepare_cache_filepath(cache_filepath)
                    _dict = {
                        'results': results,
                        'hash': new_hash_code,
                        'save_time': datetime.now(),
                    }
                    with open(cache_filepath, 'wb') as f:
                        _pickle.dump(_dict, f)
                    logger.info("Save cache to {}.".format(cache_filepath))

            return results

        return wrapper

    return wrapper_
@@ -0,0 +1,4 @@
class DummyClass:
    pass
@@ -0,0 +1,51 @@
__all__ = [
    'is_jittor_dataset',
    'jittor_collate_wraps'
]

from collections.abc import Mapping, Callable
from functools import wraps

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
    import jittor as jt

from fastNLP.core.dataset import Instance


def is_jittor_dataset(dataset) -> bool:
    try:
        if isinstance(dataset, jt.dataset.Dataset):
            return True
        else:
            return False
    except BaseException:
        return False
def jittor_collate_wraps(func, auto_collator: Callable):
    """
    Wrap jittor's collate_fn: if the batch elements are Mapping-like (or fastNLP Instances), use ``auto_collator``;
    otherwise fall back to jittor's own collate_batch.

    :param func:
    :param auto_collator:
    :return:
    """
    @wraps(func)
    def wrapper(batch):
        if isinstance(batch[0], Instance):
            if auto_collator is not None:
                result = auto_collator(batch)
            else:
                raise ValueError("auto_collator is None, but the batch contains fastNLP Instance objects!")
        elif isinstance(batch[0], Mapping):
            if auto_collator is not None:
                result = auto_collator(batch)
            else:
                result = func(batch)
        else:
            result = func(batch)
        return result

    return wrapper
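# A hedged usage sketch (assumes jittor is installed; `my_dataset` and
# `my_auto_collator` are hypothetical stand-ins):
# collate_fn = jittor_collate_wraps(my_dataset.collate_batch, my_auto_collator)
# batched = collate_fn(batch_of_instances)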
@@ -0,0 +1,89 @@
__all__ = [
    "paddle_to",
    "paddle_move_data_to_device",
    "get_paddle_gpu_str",
    "get_paddle_device_id",
    "is_in_paddle_dist",
    "is_in_fnlp_paddle_dist",
    "is_in_paddle_launch_dist",
]

import os
from typing import Any, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK

if _NEED_IMPORT_PADDLE:
    import paddle

from .utils import apply_to_collection


def paddle_to(data, device: Union[str, int]):
    if device == "cpu":
        return data.cpu()
    else:
        return data.cuda(get_paddle_device_id(device))
def get_paddle_gpu_str(device: Union[str, int]):
    """
    Get a device name of the form `gpu:x`.
    """
    if isinstance(device, str):
        return device.replace("cuda", "gpu")
    return f"gpu:{device}"
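# Pure string handling, so easy to sanity-check:
# >>> get_paddle_gpu_str("cuda:0")
# 'gpu:0'
# >>> get_paddle_gpu_str(1)
# 'gpu:1'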
def get_paddle_device_id(device: Union[str, int]):
    """
    Get the gpu device id. Note that `cpu` must not be passed in.
    """
    if isinstance(device, int):
        return device
    if device == "cpu":
        raise ValueError("Cannot get device id from `cpu`.")
    return paddle.device._convert_to_place(device).get_device_id()
def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
                               data_device: Optional[str] = None) -> Any:
    r"""
    Move a collection of data to the given device. Only paddle.Tensor objects are moved; everything else is left
    unchanged.

    :param batch:
    :param device: `cpu`, `gpu` or `gpu:x`
    :param data_device:
    :return: the same collection, but with all contained tensors residing on the new device;
    """
    if device is None:
        if data_device is not None:
            device = data_device
        else:
            return batch

    def batch_to(data: Any) -> Any:
        return paddle_to(data, device)

    return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)
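# A hedged sketch (requires paddle): only the paddle.Tensor leaf is moved, the
# plain int is returned unchanged.
# batch = {"x": paddle.to_tensor([1.0, 2.0]), "step": 3}
# batch = paddle_move_data_to_device(batch, device="gpu:0")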
def is_in_paddle_dist():
    """
    Check whether this is a distributed paddle process, judged via global_rank and selected_gpus.
    """
    return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)


def is_in_fnlp_paddle_dist():
    """
    Check whether this is a distributed process launched by FastNLP.
    """
    return FASTNLP_DISTRIBUTED_CHECK in os.environ


def is_in_paddle_launch_dist():
    """
    Check whether this is a distributed process started by launch.
    """
    return 'PADDLE_RANK_IN_NODE' in os.environ and \
           'FLAGS_selected_gpus' in os.environ and \
           FASTNLP_DISTRIBUTED_CHECK not in os.environ
@@ -0,0 +1,214 @@
"""
This file provides fastNLP with a unified progress bar manager. Only by sharing a single Task object can the
trainer's progress bar and the evaluation's progress bar avoid conflicting with each other.
"""
import sys
from typing import Any, Union, Optional

from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn

__all__ = [
    'f_rich_progress'
]

from fastNLP.envs import get_global_rank
class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
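# With this metaclass, repeated instantiation returns the same object:
# >>> FRichProgress() is FRichProgress()
# True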
# when nothing should be printed, make the whole progress bar a no-op
class DummyFRichProgress:
    def __getattr__(self, item):
        return DummyFRichProgress()

    def __call__(self, *args, **kwargs):
        # guard against calls such as DummyFRichProgress.console.print()
        return None
class FRichProgress(Progress, metaclass=Singleton):
    """
    The progress bar used by fastNLP. It adds a new_progess function, through which the style of every progress
    bar in fastNLP can be customized.
    """

    def new_progess(self, *columns: Union[str, ProgressColumn],
                    console: Optional[Console] = None,
                    auto_refresh: bool = True,
                    refresh_per_second: float = 10,
                    speed_estimate_period: float = 30.0,
                    transient: bool = True,
                    redirect_stdout: bool = True,
                    redirect_stderr: bool = True,
                    get_time: Optional[GetTimeCallable] = None,
                    disable: bool = False,
                    expand: bool = False):
        """
        Re-initialize the rich bar. If ``columns`` is not passed in, the previous column settings are kept.

        :return:
        """
        for task_id in self.task_ids:  # first remove any existing tasks
            self.remove_task(task_id)

        assert (
            refresh_per_second is None or refresh_per_second > 0
        ), "refresh_per_second must be > 0"

        # stop previous columns
        self.stop()

        # do not change these variables
        # self._lock = RLock()
        # self._tasks: Dict[TaskID, Task] = {}
        # self._task_index: TaskID = TaskID(0)

        if len(columns) != 0:
            self.columns = columns

        self.speed_estimate_period = speed_estimate_period

        self.disable = disable
        self.expand = expand

        self.live = Live(
            console=console or get_console(),
            auto_refresh=auto_refresh,
            refresh_per_second=refresh_per_second,
            transient=transient,
            redirect_stdout=redirect_stdout,
            redirect_stderr=redirect_stderr,
            get_renderable=self.get_renderable,
        )
        self.get_time = get_time or self.console.get_time
        self.print = self.console.print
        self.log = self.console.log

        # start new
        self.start()
        return self
    def set_transient(self, transient: bool = True):
        """
        Set whether the bar is removed (transient) or kept on screen after it finishes running.

        :param transient:
        :return:
        """
        self.new_progess(transient=transient)

    def set_disable(self, flag: bool = True):
        """
        Set the state of the current progress bar; if True, the bar is no longer displayed.

        :param flag:
        :return:
        """
        self.disable = flag
    def add_task(
        self,
        description: str,
        start: bool = True,
        total: float = 100.0,
        completed: int = 0,
        visible: bool = True,
        **fields: Any,
    ) -> TaskID:
        if self.live._started is False:
            self.start()
        post_desc = fields.pop('post_desc', '')
        return super().add_task(description=description,
                                start=start,
                                total=total,
                                completed=completed,
                                visible=visible,
                                post_desc=post_desc,
                                **fields)

    def stop_task(self, task_id: TaskID) -> None:
        if task_id in self._tasks:
            super().stop_task(task_id)

    def remove_task(self, task_id: TaskID) -> None:
        if task_id in self._tasks:
            super().remove_task(task_id)

    def destroy_task(self, task_id: TaskID):
        if task_id in self._tasks:
            super().stop_task(task_id)
            super().remove_task(task_id)
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
    f_rich_progress = FRichProgress().new_progess(
        "[progress.description]{task.description}",
        "[progress.percentage]{task.percentage:>3.0f}%",
        BarColumn(),
        TimeElapsedColumn(),
        "/",
        TimeRemainingColumn(),
        TextColumn("{task.fields[post_desc]}", justify="right"),
        transient=True,
        disable=False,
        speed_estimate_period=10
    )
else:
    f_rich_progress = DummyFRichProgress()
if __name__ == '__main__':
    f = DummyFRichProgress()
    f.console.print('xxx')
    f.console.print.print('xxx')

    # test creating a bar
    import time

    n_steps = 10

    task_id = f_rich_progress.add_task(description='test', total=n_steps)
    for i in range(n_steps):
        f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True)
        print(f"test:{i}")
        time.sleep(0.3)
    f_rich_progress.remove_task(task_id)

    # test nested inner/outer bars
    n_steps = 5
    f_rich_progress.start()
    outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps)
    inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps)
    for i in range(n_steps):
        f_rich_progress.reset(inner_task_id, total=n_steps)
        f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True)
        for j in range(n_steps):
            f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True,
                                   post_desc='Loss: 0.334332323')
            print(f"Outer:{i}, Inner:{j}")
            time.sleep(0.3)

    # test customizing the bar
    f_rich_progress = FRichProgress().new_progess(
        BarColumn(),
        "[progress.description]{task.description}",
        "[progress.percentage]{task.percentage:>3.0f}%",
        TimeElapsedColumn(),
        transient=True)
    n_steps = 10
    task_id = f_rich_progress.add_task(description='test', total=n_steps)
    for i in range(n_steps):
        f_rich_progress.update(task_id, description=f'test:{i}', advance=1)
        print(f"test:{i}")
        time.sleep(0.3)
    f_rich_progress.remove_task(task_id)
    f_rich_progress.stop()
@@ -0,0 +1,49 @@
from typing import Any, Optional

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH

if _NEED_IMPORT_PADDLE:
    import paddle

if _NEED_IMPORT_TORCH:
    import torch

__all__ = [
    "torch_paddle_move_data_to_device",
]

from .utils import apply_to_collection
from .paddle_utils import paddle_to


def torch_paddle_move_data_to_device(batch: Any, device: Optional[str] = None, non_blocking: Optional[bool] = True,
                                     data_device: Optional[str] = None) -> Any:
    r"""
    Move a collection of data to the given device. Only paddle.Tensor and torch.Tensor objects are moved;
    everything else is left unchanged.

    :param batch:
    :param device:
    :param non_blocking:
    :param data_device:
    :return: the same collection, but with all contained tensors residing on the new device;
    """
    if device is None:
        if data_device is not None:
            device = data_device
        else:
            return batch

    torch_device = device.replace("gpu", "cuda")
    paddle_device = device.replace("cuda", "gpu")

    def batch_to(data: Any) -> Any:
        if isinstance(data, torch.Tensor):
            data = data.to(torch_device, non_blocking=non_blocking)
        elif isinstance(data, paddle.Tensor):
            data = paddle_to(data, paddle_device)
        return data

    return apply_to_collection(batch, dtype=(paddle.Tensor, torch.Tensor), function=batch_to)
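# A hedged sketch (requires both torch and paddle installed): the device string
# is normalized per framework ("gpu:0" becomes "cuda:0" for torch and stays
# "gpu:0" for paddle).
# mixed = {"t": torch.zeros(2), "p": paddle.zeros([2])}
# mixed = torch_paddle_move_data_to_device(mixed, device="gpu:0")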
@@ -0,0 +1,63 @@
from abc import ABC
from typing import Any, Union, Optional

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch

__all__ = [
    'torch_move_data_to_device'
]

from .utils import apply_to_collection


class TorchTransferableDataType(ABC):
    """
    A custom type for data that can be moved to a torch device via `.to(...)`.

    Example:
        >>> isinstance(dict, TorchTransferableDataType)
        False
        >>> isinstance(torch.rand(2, 3), TorchTransferableDataType)
        True
        >>> class CustomObject:
        ...     def __init__(self):
        ...         self.x = torch.rand(2, 2)
        ...     def to(self, device):
        ...         self.x = self.x.to(device)
        ...         return self
        >>> isinstance(CustomObject(), TorchTransferableDataType)
        True
    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        if cls is TorchTransferableDataType:
            to = getattr(subclass, "to", None)
            return callable(to)
        return NotImplemented
def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
                              non_blocking: Optional[bool] = True) -> Any:
    r"""
    Move a collection of data to the given device. Any object that defines a method `to(device)` is moved, and all
    other objects in the collection are left unchanged.

    :param batch: the data to move;
    :param device: the device to move the data to; if None, moving the data is left entirely to the user and
        nothing is done here;
    :param non_blocking: passed through to pytorch's `to` method;
    :return: the same collection, but with all contained tensors residing on the new device;
    """
    if device is None:
        return batch

    def batch_to(data: Any) -> Any:
        kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {}
        data_output = data.to(device, **kwargs)
        if data_output is not None:
            return data_output
        # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
        return data

    dtype = TorchTransferableDataType
    return apply_to_collection(batch, dtype=dtype, function=batch_to)
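# A hedged sketch (requires torch): anything exposing a callable `.to` is moved,
# the plain int is returned as-is.
# batch = {"x": torch.zeros(2), "n": 5}
# batch = torch_move_data_to_device(batch, device="cpu")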
@@ -0,0 +1,591 @@
import inspect
from inspect import Parameter
import dataclasses
import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional, Tuple
from time import sleep

try:
    from typing import Literal, Final
except ImportError:
    from typing_extensions import Literal, Final

import os
from contextlib import contextmanager
from functools import wraps
from prettytable import PrettyTable
import numpy as np
from pathlib import Path

from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_GLOBAL_RANK

__all__ = [
    'get_fn_arg_names',
    'check_fn_not_empty_params',
    'auto_param_call',
    'check_user_specific_params',
    'dataclass_to_dict',
    'match_and_substitute_params',
    'apply_to_collection',
    'nullcontext',
    'pretty_table_printer',
    'Option',
    'indice_collate_wrapper',
    'deprecated',
    'seq_len_to_mask',
    'synchronize_safe_rm',
    'synchronize_mkdir'
]
def get_fn_arg_names(fn: Callable) -> List[str]:
    r"""
    Return the names of all parameters of a function.

    :param fn: the function to inspect;
    :return: a list whose elements are the string names of the function's parameters;
    """
    return list(inspect.signature(fn).parameters)
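# Parameters with defaults are included too, since they are still named parameters:
# >>> get_fn_arg_names(lambda a, b=1: None)
# ['a', 'b']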
def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
    r"""
    Check that the batch_step_fn passed in is valid: (1) it is callable; (2) it has exactly the specified number
    of parameters without default values. Users may also pass in a partial function, as long as it keeps the
    `trainer` and `batch` parameter slots.

    :param fn: the function passed in to replace the 'step' function of a Loop;
    :param param_num: the expected number of parameters without default values;
    :return: bool, whether the given `batch_step_fn` is valid;
    """
    if fn is None:
        return True
    if not callable(fn):
        return False
    else:
        params = inspect.signature(fn).parameters
        not_default_params = {}
        for _name, _param in params.items():
            if _param.default == Parameter.empty:
                not_default_params[_name] = _param
        return len(not_default_params) == param_num
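# e.g. a step function with exactly two non-default parameters passes the check:
# >>> def step(trainer, batch): ...
# >>> check_fn_not_empty_params(step, 2)
# True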
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                    mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
    r"""
    1. This function lets the user drive a computation automatically by matching parameter names against string
       keys;
    2. Note that `mapping` defaults to None. If you want to control how the inputs correspond to the called
       function's parameters, pass `mapping` in as such a dict. When `mapping` is not None, the keys of the input
       dicts are always renamed through it first, so make sure yourself that `mapping` is correct;
    3. If a parameter of the called function has a default value and none of the inputs provides a value for it,
       the default is used; otherwise the provided value is used;
    4. If the called function is a `partial`, the pre-bound arguments behave exactly like the default values in 3.;

    :param fn: the function that performs the actual computation; its parameters may have default values;
    :param args: positional arguments, each of which must be a dict; the actual arguments that `fn` needs are
        extracted from these inputs;
    :param signature_fn: a function whose signature replaces that of `fn`; if not None, the parameter values are
        extracted according to this function's signature and then passed to `fn` for the actual computation;
    :param mapping: a dict used to rename the keys of the preceding dicts;
    :return: the result of running `fn`;

    Examples:
        >>> # 1
        >>> loss_fn = CrossEntropyLoss()  # suppose its signature is def CrossEntropyLoss(y, pred);
        >>> batch = {"x": 20, "y": 1}
        >>> output = {"pred": 0}
        >>> acc = auto_param_call(loss_fn, batch, output)

        >>> # 2
        >>> def test_fn(x, y, a, b=10):
        ...     return x + y + a + b
        >>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30}))  # res: 70
        >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20}))  # res: 140
        >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200}))  # res: 240
    """
    if signature_fn is not None:
        if not callable(signature_fn):
            raise ValueError(f"Parameter `signature_fn` should be `Callable`.")
        _need_params = OrderedDict(inspect.signature(signature_fn).parameters)
    else:
        _need_params = OrderedDict(inspect.signature(fn).parameters)
    _kwargs = None
    for _name, _param in _need_params.items():
        if _param.kind == Parameter.VAR_POSITIONAL:
            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
        if _param.kind == Parameter.VAR_KEYWORD:
            _kwargs = (_name, _param)

    if _kwargs is not None:
        _need_params.pop(_kwargs[0])

    _default_params = {}
    for _name, _param in _need_params.items():
        if _param.default != Parameter.empty:
            _default_params[_name] = _param.default

    if mapping is not None:
        assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

    _has_params = {}
    duplicate_names = []
    for arg in args:
        assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
        for _name, _value in arg.items():
            if mapping is not None and _name in mapping:
                _name = mapping[_name]

            if _name not in _has_params:
                if _kwargs is not None or _name in _need_params:
                    _has_params[_name] = _value
            # the same parameter appears in two of the input sources, which is ambiguous;
            elif _name in _need_params and not (_has_params[_name] is _value):
                duplicate_names.append(_name)
    if duplicate_names:
        raise ValueError(f"The following keys are present in several inputs:{duplicate_names}")

    # fill in the parameters that have default values and were not overridden by the inputs;
    for _name, _value in _default_params.items():
        if _name not in _has_params:
            _has_params[_name] = _value

    if len(_has_params) < len(_need_params):
        miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.")

    return fn(**_has_params)
def check_user_specific_params(user_params: Dict, fn: Callable):
    """
    This function applies the user's input to set the parameters of a given function; it is mainly used in cases
    where the user cannot call the function directly. Its main purpose is to help check whether the user's
    parameter input for function ``fn`` contains mistakes.

    :param user_params: the parameter values the user specified; a dict whose keys are parameter names and whose
        values are the intended values of those parameters;
    :param fn: the function that will be called;
    :return: a dict containing the parameter values that will actually be passed when ``fn`` is later called;
    """
    fn_arg_names = get_fn_arg_names(fn)
    for arg_name, arg_value in user_params.items():
        if arg_name not in fn_arg_names:
            logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.")
    return user_params
def dataclass_to_dict(data: "dataclass") -> Dict:
    if not is_dataclass(data):
        raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
    _dict = dict()
    for _key in data.__dataclass_fields__:
        _dict[_key] = getattr(data, _key)
    return _dict
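# Field order follows the dataclass definition:
# >>> @dataclasses.dataclass
# ... class Point:
# ...     x: int = 1
# ...     y: int = 2
# >>> dataclass_to_dict(Point())
# {'x': 1, 'y': 2}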
def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
    r"""
    Rename the keys of an input (batch) or an output (outputs) according to `mapping`. This function backs
    `input_mapping` and `output_mapping`: for `input_mapping` it is called in `TrainBatchLoop` right after a batch
    is fetched; for `output_mapping` it is called in `Trainer.train_step` and `Evaluator.train_step` right after
    the result is obtained.

    The conversion logic, in order of priority:
    1. If `mapping` is a function, `mapping(data)` is returned directly;
    2. If `mapping` is a `Dict`, then `data` may only be one of three types: [`Dict`, `dataclass`, `Sequence`];
       If `data` is a `Dict`, each key of `data` is replaced with mapping[key];
       If `data` is a `dataclass`, it is first converted to a `Dict` via `dataclass_to_dict`, then converted;
       If `data` is a `Sequence`, it is first converted to a corresponding `Dict` of the form
       {"_0": list[0], "_1": list[1], ...}, which is then converted through mapping; keys that do not match
       anything in mapping keep their "_number" form.

    :param mapping: the dict or function used for the conversion; when mapping is a function, its return value
        must be a dict.
    :param data: the object to be converted;
    :return: the converted result;
    """
    if mapping is None:
        return data
    if callable(mapping):
        # note that `Trainer.extract_loss_from_outputs` checks the outputs; currently only `Dict` and `dataclass` outputs are supported;
        return mapping(data)
    if not isinstance(mapping, Dict):
        raise ValueError(
            f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused "
            f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.")
    if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence):
        raise ValueError("Parameter `data` should be type `Dict` or `dataclass` when the other parameter `mapping` is "
                         "type `Dict`.")

    # if `data` is a dataclass, first convert it to a `Dict`;
    if is_dataclass(data):
        data = dataclass_to_dict(data)
    # if `data` is a List, likewise convert it to a `Dict` first: {"_0": list[0], "_1": list[1], ...};
    elif isinstance(data, Sequence):
        data = {"_" + str(i): data[i] for i in range(len(data))}

    _new_data = {}
    for _name, _value in data.items():
        if _name in mapping:
            _new_data[mapping[_name]] = _value
        else:
            _new_data[_name] = _value
    return _new_data
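# Key renaming in the Dict case; unmatched keys pass through untouched:
# >>> match_and_substitute_params({"x": "input"}, {"x": 1, "y": 2})
# {'input': 1, 'y': 2}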
def _is_namedtuple(obj: object) -> bool:
    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def _is_dataclass_instance(obj: object) -> bool:
    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def apply_to_collection(
    data: Any,
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
    include_none: bool = True,
    **kwargs: Any,
) -> Any:
    """Recursively apply ``function`` to the elements of ``data``, but only to elements of type ``dtype``.

    credit to: https://github.com/PyTorchLightning/pytorch-lightning

    Args:
        data: the collection to apply the function to
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collections
            is of the ``wrong_dtype`` even if it is of type ``dtype``
        include_none: Whether to include an element if the output of ``function`` is ``None``.
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        The resulting collection
    """
    # Breaking condition
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    elem_type = type(data)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        out = []
        for k, v in data.items():
            v = apply_to_collection(
                v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
            )
            if include_none or v is not None:
                out.append((k, v))
        if isinstance(data, defaultdict):
            return elem_type(data.default_factory, OrderedDict(out))
        return elem_type(OrderedDict(out))

    is_namedtuple = _is_namedtuple(data)
    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
    if is_namedtuple or is_sequence:
        out = []
        for d in data:
            v = apply_to_collection(
                d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
            )
            if include_none or v is not None:
                out.append(v)
        return elem_type(*out) if is_namedtuple else elem_type(out)

    if _is_dataclass_instance(data):
        # make a deepcopy of the data,
        # but do not deepcopy mapped fields since the computation would
        # be wasted on values that likely get immediately overwritten
        fields = {}
        memo = {}
        for field in dataclasses.fields(data):
            field_value = getattr(data, field.name)
            fields[field.name] = (field_value, field.init)
            memo[id(field_value)] = field_value
        result = deepcopy(data, memo=memo)
        # apply function to each field
        for field_name, (field_value, field_init) in fields.items():
            if field_init:
                v = apply_to_collection(
                    field_value,
                    dtype,
                    function,
                    *args,
                    wrong_dtype=wrong_dtype,
                    include_none=include_none,
                    **kwargs,
                )
            if not field_init or (not include_none and v is None):  # retain old value
                v = getattr(data, field_name)
            setattr(result, field_name, v)
        return result

    # data is neither of dtype, nor a collection
    return data
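# Only values of the requested dtype are touched; container types are preserved:
# >>> apply_to_collection({"a": 1, "b": "x"}, dtype=int, function=lambda v: v + 1)
# {'a': 2, 'b': 'x'}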
@contextmanager
def nullcontext():
    r"""
    A dummy context manager that does nothing.
    """
    yield
def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    r"""
    :param string: the string to be truncated
    :param c: the number of terminal columns
    :param c_size: the number of fields in the instance or dataset
    :param title: the column name
    :return: the result of truncating an over-long column
    """
    avg = max(int(c / c_size / 2), len(title))
    string = str(string)
    res = ""
    counter = 0
    for char in string:
        if ord(char) > 255:
            counter += 2
        else:
            counter += 1
        res += char
        if counter > avg:
            res = res + "..."
            break
    return res
def _is_iterable(value):
    # check whether the value is iterable, via duck typing
    try:
        iter(value)
        return True
    except BaseException:
        return False
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    r"""
    :param dataset_or_ins: a DataSet or an Instance to render, e.g.
        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+
    :return: a pretty table, automatically truncated according to the terminal size
    """
    x = PrettyTable()
    try:
        sz = os.get_terminal_size()
        column = sz.columns
        row = sz.lines
    except OSError:
        column = 144
        row = 11

    if type(dataset_or_ins).__name__ == "DataSet":
        x.field_names = list(dataset_or_ins.field_arrays.keys())
        c_size = len(x.field_names)
        for ins in dataset_or_ins:
            x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names])
            row -= 1
            if row < 0:
                x.add_row(["..." for _ in range(c_size)])
                break
    elif type(dataset_or_ins).__name__ == "Instance":
        x.field_names = list(dataset_or_ins.fields.keys())
        c_size = len(x.field_names)
        x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names])
    else:
        raise Exception("only accept DataSet and Instance")
    x.align = "l"
    return x
class Option(dict):
    r"""a dict whose keys can also be accessed as attributes"""

    def __getattr__(self, item):
        try:
            return self.__getitem__(item)
        except KeyError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        self.__setitem__(key, value)

    def __delattr__(self, item):
        try:
            self.pop(item)
        except KeyError:
            raise AttributeError(item)

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)
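# Attribute access and item access stay in sync:
# >>> opt = Option(lr=0.1)
# >>> opt.lr
# 0.1
# >>> opt.batch_size = 32
# >>> opt["batch_size"]
# 32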
def indice_collate_wrapper(func):
    """
    Wrap a collate_fn one level deeper: split apart the (idx, instance) tuples fetched from the dataset, packing
    the idx values into indices.

    :param func: the function to wrap
    :return:
    """
    def wrapper(tuple_data):
        indice, ins_list = [], []
        for idx, ins in tuple_data:
            indice.append(idx)
            ins_list.append(ins)
        return indice, func(ins_list)

    return wrapper
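# The wrapped collate_fn returns (indices, collated_batch); using `list` as the
# collate function just for illustration:
# >>> collate = indice_collate_wrapper(list)
# >>> collate([(0, 'a'), (1, 'b')])
# ([0, 1], ['a', 'b'])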
_emitted_deprecation_warnings = set()


def deprecated(help_message: Optional[str] = None):
    """Decorator to mark a function as deprecated.

    Args:
        help_message (`Optional[str]`): An optional message to guide the user on how to
            switch to non-deprecated usage of the library.
    """

    def decorator(deprecated_function: Callable):
        global _emitted_deprecation_warnings
        warning_msg = (
            f"{deprecated_function.__name__} is deprecated and will be removed "
            "in the next major version of fastNLP."
        ) + (f" {help_message}" if help_message else "")

        @wraps(deprecated_function)
        def wrapper(*args, **kwargs):
            func_hash = hash(deprecated_function)
            if func_hash not in _emitted_deprecation_warnings:
                warnings.warn(warning_msg, category=FutureWarning, stacklevel=2)
                _emitted_deprecation_warnings.add(func_hash)
            return deprecated_function(*args, **kwargs)

        wrapper._decorator_name_ = "deprecated"
        return wrapper

    return decorator
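# Each decorated function warns only once per process:
# >>> @deprecated("Use new_fn instead.")
# ... def old_fn():
# ...     return 1
# >>> old_fn()  # emits a single FutureWarning, then behaves normally
# 1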
def seq_len_to_mask(seq_len, max_len=None):
    r"""
    Convert a 1-d array of sequence lengths into a 2-d mask, with the excluded positions set to 0;
    i.e. turn a 1-d seq_len into a 2-d mask.

    .. code-block::

        >>> seq_len = np.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len)
        >>> print(mask.shape)
        (14, 15)
        >>> mask = seq_len_to_mask(seq_len, max_len=100)
        >>> print(mask.shape)
        (14, 100)

    :param np.ndarray seq_len: of shape (B,); currently only 1-d numpy.ndarray is supported
    :param int max_len: pad the mask to this length. By default (None) the longest length in seq_len is used, but
        under nn.DataParallel different cards may see different seq_len, so a max_len should be passed in to make
        sure every mask is padded to the same length.
    :return: np.ndarray of shape (B, max_len), with bool-like elements
    """
    if isinstance(seq_len, np.ndarray):
        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
        max_len = int(max_len) if max_len else int(seq_len.max())
        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
    else:
        raise TypeError("Only support 1-d numpy.ndarray.")
    return mask
def wait_to_success(fn, no=False):
    while True:
        sleep(0.01)
        if (no and not fn()) or (not no and fn()):
            break


# This exists because errors can occur on distributed file systems: rank 0 issues the delete, sees it succeed and
# moves on, but the actual deletion still has to be sent from rank 0's machine to the remote file system and
# executed there. At that moment the delete has indeed succeeded from rank 0's point of view, while on the remote
# file system the operation has not finished yet, so rank 1 can still read the file as if it existed.
def synchronize_safe_rm(path: Optional[Union[str, Path]]):
    if path is None:
        return
    if isinstance(path, str):
        path = Path(path)
    if not path.exists():
        return
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        _recursive_rm(path)
    wait_to_success(path.exists, no=True)


def _recursive_rm(path: Path):
    if path.is_file() or path.is_symlink():
        if path.exists():
            try:
                path.unlink()
            except Exception:
                pass
        return
    for sub_path in list(path.iterdir()):
        _recursive_rm(sub_path)
    path.rmdir()
def synchronize_mkdir(path: Optional[Union[str, Path]]):
    """
    Create a directory in a way that is safe across distributed ranks. Note that this function creates a
    directory; do not use it to create a file.
    """
    if path is None:
        return
    if isinstance(path, str):
        path = Path(path)
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        path.mkdir(parents=True, exist_ok=True)
    wait_to_success(path.exists)
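# Only global rank 0 performs the mkdir; every rank then blocks until the path
# is visible (the path below is illustrative):
# synchronize_mkdir("outputs/checkpoints")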