Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
a2956b697e
40 changed files with 1311 additions and 1456 deletions
  1. +5
    -1
      fastNLP/core/callbacks/__init__.py
  2. +2
    -138
      fastNLP/core/callbacks/callback.py
  3. +4
    -3
      fastNLP/core/callbacks/checkpoint_callback.py
  4. +1
    -1
      fastNLP/core/callbacks/early_stop_callback.py
  5. +189
    -0
      fastNLP/core/callbacks/has_monitor_callback.py
  6. +2
    -5
      fastNLP/core/callbacks/load_best_model_callback.py
  7. +1
    -1
      fastNLP/core/callbacks/progress_callback.py
  8. +8
    -0
      fastNLP/core/callbacks/torch_callbacks/__init__.py
  9. +52
    -0
      fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py
  10. +58
    -0
      fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
  11. +2
    -2
      fastNLP/core/controllers/trainer.py
  12. +1
    -1
      fastNLP/core/drivers/driver.py
  13. +5
    -8
      fastNLP/core/drivers/jittor_driver/mpi.py
  14. +19
    -47
      fastNLP/core/drivers/jittor_driver/single_device.py
  15. +376
    -0
      fastNLP/core/drivers/paddle_driver/dist_utils.py
  16. +86
    -56
      fastNLP/core/drivers/paddle_driver/fleet.py
  17. +19
    -32
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  18. +48
    -91
      fastNLP/core/drivers/paddle_driver/single_device.py
  19. +5
    -75
      fastNLP/core/drivers/paddle_driver/utils.py
  20. +1
    -1
      fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
  21. +7
    -1
      fastNLP/core/drivers/torch_driver/single_device.py
  22. +24
    -8
      fastNLP/core/drivers/torch_driver/torch_driver.py
  23. +18
    -43
      fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py
  24. +0
    -6
      fastNLP/core/samplers/__init__.py
  25. +0
    -728
      fastNLP/core/samplers/sampler.py
  26. +31
    -7
      fastNLP/core/utils/utils.py
  27. +0
    -1
      fastNLP/envs/set_backend.py
  28. +1
    -2
      fastNLP/io/loader/conll.py
  29. +1
    -1
      fastNLP/modules/mix_modules/mix_module.py
  30. +0
    -0
      tests/core/callbacks/torch_callbacks/__init__.py
  31. +41
    -0
      tests/core/callbacks/torch_callbacks/test_torch_grad_clip_callback.py
  32. +34
    -0
      tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py
  33. +27
    -53
      tests/core/controllers/test_trainer_paddle.py
  34. +39
    -28
      tests/core/drivers/paddle_driver/test_fleet.py
  35. +133
    -82
      tests/core/drivers/paddle_driver/test_single_device.py
  36. +0
    -31
      tests/core/samplers/test_sampler.py
  37. +1
    -1
      tests/envs/test_set_backend.py
  38. +1
    -1
      tests/helpers/callbacks/helper_callbacks.py
  39. +68
    -0
      tests/helpers/callbacks/prepare_trainer_args_for_torch_test.py
  40. +1
    -1
      tests/helpers/models/paddle_model.py

+ 5
- 1
fastNLP/core/callbacks/__init__.py View File

@@ -11,7 +11,10 @@ __all__ = [
'RichCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback"
"EarlyStopCallback",

"TorchWarmupCallback",
"TorchGradClipCallback"
]


@@ -23,4 +26,5 @@ from .progress_callback import choose_progress_callback, ProgressCallback, RichC
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback
from .torch_callbacks import *


+ 2
- 138
fastNLP/core/callbacks/callback.py View File

@@ -1,16 +1,12 @@
from typing import Union, Callable, Dict, Optional, Any
from abc import ABC

__all__ = [
'Callback',
]

from typing import Union, Callable, Dict, Optional, Any

from .callback_events import Events, EventsList, Filter
from .utils import _get_monitor_value
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils.utils import _check_valid_parameters_number


class Callback:
@@ -278,135 +274,3 @@ class _CallbackWrapper(Callback):
@property
def callback_name(self):
return self.fn.__name__


class CanItemDataType(ABC):
"""
检测可以进行传输的对象。

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented


class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
self.set_monitor(monitor, larger_better)
self.must_have_moinitor = must_have_monitor

def set_monitor(self, monitor, larger_better):
if callable(monitor): # 检查是否能够接受一个参数
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor

def on_after_trainer_initialized(self, trainer, driver):
"""
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。

:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_moinitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")

def get_monitor_value(self, results:Dict)->Union[float, None]:
"""
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。

:param results:
:return: 如果为 None ,表明此次没有找到合适的monitor
"""
if len(results)==0:
return None
# 保证所有的 tensor 都被转换为了 python 特定的类型
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if monitor_value is None:
return monitor_value
# 第一次运行
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# 检测到此次和上次不同。
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results are varying between validation.")

self._real_monitor = use_monitor
return monitor_value

def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
检测 monitor_value 是否是更好的

:param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better

def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
传入的两个值中,是否monitor_value1的结果更好。

:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better

@property
def monitor_name(self):
"""
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。

:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了
monitor_name = str(self.monitor)
return monitor_name

+ 4
- 3
fastNLP/core/callbacks/checkpoint_callback.py View File

@@ -10,9 +10,9 @@ from copy import deepcopy


import fastNLP
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir


@@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback):
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
synchronize_mkdir(folder)
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只在进程0上创建
synchronize_mkdir(folder)
_fn = getattr(trainer, self.save_fn_name)
_fn(
folder=folder,


+ 1
- 1
fastNLP/core/callbacks/early_stop_callback.py View File

@@ -4,7 +4,7 @@ __all__ = [

from typing import Dict, Union, Callable

from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils.exceptions import EarlyStopException




+ 189
- 0
fastNLP/core/callbacks/has_monitor_callback.py View File

@@ -0,0 +1,189 @@
__all__ = [
'HasMonitorCallback',
'ExecuteOnceBetterMonitor'
]

from typing import Dict, Union, Any
from abc import ABC

from fastNLP.core.utils import apply_to_collection
from fastNLP.core.callbacks import Callback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.core.utils.utils import _check_valid_parameters_number


class CanItemDataType(ABC):
"""
检测可以进行传输的对象。

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented



class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: monitor 是否时越大越好
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。
"""
self.set_monitor(monitor, larger_better)
self.must_have_moinitor = must_have_monitor

def set_monitor(self, monitor, larger_better):
if callable(monitor): # 检查是否能够接受一个参数
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor

def on_after_trainer_initialized(self, trainer, driver):
"""
如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。
同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。

:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_moinitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")

def get_monitor_value(self, results:Dict)->Union[float, None]:
"""
获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。

:param results:
:return: 如果为 None ,表明此次没有找到合适的monitor
"""
if len(results)==0:
return None
# 保证所有的 tensor 都被转换为了 python 特定的类型
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if monitor_value is None:
return monitor_value
# 第一次运行
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# 检测到此次和上次不同。
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results are varying between validation.")

self._real_monitor = use_monitor
return monitor_value

def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
检测 monitor_value 是否是更好的

:param monitor_value: 待检查的 monitor_value 。如果为 None ,返回 False
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better

def is_better_results(self, results, keep_if_better=True):
"""
检测给定的 results 是否比上一次更好,如果本次 results 中没有找到相关的monitor 返回 False。

:param results: on_valid_ends() 接口中传入的 evaluation 结果。
:param keep_if_better: 当返回为 True 时,是否保存到 self.monitor_value 中。
:return:
"""
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return False
return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)

def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
传入的两个值中,是否monitor_value1的结果更好。

:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better

@property
def monitor_name(self):
"""
返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。

:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了
monitor_name = str(self.monitor)
return monitor_name


class ExecuteOnceBetterMonitor(HasMonitorCallback):
def __init__(self, monitor, larger_better, execute_fn):
"""
当监控的 monitor 结果更好的时候,调用 execute_fn 函数。

:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结
果(字典类型),返回一个 float 值作为 monitor 的结果。
:param larger_better: monitor 是否时越大越好
:param execute_fn: 一个可执行的函数,不接受任何参数,不反回值。在 monitor 取得更好结果的时候会调用。
"""
super().__init__(monitor, larger_better, must_have_monitor=True)
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
self.execute_fn = execute_fn()

def on_validate_end(self, trainer, results):
if self.is_better_results(results):
self.execute_fn()

+ 2
- 5
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -4,7 +4,7 @@ __all__ = [

import os
from typing import Optional, Callable, Union
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from io import BytesIO
import shutil

@@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback):
self.get_monitor_value(sanity_check_res)

def on_validate_end(self, trainer, results):
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if self.is_better_results(results, keep_if_better=True):
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_save_fn=self.model_save_fn)


+ 1
- 1
fastNLP/core/callbacks/progress_callback.py View File

@@ -8,7 +8,7 @@ __all__ = [
'RichCallback'
]

from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger


+ 8
- 0
fastNLP/core/callbacks/torch_callbacks/__init__.py View File

@@ -0,0 +1,8 @@
__all__ = [
'TorchWarmupCallback',
'TorchGradClipCallback'
]


from .torch_lr_sched_callback import TorchWarmupCallback
from .torch_grad_clip_callback import TorchGradClipCallback

+ 52
- 0
fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py View File

@@ -0,0 +1,52 @@
__all__ = [
'TorchGradClipCallback'
]
from ..callback import Callback


class TorchGradClipCallback(Callback):
def __init__(self, clip_value=1, clip_type='norm', parameters=None):
r"""
在每次 optimizer update 之前将 parameter 进行 clip

:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
:param str clip_type: 支持'norm', 'value'
两种::

1 'norm', 将gradient的norm rescale到[-clip_value, clip_value]

2 'value', 将gradient限制在[-clip_value, clip_value],
小于-clip_value的gradient被赋值为-clip_value;
大于clip_value的gradient被赋值为clip_value.
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。
"""
super().__init__()

from torch import nn
if clip_type == 'norm':
self.clip_fun = nn.utils.clip_grad_norm_
elif clip_type == 'value':
self.clip_fun = nn.utils.clip_grad_value_
else:
raise ValueError("Only supports `norm` or `value` right now.")
if parameters is not None:
self.parameters = list(parameters)
else:
self.parameters = None
self.clip_value = clip_value

def on_after_trainer_initialized(self, trainer, driver):
assert 'torch' in driver.__class__.__name__.lower(), f"Callback:{self.__class__.__name__} only supports torch " \
f"related drivers for now."
parameters = []
for optimizer in trainer.driver.optimizers:
for param_group in optimizer.param_groups:
parameters.extend(param_group['params'])
self.parameters = parameters
assert len(self.parameters), "There is no parameters need to be clipped."

def on_before_optimizers_step(self, trainer, optimizers):
for optimizer in trainer.driver.optimizers:
trainer.driver.grad_scaler.unscale_(optimizer)
self.clip_fun(self.parameters, self.clip_value)

+ 58
- 0
fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py View File

@@ -0,0 +1,58 @@
__all__ = [
'TorchWarmupCallback'
]
import math

from ..callback import Callback


class TorchWarmupCallback(Callback):
def __init__(self, warmup=0.1, schedule='constant'):
r"""
调整 learning rate 的 callback 。仅在实际发生参数更新的情况下

:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float,
如0.1, 则前10%的step是按照schedule策略调整learning rate。
:param str schedule: 以哪种方式调整。
linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0;
constant前warmup的step上升到指定learning rate,后面的step保持learning rate.
"""
super().__init__()
self.warmup = max(warmup, 0.)

self.initial_lrs = [] # 存放param_group的learning rate
if schedule == 'constant':
self.get_lr = self._get_constant_lr
elif schedule == 'linear':
self.get_lr = self._get_linear_lr
else:
raise RuntimeError("Only support 'linear', 'constant'.")

def _get_constant_lr(self, progress):
if progress <self.warmup:
return progress /self.warmup
return 1

def _get_linear_lr(self, progress):
if progress <self.warmup:
return progress /self.warmup
return max((progress - 1.) / (self.warmup - 1.), 0.)

def on_train_begin(self, trainer):
self.t_steps = trainer.total_batches
if self.warmup >1:
self.warmup = self.warmup / self.t_steps
self.t_steps = max(2, self.t_steps) # 不能小于2
# 防止 t_steps 不能整除 accumulation_steps
self.t_steps = math.ceil(self.t_steps/trainer.accumulation_steps) * trainer.accumulation_steps
# 获取param_group的初始learning rate
for optimizer in trainer.driver.optimizers:
for group in optimizer.param_groups:
self.initial_lrs.append(group['lr'])

def on_before_optimizers_step(self, trainer, optimizers):
# 这里需要加 accumulation_steps 是防止 lr 从 0 开始
progress = (trainer.global_forward_batches + trainer.accumulation_steps) / self.t_steps
for optimizer in trainer.driver.optimizers:
for lr, group in zip(self.initial_lrs, optimizer.param_groups):
group['lr'] = lr * self.get_lr(progress)

+ 2
- 2
fastNLP/core/controllers/trainer.py View File

@@ -219,10 +219,10 @@ class Trainer(TrainerEventTrigger):

""" 设置内部的 Evaluator """
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")

if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")

self.evaluator = None
self.monitor = monitor


+ 1
- 1
fastNLP/core/drivers/driver.py View File

@@ -129,7 +129,7 @@ class Driver(ABC):
@property
def optimizers(self) -> List:
r"""
如下所示,driver 返回的 optimizers 一定是一个 List,如果用户直接向 Trainer 传入一个单独的 optimzer,我们会使用一个 List 将其
如下所示,driver 返回的 optimizers 一定是一个 List,如果用户直接向 Trainer 传入一个单独的 optimizer,我们会使用一个 List 将其
包裹;

:return: List[optimizer0, optimizer1, optimizer2, ...]


+ 5
- 8
fastNLP/core/drivers/jittor_driver/mpi.py View File

@@ -1,5 +1,5 @@
import os
from typing import Optional, Union
from typing import Optional, Union, Callable, Dict, Tuple

from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver):
return self._data_device
return self.model_device

def train_step(self, batch):
return self._train_step(batch)

def validate_step(self, batch):
return self._validate_step(batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
pass

def test_step(self, batch):
return self._test_step(batch)
def get_model_call_fn(self, fn: str) -> Tuple:
pass

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):


+ 19
- 47
fastNLP/core/drivers/jittor_driver/single_device.py View File

@@ -1,9 +1,11 @@
from typing import Dict, Union
from typing import Dict, Union, Tuple, Callable, Optional

from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.log import logger

if _NEED_IMPORT_JITTOR:
import jittor
@@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver):
self.global_rank = 0
self.world_size = 1

if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
model = self.unwrap_model()
self._train_signature_fn = model.execute

if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.execute

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.execute

def train_step(self, batch) -> Dict:
if isinstance(batch, Dict):
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)

def step(self):
"""
jittor optimizers 的step函数可以传入参数loss
@@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver):
for optimizer in self.optimizers:
optimizer.zero_grad()

def validate_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)

def test_step(self, batch):

if isinstance(batch, Dict):
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
return fn(batch)

def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
return self._test_step(batch)
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def unwrap_model(self):
return self.model


+ 376
- 0
fastNLP/core/drivers/paddle_driver/dist_utils.py View File

@@ -0,0 +1,376 @@
import io
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List

from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
if _TORCH_GREATER_EQUAL_1_8:
try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
pass


from fastNLP.core.utils import apply_to_collection


def _validate_output_list_for_rank(my_rank, dst, gather_list):
if dst == my_rank:
if not gather_list:
raise ValueError(
"Argument ``gather_list`` must be specified on destination rank."
)
elif gather_list:
raise ValueError(
"Argument ``gather_list`` must NOT be specified "
"on non-destination ranks."
)


def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
"""
从其它 rank gather 东西到 dst rank 。

Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.

Args:
obj (Any): Input object. Must be picklable.
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
should be correctly sized as the size of the group for this
collective and will contain the output. Must be ``None`` on non-dst
ranks. (default is ``None``)
dst (int, optional): Destination rank. (default is 0)
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. On the ``dst`` rank, ``object_gather_list`` will contain the
output of the collective.

.. note:: Note that this API differs slightly from the gather collective
since it does not provide an async_op handle and thus will be a blocking
call.

.. note:: Note that this API is not supported when using the NCCL backend.

.. warning::
:func:`gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
gather_objects[dist.get_rank()],
output if dist.get_rank() == 0 else None,
dst=0
)
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
"""
if group is None:
group = DEFAULT_TORCH_GROUP

if dist.distributed_c10d._rank_not_in_group(group):
return

# Ensure object_gather_list is specified appopriately.
my_rank = dist.get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
# 防止 unpickle 的时候出现在了发送的 gpu 上。
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
input_tensor, local_size = _object_to_tensor(obj)
group_backend = dist.get_backend(group)
current_device = torch.device("cpu")
is_nccl_backend = group_backend == dist.Backend.NCCL
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes. An all-gather is needed here despite this being a
# gather, since each rank needs to broadcast a tensor of the same (maximal)
# size.
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
# Avoid populating output tensors if the result won't be gathered on this rank.
if my_rank == dst:
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
# All ranks call gather with equal-sized tensors.
dist.gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,
dst=dst,
group=group,
)
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)


def _object_to_tensor(obj, device=None):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will casue 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
if device is not None:
byte_tensor = byte_tensor.to(device)
local_size = local_size.to(device)
return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()


def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
# src rank send to all other ranks
size = torch.LongTensor([0]).to(device)

if cur_rank == src:
world_size = dist.get_world_size(group=group)
tensor, size = _object_to_tensor(obj)
tensor = tensor.to(device)
size = size.to(device)

# 首先同步 obj 的 size 的信息;
dist.broadcast(size, src, group=group)
for subrank in range(world_size):
if subrank != src:
dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
else:
dist.broadcast(size, src, group=group)
tensor = torch.ByteTensor([0] * size).to(device)
dist.recv(tensor=tensor, src=src, group=group, tag=tag)

return _tensor_to_object(tensor.cpu(), size)

def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
"""
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。

example:
obj = {
'a': [1, 1],
'b': [[1, 2], [1, 2]],
'c': {
'd': [1, 2]
}
}
->
[
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
]

:param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行
序列化之后进行传输。
:param device: 当前该参数无意义。
:param group:
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。
"""
if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
dist.all_gather(objs, obj, group=group)
else:
objs = [None for _ in range(dist.get_world_size(group))]
# 防止 unpickle 的时候弄到发送的 gpu 上了
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
dist.all_gather_object(objs, obj, group=group)
else:
objs = all_gather_object(objs, obj, group=group)
return objs


def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
"""
将 src 上的 obj 对象广播到其它 rank 上。

:param obj:
:param src:
:param device:
:param group:
:return:
"""
if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
if cur_rank == src:
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
if cur_rank!=src:
get_obj = [None]
dist.broadcast_object_list(get_obj, src=src, group=group)
return get_obj[0]
else:
dist.broadcast_object_list([obj], src=src, group=group)
return obj
if device is None:
device = torch.cuda.current_device()

if cur_rank == src:
tensor, size = _object_to_tensor(obj, device=device)
else:
size = torch.LongTensor([0]).to(device)

dist.broadcast(size, src=src, group=group)
if cur_rank != src:
tensor = torch.empty(
size.int().item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=device
)
dist.broadcast(tensor, src=src, group=group)

return _tensor_to_object(tensor, tensor_size=size.item())


def _check_for_nccl_backend(group):
pg = group or dist.distributed_c10d._get_default_group()
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, _ProcessGroupWrapper):
pg = pg.wrapped_pg

return (
dist.is_nccl_available() and
isinstance(pg, dist.ProcessGroupNCCL)
)


def all_gather_object(object_list, obj, group=None):
"""
复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。

Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.

Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
object (Any): Pickable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.

.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.

.. note:: For NCCL-based processed groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsiblity to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.

.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
"""
if dist.distributed_c10d._rank_not_in_group(group):
return
if _TORCH_GREATER_EQUAL_1_8:
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
else:
current_device = torch.cuda.current_device()

input_tensor, local_size = _object_to_tensor(obj, device=current_device)

# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=current_device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
return object_list

+ 86
- 56
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -1,13 +1,12 @@
import os
import shutil
from functools import partial
from typing import List, Union, Optional, Dict
from typing import List, Union, Optional, Dict, Tuple, Callable

from .paddle_driver import PaddleDriver
from .fleet_launcher import FleetLauncher
from .utils import (
_FleetWrappingModel,
ForwardState,
_MODE_PARAMETER,
get_device_from_visible,
reset_seed,
replace_sampler,
@@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE:
__all__ = [
"PaddleFleetDriver",
]
# if os.path.exists(self.gloo_rendezvous_dir):
# shutil.rmtree(self.gloo_rendezvous_dir)

class PaddleFleetDriver(PaddleDriver):
def __init__(
self,
@@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver):
# 我们就直接将 model_device 置为 None;
self._model_device = None

def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)

model = model._layers
if hasattr(model, "train_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "evaluate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `evaluate_step` method, which we can not call actually, "
"we will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上;
self._data_device = kwargs.get("data_device", None)
if self._data_device is not None:
@@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver):

self.world_size = None
self.global_rank = 0
self._configured = False # 防止重复调用 configure_ddp() 函数使用
self._has_setup = False # 防止重复调用 setup() 函数

self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
@@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver):
os.makedirs(name=self.output_from_new_proc, exist_ok=True)
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的;
self._has_fleetwrapped = False # 判断传入的模型是否经过 _has_fleetwrapped 包裹;

def setup(self):
"""
在主进程拉起其它子进程,将主进程作为rank 0
@@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver):
dist.barrier()

def configure_fleet(self):
if not self._configured and not isinstance(self.model, DataParallel):
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
self.model = DataParallel(
_FleetWrappingModel(self.model),
**self._fleet_kwargs
)
self._has_fleetwrapped = True

self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)

self._configured = True
def on_exception(self):
if os.path.exists(self.gloo_rendezvous_dir):
shutil.rmtree(self.gloo_rendezvous_dir)
super().on_exception()

@property
def world_size(self) -> int:
@@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver):
return self._data_device
return self.model_device

def train_step(self, batch):
return self._train_step(batch)

def validate_step(self, batch):
return self._validate_step(batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if self._has_fleetwrapped:
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
wo_auto_param_call=self.wo_auto_param_call)
else:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return fn(batch)

def get_model_call_fn(self, fn: str) -> Tuple:
model = self.unwrap_model()
if self._has_fleetwrapped:
if hasattr(model, fn):
fn = getattr(model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return model, model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your model.")
else:
if hasattr(model, fn):
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
f"the `{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
elif fn not in {"train_step", "evaluate_step"}:
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
"`DistributedDataParallel` model, which means that we will only call model.forward "
"function when we are in forward propagation.")

def test_step(self, batch):
return self._test_step(batch)
return self.model, model.forward

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
@@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver):
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

def backward(self, loss):
self.grad_scaler.scale(loss).backward()

def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()

def is_global_zero(self):
return self.global_rank == 0

@@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver):
if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")

def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
"""
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。

:param obj: obj,可能是 Tensor 或 嵌套类型的数据
:param int src: source 的 global rank 。
:param int dst: target 的 global rank,可以是多个目标 rank
:param group: 所属的 group
:param kwargs:
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。
"""
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)

def all_gather(self, obj, group) -> List:
"""
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
pickle 进行序列化,接收到之后再反序列化。

example:
obj = {
'a': [1, 1],
'b': [[1, 2], [1, 2]],
'c': {
'd': [1, 2]
}
}
->
[
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
]

:param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。
:param group:
:return:
"""
return
return fastnlp_paddle_all_gather(obj, group=group)

+ 19
- 32
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -71,6 +71,14 @@ class PaddleDriver(Driver):
for optimizer in self.optimizers:
optimizer.clear_grad()

def backward(self, loss):
self.grad_scaler.scale(loss).backward()

def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()

@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
r"""
@@ -115,28 +123,6 @@ class PaddleDriver(Driver):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")

def check_evaluator_mode(self, mode: str):
r"""
因为我们在具体的 driver 的 evaluate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数;
因此如果用户的 evaluator evaluate_fn 是 validate,但是传入的 model 却没有实现 evaluate_step 函数,而是实现了 test_step 函数,那么
我们应当提醒用户这一行为;
"""
model = self.unwrap_model()
if mode == "validate":
if not hasattr(model, "evaluate_step"):
if hasattr(model, "test_step"):
logger.warning(
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
"are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
"'evaluate_step'.")

else:
if not hasattr(model, "test_step"):
if hasattr(model, "evaluate_step"):
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
"'test_step'.")

@staticmethod
def tensor_to_numeric(tensor, reduce=None):
r"""
@@ -258,20 +244,21 @@ class PaddleDriver(Driver):
if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
sampler_states = sampler.state_dict()
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
# 会造成多余实际消耗的问题。
num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None)
# 会造成多余实际消耗的问题。
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
try:
sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size
except: # 有可能 batch_size 为 None,就只有损失精度了
pass
assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
if isinstance(sampler, ReproducibleSampler):
# 如果是 sampler 的话,需要计算出实际的 sample 数目
try:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except: # 有可能 batch_size 为 None,就只有损失精度了
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
states['sampler_states'] = sampler_states
else:
raise RuntimeError(
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
states["sampler_states"] = sampler_states

# 2. 保存模型的状态;
if should_save_model:


+ 48
- 91
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -1,5 +1,5 @@
import os
from typing import Optional, Dict, Union
from typing import Optional, Dict, Union, Callable, Tuple

from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
@@ -11,16 +11,19 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproducibleSampler,
RandomSampler,
re_instantiate_sampler,
)
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
import paddle
from paddle import DataParallel
from paddle.fluid.reader import _DatasetKind

__all__ = [
@@ -28,109 +31,57 @@ __all__ = [
]

class PaddleSingleDriver(PaddleDriver):
def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs):
def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs):
if isinstance(model, DataParallel):
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")

cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None)
if cuda_visible_devices == "":
device = "cpu"
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
"use `cpu` instead of `gpu` device.")

super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

if device is None:
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")

if device != "cpu":
if isinstance(device, int):
device_id = device
else:
device_id = get_paddle_device_id(device)
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
self.model_device = get_paddle_gpu_str(device)

self.local_rank = 0
self.global_rank = 0
self.world_size = 1

if isinstance(model, paddle.DataParallel):
# 注意这里的 unwrap_model 调用的是具体子类的方法;
model = self.unwrap_model()
if hasattr(model, "train_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `train_step` method, which we can not call actually, we will "
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = self.model
self._train_signature_fn = model.forward

if hasattr(model, "evaluate_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `evaluate_step` method, which we can not call actually, we "
"will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `test_step` method, which we can not call actually, we will "
"call `forward` function instead of `test_step` and you should note that.")
self._test_step = self.model
self._test_signature_fn = model.forward
else:
if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
# 输入的模型是 `DataParallel`,我们需要保证其 signature_fn 是正确的;
model = self.unwrap_model()
self._train_signature_fn = model.forward

if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.forward

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.forward

def setup(self):
device = self.model_device
if device != "cpu":
device_id = get_paddle_device_id(device)
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
device = get_device_from_visible(device, output_type=str)
device = get_device_from_visible(device, output_type=str)
paddle.device.set_device(device)
self.model.to(device)

def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._train_step(batch)

def backward(self, loss):
self.grad_scaler.scale(loss).backward()

def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()

def validate_step(self, batch) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
return fn(batch)

def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
return self._validate_step(batch)

def test_step(self, batch) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def move_data_to_device(self, batch: 'paddle.Tensor'):
r"""
@@ -164,12 +115,18 @@ class PaddleSingleDriver(PaddleDriver):
return replace_sampler(dataloader, sampler)

if reproducible:
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
if isinstance(args.sampler, paddle.io.RandomSampler):
# 如果本来就是随机的,直接替换
sampler = RandomSampler(args.sampler.data_source)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader



+ 5
- 75
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -11,7 +11,6 @@ from typing import Dict, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
from fastNLP.core.samplers import RandomSampler
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger

@@ -87,8 +86,6 @@ class ForwardState(IntEnum):
TEST = 2
PREDICT = 3

_MODE_PARAMETER = "forward_state"

class _FleetWrappingModel(Layer):
"""
参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和
@@ -98,83 +95,16 @@ class _FleetWrappingModel(Layer):
super(_FleetWrappingModel, self).__init__()
self.model = model

if isinstance(model, paddle.DataParallel):
model = model._layers
if hasattr(model, "train_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = self.model
self._train_signature_fn = model.forward

if hasattr(model, "evaluate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `evaluate_step` method, which we can not call actually, "
"we will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = self.model
self._test_signature_fn = model.forward
else:
if hasattr(model, "train_step"):
self._train_step = model.train_step
self._train_signature_fn = None
else:
self._train_step = model
self._train_signature_fn = model.forward

if hasattr(model, "evaluate_step"):
self._validate_step = model.validate_step
self._validate_signature_fn = None
elif hasattr(model, "test_step"):
self._validate_step = model.test_step
self._validate_signature_fn = None
else:
self._validate_step = model
self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
self._test_step = model.test_step
self._test_signature_fn = None
elif hasattr(model, "evaluate_step"):
self._test_step = model.validate_step
self._test_signature_fn = None
else:
self._test_step = model
self._test_signature_fn = model.forward

def forward(self, batch, **kwargs) -> Dict:

forward_state = kwargs.pop(_MODE_PARAMETER)
fn = kwargs.pop("fastnlp_fn")
signature_fn = kwargs.pop("fastnlp_signature_fn")
wo_auto_param_call = kwargs.pop("wo_auto_param_call")

if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
raise NotImplementedError("You should direct a concrete evaluate_fn.")
return fn(batch)

class DummyGradScaler:
"""


+ 1
- 1
fastNLP/core/drivers/torch_driver/initialize_torch_driver.py View File

@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
# world_size 和 rank
if FASTNLP_BACKEND_LAUNCH in os.environ:
if device is not None:
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
"up your script. And we will directly get the local device via "
"`os.environ['LOCAL_RANK']`.")
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)


+ 7
- 1
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -37,7 +37,12 @@ class TorchSingleDriver(TorchDriver):
super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

if device is None:
raise ValueError("Parameter `device` can not be None in `TorchSingleDriver`.")
logger.debug("device is not set, fastNLP will try to automatically get it.")
try:
device = next(model.parameters()).device
assert isinstance(device, torch.device)
except:
raise ValueError("fastNLP cannot get device automatically, please set device explicitly.")

self.model_device = device

@@ -70,6 +75,7 @@ class TorchSingleDriver(TorchDriver):

return self.model, model.forward
else:
# TODO 这种直接调用模型某个接口的方法无法触发hook,也许需要做一个warning,如果用户有钩子,提醒他train_step无法触发。
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):


+ 24
- 8
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -25,7 +25,7 @@ __all__ = [

from .utils import optimizer_state_to_device
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
@@ -224,6 +224,11 @@ class TorchDriver(Driver):
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;

# 4. 保存fp16的状态
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict

logger.debug("Save optimizer state dict")
states["optimizers_state_dict"] = optimizers_state_dict
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
@@ -232,7 +237,7 @@ class TorchDriver(Driver):
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

# 1. 加载 optimizers 的状态;
optimizers_state_dict = states["optimizers_state_dict"]
optimizers_state_dict = states.pop("optimizers_state_dict")
for i in range(len(self.optimizers)):
optimizer: torch.optim.Optimizer = self.optimizers[i]
optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
@@ -244,26 +249,37 @@ class TorchDriver(Driver):
res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
if only_state_dict:
model.load_state_dict(res)
logger.debug("Load model state dict.")
logger.debug("Load model state dict...")
else:
model.load_state_dict(res.state_dict())
logger.debug("Load model.")

# 3. 恢复 sampler 的状态;
logger.debug("Load model...")

# 3. 加载fp16的状态
if 'grad_scaler_state_dict' in states:
grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
if not isinstance(self.grad_scaler, DummyGradScaler):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")

# 4. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
sampler.load_state_dict(states.pop('sampler_states'))
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

# 4. 修改 trainer_state.batch_idx_in_epoch


+ 18
- 43
fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py View File

@@ -1,6 +1,7 @@
from typing import Optional, Dict, Union, Callable
from typing import Optional, Dict, Union, Callable, Tuple

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.core.utils.utils import _get_fun_msg


if _NEED_IMPORT_PADDLE:
@@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver):
elif self._data_device is not None:
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")

if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
self._train_signature_fn = self.model.forward

if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.forward
else:
self._validate_step = self.model
self._validate_signature_fn = self.model.forward

if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.forward
else:
self._test_step = self.model
self._test_signature_fn = self.model.forward

def setup(self):
if self.model_device is not None:
paddle.device.set_device(self.model_device.replace("cuda", "gpu"))
@@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver):
f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, "
f"not {type(each_optimizer)}.")

def train_step(self, batch) -> Dict:
if isinstance(batch, Dict):
return auto_param_call(self._train_step, batch)
else:
return self._train_step(batch)

def step(self):
for optimizer in self.optimizers:
optimizer.step()
@@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver):
else:
raise ValueError("Unknown optimizers type.")

def validate_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._validate_step, batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)

def test_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._test_step, batch)
return fn(batch)

def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
return self._test_step(batch)
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

def predict_step(self, batch):
if isinstance(batch, Dict):


+ 0
- 6
fastNLP/core/samplers/__init__.py View File

@@ -1,9 +1,4 @@
__all__ = [
'BucketSampler',
'SortedSampler',
'ConstTokenNumSampler',
'ConstantTokenNumSampler',

'MixSampler',
'DopedSampler',
'MixSequentialSampler',
@@ -26,7 +21,6 @@ __all__ = [
"re_instantiate_sampler"
]

from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler


+ 0
- 728
fastNLP/core/samplers/sampler.py View File

@@ -1,728 +0,0 @@
r"""
sampler 子类实现了 fastNLP 所需的各种采样器。
"""

__all__ = [
"BucketSampler",
"SortedSampler",
'ConstTokenNumSampler',
"ConstantTokenNumSampler",
]

from itertools import chain
from typing import List, Iterable

import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
from torch.utils.data import Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Sampler

# class DopedSampler(Sampler):
# """
# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。
# """
#
# def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
# sampler: Union[List[Sampler], Dict[str, Sampler]] = None,
# ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, drop_last: bool = False) -> None:
# if batch_size <= 0:
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(batch_size))
# if not isinstance(drop_last, bool):
# raise ValueError("drop_last should be a boolean value, but got "
# "drop_last={}".format(drop_last))
# self.batch_size = batch_size
# self.drop_last = drop_last
# self.ds_ratio = ds_ratio
# if sampler is None:
# if isinstance(dataset, List):
# self.sampler = [SequentialSampler(ds) for ds in dataset]
# elif isinstance(dataset, Dict):
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#
# elif isinstance(sampler, List):
# if len(sampler) != len(dataset):
# raise ValueError("the length of sampler != the length of sampler")
# self.sampler = sampler
# else:
# self.sampler = sampler
# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None:
# self.ds_ratio = ds_ratio
# elif isinstance(ds_ratio, List):
# if not all(item >= 0 for item in ds_ratio):
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(ds_ratio))
# self.ds_ratio = ds_ratio
# else:
# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None")
#
# def __iter__(self):
# samplers, index = [], 0
# if isinstance(self.sampler, List):
# for idx, sampler in enumerate(self.sampler):
# samplers.append((iter(sampler), self.batch_size, index, 0, idx))
# index += len(sampler)
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# samplers.append((iter(sampler), self.batch_size, index, 0, name))
# index += len(sampler)
#
# def __len__(self):
# lens = 0
# max_len, ds_len = 0, 0
# if self.ds_ratio == 'truncate_to_least':
# if isinstance(self.sampler, List):
# max_len = min(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = min(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
#
# elif self.ds_ratio == 'pad_to_most':
# if isinstance(self.sampler, List):
# max_len = max(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = max(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
#
# if self.ds_ratio is None:
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif self.ds_ratio == 'truncate_to_least' or self.ds_ratio == 'pad_to_most':
# for i in range(ds_len):
# if self.drop_last:
# lens += max_len // self.batch_size
# else:
# lens += (max_len + self.batch_size - 1) // self.batch_size
# return lens
#
# def demo(self):
# indexes = np.array([0]*self.batch_size + [1]*self.batch_size + [2]*self.batch_size)
# shift = np.array([0]*self.batch_size + [len(ds1)]*self.batch_size + [len(ds1)+len(ds2)]*self.batch_size)
# buffer = np.zeros(self.batch_size*self.num_ds, dtype=int)
# select_sampler = np.random.randint(0, self.batch_size*self.num_ds, num_sample=self.batch_size)
# select_indices = buffer[select_sampler] + shift[select_sampler]
# num_1 = (indexes[select_sampler]==0).sum()
#


# class MixSequentialSampler(Sampler):
# """
# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表顺序采样并返回index,只有处理了上一个dataset才会处理下一个。
# """
#
# def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
# sampler: Union[List[Sampler], Dict[str, Sampler], None] = None,
# drop_last: bool = False) -> None:
# """
#
# :param dataset: 实现了__getitem__和__len__的数据容器列表
# :param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset
# :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象
# :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size
# """
# # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable,
# if isinstance(dataset, Dict) and isinstance(sampler, List):
# raise ValueError(f"{sampler} must be dict")
#
# # 判断batch_size是否大于等于0
# if batch_size <= 0:
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(batch_size))
#
# if not isinstance(drop_last, bool):
# raise ValueError("drop_last should be a boolean value, but got "
# "drop_last={}".format(drop_last))
# self.batch_size = batch_size
# self.drop_last = drop_last
# if sampler is None:
# if isinstance(dataset, List):
# self.sampler = [SequentialSampler(ds) for ds in dataset]
# elif isinstance(dataset, Dict):
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
# elif isinstance(sampler, List):
# if len(sampler) != len(dataset):
# raise ValueError("the length of sampler != the length of sampler")
# self.sampler = sampler
#
# def __iter__(self) -> Iterable[List[int]]:
# """
# 按照dataset的顺序采样,打包成一个batch后返回
# :return:
# """
# index = 0
# batch = []
# if isinstance(self. sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# for idx in sampler:
# batch.append(idx + index)
# if len(batch) == self.batch_size:
# yield batch
# batch = []
# if len(batch) > 0 and not self.drop_last:
# yield batch
# batch = []
# index += len(sampler)
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# for idx in sampler:
# batch.append(idx + index)
# if len(batch) == self.batch_size:
# yield batch
# batch = []
# if len(batch) > 0 and not self.drop_last:
# yield batch
# batch = []
# index += len(sampler)
#
# def __len__(self) -> int:
# lens = 0
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif isinstance(self.sampler, Dict):
# for _, sampler in self.sampler.items():
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# return lens


# class PollingSampler(Sampler):
# """
# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表轮流采样并返回index,处理了上个dataset的一个batch后会处理下一个。
# """
#
# def __init__(self, dataset: Union[List, Dict], batch_size: int = 16,
# sampler: Union[List[Sampler], Dict[str, Sampler]] = None,
# drop_last: bool = False, ds_ratio="pad_to_most") -> None:
# """
#
# :param dataset: 实现了__getitem__和__len__的数据容器列表
# :param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset
# :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象
# :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size
# :param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时,
# 以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样
# """
# # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable,
# if isinstance(dataset, Dict) and isinstance(sampler, List):
# raise ValueError(f"{sampler} must be dict")
# if isinstance(dataset, List) and isinstance(sampler, Dict):
# raise ValueError(f"{sampler} must be list")
# # 判断batch_size是否大于等于0
# if batch_size <= 0:
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(batch_size))
#
# if not isinstance(drop_last, bool):
# raise ValueError("drop_last should be a boolean value, but got "
# "drop_last={}".format(drop_last))
#
# self.batch_size = batch_size
# self.drop_last = drop_last
# if sampler is None:
# if isinstance(dataset, List):
# self.sampler = [SequentialSampler(ds) for ds in dataset]
# elif isinstance(dataset, Dict):
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#
# elif isinstance(sampler, List):
# if len(sampler) != len(dataset):
# raise ValueError("the length of sampler != the length of sampler")
# self.sampler = sampler
# else:
# self.sampler = sampler
# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None:
# self.ds_ratio = ds_ratio
# else:
# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None")
#
# def __iter__(self) -> Iterable[List[int]]:
# # index是数据集下标基址, pointer指向数据集列表的某个数据集
# index, pointer, samplers, flag = 0, 0, [], False
#
# if isinstance(self.sampler, List):
# for idx, sampler in enumerate(self.sampler):
# samplers.append((iter(sampler), self.batch_size, index, 0, idx))
# index += len(sampler)
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# samplers.append((iter(sampler), self.batch_size, index, 0, name))
# index += len(sampler)
# if self.ds_ratio == 'pad_to_most':
# if isinstance(self.sampler, List):
# limit_len = max(len(ds) for ds in self.sampler)
# else:
# limit_len = max(len(ds) for _, ds in self.sampler.items())
# elif self.ds_ratio == 'truncate_to_least':
# if isinstance(self.sampler, List):
# limit_len = min(len(ds) for ds in self.sampler)
# else:
# limit_len = min(len(ds) for _, ds in self.sampler.items())
# else:
# limit_len = 0
# # 最后一个批次的大小
# last_batch_size = limit_len % self.batch_size
#
# while True:
# # 全部采样完,退出
# if len(samplers) == 0:
# break
# batch, flag = [], False
# # sampler_len代表已经取出来的数据个数
# sampler, batch_size, index, sampler_len, name = samplers.pop(0)
# for _ in range(batch_size):
# try:
# batch.append(index + next(sampler))
# sampler_len += 1
# except StopIteration:
# flag = True
# # ds_ratio为None,第一种情况,删除掉采样完的数据即可。
# if self.ds_ratio == 'pad_to_most' and sampler_len < limit_len:
# # 重置sampler,并取足一个batch数据
# sampler = iter(self.sampler[name])
# # 由于batch_size一定小于等于ds的长度,故能够取足一个batch_size的数据
# for _ in range(batch_size-len(batch)):
# batch.append(next(sampler) + index)
# sampler_len += 1
# break
#
# # ds_ratio不为None情况
# # 两种情况会触发一下逻辑:1.truncate_to_least时,最短的数据集最后一个batch大小不等于batch_size时,
# # 其他较长的数据集的最后一个batch长度会较长;2. pad_to_most,最长的数据集最后一个batch不等于batch_size时,较短数据集最后一个
# # batch长度会较长
# if limit_len != 0 and limit_len < sampler_len:
# batch = batch[:last_batch_size]
# # ds_ratio为任意情况下, 没有取完所有数据,则添加到队列尾部
# elif (limit_len == 0 and flag == False) or limit_len > sampler_len:
# samplers.append((sampler, batch_size, index, sampler_len, name))
# if len(batch) == batch_size:
# yield batch
# elif len(batch) > 0 and not self.drop_last:
# yield batch
#
# def __len__(self) -> int:
# lens = 0
# max_len, ds_len = 0, 0
# if self.ds_ratio == 'truncate_to_least':
# if isinstance(self.sampler, List):
# max_len = min(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = min(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
#
# elif self.ds_ratio == 'pad_to_most':
# if isinstance(self.sampler, List):
# max_len = max(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = max(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
# if self.ds_ratio is None:
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# else:
# for i in range(ds_len):
# if self.drop_last:
# lens += max_len // self.batch_size
# else:
# lens += (max_len + self.batch_size - 1) // self.batch_size
# return lens


class BucketSampler(Sampler):
r"""
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素
"""

def __init__(self, dataset, num_buckets=10, batch_size=None, seq_len_field_name='seq_len', drop_last=False) -> None:
r"""
:param int num_buckets: bucket的数量
:param int batch_size: batch的大小. 默认为None,Trainer/Tester在调用BucketSampler时,会将该值正确设置,如果是非
Trainer/Tester场景使用,需要显示传递该值
:param str seq_len_field_name: 对应序列长度的 `field` 的名字
"""
self.dataset = dataset
self.num_buckets = num_buckets
self.batch_size = batch_size
self.seq_len_field_name = seq_len_field_name

def set_batch_size(self, batch_size) -> None:
r"""

:param int batch_size: 每个batch的大小
:return:
"""
self.batch_size = batch_size

def __iter__(self):
if self.batch_size is None:
raise RuntimeError("batch_size is None.")
seq_lens = self.dataset.get_all_fields()[self.seq_len_field_name].content
total_sample_num = len(seq_lens)

bucket_indexes = []
assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets."
num_sample_per_bucket = total_sample_num // self.num_buckets
for i in range(self.num_buckets):
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
bucket_indexes[-1][1] = total_sample_num

sorted_seq_lens = list(sorted([(idx, seq_len) for
idx, seq_len in zip(range(total_sample_num), seq_lens)],
key=lambda x: x[1]))

batchs = []

left_init_indexes = []
for b_idx in range(self.num_buckets):
start_idx = bucket_indexes[b_idx][0]
end_idx = bucket_indexes[b_idx][1]
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
num_batch_per_bucket = len(left_init_indexes) // self.batch_size
np.random.shuffle(left_init_indexes)
for i in range(num_batch_per_bucket):
batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
if (left_init_indexes) != 0:
batchs.append(left_init_indexes)
np.random.shuffle(batchs)

return chain(*batchs)


class ConstTokenNumSampler(Sampler):
"""
尽量保证每个batch的输入token数量是接近的。

"""

def __init__(self, dataset, seq_len_field_name: List[int], max_token: int = 4096, max_sentence: int = -1,
need_be_multiple_of: int = 1, num_bucket: int = -1) -> None:
"""

:param dataset:
:param List[int] seq_len_field_name: 哪个field指示的sample的长度
:param int max_token: 每个batch的最大的token数量
:param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定
:param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到
:param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。
"""
assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
self.dataset = dataset
self.seq_len_field_name = seq_len_field_name
self.num_bucket = num_bucket
self.max_token = max_token
self._max_sentence = max_sentence
self.need_be_multiple_of = need_be_multiple_of

assert len(self.dataset) > self.num_bucket, "The number of samples should be larger than buckets."
seq_len = self.dataset.get_field(self.seq_len_field_name)
self.seq_len = seq_len
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
seq_len_indice.sort(key=lambda x: x[0])
indice_in_buckets = []
if self.num_bucket > 0:
sample_per_bucket = len(seq_len_indice) // self.num_bucket
i = 0
while len(indice_in_buckets) < len(seq_len_indice):
indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
i += 1
else:
indice_in_buckets = [seq_len_indice]
self.indice_in_buckets = indice_in_buckets
self.get_new_order()

@property
def max_sentence(self):
if self._max_sentence < 1:
return 100000000
return self._max_sentence

@max_sentence.setter
def max_sentence(self, max_sentence):
self._max_sentence = max_sentence

def get_new_order(self) -> None:
np.random.shuffle(self.indice_in_buckets)
for bucket in self.indice_in_buckets:
np.random.shuffle(bucket)
indices = list(chain(*self.indice_in_buckets))
batches = []
cur_max_len = 0
batch = []
for length, i in indices:
max_len = max(length, cur_max_len)
if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
cur_max_len = length
if left_sample != 0:
add_samples = add_samples[:-left_sample]
batch = batch[-left_sample:]
cur_max_len = max(cur_max_len, max(batch))
else:
batch = []
if len(add_samples) == 0:
raise RuntimeError(
f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
batches.append(add_samples)
else:
cur_max_len = max_len
batch.append(i)
if batch:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
if left_sample != 0:
add_samples = add_samples[:-left_sample].copy()
if add_samples:
batches.append(add_samples)
np.random.shuffle(batches)
self.batches = batches

def __iter__(self) -> Iterable[int]:
for batch in self.batches:
yield batch
self.get_new_order()

def __len__(self):
return len(self.batches)


class ConstantTokenNumSampler:
"""
尽量保证每个batch的输入token数量是接近的。

"""

def __init__(self, seq_len, max_token: List[int] = 4096, max_sentence: int = -1,
need_be_multiple_of: int = 1, num_bucket: int = -1) -> None:
"""

:param List[int] seq_len: list[int], 是每个sample的长度。一般可以通过dataset.get_field('seq_len').content传入
:param int max_token: 每个batch的最大的token数量
:param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定
:param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到
:param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。
"""
assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets."
self.seq_len = seq_len
self.max_token = max_token
self._max_sentence = max_sentence
self.need_be_multiple_of = need_be_multiple_of
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
seq_len_indice.sort(key=lambda x: x[0])
indice_in_buckets = []
if num_bucket > 0:
sample_per_bucket = len(seq_len_indice) // num_bucket
i = 0
while len(indice_in_buckets) < len(seq_len_indice):
indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
i += 1
else:
indice_in_buckets = [seq_len_indice]
self.indice_in_buckets = indice_in_buckets
self.get_new_order()

@property
def max_sentence(self):
if self._max_sentence < 1:
return 100000000
return self._max_sentence

@max_sentence.setter
def max_sentence(self, max_sentence):
self._max_sentence = max_sentence

def get_new_order(self) -> None:
np.random.shuffle(self.indice_in_buckets)
for bucket in self.indice_in_buckets:
np.random.shuffle(bucket)
indices = list(chain(*self.indice_in_buckets))
batches = []
cur_max_len = 0
batch = []
for length, i in indices:
max_len = max(length, cur_max_len)
if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
cur_max_len = length
if left_sample != 0:
add_samples = add_samples[:-left_sample]
batch = batch[-left_sample:]
cur_max_len = max(cur_max_len, max(batch))
else:
batch = []
if len(add_samples) == 0:
raise RuntimeError(
f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
batches.append(add_samples)
else:
cur_max_len = max_len
batch.append(i)
if batch:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
if left_sample != 0:
add_samples = add_samples[:-left_sample].copy()
if add_samples:
batches.append(add_samples)
np.random.shuffle(batches)
self.batches = batches

def __iter__(self) -> Iterable[int]:
for batch in self.batches:
yield batch
self.get_new_order()

def __len__(self):
return len(self.batches)


class SortedSampler(Sampler):
r"""
按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding)
"""

def __init__(self, dataset, seq_len_field_name: str = 'seq_len', descending: bool = True) -> None:
"""

:param str seq_len_field_name: 按哪个field进行排序。如果传入的field是数字,则直接按照该数字大小排序;如果传入的field不是
数字,则使用该field的长度进行排序
:param bool descending: 是否降序排列
"""
self.dataset = dataset
self.seq_len_field_name = seq_len_field_name
self.descending = descending

def __iter__(self) -> Iterable[int]:
seq_lens = self.dataset.get_field(self.seq_len_field_name).content
try:
seq_lens = list(map(len, seq_lens))
except:
pass

orders = np.argsort(seq_lens).tolist() # 从小到大的顺序
if self.descending:
orders = orders[::-1]
for order in orders:
yield order


def simple_sort_bucketing(lengths):
r"""

:param lengths: list of int, the lengths of all examples.
:return data: 2-level list
::

[
[index_11, index_12, ...], # bucket 1
[index_21, index_22, ...], # bucket 2
...
]

"""
lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)]
sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1])
# TODO: need to return buckets
return [idx for idx, _ in sorted_lengths]


def k_means_1d(x, k, max_iter=100):
r"""Perform k-means on 1-D data.

:param x: list of int, representing points in 1-D.
:param k: the number of clusters required.
:param max_iter: maximum iteration
:return centroids: numpy array, centroids of the k clusters
assignment: numpy array, 1-D, the bucket id assigned to each example.
"""
sorted_x = sorted(list(set(x)))
x = np.array(x)
if len(sorted_x) < k:
raise ValueError("too few buckets")
gap = len(sorted_x) / k

centroids = np.array([sorted_x[int(x * gap)] for x in range(k)])
assign = None

for i in range(max_iter):
# Cluster Assignment step
assign = np.array([np.argmin([np.absolute(x_i - x) for x in centroids]) for x_i in x])
# Move centroids step
new_centroids = np.array([x[assign == k].mean() for k in range(k)])
if (new_centroids == centroids).all():
centroids = new_centroids
break
centroids = new_centroids
return np.array(centroids), assign


def k_means_bucketing(lengths, buckets):
r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.

:param lengths: list of int, the length of all samples.
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
threshold for each bucket (This is usually None.).
:return data: 2-level list
::

[
[index_11, index_12, ...], # bucket 1
[index_21, index_22, ...], # bucket 2
...
]

"""
bucket_data = [[] for _ in buckets]
num_buckets = len(buckets)
_, assignments = k_means_1d(lengths, num_buckets)

for idx, bucket_id in enumerate(assignments):
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
bucket_data[bucket_id].append(idx)
return bucket_data

+ 31
- 7
fastNLP/core/utils/utils.py View File

@@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
:return:
"""
if fn_name is not None:
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."

parameters = list(inspect.signature(fn).parameters.values())
if inspect.ismethod(fn):
@@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None):
return mask


def wait_to_success(fn, no=False):
def wait_filepath(path, exist=True):
"""
等待当 path 的存在状态为 {exist} 时返回

:param path: 待检测的 path
:param exist: 为 True 时表明检测这个 path 存在就返回; 为 False 表明检测到这个 path 不存在 返回。
:return:
"""
if isinstance(path, str):
path = Path(path)
assert isinstance(path, Path)
count = 0
while True:
sleep(0.01)
if (no and not fn()) or (not no and fn()):
if path.exists() == exist:
break
count += 1
if count % 1000 == 0:
msg = 'create' if exist else 'delete'
logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...")



# 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
# 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
def synchronize_safe_rm(path: Optional[Union[str, Path]]):
"""
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。

:param path:
:return:
"""
if path is None:
return
if isinstance(path, str):
@@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
return
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
_recursive_rm(path)
wait_to_success(path.exists, no=True)
wait_filepath(path, exist=False)


def _recursive_rm(path: Path):
@@ -643,6 +665,8 @@ def _recursive_rm(path: Path):
def synchronize_mkdir(path: Optional[Union[str, Path]]):
"""
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数;
该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。

"""
if path is None:
return
@@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
path.mkdir(parents=True, exist_ok=True)

wait_to_success(path.exists)
wait_filepath(path, exist=True)


def get_class_that_defined_method(method):


+ 0
- 1
fastNLP/envs/set_backend.py View File

@@ -5,7 +5,6 @@
import os
import json
import sys
import subprocess
from collections import defaultdict




+ 1
- 2
fastNLP/io/loader/conll.py View File

@@ -50,8 +50,6 @@ class ConllLoader(Loader):

ConllLoader返回的DataSet的field由传入的headers确定。

数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。

"""
def __init__(self, headers, sep=None, indexes=None, dropna=True):
@@ -93,6 +91,7 @@ class ConllLoader(Loader):
class Conll2003Loader(ConllLoader):
r"""
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。

Example::



+ 1
- 1
fastNLP/modules/mix_modules/mix_module.py View File

@@ -85,7 +85,7 @@ class MixModule:
def test_step(self, batch):
raise NotImplementedError

def validate_step(self, batch):
def evaluate_step(self, batch):
raise NotImplementedError

def train(self):


+ 0
- 0
tests/core/callbacks/torch_callbacks/__init__.py View File


+ 41
- 0
tests/core/callbacks/torch_callbacks/test_torch_grad_clip_callback.py View File

@@ -0,0 +1,41 @@
import pytest
import numpy as np

from fastNLP.core.callbacks import TorchGradClipCallback, Callback
from fastNLP import Trainer
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch

from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args


class CheckClipCallback(Callback):
def __init__(self, parameters, clip_type, clip_value):
self.parameters = parameters
self.clip_type = clip_type
self.clip_value = clip_value

def on_after_optimizers_step(self, trainer, optimizers):
for param in self.parameters:
if self.clip_type == 'value':
assert param.grad.max().item()<=self.clip_value
else:
assert np.linalg.norm(param.grad.cpu().view(-1).numpy())<=self.clip_value


@pytest.mark.parametrize('accumulation_steps', [1, 3, 5])
@pytest.mark.parametrize('fp16', [True, False])
@pytest.mark.parametrize('clip_type', ['norm', 'value'])
@pytest.mark.parametrize('clip_value', [1, 2])
def test_torch_grad_clip_callback(accumulation_steps, fp16, clip_type, clip_value):
if not torch.cuda.is_available() and fp16:
pytest.skip("No cuda, cannot test fp16.")
device = 'cuda' if fp16 else 'cpu'
kwargs = get_trainer_args(lr=1, device=device)
callbacks = []
callbacks.append(TorchGradClipCallback(clip_value=clip_value, clip_type=clip_type))
callbacks.append(CheckClipCallback(kwargs['model'].parameters(), clip_type, clip_value))
trainer = Trainer(**kwargs, callbacks=callbacks, fp16=fp16)
trainer.run()

+ 34
- 0
tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py View File

@@ -0,0 +1,34 @@
import pytest
import numpy as np

from fastNLP.core.callbacks import TorchWarmupCallback, Callback
from fastNLP import Trainer

from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args


class RecordLrCallback(Callback):
def __init__(self):
self.lrs = []

def on_after_optimizers_step(self, trainer, optimizers):
self.lrs.append(trainer.driver.optimizers[0].param_groups[0]['lr'])


@pytest.mark.parametrize('warmup', [5, 0.1])
@pytest.mark.parametrize('schedule', ['constant', 'linear'])
@pytest.mark.parametrize('accumulation_steps', [1, 3, 4])
def test_torch_warmup_callback(warmup, schedule, accumulation_steps):
kwargs = get_trainer_args(lr=0.1, bsz=4)
callback = TorchWarmupCallback(warmup, schedule)
r_callback = RecordLrCallback()
kwargs['callbacks'] = [callback, r_callback]
trainer = Trainer(**kwargs, accumulation_steps=accumulation_steps)
trainer.run()

if schedule == 'linear':
assert kwargs['optimizers'].param_groups[0]['lr'] <= 0.01
elif schedule == 'constant':
assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr'])

assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1

+ 27
- 53
tests/core/controllers/test_trainer_paddle.py View File

@@ -1,13 +1,11 @@
import pytest
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
from typing import Any
from dataclasses import dataclass

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK

from paddle.optimizer import Adam
from paddle.io import DataLoader
@@ -19,40 +17,18 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordM
from tests.helpers.utils import magic_argv_env_context

@dataclass
class MNISTTrainPaddleConfig:
class TrainPaddleConfig:
num_labels: int = 10
feature_dimension: int = 784
feature_dimension: int = 10

batch_size: int = 32
batch_size: int = 2
shuffle: bool = True
validate_every = -5
evaluate_every = 2

driver: str = "paddle"
device = "gpu"

@dataclass
class MNISTTrainFleetConfig:
num_labels: int = 10
feature_dimension: int = 784

batch_size: int = 32
shuffle: bool = True
validate_every = -5

@dataclass
class TrainerParameters:
model: Any = None
optimizers: Any = None
train_dataloader: Any = None
validate_dataloaders: Any = None
input_mapping: Any = None
output_mapping: Any = None
metrics: Any = None

@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
RichCallback(5)]])
@magic_argv_env_context
def test_trainer_paddle(
driver,
@@ -60,38 +36,36 @@ def test_trainer_paddle(
callbacks,
n_epochs=2,
):
trainer_params = TrainerParameters()

trainer_params.model = PaddleNormalModel_Classification_1(
num_labels=MNISTTrainPaddleConfig.num_labels,
feature_dimension=MNISTTrainPaddleConfig.feature_dimension
model = PaddleNormalModel_Classification_1(
num_labels=TrainPaddleConfig.num_labels,
feature_dimension=TrainPaddleConfig.feature_dimension
)
trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(6400, 10),
batch_size=MNISTTrainPaddleConfig.batch_size,
dataset=PaddleRandomMaxDataset(20, 10),
batch_size=TrainPaddleConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(1000, 10),
batch_size=MNISTTrainPaddleConfig.batch_size,
dataset=PaddleRandomMaxDataset(20, 10),
batch_size=TrainPaddleConfig.batch_size,
shuffle=True
)
trainer_params.train_dataloader = train_dataloader
trainer_params.validate_dataloaders = val_dataloader
trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
train_dataloader = train_dataloader
evaluate_dataloaders = val_dataloader
evaluate_every = TrainPaddleConfig.evaluate_every
metrics = {"acc": Accuracy(backend="paddle")}
trainer = Trainer(
model=trainer_params.model,
model=model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,
optimizers=optimizers,
train_dataloader=train_dataloader,
evaluate_dataloaders=evaluate_dataloaders,
evaluate_every=evaluate_every,
input_mapping=None,
output_mapping=None,
metrics=metrics,

n_epochs=n_epochs,
callbacks=callbacks,


+ 39
- 28
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -117,12 +117,13 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

assert not (replaced_loader is dataloader)
@@ -133,12 +134,13 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
sampler = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

assert not (replaced_loader is dataloader)
@@ -171,14 +173,15 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
时的表现
"""
dataloader = DataLoader(
self.dataset,
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle),
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
@@ -195,12 +198,13 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
@@ -222,11 +226,12 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert replaced_loader is dataloader
@@ -238,14 +243,15 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler
的表现
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

@@ -258,13 +264,14 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -276,16 +283,17 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
@@ -293,7 +301,7 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
dist.barrier()

"""
@@ -302,13 +310,14 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -320,18 +329,19 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -349,11 +359,12 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)


+ 133
- 82
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -1,4 +1,5 @@
import os
from re import S
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest
from pathlib import Path
@@ -56,34 +57,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
num_consumed_batches = 2

# TODO 断点重训完善后在这里迭代几次
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)

sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# 加载
# 更改 batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空

# 2. 检查 batch_sampler 是否被正确地加载和替换
replaced_loader = states["dataloader"]
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]

# 3. 检查 model 的参数是否被正确加载
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert paddle.equal_all(res1["pred"], res2["pred"])

# 4. 检查 batch_idx
# TODO
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)

assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)


finally:
synchronize_safe_rm(path)

@@ -104,21 +128,36 @@ def test_save_and_load_with_randomsampler(only_state_dict):
dataset,
batch_sampler=batch_sampler
)
num_consumed_batches = 2

# TODO 断点重训完善后在这里迭代几次
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)

sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])

# 加载
# 更改 batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空

# 2. 检查 sampler 是否被正确地加载和替换
replaced_loader = states["dataloader"]
replaced_loader = load_states["dataloader"]

assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
@@ -129,60 +168,51 @@ def test_save_and_load_with_randomsampler(only_state_dict):

# 3. 检查 model 的参数是否被正确加载
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert paddle.equal_all(res1["pred"], res2["pred"])

# 4. 检查 batch_idx
# TODO
finally:
synchronize_safe_rm(path)

def test_save_and_load_state_dict(prepare_test_save_load):
"""
测试save和load函数
TODO optimizer的state_dict为空,暂时不测试
"""
try:
path = "dict"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path)
driver2.load_model(path)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)

assert paddle.equal_all(res1["pred"], res2["pred"])
assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)
finally:
synchronize_safe_rm(path)

def test_save_and_load_whole_model(prepare_test_save_load):
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
"""
测试save和load函数
TODO optimizer的state_dict为空,暂时不测试
测试 save_model 和 load_model 函数
"""
try:
path = "model"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict=False)
if only_state_dict:
driver1.save_model(path, only_state_dict)
else:
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")

if only_state_dict:
synchronize_safe_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")

class TestSingleDeviceFunction:
"""
@@ -199,13 +229,7 @@ class TestSingleDeviceFunction:
测试能否运行
"""
res = self.driver.unwrap_model()

def test_check_evaluator_mode(self):
"""
这两个函数没有返回值和抛出异常,仅检查是否有import错误等影响运行的因素
"""
self.driver.check_evaluator_mode("validate")
self.driver.check_evaluator_mode("test")
assert res is self.driver.model

def test_is_distributed(self):
assert self.driver.is_distributed() == False
@@ -237,44 +261,55 @@ class TestSetDistReproDataloder:

assert replaced_loader is dataloader

def test_set_dist_repro_dataloader_with_reproducible_true(self):
@pytest.mark.parametrize("shuffle", [True, False])
def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True),
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
if shuffle:
# 此时会替换 sampler
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# 此时会替换 batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
@@ -284,16 +319,21 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
batch_sampler=RandomBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

@@ -303,15 +343,16 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -323,11 +364,11 @@ class TestSetDistReproDataloder:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确
"""
@@ -346,9 +387,6 @@ class TestSetDistReproDataloder:
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

import time
time.sleep(5)

# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@@ -357,16 +395,29 @@ class TestSetDistReproDataloder:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.load_state_dict(sampler_states)
# 重新改造 dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
)
)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_batches = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
replaced_loader.batch_sampler.sampler.set_epoch(0)
for idx, batch in enumerate(replaced_loader):
# 重新构造 dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)


+ 0
- 31
tests/core/samplers/test_sampler.py View File

@@ -1,31 +0,0 @@
import unittest
import random
from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler
from fastNLP.core.dataset import DataSet
from array import array
import torch

from fastNLP.core.samplers.sampler import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


class SamplerTest(unittest.TestCase):

def test_sequentialsampler(self):
ds = DataSet({'x': [1, 2, 3, 4] * 10})
sqspl = SequentialSampler(ds)
for idx, inst in enumerate(sqspl):
self.assertEqual(idx, inst)

def test_randomsampler(self):
ds = DataSet({'x': [1, 2, 3, 4] * 10})
rdspl = RandomSampler(ds)
ans = [ds[i] for i in rdspl]
self.assertEqual(len(ans), len(ds))

def test_bucketsampler(self):
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len")



+ 1
- 1
tests/envs/test_set_backend.py View File

@@ -1,6 +1,6 @@
import os

from fastNLP.envs.set_env import dump_fastnlp_backend
from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
from fastNLP.core import synchronize_safe_rm



+ 1
- 1
tests/helpers/callbacks/helper_callbacks.py View File

@@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback):
print("on_train_end")

def on_train_epoch_begin(self, trainer):
if trainer.current_epoch_idx >= 1:
if trainer.cur_epoch_idx >= 1:
# 触发 on_exception;
raise Exception
print("on_train_epoch_begin")


+ 68
- 0
tests/helpers/callbacks/prepare_trainer_args_for_torch_test.py View File

@@ -0,0 +1,68 @@

"""
这个文件主要用于提供测试 callback 时的 Trainer 的参数,可以直接使用进行对Trainer进行初始化。只需要再额外传入相应的callback就可以运行

"""

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.metrics import Accuracy


if _NEED_IMPORT_TORCH:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

class DataSet:
def __init__(self, num_samples=1000, num_features=10):
g = torch.Generator()
g.manual_seed(1000)
self.data = torch.randn(num_samples, num_features, generator=g)
self.y = self.data.argmax(dim=-1)

def __getitem__(self, item):
return {'x': self.data[item], 'target': self.y[item]}

def __len__(self):
return len(self.data)


class Model(nn.Module):
def __init__(self, num_features=5):
super().__init__()
self.mlps = nn.Sequential(
nn.Linear(num_features, 20),
nn.ReLU(),
nn.Linear(20, 20),
nn.Dropout(p=0.3),
nn.ReLU(),
nn.Linear(20, num_features)
)

def forward(self, x, target):
y = self.mlps(x)
if self.training:
return {'loss': F.cross_entropy(y, target)}
return {'pred': y}


def get_trainer_args(num_features=5, num_samples=20, bsz=4, lr=0.1, n_epochs=5, device=None):
ds = DataSet(num_samples=num_samples, num_features=num_features)
dl = DataLoader(ds, batch_size=bsz)
model = Model(num_features=num_features)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)

kwargs = {
'model': model,
'driver': 'torch',
'device': device,
'optimizers': optimizer,
'train_dataloader': dl,
'evaluate_dataloaders': dl,
'metrics': {'acc': Accuracy()},
'n_epochs': n_epochs
}

return kwargs

+ 1
- 1
tests/helpers/models/paddle_model.py View File

@@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer):
x = self(x)
return {"loss": self.loss_fn(x, y)}

def validate_step(self, x, y):
def evaluate_step(self, x, y):

x = self(x)
return {"pred": x, "target": y.reshape((-1,))}


Loading…
Cancel
Save