|
|
@@ -103,10 +103,12 @@ class Trainer(TrainerEventTrigger): |
|
|
|
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 |
|
|
|
类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; |
|
|
|
注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时); |
|
|
|
如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。 |
|
|
|
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 |
|
|
|
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, |
|
|
|
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; |
|
|
|
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; |
|
|
|
如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。 |
|
|
|
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; |
|
|
|
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 |
|
|
|
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; |
|
|
@@ -133,6 +135,10 @@ class Trainer(TrainerEventTrigger): |
|
|
|
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, |
|
|
|
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 |
|
|
|
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 |
|
|
|
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 |
|
|
|
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 |
|
|
|
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 |
|
|
|
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 |
|
|
|
""" |
|
|
|
self.model = model |
|
|
|
self.marker = marker |
|
|
@@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger): |
|
|
|
self.evaluate_dataloaders = evaluate_dataloaders |
|
|
|
self.optimizers = optimizers |
|
|
|
self.fp16 = fp16 |
|
|
|
self.input_mapping = input_mapping |
|
|
|
self.output_mapping = output_mapping |
|
|
|
|
|
|
|
train_input_mapping = kwargs.get('train_input_mapping', None) |
|
|
|
train_output_mapping = kwargs.get('train_output_mapping', None) |
|
|
|
evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None) |
|
|
|
evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None) |
|
|
|
|
|
|
|
train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \ |
|
|
|
_get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, |
|
|
|
evaluate_input_mapping, evaluate_output_mapping) |
|
|
|
|
|
|
|
self.input_mapping = train_input_mapping |
|
|
|
self.output_mapping = train_output_mapping |
|
|
|
self.evaluate_fn = evaluate_fn |
|
|
|
|
|
|
|
self.batch_step_fn = batch_step_fn |
|
|
@@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger): |
|
|
|
callbacks=callbacks, |
|
|
|
metrics=metrics, |
|
|
|
evaluate_every=evaluate_every, |
|
|
|
input_mapping=input_mapping, |
|
|
|
output_mapping=output_mapping, |
|
|
|
input_mapping=evaluate_input_mapping, |
|
|
|
output_mapping=evaluate_output_mapping, |
|
|
|
model_wo_auto_param_call=model_wo_auto_param_call, |
|
|
|
accumulation_steps=accumulation_steps, |
|
|
|
fp16=fp16, |
|
|
@@ -854,6 +870,32 @@ class Trainer(TrainerEventTrigger): |
|
|
|
self._evaluate_dataloaders = evaluate_dataloaders |
|
|
|
|
|
|
|
|
|
|
|
def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, |
|
|
|
evaluate_input_mapping, evaluate_output_mapping): |
|
|
|
if train_input_mapping is not None and input_mapping is not None: |
|
|
|
raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.") |
|
|
|
|
|
|
|
if evaluate_input_mapping is not None and input_mapping is not None: |
|
|
|
raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.") |
|
|
|
|
|
|
|
if train_output_mapping is not None and output_mapping is not None: |
|
|
|
raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.") |
|
|
|
|
|
|
|
if evaluate_output_mapping is not None and output_mapping is not None: |
|
|
|
raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.") |
|
|
|
|
|
|
|
if train_input_mapping is None: |
|
|
|
train_input_mapping = input_mapping |
|
|
|
if evaluate_input_mapping is None: |
|
|
|
evaluate_input_mapping = input_mapping |
|
|
|
|
|
|
|
if train_output_mapping is None: |
|
|
|
train_output_mapping = output_mapping |
|
|
|
if evaluate_output_mapping is None: |
|
|
|
evaluate_output_mapping = output_mapping |
|
|
|
|
|
|
|
return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|