From 7e40d984041ad3700d1598863906017887517e91 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 25 Apr 2022 16:47:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Etrain=5Finput=5Fmapping=20?= =?UTF-8?q?=E5=92=8C=20evaluate=5Finput=5Fmapping=20=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 50 ++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 8a888c2e..9a3c30d5 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -103,10 +103,12 @@ class Trainer(TrainerEventTrigger): value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; 注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时); + 如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。 :param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; + 如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。 :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; @@ -133,6 +135,10 @@ class Trainer(TrainerEventTrigger): progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 + train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 + evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 """ self.model = model self.marker = marker @@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger): self.evaluate_dataloaders = evaluate_dataloaders self.optimizers = optimizers self.fp16 = fp16 - self.input_mapping = input_mapping - self.output_mapping = output_mapping + + train_input_mapping = kwargs.get('train_input_mapping', None) + train_output_mapping = kwargs.get('train_output_mapping', None) + evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None) + evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None) + + train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \ + _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, + evaluate_input_mapping, evaluate_output_mapping) + + self.input_mapping = train_input_mapping + self.output_mapping = train_output_mapping self.evaluate_fn = evaluate_fn self.batch_step_fn = batch_step_fn @@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger): callbacks=callbacks, metrics=metrics, evaluate_every=evaluate_every, - input_mapping=input_mapping, - output_mapping=output_mapping, + input_mapping=evaluate_input_mapping, + output_mapping=evaluate_output_mapping, model_wo_auto_param_call=model_wo_auto_param_call, accumulation_steps=accumulation_steps, fp16=fp16, @@ -854,6 +870,32 @@ class Trainer(TrainerEventTrigger): self._evaluate_dataloaders = evaluate_dataloaders +def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, + evaluate_input_mapping, evaluate_output_mapping): + if train_input_mapping is not None and input_mapping is not None: + raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.") + + if evaluate_input_mapping is not None and input_mapping is not None: + raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.") + + if train_output_mapping is not None and output_mapping is not None: + raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.") + + if evaluate_output_mapping is not None and output_mapping is not None: + raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.") + + if train_input_mapping is None: + train_input_mapping = input_mapping + if evaluate_input_mapping is None: + evaluate_input_mapping = input_mapping + + if train_output_mapping is None: + train_output_mapping = output_mapping + if evaluate_output_mapping is None: + evaluate_output_mapping = output_mapping + + return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping +