Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
c2575ab357
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      fastNLP/core/controllers/trainer.py
  2. +1
    -1
      tests/core/controllers/test_trainer_w_evaluator_torch.py

+ 2
- 2
fastNLP/core/controllers/trainer.py View File

@@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger):
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value;
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`;
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`;
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1;
:param fp16: 是否开启混合精度训练;默认为 False;
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有


+ 1
- 1
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -143,7 +143,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
accumulation_steps,
n_epochs=6,
):
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.1, larger_better=True)]
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,


Loading…
Cancel
Save