From 58a21c2b63873d5c9a8c054680428de8a16495a8 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 5 May 2022 23:57:51 +0800 Subject: [PATCH 1/2] =?UTF-8?q?dataset=E5=A4=9A=E8=BF=9B=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset/dataset.py | 176 +++++++++++++------------------- fastNLP/core/log/logger.py | 1 + 2 files changed, 73 insertions(+), 104 deletions(-) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index fa330854..98f23286 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -9,22 +9,18 @@ __all__ = [ import _pickle as pickle from copy import deepcopy from typing import Optional, List, Callable, Union, Dict, Any, Mapping -from functools import partial +from types import LambdaType +import sys +import time import numpy as np -from threading import Thread - -try: - import multiprocessing as mp -except: - pass from .field import FieldArray from .instance import Instance -from fastNLP.core.utils.utils import pretty_table_printer +from fastNLP.core.utils.utils import pretty_table_printer, deprecated from fastNLP.core.collators import Collator from fastNLP.core.utils.rich_progress import f_rich_progress -from fastNLP.core.log import logger +from ..log import logger class ApplyResultException(Exception): @@ -35,14 +31,13 @@ class ApplyResultException(Exception): def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, - pipe=None, desc: str = None) -> list: + desc: str = None) -> list: """ 对数据集进行处理封装函数,以便多进程使用 :param ds: 数据集 :param _apply_field: 需要处理数据集的field_name :param func: 用户自定义的func - :param pipe: 管道 :param desc: 进度条的描述字符 :param show_progress_bar: 是否展示子进程进度条 :return: @@ -60,8 +55,6 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s results.append(func(ins[_apply_field])) else: results.append(func(ins)) - if pipe is not None: - pipe.send([idx + 1]) if show_progress_bar: f_rich_progress.update(pg_main, advance=1) @@ -75,31 +68,36 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s return results -def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None: +def _multi_proc(ds, _apply_field, func, counter, queue): """ - 多进程下显示主进程的进度条 + 对数据集进行处理封装函数,以便多进程使用 - :param parent: 进程管道 - :param total_len: 数据集总长度 - :param desc: 进度条描述符 - :param show_progress_bar: 是否展示进度条 + :param ds: 数据集 + :param _apply_field: 需要处理数据集的field_name + :param func: 用户自定义的func + :param counter: 计数器 + :param queue: 多进程时,将结果输入到这个 queue 中 :return: """ - desc = desc if desc else "Main" - - main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar) - # pb_main = tqdm(total=total_len, desc=desc, position=0) - nums = 0 - while True: - msg = parent.recv()[0] - if msg is not None: - f_rich_progress.update(main_pro, advance=1) - nums += 1 - - if nums == total_len: - break - f_rich_progress.destroy_task(main_pro) - # pb_main.close() + idx = -1 + import contextlib + with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁 + logger.set_stdout(stdout='raw') + results = [] + try: + for idx, ins in enumerate(ds): + if _apply_field is not None: + res = func(ins[_apply_field]) + else: + res = func(ins) + results.append(res) + with counter.get_lock(): + counter.value += 1 + except BaseException as e: + if idx != -1: + logger.error("Exception happens at the `{}`th instance.".format(idx)) + raise e + 
queue.put(pickle.dumps(results)) class DataSet: @@ -114,7 +112,7 @@ class DataSet: 每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 """ self.field_arrays = {} - self._collator = Collator(backend="numpy") + self._collator = Collator() if data is not None: if isinstance(data, Dict): length_set = set() @@ -127,7 +125,6 @@ class DataSet: for ins in data: assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) self.append(ins) - else: raise ValueError("data only be dict or list type.") @@ -263,7 +260,7 @@ class DataSet: try: self.field_arrays[name].append(field) except Exception as e: - print(f"Cannot append to field:{name}.") + logger.error(f"Cannot append to field:{name}.") raise e def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: @@ -469,9 +466,7 @@ class DataSet: except Exception as e: if idx != -1: - if isinstance(e, ApplyResultException): - print(e.msg) - print("Exception happens at the `{}`th instance.".format(idx + 1)) + logger.error("Exception happens at the `{}`th instance.".format(idx + 1)) raise e if modify_fields is True: @@ -490,18 +485,19 @@ class DataSet: :param show_progress_bar: 是否展示progress进度条,默认为展示 :param progress_desc: 进度条的描述字符,默认为'Main """ + if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "": + raise ("Lambda function does not support multiple processes, please set `num_proc=0`.") + if num_proc>1 and sys.platform in ('win32', 'msys', 'cygwin'): + raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`") - if num_proc == 0: + if num_proc < 2: results = _apply_single(ds=self, _apply_field=_apply_field, func=func, desc=progress_desc, show_progress_bar=show_progress_bar) else: # TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 - results = [] - if num_proc > len(self): - num_proc = len(self) - print( - f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}." - ) + import multiprocessing as mp + ctx = mp.get_context('fork') + num_proc = min(num_proc, len(self)) # 划分数据集 shard_len = len(self) // num_proc num_left_sample = len(self) % num_proc @@ -511,24 +507,32 @@ class DataSet: end = shard_len + int(_i= 0, "num_proc must >= 0" idx = -1 @@ -577,9 +580,7 @@ class DataSet: except Exception as e: if idx != -1: - if isinstance(e, ApplyResultException): - print(e.msg) - print("Exception happens at the `{}`th instance.".format(idx + 1)) + logger.error("Exception happens at the `{}`th instance.".format(idx + 1)) raise e if modify_fields is True: @@ -665,8 +666,7 @@ class DataSet: np.random.shuffle(all_indices) split = int(ratio * len(self)) if split == 0: - error_msg = f'Dev DataSet has {split} instance after split.' - print(error_msg) + error_msg = f'Dev DataSet has `{split}` instance after split.' 
raise IndexError(error_msg) dev_indices = all_indices[:split] train_indices = all_indices[split:] @@ -776,35 +776,3 @@ class DataSet: if self._collator is None: self._collator = Collator() return self._collator - - -if __name__ == '__main__': - # from fastNLP import DataSet - - # if __name__=='__main__': - # data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) - # data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True) - - import multiprocess as mp - # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar - from functools import partial - from threading import Thread - - shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}), - DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})] - parent, chid = mp.Pipe() - partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x), - pipe=chid, show_progress_bar=False) - thread = Thread(target=_progress_bar, args=(parent, 400, 'main')) - thread.start() - pool = mp.Pool(processes=6) - pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds}) - for proc_id, ds in enumerate(shard_data)] - pool.close() - pool.join() - thread.join() - results = [] - for async_result in pool_outs: - data = async_result.get() - results.extend(data) - print(results) diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 179755e2..6610b30c 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -302,6 +302,7 @@ def _set_stdout_handler(_logger, stdout='raw', level='INFO'): break if stream_handler is not None: _logger.removeHandler(stream_handler) + del stream_handler # Stream Handler if stdout == 'raw': From 6f402b9cddb3eda5f740aad97a94a04c58a70564 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 6 May 2022 00:50:37 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=83=A8=E5=88=86warning?= =?UTF-8?q?=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/has_monitor_callback.py | 6 +++--- fastNLP/core/callbacks/topk_saver.py | 2 +- fastNLP/core/controllers/evaluator.py | 2 +- fastNLP/core/controllers/trainer.py | 4 ++-- .../drivers/paddle_driver/initialize_paddle_driver.py | 4 ++-- fastNLP/core/drivers/paddle_driver/paddle_driver.py | 8 ++++---- fastNLP/core/drivers/paddle_driver/utils.py | 2 +- fastNLP/core/drivers/torch_driver/torch_driver.py | 4 ++-- fastNLP/core/drivers/torch_driver/utils.py | 4 ++-- fastNLP/core/log/logger.py | 7 +++++-- fastNLP/core/utils/cache_results.py | 3 ++- fastNLP/core/utils/utils.py | 2 +- 12 files changed, 26 insertions(+), 22 deletions(-) diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index c5c5edde..8e5eb0aa 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -78,11 +78,11 @@ class MonitorUtility: return monitor_value # 第一次运行 if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: - logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " - f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") + logger.rank_zero_warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as " + f"{list(results.keys())}), we use the 
`{use_monitor}` as the monitor.", once=True) # 检测到此次和上次不同。 elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: - logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " + logger.rank_zero_warning(f"Change of monitor detected for `{self.__class__.__name__}`. " f"The expected monitor is:`{self.monitor}`, last used monitor is:" f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " f"customized monitor function when the evaluation results are varying between validation.") diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 25e66cb9..bd630836 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -32,7 +32,7 @@ class Saver: :param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 """ if folder is None: - logger.warning( + logger.rank_zero_warning( "Parameter `folder` is None, and we will use the current work directory to find and load your model.") folder = Path.cwd() folder = Path(folder) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 4dba8a4c..70c7fbd0 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -221,7 +221,7 @@ class Evaluator: @evaluate_batch_loop.setter def evaluate_batch_loop(self, loop: Loop): if self.evaluate_batch_step_fn is not None: - logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " + logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " "when the `evaluate_batch_loop` is also customized.") self._evaluate_batch_loop = loop diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index f720fe5b..8fd3c65e 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -305,7 +305,7 @@ class Trainer(TrainerEventTrigger): else: if self.driver.is_distributed(): if catch_KeyboardInterrupt: - logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " + logger.rank_zero_warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " "driver. 
And we are gonna to set it to False.") catch_KeyboardInterrupt = False @@ -535,7 +535,7 @@ class Trainer(TrainerEventTrigger): _not_called_callback_fns.append(each_callback_fn) if check_mode: - logger.warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " + logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " f"callback_fns: {_not_called_callback_fns}, but it seems that" "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index 46f51b9c..66da8cf1 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -69,7 +69,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ if not isinstance(device, List): return PaddleSingleDriver(model, device, **kwargs) else: - logger.warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" + logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" "`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" "`driver` as `PaddleFleetDriver`.") return PaddleFleetDriver(model, device, **kwargs) @@ -77,7 +77,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ if not isinstance(device, List): if device == "cpu": raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") - logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" + logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " "choose `paddle` driver.") return PaddleFleetDriver(model, [device], **kwargs) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 48ff9de1..795cb7bf 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -72,7 +72,7 @@ class PaddleDriver(Driver): :param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 """ if set_to_none: - logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") + logger.rank_zero_warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") for optimizer in self.optimizers: optimizer.clear_grad() @@ -256,7 +256,7 @@ class PaddleDriver(Driver): if dataloader_args.batch_size is not None: num_consumed_batches = num_consumed_batches * dataloader_args.batch_size else: # 有可能 batch_size 为 None,就只有损失精度了 - logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " "it may cause missing some samples when reload.") num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] @@ -266,7 +266,7 @@ class PaddleDriver(Driver): 
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ * num_consumed_batches else: - logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " "it may cause missing some samples when reload.") else: raise RuntimeError( @@ -329,7 +329,7 @@ class PaddleDriver(Driver): self.grad_scaler.load_state_dict(grad_scaler_state_dict) logger.debug("Load grad_scaler state dict...") elif not isinstance(self.grad_scaler, DummyGradScaler): - logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " + logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " f"the training process may be unstable.") # 4. 恢复 sampler 的状态; diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 60d243e7..6362193e 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -51,7 +51,7 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): - logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for " + logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for " "you.") # rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 8c332251..382ac2c1 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -197,7 +197,7 @@ class TorchDriver(Driver): if dataloader_args.batch_size is not None: num_consumed_batches = num_consumed_batches * dataloader_args.batch_size else: # 有可能 batch_size 为 None,就只有损失精度了 - logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " "it may cause missing some samples when reload.") num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] @@ -207,7 +207,7 @@ class TorchDriver(Driver): sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ * num_consumed_batches else: - logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " + logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " "it may cause missing some samples when reload.") states['sampler_states'] = sampler_states diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index 941e4445..d756cf77 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -60,7 +60,7 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): - logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") + logger.rank_zero_warning("Your seed value is two big or two 
small for numpy, we will choose a random seed for you.") seed = _select_seed_randomly(min_seed_value, max_seed_value) @@ -162,7 +162,7 @@ def _build_fp16_env(dummy=False): if not torch.cuda.is_available(): raise RuntimeError("No cuda") if torch.cuda.get_device_capability(0)[0] < 7: - logger.warning( + logger.rank_zero_warning( "NOTE: your device does NOT support faster training with fp16, " "please switch to FP32 which is likely to be faster" ) diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 6610b30c..eea54f36 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -124,18 +124,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): self._log(WARNING, msg, args, **kwargs) self._warning_msgs.add(msg) - def rank_zero_warning(self, msg, *args, **kwargs): + def rank_zero_warning(self, msg, *args, once=False, **kwargs): """ 只在 rank 0 上 warning 。 :param msg: :param args: + :param once: 是否只 warning 一次 :param kwargs: :return: """ if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': + if once and msg in self._warning_msgs: + return if self.isEnabledFor(WARNING): - # kwargs = self._add_rank_info(kwargs) + kwargs = self._add_rank_info(kwargs) self._log(WARNING, msg, args, **kwargs) def warn(self, msg, *args, **kwargs): diff --git a/fastNLP/core/utils/cache_results.py b/fastNLP/core/utils/cache_results.py index f8d34bc9..cde4a51e 100644 --- a/fastNLP/core/utils/cache_results.py +++ b/fastNLP/core/utils/cache_results.py @@ -15,6 +15,7 @@ __all__ = [ from fastNLP.core.log.logger import logger from fastNLP.core.log.highlighter import ColorHighlighter +from .utils import _get_fun_msg class FuncCallVisitor(ast.NodeVisitor): @@ -306,7 +307,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec if verbose == 1: logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) if check_hash and old_hash_code != new_hash_code: - logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " + logger.warning(f"The function {_get_fun_msg(func)} is different from its last cache (Save on {save_time}). The " f"difference may caused by the sourcecode change.", extra={'highlighter': ColorHighlighter('red')}) refresh_flag = False diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 93f38e2a..c894131d 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -239,7 +239,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable): fn_arg_names = get_fn_arg_names(fn) for arg_name, arg_value in user_params.items(): if arg_name not in fn_arg_names: - logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") + logger.rank_zero_warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") return user_params
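
A minimal usage sketch of the behaviour these two commits introduce: multi-process `DataSet.apply_field` via the `num_proc` parameter, and the new `once=True` option of `logger.rank_zero_warning`. This sketch is not part of the patch; it assumes the fastNLP APIs exactly as patched above, the data and the `char_count` helper are illustrative only, and signatures in a released version may differ.

    from fastNLP import DataSet
    from fastNLP.core.log import logger

    ds = DataSet({'x': ['a short sentence', 'another short sentence'] * 100,
                  'y': [0, 1] * 100})

    def char_count(text):
        # A named function: with num_proc > 1 the patched code rejects lambdas,
        # and platforms without fork() raise a RuntimeError.
        return len(text)

    # num_proc < 2 keeps the original single-process path; num_proc >= 2 forks
    # worker processes, each running _multi_proc() on its shard and reporting
    # progress through a shared counter.
    ds.apply_field(char_count, field_name='x', new_field_name='len_x',
                   num_proc=2, progress_desc='counting characters')

    # Second commit: rank_zero_warning() only emits on rank 0, and with
    # once=True a message already recorded in _warning_msgs is skipped.
    logger.rank_zero_warning("this message should appear at most once", once=True)
    logger.rank_zero_warning("this message should appear at most once", once=True)

The design relies on the fork start method plus a shared integer counter and per-process queues: each worker counts finished instances and sends its pickled results back, presumably so the parent process can drive a single rich progress bar instead of one bar per subprocess.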