@@ -78,11 +78,11 @@ class MonitorUtility: | |||
return monitor_value | |||
# 第一次运行 | |||
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||
logger.rank_zero_warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as " | |||
f"{list(results.keys())}), we use the `{use_monitor}` as the monitor.", once=True) | |||
# 检测到此次和上次不同。 | |||
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||
logger.rank_zero_warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||
f"The expected monitor is:`{self.monitor}`, last used monitor is:" | |||
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " | |||
f"customized monitor function when the evaluation results are varying between validation.") | |||
@@ -33,7 +33,7 @@ class Saver: | |||
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 | |||
""" | |||
if folder is None: | |||
logger.warning( | |||
logger.rank_zero_warning( | |||
"Parameter `folder` is None, and we will use the current work directory to find and load your model.") | |||
folder = Path.cwd() | |||
folder = Path(folder) | |||
@@ -221,7 +221,7 @@ class Evaluator: | |||
@evaluate_batch_loop.setter | |||
def evaluate_batch_loop(self, loop: Loop): | |||
if self.evaluate_batch_step_fn is not None: | |||
logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||
logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||
"when the `evaluate_batch_loop` is also customized.") | |||
self._evaluate_batch_loop = loop | |||
@@ -305,7 +305,7 @@ class Trainer(TrainerEventTrigger): | |||
else: | |||
if self.driver.is_distributed(): | |||
if catch_KeyboardInterrupt: | |||
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||
logger.rank_zero_warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||
"driver. And we are gonna to set it to False.") | |||
catch_KeyboardInterrupt = False | |||
@@ -535,7 +535,7 @@ class Trainer(TrainerEventTrigger): | |||
_not_called_callback_fns.append(each_callback_fn) | |||
if check_mode: | |||
logger.warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||
f"callback_fns: {_not_called_callback_fns}, but it seems that" | |||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||
# 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | |||
@@ -9,22 +9,18 @@ __all__ = [ | |||
import _pickle as pickle | |||
from copy import deepcopy | |||
from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||
from functools import partial | |||
from types import LambdaType | |||
import sys | |||
import time | |||
import numpy as np | |||
from threading import Thread | |||
try: | |||
import multiprocessing as mp | |||
except: | |||
pass | |||
from .field import FieldArray | |||
from .instance import Instance | |||
from fastNLP.core.utils.utils import pretty_table_printer | |||
from fastNLP.core.utils.utils import pretty_table_printer, deprecated | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.core.utils.rich_progress import f_rich_progress | |||
from fastNLP.core.log import logger | |||
from ..log import logger | |||
class ApplyResultException(Exception): | |||
@@ -35,14 +31,13 @@ class ApplyResultException(Exception): | |||
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, | |||
pipe=None, desc: str = None) -> list: | |||
desc: str = None) -> list: | |||
""" | |||
对数据集进行处理封装函数,以便多进程使用 | |||
:param ds: 数据集 | |||
:param _apply_field: 需要处理数据集的field_name | |||
:param func: 用户自定义的func | |||
:param pipe: 管道 | |||
:param desc: 进度条的描述字符 | |||
:param show_progress_bar: 是否展示子进程进度条 | |||
:return: | |||
@@ -60,8 +55,6 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s | |||
results.append(func(ins[_apply_field])) | |||
else: | |||
results.append(func(ins)) | |||
if pipe is not None: | |||
pipe.send([idx + 1]) | |||
if show_progress_bar: | |||
f_rich_progress.update(pg_main, advance=1) | |||
@@ -75,31 +68,36 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s | |||
return results | |||
def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None: | |||
def _multi_proc(ds, _apply_field, func, counter, queue): | |||
""" | |||
多进程下显示主进程的进度条 | |||
对数据集进行处理封装函数,以便多进程使用 | |||
:param parent: 进程管道 | |||
:param total_len: 数据集总长度 | |||
:param desc: 进度条描述符 | |||
:param show_progress_bar: 是否展示进度条 | |||
:param ds: 数据集 | |||
:param _apply_field: 需要处理数据集的field_name | |||
:param func: 用户自定义的func | |||
:param counter: 计数器 | |||
:param queue: 多进程时,将结果输入到这个 queue 中 | |||
:return: | |||
""" | |||
desc = desc if desc else "Main" | |||
main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar) | |||
# pb_main = tqdm(total=total_len, desc=desc, position=0) | |||
nums = 0 | |||
while True: | |||
msg = parent.recv()[0] | |||
if msg is not None: | |||
f_rich_progress.update(main_pro, advance=1) | |||
nums += 1 | |||
if nums == total_len: | |||
break | |||
f_rich_progress.destroy_task(main_pro) | |||
# pb_main.close() | |||
idx = -1 | |||
import contextlib | |||
with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁 | |||
logger.set_stdout(stdout='raw') | |||
results = [] | |||
try: | |||
for idx, ins in enumerate(ds): | |||
if _apply_field is not None: | |||
res = func(ins[_apply_field]) | |||
else: | |||
res = func(ins) | |||
results.append(res) | |||
with counter.get_lock(): | |||
counter.value += 1 | |||
except BaseException as e: | |||
if idx != -1: | |||
logger.error("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
queue.put(pickle.dumps(results)) | |||
class DataSet: | |||
@@ -114,7 +112,7 @@ class DataSet: | |||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
""" | |||
self.field_arrays = {} | |||
self._collator = Collator(backend="numpy") | |||
self._collator = Collator() | |||
if data is not None: | |||
if isinstance(data, Dict): | |||
length_set = set() | |||
@@ -127,7 +125,6 @@ class DataSet: | |||
for ins in data: | |||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
self.append(ins) | |||
else: | |||
raise ValueError("data only be dict or list type.") | |||
@@ -263,7 +260,7 @@ class DataSet: | |||
try: | |||
self.field_arrays[name].append(field) | |||
except Exception as e: | |||
print(f"Cannot append to field:{name}.") | |||
logger.error(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | |||
@@ -469,9 +466,7 @@ class DataSet: | |||
except Exception as e: | |||
if idx != -1: | |||
if isinstance(e, ApplyResultException): | |||
print(e.msg) | |||
print("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
logger.error("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
raise e | |||
if modify_fields is True: | |||
@@ -490,18 +485,19 @@ class DataSet: | |||
:param show_progress_bar: 是否展示progress进度条,默认为展示 | |||
:param progress_desc: 进度条的描述字符,默认为'Main | |||
""" | |||
if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>": | |||
raise ("Lambda function does not support multiple processes, please set `num_proc=0`.") | |||
if num_proc>1 and sys.platform in ('win32', 'msys', 'cygwin'): | |||
raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`") | |||
if num_proc == 0: | |||
if num_proc < 2: | |||
results = _apply_single(ds=self, _apply_field=_apply_field, func=func, | |||
desc=progress_desc, show_progress_bar=show_progress_bar) | |||
else: | |||
# TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 | |||
results = [] | |||
if num_proc > len(self): | |||
num_proc = len(self) | |||
print( | |||
f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}." | |||
) | |||
import multiprocessing as mp | |||
ctx = mp.get_context('fork') | |||
num_proc = min(num_proc, len(self)) | |||
# 划分数据集 | |||
shard_len = len(self) // num_proc | |||
num_left_sample = len(self) % num_proc | |||
@@ -511,24 +507,32 @@ class DataSet: | |||
end = shard_len + int(_i<num_left_sample) + start | |||
shard_data.append(self[start:end]) | |||
start = end | |||
# 配置管道,线程以实现 main progress 能够实时更新。 | |||
parent, child = mp.Pipe() | |||
main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc, | |||
show_progress_bar)) | |||
partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func, | |||
pipe=child, show_progress_bar=False) | |||
# 开启进程池,线程 | |||
main_thread.start() | |||
pool = mp.Pool(processes=num_proc) | |||
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds}) | |||
for proc_id, ds in enumerate(shard_data)] | |||
pool.close() | |||
pool.join() | |||
main_thread.join() | |||
for async_result in pool_outs: | |||
data = async_result.get() | |||
results.extend(data) | |||
# 配置共享参数,线程以实现 main progress 能够实时更新。 | |||
counter = ctx.Value('i', 0, lock=True) | |||
pool = [] | |||
queues = [] | |||
results = [] | |||
for i in range(num_proc): | |||
queue = ctx.SimpleQueue() | |||
proc = ctx.Process(target=_multi_proc, args=(shard_data[i], _apply_field, func, counter, queue)) | |||
proc.start() | |||
pool.append(proc) | |||
queues.append(queue) | |||
total_len = len(self) | |||
task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar) | |||
last_count = -1 | |||
while counter.value < total_len or last_count == -1: | |||
while counter.value == last_count: | |||
time.sleep(0.1) | |||
advance = counter.value - last_count | |||
last_count = counter.value | |||
f_rich_progress.update(task_id, advance=advance, refresh=True) | |||
for idx, proc in enumerate(pool): | |||
results.extend(pickle.loads(queues[idx].get())) | |||
proc.join() | |||
f_rich_progress.destroy_task(task_id) | |||
return results | |||
def apply_more(self, func: Callable = None, modify_fields: bool = True, | |||
@@ -552,8 +556,7 @@ class DataSet: | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||
:return Dict[str:Field]: 返回一个字典 | |||
""" | |||
# 返回 dict , 检查是否一直相同 | |||
assert callable(func), "The func you provide is not callable." | |||
assert callable(func), "The func is not callable." | |||
assert len(self) != 0, "Null DataSet cannot use apply()." | |||
assert num_proc >= 0, "num_proc must >= 0" | |||
idx = -1 | |||
@@ -577,9 +580,7 @@ class DataSet: | |||
except Exception as e: | |||
if idx != -1: | |||
if isinstance(e, ApplyResultException): | |||
print(e.msg) | |||
print("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
logger.error("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
raise e | |||
if modify_fields is True: | |||
@@ -665,8 +666,7 @@ class DataSet: | |||
np.random.shuffle(all_indices) | |||
split = int(ratio * len(self)) | |||
if split == 0: | |||
error_msg = f'Dev DataSet has {split} instance after split.' | |||
print(error_msg) | |||
error_msg = f'Dev DataSet has `{split}` instance after split.' | |||
raise IndexError(error_msg) | |||
dev_indices = all_indices[:split] | |||
train_indices = all_indices[split:] | |||
@@ -776,35 +776,3 @@ class DataSet: | |||
if self._collator is None: | |||
self._collator = Collator() | |||
return self._collator | |||
if __name__ == '__main__': | |||
# from fastNLP import DataSet | |||
# if __name__=='__main__': | |||
# data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) | |||
# data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True) | |||
import multiprocess as mp | |||
# from fastNLP.core.dataset.dataset import _apply_single, _progress_bar | |||
from functools import partial | |||
from threading import Thread | |||
shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}), | |||
DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})] | |||
parent, chid = mp.Pipe() | |||
partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x), | |||
pipe=chid, show_progress_bar=False) | |||
thread = Thread(target=_progress_bar, args=(parent, 400, 'main')) | |||
thread.start() | |||
pool = mp.Pool(processes=6) | |||
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds}) | |||
for proc_id, ds in enumerate(shard_data)] | |||
pool.close() | |||
pool.join() | |||
thread.join() | |||
results = [] | |||
for async_result in pool_outs: | |||
data = async_result.get() | |||
results.extend(data) | |||
print(results) |
@@ -69,7 +69,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
if not isinstance(device, List): | |||
return PaddleSingleDriver(model, device, **kwargs) | |||
else: | |||
logger.warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||
logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||
"`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||
"`driver` as `PaddleFleetDriver`.") | |||
return PaddleFleetDriver(model, device, **kwargs) | |||
@@ -77,7 +77,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
if not isinstance(device, List): | |||
if device == "cpu": | |||
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | |||
logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||
logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | |||
"choose `paddle` driver.") | |||
return PaddleFleetDriver(model, [device], **kwargs) | |||
@@ -72,7 +72,7 @@ class PaddleDriver(Driver): | |||
:param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 | |||
""" | |||
if set_to_none: | |||
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
logger.rank_zero_warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
for optimizer in self.optimizers: | |||
optimizer.clear_grad() | |||
@@ -233,7 +233,7 @@ class PaddleDriver(Driver): | |||
if dataloader_args.batch_size is not None: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
@@ -243,7 +243,7 @@ class PaddleDriver(Driver): | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
else: | |||
raise RuntimeError( | |||
@@ -306,7 +306,7 @@ class PaddleDriver(Driver): | |||
self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||
logger.debug("Load grad_scaler state dict...") | |||
elif not isinstance(self.grad_scaler, DummyGradScaler): | |||
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||
f"the training process may be unstable.") | |||
# 4. 恢复 sampler 的状态; | |||
@@ -51,7 +51,7 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||
seed = int(seed) | |||
if not (min_seed_value <= seed <= max_seed_value): | |||
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for " | |||
logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for " | |||
"you.") | |||
# rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") | |||
@@ -197,7 +197,7 @@ class TorchDriver(Driver): | |||
if dataloader_args.batch_size is not None: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
@@ -207,7 +207,7 @@ class TorchDriver(Driver): | |||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
* num_consumed_batches | |||
else: | |||
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
"it may cause missing some samples when reload.") | |||
states['sampler_states'] = sampler_states | |||
@@ -60,7 +60,7 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||
seed = int(seed) | |||
if not (min_seed_value <= seed <= max_seed_value): | |||
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | |||
logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | |||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||
@@ -162,7 +162,7 @@ def _build_fp16_env(dummy=False): | |||
if not torch.cuda.is_available(): | |||
raise RuntimeError("No cuda") | |||
if torch.cuda.get_device_capability(0)[0] < 7: | |||
logger.warning( | |||
logger.rank_zero_warning( | |||
"NOTE: your device does NOT support faster training with fp16, " | |||
"please switch to FP32 which is likely to be faster" | |||
) | |||
@@ -124,18 +124,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
self._log(WARNING, msg, args, **kwargs) | |||
self._warning_msgs.add(msg) | |||
def rank_zero_warning(self, msg, *args, **kwargs): | |||
def rank_zero_warning(self, msg, *args, once=False, **kwargs): | |||
""" | |||
只在 rank 0 上 warning 。 | |||
:param msg: | |||
:param args: | |||
:param once: 是否只 warning 一次 | |||
:param kwargs: | |||
:return: | |||
""" | |||
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | |||
if once and msg in self._warning_msgs: | |||
return | |||
if self.isEnabledFor(WARNING): | |||
# kwargs = self._add_rank_info(kwargs) | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
def warn(self, msg, *args, **kwargs): | |||
@@ -302,6 +305,7 @@ def _set_stdout_handler(_logger, stdout='raw', level='INFO'): | |||
break | |||
if stream_handler is not None: | |||
_logger.removeHandler(stream_handler) | |||
del stream_handler | |||
# Stream Handler | |||
if stdout == 'raw': | |||
@@ -15,6 +15,7 @@ __all__ = [ | |||
from fastNLP.core.log.logger import logger | |||
from fastNLP.core.log.highlighter import ColorHighlighter | |||
from .utils import _get_fun_msg | |||
class FuncCallVisitor(ast.NodeVisitor): | |||
@@ -306,7 +307,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec | |||
if verbose == 1: | |||
logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | |||
if check_hash and old_hash_code != new_hash_code: | |||
logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | |||
logger.warning(f"The function {_get_fun_msg(func)} is different from its last cache (Save on {save_time}). The " | |||
f"difference may caused by the sourcecode change.", | |||
extra={'highlighter': ColorHighlighter('red')}) | |||
refresh_flag = False | |||
@@ -239,7 +239,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable): | |||
fn_arg_names = get_fn_arg_names(fn) | |||
for arg_name, arg_value in user_params.items(): | |||
if arg_name not in fn_arg_names: | |||
logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
logger.rank_zero_warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
return user_params | |||