
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
MorningForest, 3 years ago
commit 1216ac6e90
61 changed files with 5012 additions and 317 deletions
  1. fastNLP/core/callbacks/callback.py (+32 -4)
  2. fastNLP/core/callbacks/callback_events.py (+2 -2)
  3. fastNLP/core/callbacks/callback_manager.py (+3 -14)
  4. fastNLP/core/callbacks/checkpoint_callback.py (+1 -1)
  5. fastNLP/core/callbacks/early_stop_callback.py (+4 -4)
  6. fastNLP/core/callbacks/has_monitor_callback.py (+1 -1)
  7. fastNLP/core/callbacks/load_best_model_callback.py (+6 -17)
  8. fastNLP/core/callbacks/more_evaluate_callback.py (+6 -6)
  9. fastNLP/core/callbacks/progress_callback.py (+26 -13)
  10. fastNLP/core/collators/new_collator.py (+181 -0)
  11. fastNLP/core/collators/padders/__init__.py (+0 -0)
  12. fastNLP/core/collators/padders/exceptions.py (+44 -0)
  13. fastNLP/core/collators/padders/get_padder.py (+193 -0)
  14. fastNLP/core/collators/padders/numpy_padder.py (+72 -0)
  15. fastNLP/core/collators/padders/padder.py (+21 -0)
  16. fastNLP/core/collators/padders/raw_padder.py (+48 -0)
  17. fastNLP/core/collators/padders/torch_padder.py (+157 -0)
  18. fastNLP/core/collators/padders/torch_utils.py (+20 -0)
  19. fastNLP/core/collators/padders/utils.py (+173 -0)
  20. fastNLP/core/collators/utils.py (+103 -0)
  21. fastNLP/core/controllers/evaluator.py (+32 -47)
  22. fastNLP/core/controllers/loops/train_batch_loop.py (+1 -1)
  23. fastNLP/core/controllers/trainer.py (+95 -46)
  24. fastNLP/core/controllers/utils/utils.py (+8 -8)
  25. fastNLP/core/drivers/jittor_driver/jittor_driver.py (+1 -1)
  26. fastNLP/core/drivers/paddle_driver/fleet.py (+210 -25)
  27. fastNLP/core/drivers/paddle_driver/fleet_launcher.py (+25 -7)
  28. fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py (+7 -10)
  29. fastNLP/core/drivers/paddle_driver/paddle_driver.py (+12 -4)
  30. fastNLP/core/drivers/paddle_driver/single_device.py (+56 -0)
  31. fastNLP/core/drivers/paddle_driver/utils.py (+6 -18)
  32. fastNLP/core/drivers/torch_driver/initialize_torch_driver.py (+1 -1)
  33. fastNLP/core/drivers/torch_driver/torch_driver.py (+13 -1)
  34. fastNLP/core/log/logger.py (+13 -0)
  35. fastNLP/core/samplers/reproducible_batch_sampler.py (+3 -3)
  36. fastNLP/core/utils/paddle_utils.py (+14 -1)
  37. fastNLP/core/utils/rich_progress.py (+22 -1)
  38. fastNLP/core/utils/utils.py (+20 -1)
  39. tests/core/callbacks/test_checkpoint_callback_torch.py (+2 -2)
  40. tests/core/callbacks/test_load_best_model_callback_torch.py (+2 -2)
  41. tests/core/callbacks/test_more_evaluate_callback.py (+2 -2)
  42. tests/core/collators/__init__.py (+0 -0)
  43. tests/core/collators/padders/__init__.py (+0 -0)
  44. tests/core/collators/padders/test_get_padder.py (+139 -0)
  45. tests/core/collators/padders/test_numpy_padder.py (+81 -0)
  46. tests/core/collators/padders/test_raw_padder.py (+29 -0)
  47. tests/core/collators/padders/test_torch_padder.py (+105 -0)
  48. tests/core/collators/padders/test_utils.py (+90 -0)
  49. tests/core/collators/test_new_collator.py (+225 -0)
  50. tests/core/collators/test_utils.py (+37 -0)
  51. tests/core/controllers/test_trainer_w_evaluator_torch.py (+2 -2)
  52. tests/core/drivers/paddle_driver/test_fleet.py (+2 -2)
  53. tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py (+4 -4)
  54. tests/core/drivers/paddle_driver/test_single_device.py (+21 -25)
  55. tests/core/drivers/torch_driver/test_ddp.py (+788 -0)
  56. tests/core/drivers/torch_driver/test_initialize_torch_driver.py (+103 -0)
  57. tests/core/drivers/torch_driver/test_single_device.py (+697 -0)
  58. tests/core/drivers/torch_driver/test_utils.py (+36 -35)
  59. tests/helpers/callbacks/helper_callbacks.py (+5 -5)
  60. tests/helpers/datasets/torch_data.py (+1 -1)
  61. tutorials/fastnlp_tutorial_0.ipynb (+1009 -0)

fastNLP/core/callbacks/callback.py (+32 -4)

@@ -12,6 +12,34 @@ from fastNLP.core.callbacks.callback_events import _SingleEventState
class Callback:
    r"""
    The base class for callbacks; every callback class, whether one of the defaults provided by fastNLP or one customized by the user, should inherit from it.
+    The approximate order in which callbacks are invoked:
+        Trainer.__init__():
+            on_after_trainer_initialized()
+        Trainer.run():
+            if num_eval_sanity_batch > 0:
+                on_sanity_check_begin()   # only if num_eval_sanity_batch is set
+                on_sanity_check_end()
+            try:
+                on_train_begin()
+                while cur_epoch_idx < n_epochs:
+                    on_train_epoch_begin()
+                    while batch_idx_in_epoch <= num_batches_per_epoch:
+                        on_fetch_data_begin()
+                        on_fetch_data_end()
+                        on_train_batch_begin()
+                        on_before_backward()
+                        on_after_backward()
+                        on_before_zero_grad()        # actual invocation depends on accumulation_steps
+                        on_after_zero_grad()         # actual invocation depends on accumulation_steps
+                        on_before_optimizers_step()  # actual invocation depends on accumulation_steps
+                        on_after_optimizers_step()   # actual invocation depends on accumulation_steps
+                        on_train_batch_end()
+                    on_train_epoch_end()
+            except BaseException:
+                self.on_exception()
+            finally:
+                on_train_end()
+    Other callbacks, such as on_evaluate_begin()/on_evaluate_end(), will …
    """


    def on_after_trainer_initialized(self, trainer, driver):
@@ -221,9 +249,9 @@ class Callback:
        """
        pass

-    def on_validate_begin(self, trainer):
+    def on_evaluate_begin(self, trainer):
        """
-        Called right before a validate run. If the validate frequency is controlled by a step count or by a custom function, this hook is called after on_train_batch_end;
+        Called right before an evaluate run. If the evaluate frequency is controlled by a step count or by a custom function, this hook is called after on_train_batch_end;
        if it is controlled by an epoch count, this hook is called after on_train_epoch_end.

        :param trainer:
@@ -231,9 +259,9 @@ class Callback:
        """
        pass

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        """
-        Called when a validate run ends; the validate results are passed in.
+        Called when an evaluate run ends; the evaluate results are passed in.

        :param trainer:
        :param results: the results of the evaluate run, usually a dict.
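
As a quick illustration of the renamed hooks, here is a minimal, hypothetical callback (not part of this commit) that tracks the best evaluation result; the metric key 'acc#acc' is invented for the example:

from fastNLP.core.callbacks.callback import Callback

class TrackBestCallback(Callback):
    """Toy callback: remember the best 'acc#acc' seen across evaluate runs."""

    def __init__(self):
        self.best = float('-inf')

    def on_evaluate_end(self, trainer, results):
        # `results` is the evaluate result, usually a dict such as {'acc#acc': 0.91}.
        acc = results.get('acc#acc')
        if acc is not None and acc > self.best:
            self.best = acc
            print(f"new best {acc:.4f} at epoch {trainer.cur_epoch_idx}")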


fastNLP/core/callbacks/callback_events.py (+2 -2)

@@ -96,8 +96,8 @@ class Events(EventEnum):
    on_after_optimizers_step = "on_after_optimizers_step"
    on_before_zero_grad = "on_before_zero_grad"
    on_after_zero_grad = "on_after_zero_grad"
-    on_validate_begin = "on_validate_begin"
-    on_validate_end = "on_validate_end"
+    on_evaluate_begin = "on_evaluate_begin"
+    on_evaluate_end = "on_evaluate_end"


class EventsList:
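
A sketch of how the renamed events could be referenced through the `Trainer.on` decorator mentioned in callback_manager.py below; the import locations and the exact decorator signature are assumptions, not taken from this commit:

from fastNLP import Trainer, Events  # assumed import locations

@Trainer.on(Events.on_evaluate_end)  # was Events.on_validate_end before this commit
def print_eval(trainer, results):
    print(f"evaluate finished at batch {trainer.global_forward_batches}: {results}")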


fastNLP/core/callbacks/callback_manager.py (+3 -14)

@@ -8,7 +8,6 @@ __all__ = [

from .callback_events import Events
from .callback import Callback
-from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.core.log import logger


@@ -35,7 +34,7 @@ class CallbackManager:
    class_callbacks: Optional[List[Callback]]  # keeps the original class-style callbacks
    callback_fns: dict

-    def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'):
+    def __init__(self, callbacks: Optional[List[Callback]]):
        r"""
        Note the invocation order of callbacks:
        1. callback functions added through the function decorator `Trainer.on`;
@@ -46,7 +45,6 @@ class CallbackManager:
        """
        self._need_reproducible_sampler = False

-        _has_progress_callback = False
        _callbacks = []
        if callbacks is not None:
            if isinstance(callbacks, Callback):
@@ -57,16 +55,7 @@ class CallbackManager:
            for _callback in callbacks:
                if not isinstance(_callback, Callback):
                    raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
-                if isinstance(_callback, ProgressCallback):
-                    _has_progress_callback = True
            _callbacks += callbacks
-        if not _has_progress_callback:
-            # add a progress callback
-            progress_callback = choose_progress_callback(progress_bar=progress_bar)
-            if progress_callback is None:
-                logger.info("There is no progress bar, Trainer will not output training progress.")
-            else:
-                _callbacks.append(progress_callback)
        self.callback_fns = defaultdict(list)
        # In theory the user can only trigger these via 'trainer.on_train_begin' or 'trainer.callback_manager.on_train_begin', i.e. there is no way
        # to call one specific callback function directly without also calling the other callbacks of the same name, so recording the timing of each Event is enough;
@@ -292,9 +281,9 @@ class CallbackManager:
        pass

    @_transfer
-    def on_validate_begin(self, trainer):
+    def on_evaluate_begin(self, trainer):
        pass

    @_transfer
-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        pass

fastNLP/core/callbacks/checkpoint_callback.py (+1 -1)

@@ -114,7 +114,7 @@ class CheckpointCallback(Callback):
        if self.topk_saver.topk_queue and trainer.evaluator is None:
            logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.")

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        # if a save happened, the returned folder is not None
        folder = self.topk_saver.save_topk(trainer, results)




fastNLP/core/callbacks/early_stop_callback.py (+4 -4)

@@ -16,13 +16,13 @@ class EarlyStopCallback(HasMonitorCallback):
        the best-matching one is used as the monitor. If None, the monitor set in the Trainer is tried. A function may also be passed in; it receives the evaluation
        results (a dict) and returns a float to be used as the monitor value.
    :param larger_better: whether a larger monitor value is better.
-    :param patience: stop after this many validate runs without improvement.
+    :param patience: stop after this many evaluate runs without improvement.
    """
    super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
    self.wait = 0
    self.patience = patience

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        monitor_value = self.get_monitor_value(results)
        if monitor_value is None:
            return
@@ -32,13 +32,13 @@ class EarlyStopCallback(HasMonitorCallback):
        self.wait += 1

    def on_fetch_data_begin(self, trainer):
-        # For step-based validate, this is the next hook to run, so the check happens here.
+        # For step-based evaluate, this is the next hook to run, so the check happens here.
        if self.wait >= self.patience:
            raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                     f"metric `{self._real_monitor}`")

    def on_train_epoch_begin(self, trainer):
-        # For epoch-based validate, this is the next hook to run, so the check happens here.
+        # For epoch-based evaluate, this is the next hook to run, so the check happens here.
        if self.wait >= self.patience:
            raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                     f"metric `{self._real_monitor}`(best value: {self.monitor_value})")


fastNLP/core/callbacks/has_monitor_callback.py (+1 -1)

@@ -216,6 +216,6 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback):
        _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
        self.execute_fn = execute_fn

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        if self.is_better_results(results):
            self.execute_fn()

fastNLP/core/callbacks/load_best_model_callback.py (+6 -17)

@@ -76,7 +76,7 @@ class LoadBestModelCallback(HasMonitorCallback):

        super().on_after_trainer_initialized(trainer, driver)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        if self.is_better_results(results, keep_if_better=True):
            if self.real_save_folder:
                trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -95,25 +95,14 @@ class LoadBestModelCallback(HasMonitorCallback):
            self.buffer.seek(0)
            trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)

-        if self.delete_after_after:
-            if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
-                # only rank 0 needs to perform the deletion
-                logger.info(f"Deleting {self.real_save_folder}...")
-                shutil.rmtree(self.real_save_folder)
-                try:
-                    # removed if it is empty
-                    os.rmdir(self.save_folder)
-                except:
-                    pass
-            elif hasattr(self, 'buffer'):
-                self.buffer.close()
-                del self.buffer
+        self._delete_after_after(trainer)

-    def on_exception(self, trainer, exception):
+    def _delete_after_after(self, trainer):
+        trainer.driver.barrier()
        if self.delete_after_after:
-            if self.real_save_folder:  # here, whoever hits the exception performs the deletion
+            if self.real_save_folder:
                logger.info(f"Deleting {self.real_save_folder}...")
-                shutil.rmtree(self.real_save_folder)
+                shutil.rmtree(self.real_save_folder, ignore_errors=True)
                try:
                    # removed if it is empty
                    os.rmdir(self.save_folder)


fastNLP/core/callbacks/more_evaluate_callback.py (+6 -6)

@@ -31,8 +31,8 @@ class MoreEvaluateCallback(HasMonitorCallback):

    :param dataloaders: the data to evaluate on.
    :param metrics: the metrics to use.
-    :param evaluate_every: can be a negative number, a positive number or a function; (1) a negative integer means run validate once every that many epochs; (2) a positive integer means run
-        evaluate once every that many batches; (3) a function is a user-supplied function controlling the validate frequency; it should accept the trainer object as its argument and return
+    :param evaluate_every: can be a negative number, a positive number or a function; (1) a negative integer means run evaluate once every that many epochs; (2) a positive integer means run
+        evaluate once every that many batches; (3) a function is a user-supplied function controlling the evaluate frequency; it should accept the trainer object as its argument and return
        a bool; True means an evaluate run is needed. The function is called after each batch to decide whether to evaluate.
    :param watch_monitor: this value watches the evaluate results of the Trainer; when it is not None, evaluate_every has no effect. The meaning of this
        parameter is: when the {watch_monitor} entry of the Trainer's evaluate results improves, run one evaluate. The parameter has two …
@@ -108,7 +108,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
            'metrics': self.metrics,
            'driver': self.kwargs.get('driver', trainer.driver),
            'device': self.kwargs.get('device', trainer.device),
-            'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn),
+            'evaluate_batch_step_fn': self.kwargs.get('evaluate_batch_step_fn', trainer.evaluate_batch_step_fn),
            'evaluate_fn': self.evaluate_fn,
            'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping),
            'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping),
@@ -128,7 +128,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
            results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch)
            self.topk_saver.get_monitor_value(results)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        if self.is_better_results(results, keep_if_better=True):
            results = self.evaluator.run()
            self.topk_saver.save_topk(trainer, results)
@@ -137,8 +137,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
        if self.watch_monitor is not None:
            return
        if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
-            validate_every = -self.evaluate_every
-            if trainer.cur_epoch_idx % validate_every == 0:
+            evaluate_every = -self.evaluate_every
+            if trainer.cur_epoch_idx % evaluate_every == 0:
                results = self.evaluator.run()
                self.topk_saver.save_topk(trainer, results)




fastNLP/core/callbacks/progress_callback.py (+26 -13)

@@ -1,6 +1,6 @@
import json
import sys
+from typing import Union

__all__ = [
    'choose_progress_callback',
@@ -11,11 +11,22 @@ __all__ = [
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger
+from fastNLP.core.utils.utils import is_notebook


-def choose_progress_callback(progress_bar:str):
+class ProgressCallback(HasMonitorCallback):
+    def on_train_end(self, trainer):
+        f_rich_progress.stop()
+
+    @property
+    def name(self):  # the name of the progress bar
+        return 'auto'
+
+
+def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback:
    if progress_bar == 'auto':
-        if (sys.stdin and sys.stdin.isatty()):
+        if not f_rich_progress.dummy_rich:
            progress_bar = 'rich'
        else:
            progress_bar = 'raw'
@@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str):
        return RichCallback()
    elif progress_bar == 'raw':
        return RawTextCallback()
+    elif isinstance(progress_bar, ProgressCallback):
+        return progress_bar
    else:
        return None


-class ProgressCallback(HasMonitorCallback):
-    def on_train_end(self, trainer):
-        f_rich_progress.stop()
-
-
class RichCallback(ProgressCallback):
    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
                 format_json=True):
@@ -92,7 +100,7 @@ class RichCallback(ProgressCallback):
            self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}',
                                     advance=self.epoch_bar_update_advance, refresh=True)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        if len(results)==0:
            return
        rule_style = ''
@@ -114,9 +122,6 @@ class RichCallback(ProgressCallback):
        else:
            self.progress_bar.print(results)

-    def on_exception(self, trainer, exception):
-        self.clear_tasks()
-
    def clear_tasks(self):
        for key, taskid in self.task2id.items():
            self.progress_bar.destroy_task(taskid)
@@ -124,6 +129,10 @@ class RichCallback(ProgressCallback):
        self.task2id = {}
        self.loss = 0

+    @property
+    def name(self):  # the name of the progress bar
+        return 'rich'


class RawTextCallback(ProgressCallback):
    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
@@ -166,7 +175,7 @@ class RawTextCallback(ProgressCallback):
                   f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.'
            logger.info(text)

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        if len(results)==0:
            return
        base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
@@ -184,3 +193,7 @@ class RawTextCallback(ProgressCallback):
            logger.info(json.dumps(trainer.driver.tensor_to_numeric(results)))
        else:
            logger.info(results)
+
+    @property
+    def name(self):  # the name of the progress bar
+        return 'raw'
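
With this change, choose_progress_callback also accepts an already constructed ProgressCallback and hands it back unchanged, so a pre-configured bar can be injected. A small sketch (the monitor name 'acc#acc' is invented):

from fastNLP.core.callbacks.progress_callback import RichCallback, choose_progress_callback

bar = RichCallback(print_every=10, monitor='acc#acc', larger_better=True)
assert choose_progress_callback(bar) is bar  # instances pass straight through
assert bar.name == 'rich'                    # the new `name` property

auto_bar = choose_progress_callback('auto')  # 'auto' picks rich unless rich is unavailable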

fastNLP/core/collators/new_collator.py (+181 -0)

@@ -0,0 +1,181 @@
from typing import List, Union, Dict, Callable, Sequence, Mapping

from fastNLP.core.log import logger
from .padders.get_padder import get_padder

import re

from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
    pack_batch_sequence, NESTED_DICT_SEPARATOR

sequence_idx_str = re.compile(r'^_\d+$')  # matches _0, _1, ...
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None]


class Collator:
    def __init__(self, backend='torch'):
        """
        An object used to pad data. It automatically pads every field that can be padded (fastNLP judges paddability from the
        data); the default pad value is 0. Use the set_pad() function to adjust this. If some fields should not be output,
        configure that with the set_ignore() function.

        :param backend: which kind of tensor to produce for paddable fields; one of ['torch', 'jittor', 'paddle', 'numpy', 'raw', None].
            If None, no padding is performed.
        """
        self.unpack_batch_func = None
        self.pack_batch_func = None
        self.ignore_fields = set()
        self.padders = {}
        self.input_fields = {}
        self.batch_data_type = None  # can only be 'd', 's' or 'l', meaning each sample of the input batch is a dict, a single value, or a list.
        self.set_backend(backend)

    def __call__(self, batch) -> Union[List, Dict]:
        """
        A batch can come in three forms:
            List[Dict], List[List], List[Sample]

        Step 1: use unpack_batch_func to gather the contents of the same field into one list.
        Step 2: pad each field with its own padder.
        Step 3: return the same type as each sample of the input batch.

        The first call decides which unpack_batch_func to use based on the current batch; that function collects the same field of
        different samples into one list. It also decides pack_batch_func, whose job is to restore the batch to the type of an
        input sample before returning the padded batch.
        The first call also decides the Padder for each field.

        """
        if self.unpack_batch_func is None:
            # decide which unpack_batch_func to use; all of them return a dict
            if self.batch_data_type is None:
                if isinstance(batch[0], Mapping):
                    self.batch_data_type = 'd'
                elif isinstance(batch[0], Sequence):  # there is a risk of misjudgement here
                    self.batch_data_type = 'l'
                else:
                    self.batch_data_type = 's'
                logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
                             f"is {self.batch_data_type}")
            if self.batch_data_type == 's':
                self.unpack_batch_func = lambda x: {'_single': x}  # no adjustment needed
                self.pack_batch_func = lambda x: x['_single']
            elif self.batch_data_type == 'l':
                self.unpack_batch_func = unpack_batch_sequence
                self.pack_batch_func = pack_batch_sequence
            elif self.batch_data_type == 'd':
                if any([isinstance(v, Mapping) for v in batch[0].values()]):  # there may be nested dicts: {'a': {'b': xx}} -> {'a@@b': value}
                    self.unpack_batch_func = unpack_batch_nested_mapping
                    self.pack_batch_func = pack_batch_nested_mapping
                else:
                    self.unpack_batch_func = unpack_batch_mapping
                    self.pack_batch_func = lambda x: x

        unpack_batch: Dict = self.unpack_batch_func(batch)  # regroup each field into batch form.

        pad_batch = {}
        if len(self.padders) == 0:  # first run: prepare the padders
            for key in unpack_batch.keys():
                if key not in self.input_fields and key not in self.ignore_fields:
                    self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}

            for field_name, setting in self.input_fields.items():
                pad_fn = setting.get('pad_fn', None)
                if callable(pad_fn):
                    padder = pad_fn
                else:
                    batch_field = unpack_batch.get(field_name)
                    padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
                                        dtype=setting['dtype'], backend=setting['backend'],
                                        field_name=field_name)
                self.padders[field_name] = padder
            if self.batch_data_type == 'l':
                self.padders = dict(sorted(self.padders.items(), key=lambda x: int(x[0][1:])))  # sort so that _0, _1 keep their order

        for key, padder in self.padders.items():
            batch = unpack_batch.get(key)
            pad_batch[key] = padder(batch)

        return self.pack_batch_func(pad_batch)  # restore the type of the input as appropriate

    def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None,
                pad_fn:Callable=None) -> "Collator":
        """
        Use this function if the content of some field needs special treatment.

        :param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key can be
            used directly; for a nested dict, use @@ to join keys of different levels, e.g. use a@@b for {'a': {'b': 1}}.
            If __getitem__ returns a Sequence, '_0', '_1' can be used to denote the 0th or 1st element of the sequence. If the
            field is not found in the data, an error is raised; if __getitem__ returns the whole content itself, use "_single".
        :param pad_val: the default pad value of this field. If set to None, the field is not padded; by default fastNLP only pads
            fields that can be padded, so if the field is not in a paddable form anyway, there is no need to actively set this to None.
        :param dtype: for a field that needs padding, the dtype its data should have.
        :param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], meaning the output is a list, numpy.ndarray,
            torch.Tensor, paddle.Tensor or jittor.Var respectively. If pad_val is None, this value can only be None or numpy.
        :param pad_fn: a custom pad function for this field; when given, pad_val, dtype, backend etc. have no effect. The input of
            pad_fn is the batch form of the current field: the Collator automatically unbatches the data and groups each field
            into its own batch; the input of pad_fn is exactly the field's batch form, and its output is used directly as the result.
        :return: the Collator itself
        """
        self.padders.clear()  # regenerate

        if self.batch_data_type is not None:
            if self.batch_data_type == 's':
                logger.debug("Set as single field mode.")
                self.input_fields.clear()
            elif self.batch_data_type == 'd':
                assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
                                                                   f"index, but other field is set as dict mode."
            elif self.batch_data_type == 'l':
                assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
                                                                       f"field name is {field_name}"

        if field_name == '_single':
            self.batch_data_type = 's'
        elif sequence_idx_str.match(field_name):
            self.batch_data_type = 'l'
        else:
            self.batch_data_type = 'd'

        if field_name in self.ignore_fields:
            logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
        if backend is None:
            backend = self.backend
        else:
            assert backend in SUPPORTED_BACKENDS

        self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}

        return self

    def set_backend(self, backend:str):
        """
        Set what kind of tensor paddable fields are padded to by default.

        :param backend: which kind of tensor to produce for paddable fields; one of ['torch', 'jittor', 'paddle', 'numpy', 'raw', None].
            If None, no padding is performed.
        :return:
        """
        assert backend in SUPPORTED_BACKENDS
        self.padders.clear()
        self.backend = backend

    def set_ignore(self, *field_names) -> "Collator":
        """
        Use this to configure contents that should not be output; the configured fields are left out of the output batch.
        Ex::
            collator.set_ignore('field1', 'field2')

        :param field_names: the names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's key can be
            used directly; for a nested dict, use @@ to join keys of different levels, e.g. use a@@b for {'a': {'b': 1}}.
            If __getitem__ returns a Sequence, '_0', '_1' can be used to denote the 0th or 1st element of the sequence.
        :return: the Collator itself
        """
        for field_name in field_names:
            if field_name in self.input_fields:
                self.input_fields.pop(field_name)
                logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
            self.padders.pop(field_name, None)  # if it has a padder, drop it.
            self.ignore_fields.add(field_name)

        return self
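
A minimal usage sketch (not part of the commit) of the new Collator on a batch of dict samples; the field names 'words' and 'label' are made up for illustration:

from fastNLP.core.collators.new_collator import Collator

batch = [
    {'words': [3, 5, 8], 'label': 0},
    {'words': [2, 9],    'label': 1},
]

collator = Collator(backend='numpy')   # pad to numpy arrays
collator.set_pad('words', pad_val=-1)  # custom pad value for one field
collator.set_ignore('label')           # drop 'label' from the output

padded = collator(batch)
# padded['words'] is a 2x3 array: [[3, 5, 8], [2, 9, -1]]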



fastNLP/core/collators/padders/__init__.py (+0 -0)


fastNLP/core/collators/padders/exceptions.py (+44 -0)

@@ -0,0 +1,44 @@
__all__ = [
    'InconsistencyError',
    'EleDtypeUnsupportedError',
    'EleDtypeDtypeConversionError',
    'DtypeUnsupportedError',
    "DtypeError"
]


class InconsistencyError(BaseException):
    """
    Raised when the data in one batch is inconsistent in shape, dtype, or the like.

    """
    def __init__(self, msg, *args):
        super(InconsistencyError, self).__init__(msg, *args)


class DtypeError(BaseException):
    def __init__(self, msg, *args):
        super(DtypeError, self).__init__(msg, *args)
        self.msg = msg


class EleDtypeUnsupportedError(DtypeError):
    """
    Raised when the element type in the batch cannot be padded at all,
    for example when padding is requested for str data.

    """


class EleDtypeDtypeConversionError(DtypeError):
    """
    Raised when the element type in the batch cannot be converted to the dtype.

    """


class DtypeUnsupportedError(DtypeError):
    """
    Raised when the current backend does not support this kind of dtype.

    """

fastNLP/core/collators/padders/get_padder.py (+193 -0)

@@ -0,0 +1,193 @@

from typing import Dict



from typing import Sequence, Any, Union, Dict
from abc import ABC

from fastNLP.core.log import logger


from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .exceptions import *


def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
    """
    Return a padder suitable for the current batch_field, based on the parameters and on batch_field itself.

    :param batch_field: the contents of one field, grouped into a batch.
    :param pad_val:
    :param backend:
    :param dtype:
    :param field_name: for more helpful error messages.
    :return:
    """
    logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field))
    if pad_val is None:
        logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
        return NullPadder()
    if backend is None:
        logger.debug(f"The backend for field:{field_name} is None, not padding this field.")
        return NullPadder()

    # First decide whether this field must be padded, judging from the user-set pad_val, dtype, etc.
    must_pad = False
    if pad_val != 0 or dtype is not None:
        must_pad = True

    catalog = _get_element_shape_dtype(batch_field)  # first collect basic information about the data.

    # Use the catalog to decide whether padding is currently possible.
    # First check that all keys have the same length, which means the depth is consistent.
    depths = set(map(len, catalog.keys()))
    num_depth = len(depths)
    if num_depth != 1:
        msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \
              f"information please set logger's level to DEBUG."
        if must_pad:
            raise InconsistencyError(msg)
        logger.debug(msg)
        return NullPadder()

    # Then check whether all element shapes are consistent.
    shape_lens = set([len(v[0]) for v in catalog.values()])
    num_shape = len(shape_lens)
    if num_shape != 1:
        msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \
              f"information please set logger's level to DEBUG."
        if must_pad:
            raise InconsistencyError(msg)
        logger.debug(msg)
        return NullPadder()

    # Then check whether all element types are consistent.
    ele_dtypes = set([v[1] for v in catalog.values()])
    num_eletypes = len(ele_dtypes)
    if num_eletypes != 1:
        msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
              f"information please set logger's level to DEBUG."
        if must_pad:
            raise InconsistencyError(msg)
        logger.debug(msg)
        return NullPadder()

    depth = depths.pop()
    shape_len = shape_lens.pop()
    ele_dtype = ele_dtypes.pop()

    # The padder itself has to decide whether padding is possible.
    try:
        if depth == 1 and shape_len == 0:  # e.g. [0, 1, 2] or [True, False, True]
            if backend == 'raw':
                return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'numpy':
                return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'torch':
                return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

        if depth > 1 and shape_len == 0:  # e.g. [[0, 1], [2]]
            if backend == 'raw':
                return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'numpy':
                return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'torch':
                return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

        if depth == 1 and shape_len != 0:
            if backend == 'numpy':
                return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'torch':
                return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

        if shape_len != 0 and depth > 1:
            msg = "Does not support pad tensor under nested list. If you need this, please report."
            if must_pad:
                raise RuntimeError(msg)
            logger.debug(msg)
            return NullPadder()

    except DtypeError as e:
        msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
              "information please set logger's level to DEBUG."
        if must_pad:
            raise type(e)(msg=msg)
        logger.debug(msg)
        return NullPadder()

    except BaseException as e:
        raise e

    return NullPadder()


class HasShapeDtype(ABC):
    """
    Detects objects that have shape and dtype attributes, typically np.ndarray or the various kinds of tensors.

    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        if cls is HasShapeDtype:
            if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'):
                return True
            return False
        return NotImplemented


def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
    """
    Collect basic information about the elements of an object, used to decide whether padding is possible.

    :param content:
    :param tuple parent:
    :param dict catalog: a dict recording element information; the index records the topological position of each element.
        e.g. [1, 2, 3] -> {(0,): ((), <class 'int'>), (1,): ((), <class 'int'>), (2,): ((), <class 'int'>)}
        e.g. [1, [2, 3], 4] -> {(0,): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (1, 1): ((), <class 'int'>), (2,): ((), <class 'int'>)}
        e.g. [[1, 2], [3], [4, 5]] -> {(0, 0): ((), <class 'int'>), (0, 1): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (2, 0): ((), <class 'int'>), (2, 1): ((), <class 'int'>)}
        e.g. [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)]
            -> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)}

    :return:
    """
    if catalog is None:
        catalog = {}

    if parent is None:
        parent = ()

    if isinstance(content, HasShapeDtype):  # any kind of tensor, or np.ndarray
        shape = content.shape
        dtype = content.dtype
        catalog[parent] = (shape, dtype)
    elif isinstance(content, (tuple, list)):
        for i, c in enumerate(content):
            _get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog)
    else:  # includes int/float/bool/dict and other things that cannot be padded
        catalog[parent] = ((), type(content))  # () means the shape has length 0; the second entry is the type
    return catalog




"""
from numbers import Number

issubclass(type(3), Number)  # True
issubclass(type(3.1), Number)  # True
issubclass(type('3'), Number)  # False
issubclass(type(True), Number)  # True
issubclass(type(np.zeros(3)[0]), Number)  # True
isinstance(np.zeros(3, dtype=float).dtype, np.dtype)  # True
isinstance(np.zeros(3, dtype=int).dtype, np.dtype)  # True
isinstance(np.zeros(3, dtype=str).dtype, np.dtype)  # True, needs a further check to tell apart
is_torch_tensor_dtype()  # can be checked via isinstance(torch.zeros(3).dtype, torch.dtype)
"""




fastNLP/core/collators/padders/numpy_padder.py (+72 -0)

@@ -0,0 +1,72 @@
__all__ = [
    'NumpyNumberPadder',
    'NumpySequencePadder',
]

from numbers import Number
from abc import ABC
from typing import Any, Union
import numpy as np

from .padder import Padder
from .utils import get_padded_numpy_array, is_number_or_numpy_number
from .exceptions import *


def _get_dtype(ele_dtype, dtype, class_name):
    if not is_number_or_numpy_number(ele_dtype):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"or numpy numbers but get `{ele_dtype}`.")

    if dtype is None:
        dtype = ele_dtype
    else:
        if not is_number_or_numpy_number(dtype):
            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
                                        f"or numpy numbers but get `{dtype}`.")
        dtype = dtype
    return dtype


class NumpyNumberPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        return np.array(batch_field, dtype=dtype)


class NumpySequencePadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val)


class NumpyTensorPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        Pads fields like [np.array([3, 4]), np.array([1])].

        :param ele_dtype:
        :param pad_val:
        :param dtype:
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        shapes = [field.shape for field in batch_field]
        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
        array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
        for i, field in enumerate(batch_field):
            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
            array[slices] = field
        return array
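
A short sketch of the two padders above on hand-made inputs (illustrative only):

import numpy as np
from fastNLP.core.collators.padders.numpy_padder import NumpySequencePadder, NumpyTensorPadder

# Ragged python lists -> one rectangular array.
seq_padder = NumpySequencePadder(ele_dtype=int, pad_val=0)
print(seq_padder([[1, 2], [3]]))  # [[1 2] [3 0]]

# Ragged 1-d arrays -> stacked and right-padded in every dimension.
tensor_padder = NumpyTensorPadder(ele_dtype=np.zeros(1).dtype, pad_val=-1)
print(tensor_padder([np.array([3., 4.]), np.array([1.])]))  # [[ 3.  4.] [ 1. -1.]]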


fastNLP/core/collators/padders/padder.py (+21 -0)

@@ -0,0 +1,21 @@

class Padder:
    def __init__(self, pad_val, dtype):
        self.pad_val = pad_val
        self.dtype = dtype

    def __call__(self, batch_field):
        return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        raise NotImplementedError()


class NullPadder(Padder):
    def __init__(self, ele_dtype=None, pad_val=None, dtype=None):
        super().__init__(pad_val=pad_val, dtype=dtype)

    def __call__(self, batch_field):
        # return directly without calling the pad() method, for speed.
        return batch_field

fastNLP/core/collators/padders/raw_padder.py (+48 -0)

@@ -0,0 +1,48 @@


from .padder import Padder
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array
from .exceptions import *


def _get_dtype(ele_dtype, dtype, class_name):
    if is_number(ele_dtype):
        if dtype is None:
            dtype = ele_dtype
        elif not is_number(dtype):
            raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but "
                                        f"get `{dtype}`.")
    else:
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"but get `{ele_dtype}`.")
    return dtype


class RawNumberPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    def __call__(self, batch_field):
        return batch_field

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        raise NotImplementedError()


class RawSequencePadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        """

        :param batch_field:
        :param pad_val:
        :param dtype: this parameter has no effect.
        :return:
        """
        return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()

fastNLP/core/collators/padders/torch_padder.py (+157 -0)

@@ -0,0 +1,157 @@

from inspect import isclass
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch
    numpy_to_torch_dtype_dict = {
        np.bool_: torch.bool,
        np.uint8: torch.uint8,
        np.int8: torch.int8,
        np.int16: torch.int16,
        np.int32: torch.int32,
        np.int64: torch.int64,
        np.float16: torch.float16,
        np.float32: torch.float32,
        np.float64: torch.float32,  # unify everything to float32 here, since numpy defaults to float64 most of the time
        np.complex64: torch.complex64,
        np.complex128: torch.complex128
    }
    number_to_torch_dtype_dict = {
        float: torch.float32,  # because torch.tensor([1], dtype=float) is torch.float64
        int: torch.int64,
        bool: torch.bool
    }

from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *


def is_torch_tensor(dtype):
    if not isclass(dtype) and isinstance(dtype, torch.dtype):
        return True
    return False


def _get_dtype(ele_dtype, dtype, class_name):
    if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")

    if dtype is not None:
        if not (is_torch_tensor(dtype) or is_number(dtype)):
            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
                                        f"or torch.dtype but get `{dtype}`.")
        dtype = number_to_torch_dtype_dict.get(dtype, dtype)
    else:
        if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)):
            ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype)
            dtype = ele_dtype
        elif is_numpy_number_dtype(ele_dtype):  # there is a conversion issue here
            dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type)
        elif is_numpy_generic_class(ele_dtype):
            dtype = numpy_to_torch_dtype_dict.get(ele_dtype)

    return dtype


class TorchNumberPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        # only when ele_dtype is a python number / numpy number or a tensor
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        return torch.tensor(batch_field, dtype=dtype)


class TorchSequencePadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val)
        return tensor


class TorchTensorPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        Currently only supports fields like [torch.tensor([3, 2]), torch.tensor([1])].

        :param ele_dtype:
        :param pad_val:
        :param dtype:
        """
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        shapes = [field.shape for field in batch_field]
        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
        if isinstance(dtype, np.dtype):
            print(dtype)
        tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
        for i, field in enumerate(batch_field):
            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
            if isinstance(field, np.ndarray):
                field = torch.from_numpy(field)
            tensor[slices] = field
        return tensor


def fill_tensor(batch_field, padded_batch, dtype):
    """
    Fill the values from batch_field into the tensor.

    :param batch_field: the content to fill into the tensor
    :param padded_batch: the tensor to fill
    :param dtype: the dtype of the data

    :return:
    """
    if padded_batch.ndim == 2:
        for i, content_i in enumerate(batch_field):
            padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype)
    elif padded_batch.ndim == 3:
        for i, content_i in enumerate(batch_field):
            for j, content_ii in enumerate(content_i):
                padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype)
    elif padded_batch.ndim == 4:
        try:  # this should be an image, so converting directly should just work.
            padded_batch = np.array(batch_field)
        except:
            for i, content_i in enumerate(batch_field):
                for j, content_ii in enumerate(content_i):
                    for k, content_iii in enumerate(content_ii):
                        padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype)
    elif padded_batch.ndim == 1:
        padded_batch[:] = torch.tensor(batch_field, dtype=dtype)
    else:
        raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
                           "report.")
    return padded_batch


def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0):
    """
    For example:
        [[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]])

    :param batch_field: the object to pad; it must be guaranteed to be paddable. Supports 1d (often sentence lengths) / 2d (often
        token sequences) / 3d (often character sequences) / 4d (often images).
    :param dtype: the target dtype
    :param pad_val: the value to pad with
    :return:
    """
    shapes = get_shape(batch_field)
    tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val)
    tensor = fill_tensor(batch_field, tensor, dtype=dtype)
    return tensor
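
A short sketch of TorchSequencePadder (illustrative, assumes torch is installed); python int elements are mapped to torch.int64 through number_to_torch_dtype_dict:

from fastNLP.core.collators.padders.torch_padder import TorchSequencePadder

padder = TorchSequencePadder(ele_dtype=int, pad_val=0)
batch = padder([[1, 2, 3], [4, 5]])
print(batch.dtype)  # torch.int64
print(batch)        # tensor([[1, 2, 3], [4, 5, 0]])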

fastNLP/core/collators/padders/torch_utils.py (+20 -0)

@@ -0,0 +1,20 @@


from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch


def is_torch_tensor_dtype(dtype) -> bool:
    """
    Return whether the given dtype is a torch dtype.


    :param dtype: should be a value obtained in a way similar to torch.ones(3).dtype
    :return:
    """
    try:
        return isinstance(dtype, torch.dtype)
    except:
        return False

fastNLP/core/collators/padders/utils.py (+173 -0)

@@ -0,0 +1,173 @@

from typing import Sequence, List
from numbers import Number
import re
from inspect import isclass

import numpy as np
np_str_obj_array_pattern = re.compile(r'[SaUO]')


def get_shape(batch_field:List, shape=None):
    """
    Given a field, return the shape this field will have once padding is finished.
    For example: [[1, 2, 3], [3]] -> [2, 3]
                 [[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3]

    :param batch_field: a list; dimension 0 is usually the batch dimension.
    :param shape: no need to pass this in.
    :return:
    """
    if shape is None:
        shape = []
    if isinstance(batch_field, Sequence):
        num_ele = len(batch_field)
        _shape = shape + [num_ele]
        try:
            shapes = []
            if isinstance(batch_field[0], Sequence):
                for _field in batch_field:
                    shapes.append(get_shape(_field, _shape))
                max_shape = [max(_) for _ in zip(*shapes)]
                return max_shape
        except IndexError:  # empty shape
            pass
        return _shape  # this means it is an empty sequence
    else:
        return shape


def fill_array(batch_field:List, padded_batch:np.ndarray):
    """
    Fill the values from batch_field into the array.

    :param batch_field: the content to fill into the array
    :param padded_batch: the np.ndarray to fill
    :return:
    """
    if padded_batch.ndim == 2:
        for i, content_i in enumerate(batch_field):
            padded_batch[i, :len(content_i)] = content_i
    elif padded_batch.ndim == 3:
        for i, content_i in enumerate(batch_field):
            for j, content_ii in enumerate(content_i):
                padded_batch[i, j, :len(content_ii)] = content_ii
    elif padded_batch.ndim == 4:
        try:  # this should be an image, so converting directly should just work.
            padded_batch = np.array(batch_field)
        except:
            for i, content_i in enumerate(batch_field):
                for j, content_ii in enumerate(content_i):
                    for k, content_iii in enumerate(content_ii):
                        padded_batch[i, j, k, :len(content_iii)] = content_iii
    elif padded_batch.ndim == 1:
        padded_batch[:] = batch_field
    else:
        raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
                           "report.")
    return padded_batch


def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray:
    """
    For example:
        [[1,2], [3]] -> np.array([[1, 2], [3, 0]])

    :param batch_field: the object to pad; it must be guaranteed to be paddable. Supports 1d (often sentence lengths) / 2d (often
        token sequences) / 3d (often character sequences) / 4d (often images).
    :param dtype: the target dtype
    :param pad_val: the value to pad with
    :return:
    """
    shapes = get_shape(batch_field)
    array = np.full(shapes, dtype=dtype, fill_value=pad_val)
    array = fill_array(batch_field, array)
    return array


def get_padded_nest_list(batch_field: List, pad_val=0) -> List:
    """
    For example:
        [[1,2], [3]] -> [[1, 2], [3, 0]]

    :param batch_field: the object to pad; it must be guaranteed to be paddable. Supports 1d (often sentence lengths) / 2d (often
        token sequences) / 3d (often character sequences) / 4d (often images).
    :param pad_val: the value to pad with
    :return:
    """

    array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist()
    return array


def is_number_or_numpy_number(dtype):
    """
    Decide whether dtype is a python number type or a numpy number type.
        is_number_or_numpy_number(type(3))  # True
        is_number_or_numpy_number(type(3.1))  # True
        is_number_or_numpy_number(type('3'))  # False
        is_number_or_numpy_number(type(True))  # True
        is_number_or_numpy_number(type(np.zeros(3)[0]))  # True
        is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)  # True
        is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)  # True
        is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)  # False
        is_number_or_numpy_number(np.array([1, [2]]).dtype)  # False

    :param dtype:
    :return:
    """
    if is_number(dtype):
        return True
    else:
        if isclass(dtype):
            return is_numpy_generic_class(dtype)
        elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None:
            return True
    return False


def is_numpy_number_dtype(dtype):
    if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None:
        return True
    return False


def is_numpy_generic_class(dtype):
    """
    Values like np.int64, or np.zeros(1).dtype.type.

    :param dtype:
    :return:
    """
    if isclass(dtype) and issubclass(dtype, np.generic):
        return True
    return False


def is_number(dtype):
    try:
        if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \
                and not is_numpy_number_dtype(dtype):
            return True
    except:
        return False



if __name__ == '__main__':
    # a = [[[1]], [1, 2, 3], [3]]
    # a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
    # b = get_padded_nest_list(a)
    # print(type(b[0]))
    # print(b)
    # import torch
    print(is_number_or_numpy_number(type(3)))  # True
    print(is_number_or_numpy_number(type(3.1)))  # True
    print(is_number_or_numpy_number(type('3')))  # False
    print(is_number_or_numpy_number(type(True)))  # True
    print(is_number_or_numpy_number(type(np.zeros(3)[0])))  # True
    print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype))  # True
    print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype))  # True
    print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype))  # False
    print(is_number_or_numpy_number(np.array([1, [2]]).dtype))  # False


fastNLP/core/collators/utils.py (+103 -0)

@@ -0,0 +1,103 @@
from collections import defaultdict
from functools import reduce
from typing import Sequence, Mapping, Dict

NESTED_DICT_SEPARATOR = '@@'


def unpack_batch_mapping(batch:Sequence[Mapping])->Dict:
    """
    Convert a Sequence[Mapping] into a Dict. For example [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]}

    :param batch:
    :return:
    """
    dict_batch = defaultdict(list)
    for sample in batch:
        for key, value in sample.items():
            dict_batch[key].append(value)
    return dict_batch


def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict:
    """
    Unfold the contents of a nested dict into one flat dict.

    :param batch:
    :param _parent: used internally
    :return:
    """
    dict_batch = defaultdict(list)
    if _parent != '':
        _parent += NESTED_DICT_SEPARATOR
    for sample in batch:
        for key, value in sample.items():
            if isinstance(value, Mapping):
                _dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key)
                for key, value in _dict_batch.items():
                    dict_batch[key].append(value)
            else:
                dict_batch[_parent + key].append(value)
    return dict_batch


def _unpack_batch_nested_mapping(value, _parent)->Dict:
    _dict = {}
    _parent += NESTED_DICT_SEPARATOR
    for k, v in value.items():
        if isinstance(v, Mapping):
            __dict = _unpack_batch_nested_mapping(v, _parent=_parent + k)
            _dict.update(__dict)
        else:
            _dict[_parent + k] = v
    return _dict


def pack_batch_nested_mapping(batch:Mapping) -> Dict:
    """
    Restore the original structure of the nested dict.

    :param batch:
    :return:
    """
    dicts = []

    for key, value in batch.items():
        keys = key.split(NESTED_DICT_SEPARATOR)
        d = {keys[-1]: value}
        for key in keys[:-1:][::-1]:
            d = {key: d}
        dicts.append(d)
    return reduce(_merge_dict, dicts)


def _merge_dict(a, b, path=None):
    "merges b into a"
    if path is None: path = []
    for key in b:
        if key in a:
            if isinstance(a[key], dict) and isinstance(b[key], dict):
                _merge_dict(a[key], b[key], path + [str(key)])
            else:
                raise Exception('Conflict at %s' % '.'.join(path + [str(key)]))
        else:
            a[key] = b[key]
    return a


def unpack_batch_sequence(batch:Sequence[Sequence])->Dict:
    """
    Convert a Sequence[Sequence] into a Mapping. For example [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [2, 2]}

    :param batch:
    :return:
    """
    dict_batch = defaultdict(list)
    for sample in batch:
        for i, content in enumerate(sample):
            dict_batch[f'_{i}'].append(content)
    return dict_batch


def pack_batch_sequence(batch:Mapping)->Sequence:
    return list(batch.values())
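
A round-trip sketch of the nested-mapping helpers (illustrative only):

from fastNLP.core.collators.utils import unpack_batch_nested_mapping, pack_batch_nested_mapping

batch = [
    {'x': 1, 'meta': {'len': 3}},
    {'x': 2, 'meta': {'len': 5}},
]

flat = unpack_batch_nested_mapping(batch)
print(dict(flat))  # {'x': [1, 2], 'meta@@len': [3, 5]}

print(pack_batch_nested_mapping(flat))  # {'x': [1, 2], 'meta': {'len': [3, 5]}}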

fastNLP/core/controllers/evaluator.py (+32 -47)

@@ -20,47 +20,31 @@ from fastNLP.core.log import logger




class Evaluator: class Evaluator:
"""
1. 我们目前不直接提供每一个 metric 对应一个或者特殊的多个 dataloader 的功能,默认就是所有 metric 处理所有 dataloader,如果用户有这种
需求,请使用多个 Tester 进行操作;
2. Trainer 的 validate dataloader 只允许传进去一个,而 Tester 则可以多个;因为 Trainer 涉及到保存 topk 模型的逻辑,而 Tester
则只需要给出评测的结果即可;

"""
driver: Driver driver: Driver
_evaluate_batch_loop: Loop _evaluate_batch_loop: Loop


def __init__(
self,
model,
dataloaders,
metrics: Optional[Union[Dict, Metric]] = None,
driver: Union[str, Driver] = 'torch',
device: Optional[Union[int, List[int], str]] = None,
batch_step_fn: Optional[callable] = None,
evaluate_fn: Optional[str] = None,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
fp16: bool = False,
verbose: int = 1,
**kwargs
):
def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None,
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None,
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False,
fp16: bool = False, verbose: int = 1, **kwargs):
""" """
用于对数据进行评测。


:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 :param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。
:param dataloaders: 待评测的数据集。
:param dataloaders: 待评测的数据集。如果为多个,请使用 dict 传入。
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的
metric ,torchmetrics,allennlpmetrics等。
metric ,torchmetrics,allennlpmetrics 等。
:param driver: 使用 driver 。 :param driver: 使用 driver 。
:param device: 使用的设备。 :param device: 使用的设备。
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的
batch_step_fn 函数。
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`,
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对
model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
@@ -69,7 +53,8 @@ class Evaluator:
:param verbose: 是否打印 evaluate 的结果。 :param verbose: 是否打印 evaluate 的结果。
:param kwargs: :param kwargs:
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout
与 batch normalization 将会关闭。默认为True。
与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。
TODO 还没完成。 TODO 还没完成。
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,
@@ -96,9 +81,9 @@ class Evaluator:
self.device = device
self.verbose = verbose


if batch_step_fn is not None:
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
self.batch_step_fn = batch_step_fn
if evaluate_batch_step_fn is not None:
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
self.evaluate_batch_step_fn = evaluate_batch_step_fn


self.input_mapping = input_mapping
self.output_mapping = output_mapping
@@ -106,14 +91,14 @@ class Evaluator:
if not isinstance(dataloaders, dict):
dataloaders = {None: dataloaders}


self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn)
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)


self.driver.setup()
self.driver.barrier()


self.separator = kwargs.get('separator', '#')
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed())
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
if use_dist_sampler:
self._dist_sampler = "unrepeatdist"
else:
@@ -134,7 +119,7 @@ class Evaluator:


self.progress_bar = kwargs.get('progress_bar', 'auto')
if self.progress_bar == 'auto':
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich'


self.driver.barrier()


@@ -235,8 +220,8 @@ class Evaluator:


@evaluate_batch_loop.setter
def evaluate_batch_loop(self, loop: Loop):
if self.batch_step_fn is not None:
logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
if self.evaluate_batch_step_fn is not None:
logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
"when the `evaluate_batch_loop` is also customized.") "when the `evaluate_batch_loop` is also customized.")
self._evaluate_batch_loop = loop self._evaluate_batch_loop = loop


@@ -249,15 +234,15 @@ class Evaluator:
""" """
self.metrics_wrapper.reset() self.metrics_wrapper.reset()


def update(self, *args, **kwargs):
def update(self, batch, outputs):
""" """
调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。
自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。


:param args:
:param kwargs:
:param batch: Usually the output of the DataLoader; if it is not of dict type, this value is ignored.
:param outputs: Usually the output of the model; its type should be dict or dataclass.
:return:
"""
self.metrics_wrapper.update(*args, **kwargs)
self.metrics_wrapper.update(batch, outputs)


def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
"""
@@ -271,7 +256,7 @@ class Evaluator:
@property
def metrics_wrapper(self):
"""
Because the metrics object inside the Evaluator must stay exactly identical to the metrics the user passed in (so that they can be used in batch_step_fn), and also to support
Because the metrics object inside the Evaluator must stay exactly identical to the metrics the user passed in (so that they can be used in evaluate_batch_step_fn), and also to support
different kinds of metric (fastNLP metrics, torchmetrics, etc.), the Evaluator always goes through metrics_wrapper when operating
on metrics.


@@ -283,11 +268,11 @@ class Evaluator:


def evaluate_step(self, batch):
"""
Passes the batch into the model for processing, choosing between evaluate and test according to the current evaluate_fn. The returned result is processed by output_mapping before being
Passes the batch into the model for processing, running evaluation according to the current evaluate_fn. The returned result is processed by output_mapping before being
returned.


:param batch:
:return:
:param batch: Any input type supported by the {evaluate_fn} function
:return: The output of the {evaluate_fn} function; if output_mapping is set, the result after output_mapping has been applied.
"""
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn)
outputs = match_and_substitute_params(self.output_mapping, outputs)
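To make the renamed interface concrete, here is a minimal usage sketch; the `model`, `dev_dataloader` and `Accuracy` names are hypothetical placeholders (assumptions, not taken from this diff):

    # Minimal sketch of the renamed Evaluator API shown above.
    def my_step(evaluator, batch):                # checked signature: (evaluator, batch)
        outputs = evaluator.evaluate_step(batch)  # runs evaluate_fn, then output_mapping
        evaluator.update(batch, outputs)          # new update(batch, outputs) signature

    evaluator = Evaluator(model, dataloaders={'dev': dev_dataloader},
                          metrics={'acc': Accuracy()},
                          evaluate_batch_step_fn=my_step)  # renamed from batch_step_fn
    results = evaluator.run()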


+ 1
- 1
fastNLP/core/controllers/loops/train_batch_loop.py

@@ -43,7 +43,7 @@ class TrainBatchLoop(Loop):


trainer.check_batch_step_fn()
trainer.on_train_batch_end()
trainer.step_validate()
trainer.step_evaluate()
trainer.batch_idx_in_epoch = 0


@staticmethod


+ 95
- 46
fastNLP/core/controllers/trainer.py

@@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.progress_callback import choose_progress_callback
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -82,10 +83,10 @@ class Trainer(TrainerEventTrigger):
:param n_epochs: The total number of training epochs; defaults to 20;
:param evaluate_dataloaders: The validation dataset(s); may be a single dataset or several; when there are several, note that it must be a Dict; defaults
to None;
:param batch_step_fn: Replaces the `batch_step_fn` function inside `TrainBatchLoop`; note that its two parameters must be `trainer` and
`batch`; defaults to None;
:param evaluate_batch_step_fn: Replaces the `batch_step_fn` function of the `EvaluateBatchLoop` inside the 'Evaluator'; note that its
two parameters must be `evaluator` and `batch`; defaults to None;
:param batch_step_fn: Customizes the function executed for each train batch. It should accept the two parameters `trainer` and `batch` and needs
no return value; see the batch_step_fn function in fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop for reference.
:param evaluate_batch_step_fn: Customizes the function executed for each evaluate batch. It should accept the two parameters `evaluator` and `batch`
and needs no return value; see the batch_step_fn function in fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop for reference.
:param train_fn: Controls which function of the model the `Trainer` calls for the forward pass during training, e.g. `train_step` or `forward`;
defaults to None, in which case `train_step` is used by default as the forward function, falling back to
the model's default forward function when that method is not found.
@@ -102,10 +103,12 @@ class Trainer(TrainerEventTrigger):
value; if the batch is a `dataclass`, we first convert it to a Dict and then apply the conversion above; if the batch is of any other
type, we raise an error directly; if input_mapping is a function, we do nothing to the extracted batch and simply pass it into that function;
note that this parameter is also passed into the `Evaluator`; you can therefore use it to move a training batch onto the right machine (e.g. when the parameter `device` is None);
If train and evaluate need different input_mappings, set them via train_input_mapping and evaluate_input_mapping.
:param output_mapping: Should be a dict or a function. Works like input_mapping, except that it converts outputs; if output_mapping is a
function, the model output is passed straight into it; if it is a `Dict`, the batch must be of `Dict` or `dataclass` type:
if the batch is a `Dict`, keys of the batch that also appear in output_mapping are renamed to output_mapping's corresponding values;
if the batch is a `dataclass`, it is first converted to a Dict and the conversion above is then applied;
If train and evaluate need different output_mappings, set them via train_output_mapping and evaluate_output_mapping.
:param model_wo_auto_param_call: Whether to disable calling our auto_param_call during training to automatically match batch contents to forward-function parameters;
if this value is False and the batch is a dict, we extract the objects the forward function needs from the batch and pass them into the forward function; if this value
is True, we pass the batch straight through to the model. Note this parameter applies to `train_step`, `evaluate_step` and `test_step`;
@@ -125,14 +128,17 @@ class Trainer(TrainerEventTrigger):
set_grad_to_none: Whether to set grads to None after every optimizer update during training;
use_dist_sampler: Whether to use a distributed sampler. With multiple cards, the distributed sampler automatically decides the samples
read on each card, so that within one epoch the samples of all cards add up to exactly one pass over the dataset. Defaults according to whether the driver is distributed.
eval_use_dist_sampler: Whether the Evaluator, when using TorchDDPDriver, replaces the dataloader's sampler with a distributed sampler; defaults to True;
evaluate_use_dist_sampler: Whether the Evaluator, when running distributed, replaces the dataloader's sampler with a distributed sampler; defaults to True;
output_from_new_proc: Should be a string describing how the output streams of the other processes in a multi-process driver are handled; its value
should be one of ["all", "ignore", "only_error"]; any other value is treated as a folder name, and the output streams of the other ranks are redirected into
log files, which are then saved in the folder named by this parameter; defaults to "only_error";
progress_bar: How progress is displayed; currently one of [None, 'raw', 'rich', 'auto'], defaulting to auto. Progress is implemented
via a callback; if a callback of type ProgressCallback is detected among the passed-in callbacks, this parameter has no effect on the Trainer.
auto means rich is used when the current terminal is detected to be interactive, and raw otherwise.

progress_bar: How progress is displayed; currently one of [None, 'raw', 'rich', 'auto'] or a RichCallback or RawTextCallback object,
defaulting to auto, which uses a RichCallback when the current terminal is detected to be interactive and a RawTextCallback object otherwise. To
customize the progress bar's parameters, e.g. its print frequency, pass in a RichCallback or RawTextCallback object.
train_input_mapping: Same as input_mapping, but only used during train. Mutually exclusive with input_mapping.
train_output_mapping: Same as output_mapping, but only used during train. Mutually exclusive with output_mapping.
evaluate_input_mapping: Same as input_mapping, but only used during evaluate. Mutually exclusive with input_mapping.
evaluate_output_mapping: Same as output_mapping, but only used during evaluate. Mutually exclusive with output_mapping.
""" """
self.model = model self.model = model
self.marker = marker self.marker = marker
@@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger):
self.evaluate_dataloaders = evaluate_dataloaders
self.optimizers = optimizers
self.fp16 = fp16
self.input_mapping = input_mapping
self.output_mapping = output_mapping

train_input_mapping = kwargs.get('train_input_mapping', None)
train_output_mapping = kwargs.get('train_output_mapping', None)
evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None)
evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None)

train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
_get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
evaluate_input_mapping, evaluate_output_mapping)

self.input_mapping = train_input_mapping
self.output_mapping = train_output_mapping
self.evaluate_fn = evaluate_fn


self.batch_step_fn = batch_step_fn
@@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger):
callbacks=callbacks,
metrics=metrics,
evaluate_every=evaluate_every,
input_mapping=input_mapping,
output_mapping=output_mapping,
input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping,
model_wo_auto_param_call=model_wo_auto_param_call,
accumulation_steps=accumulation_steps,
fp16=fp16,
@@ -195,8 +211,20 @@ class Trainer(TrainerEventTrigger):
)
self.driver.set_optimizers(optimizers=optimizers)


# Choose the ProgressBarCallback according to the progress_bar parameter
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto'))
if progress_bar_callback is not None:
if callbacks is None:
callbacks = []
elif not isinstance(callbacks, Sequence):
callbacks = [callbacks]

callbacks = list(callbacks) + [progress_bar_callback]
else:
rank_zero_call(logger.warning)("No progress bar is provided, there will be no information output "
"during training.")
# Initialize the callback manager;
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto'))
self.callback_manager = CallbackManager(callbacks)
# Add all function-style callbacks;
self._fetch_matched_fn_callbacks()
# Add all class-style callbacks;
@@ -237,21 +265,15 @@ class Trainer(TrainerEventTrigger):
self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
self.evaluator = Evaluator(
model=model,
dataloaders=evaluate_dataloaders,
metrics=metrics,
driver=self.driver,
device=device,
batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn,
input_mapping=input_mapping,
output_mapping=output_mapping,
fp16=fp16,
verbose=0,
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None),
progress_bar=kwargs.get('progress_bar', 'auto')
)
progress_bar = kwargs.get('progress_bar', 'auto')
if not (isinstance(progress_bar, str) or progress_bar is None):  # then it should be a ProgressCallback; take its name.
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=input_mapping,
output_mapping=output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
progress_bar=progress_bar)


if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
@@ -317,11 +339,11 @@ class Trainer(TrainerEventTrigger):
self.num_batches_per_epoch = len(self.dataloader)
self.total_batches = self.num_batches_per_epoch * self.n_epochs
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch
self.on_train_begin()
self.driver.barrier()
self.driver.zero_grad(self.set_grad_to_none)


try:
self.on_train_begin()
self.driver.barrier()
self.driver.zero_grad(self.set_grad_to_none)
while self.cur_epoch_idx < self.n_epochs:
# This guards against saving again after Trainer.load before the current epoch has finished
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
@@ -334,10 +356,8 @@ class Trainer(TrainerEventTrigger):
self.cur_epoch_idx += 1
self.on_train_epoch_end()
self.driver.barrier()
self.epoch_validate()
self.epoch_evaluate()
self.driver.barrier()
self.on_train_end()
self.driver.barrier()


except EarlyStopException as e:
logger.info(f"Catch early stop exception: {e.msg}.")
@@ -351,17 +371,20 @@ class Trainer(TrainerEventTrigger):
self.driver.on_exception()
self.on_exception(e)
raise e
finally:
self.on_train_end()
self.driver.barrier()


def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None:
trainer.on_validate_begin()
_validate_res: dict = validate_fn()
trainer.on_validate_end(_validate_res)
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
trainer.on_evaluate_begin()
_evaluate_res: dict = evaluate_fn()
trainer.on_evaluate_end(_evaluate_res)


if self.evaluator is not None:
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))


def step_validate(self):
def step_evaluate(self):
""" """
在每个 batch 结束后调用,根据设置执行 evaluate 。 在每个 batch 结束后调用,根据设置执行 evaluate 。


@@ -374,7 +397,7 @@ class Trainer(TrainerEventTrigger):
elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0:
self.run_evaluate()


def epoch_validate(self):
def epoch_evaluate(self):
""" """
在每个 epoch 结束后调用,根据设置执行 evaluate 。 在每个 epoch 结束后调用,根据设置执行 evaluate 。


@@ -382,8 +405,8 @@ class Trainer(TrainerEventTrigger):
""" """
if self.evaluator is not None: if self.evaluator is not None:
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
validate_every = -self.evaluate_every
if self.cur_epoch_idx % validate_every == 0:
evaluate_every = -self.evaluate_every
if self.cur_epoch_idx % evaluate_every == 0:
self.run_evaluate()
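The sign convention these two hooks implement can be written down as a standalone sketch (the helper names below are illustrative, not fastNLP code):

    # Sketch of the evaluate_every convention used by step_evaluate/epoch_evaluate above.
    def should_evaluate_on_batch(evaluate_every: int, global_forward_batches: int) -> bool:
        # positive: evaluate every `evaluate_every` batches
        return evaluate_every > 0 and global_forward_batches % evaluate_every == 0

    def should_evaluate_on_epoch(evaluate_every: int, cur_epoch_idx: int) -> bool:
        # negative: evaluate every `-evaluate_every` epochs
        return evaluate_every < 0 and cur_epoch_idx % (-evaluate_every) == 0

    assert should_evaluate_on_batch(100, 200)   # every 100 batches
    assert should_evaluate_on_epoch(-2, 4)      # every 2 epochs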


def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
@@ -576,7 +599,7 @@ class Trainer(TrainerEventTrigger):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_load_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_load_fn)(folder)
model_load_fn(folder)
else:
if isinstance(folder, str):
folder = Path(folder)
@@ -653,7 +676,7 @@ class Trainer(TrainerEventTrigger):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_load_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
model_load_fn(folder)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
@@ -839,6 +862,32 @@ class Trainer(TrainerEventTrigger):
self._evaluate_dataloaders = evaluate_dataloaders




def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
evaluate_input_mapping, evaluate_output_mapping):
if train_input_mapping is not None and input_mapping is not None:
raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.")

if evaluate_input_mapping is not None and input_mapping is not None:
raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.")

if train_output_mapping is not None and output_mapping is not None:
raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.")

if evaluate_output_mapping is not None and output_mapping is not None:
raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.")

if train_input_mapping is None:
train_input_mapping = input_mapping
if evaluate_input_mapping is None:
evaluate_input_mapping = input_mapping

if train_output_mapping is None:
train_output_mapping = output_mapping
if evaluate_output_mapping is None:
evaluate_output_mapping = output_mapping

return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping
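A quick behavioral sketch of the helper above, runnable as-is next to the function definition:

    # A shared mapping fans out to both phases unless a phase-specific one is given;
    # setting both the shared and a phase-specific mapping raises ValueError.
    shared = {'words': 'input_ids'}
    train_in, train_out, eval_in, eval_out = _get_input_output_mapping(
        shared, None, None, None, None, None)
    assert train_in == eval_in == shared and train_out is eval_out is None

    try:
        _get_input_output_mapping(shared, None, {'words': 'x'}, None, None, None)
    except ValueError:
        pass  # input_mapping and train_input_mapping are mutually exclusive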









+ 8
- 8
fastNLP/core/controllers/utils/utils.py

@@ -81,12 +81,12 @@ class TrainerEventTrigger:
def on_after_zero_grad(self, optimizers):
self.callback_manager.on_after_zero_grad(self, optimizers)


def on_validate_begin(self):
self.callback_manager.on_validate_begin(self)
def on_evaluate_begin(self):
self.callback_manager.on_evaluate_begin(self)


def on_validate_end(self, results):
def on_evaluate_end(self, results):
self.trainer_state.save_on_this_step = True
self.callback_manager.on_validate_end(self, results)
self.callback_manager.on_evaluate_end(self, results)




class _TruncatedDataLoader:
@@ -126,8 +126,8 @@ class _TruncatedDataLoader:
return getattr(self.dataloader, item)




def check_evaluate_every(validate_every):
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
def check_evaluate_every(evaluate_every):
if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0):
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
if callable(validate_every):
_check_valid_parameters_number(validate_every, expected_params=['trainer'])
if callable(evaluate_every):
_check_valid_parameters_number(evaluate_every, expected_params=['trainer'])
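What the renamed check accepts, as a short sketch against the function above:

    check_evaluate_every(-1)     # ok: evaluate every epoch
    check_evaluate_every(100)    # ok: evaluate every 100 batches
    check_evaluate_every(lambda trainer: trainer.global_forward_batches % 50 == 0)
    try:
        check_evaluate_every(0)  # rejected: must be < 0 or > 0
    except ValueError:
        pass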

+ 1
- 1
fastNLP/core/drivers/jittor_driver/jittor_driver.py

@@ -63,7 +63,7 @@ class JittorDriver(Driver):


def check_evaluator_mode(self, mode: str):
model = self.unwrap_model()
if mode == "validate":
if mode == "evaluate":
if not hasattr(model, "evaluate_step"): if not hasattr(model, "evaluate_step"):
if hasattr(model, "test_step"): if hasattr(model, "test_step"):
logger.warning_once( logger.warning_once(


+ 210
- 25
fastNLP/core/drivers/paddle_driver/fleet.py

@@ -19,6 +19,7 @@ from fastNLP.core.utils import (
check_user_specific_params,
paddle_move_data_to_device,
is_in_paddle_dist,
rank_zero_rm
)
from fastNLP.core.samplers import (
RandomBatchSampler,
@@ -55,20 +56,134 @@ class PaddleFleetDriver(PaddleDriver):
fp16: bool = False,
**kwargs
):
"""
采用fleet接口进行并行paddle训练的driver
PaddleFleetDriver 目前考虑支持的三种启动方式:
1. 用户自己不进行 fleet 的任何操作,直接使用我们的 Trainer,并且只运行一个 main 脚本,这时是由我们自己使用 open_subprocesses 拉起
多个进程,然后由 Driver 自己进行初始化
2. 其它情况同 1,但是用户自己使用 python -m paddle.distributed.launch 拉起;
3. 用户自己在外面初始化 Fleet,并且通过 python -m paddle.distributed.launch 拉起;

注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动;

如果用户自己在外面初始化了 fleet,那么
parallel_device 为 None;
data_device 为 表示单卡的一个参数;
dist.is_initialized 为 true;
r"""
通过使用 PaddlePaddle 的 Fleet 框架启动多卡进程的 Driver。
需要注意的一点是,由于 PaddlePaddle 框架的特性,如果直接使用在 rank0 拉起其它进程的方法的话,如果不加以任何限制,PaddlePaddle会出现
第一次前向传播后卡住或占用所有显卡的现象;为了解决这一问题,我们在引入 FastNLP 时,会使用 `CUDA_VISIBLE_DEVICES` 将设备限制在卡0上,
而用户如果使用了这一环境变量,我们会将其储存在 `USER_CUDA_VISIBLE_DEVICES` 中,并且通过一定的手段实现了转换(详细的设置请参见:
`fastNLP/envs/set_backend.py`)。在拉起其它进程的时候,我们会如法炮制,将环境限制在对应的设备上。

`PaddleFleetDriver` 目前支持的三种启动方式:
1. 用户自己不进行分布式的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `FleetLauncher` 拉起多个进程,
然后 `PaddleFleetDriver` 自己通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 A)
2. 用户同样不在 Trainer 之外初始化分布式训练,但是用户自己使用 python -m paddle.distributed.launch 拉起来创建多个进程,这时我们仍旧
会通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 B)
3. 用户自己在外面初始化分布式,并且通过 python -m paddle.distributed.launch 拉起,这时无论是多个进程的拉起和通信组的建立
都由用户自己操作,我们只会在 driver.setup 的时候对 `PaddleFleetDriver` 设置一些必要的属性值;(情况 C)

注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动;因此我们不会在 `PaddleFleetDriver` 中保存
任何当前有多少台机器的信息;

Part 1: A detailed analysis of the three launch modes:
(1) If `driver.setup` is called only once in the user's script (meaning the script initializes only one trainer/evaluator),
`PaddleFleetDriver` does the following during initialization and inside its `setup` function:
-> Case A: here the model the user passes in must be a plain model (one not wrapped by `DataParallel`),
because using `DataParallel` requires that fleet.init has already been called to build the current ddp communication group; but this means that if
the user needs 2 or more cards, they necessarily have to launch with paddle.distributed.launch, which would no longer be case A;
here we first call the `FleetLauncher.launch` function to spawn the processes, one per gpu the user handed to the trainer
(e.g. with device=[0, 1, 6, 7] in `Trainer`, we use gpus 0, 1, 6 and 7 to spawn 4 processes);
we then call `fleet.init` to initialize the communication group between the processes;
note that every newly spawned process runs the user's launch script (e.g. main.py) from top to bottom again, so it also runs both of these functions, but only process 0
actually runs `FleetLauncher.launch`; when process 0 reaches `fleet.init`, paddle blocks it from running
further until the other processes have reached the same point;
finally we set the device of this process, move the model onto the corresponding machine, and wrap the model with `DataParallel`;
at this point the configuration of the paddle distributed environment is complete;

-> Case B: note that here we assume the user launched via paddle.distributed.launch and did not build a distributed communication group themselves. During
`PaddleFleetDriver`'s initialization and its setup call, the main difference from case A is that the device parameter the user passed to the trainer is no longer effective;
the gpu used by each process is then configured directly by us via `CUDA_VISIBLE_DEVICE`; so a user who wants specific gpu
devices can achieve that by setting the environment variable themselves (e.g. via os.environ["CUDA_VISIBLE_DEVICE"], which we take care to preserve);
the remaining steps are similar to case A;

-> Case C: note that here we assume the user launched via paddle.distributed.launch and also built the ddp communication group themselves. Essentially every
distributed-related operation should then be performed by the user, including moving the model onto the right gpu and wrapping it with `DataParallel`.
(2) If `driver.setup` is called twice or more in the script (meaning the script initializes two or more trainers/evaluators):
note that in this situation we guarantee that consecutive trainers/evaluators use the same `PaddleFleetDriver` with a consistent initialization mode; in other words, if trainer1
detects launch mode 'case A', we guarantee trainer2 also detects 'case A' (even when this needs some extra handling); so the discussion here is mainly about
the operations through which we make sure that the launch mode detected by trainer2/3/... matches trainer1's; in short, we mark each distinct launch
mode with an environment variable:
we use `FASTNLP_DISTRIBUTED_CHECK` to mark 'case A' and `fastnlp_torch_launch_not_ddp` to mark 'case B', meaning that when we
launch `PaddleFleetDriver` under 'case A' we inject the string `FASTNLP_DISTRIBUTED_CHECK` into the environment, while under 'case B' we
inject the string `fastnlp_torch_launch_not_ddp`. During trainer2's `PaddleFleetDriver` initialization and setup,
if these special environment variables are detected we switch the launch mode to the corresponding one, even if the other parameter characteristics point at a different launch mode.

Part 2: The corresponding implementation details:
1. How to tell that the communication group between the processes has already been built (fleet has been initialized):
parallel_helper._is_parallel_ctx_initialized();
2. How to tell whether the processes were spawned by `python -m paddle.distributed.launch` or by our own `FleetLauncher.launch()`
function:
when the user script does `import fastNLP` we check whether the current environment contains 'PADDLE_RANK_IN_NODE' and 'PADDLE_TRAINER_ID'
but not `FASTNLP_DISTRIBUTED_CHECK`;
if those conditions hold, we inject the special value 'FASTNLP_BACKEND_LAUNCH' into the environment to mark that the user spawned the
processes with `python -m paddle.distributed.launch`;
3. The overall decision flow:
    enter PaddleFleetDriver's __init__
    -> were the processes spawned by paddle.distributed.launch
       (or by our own FleetLauncher function)?
       -> spawned by our own FleetLauncher: case A
       -> spawned by paddle.distributed.launch:
          did the user initialize fleet themselves?
          -> yes: case C
          -> no:  case B
4. The work all three cases need for a complete distributed setup, and who is responsible for each item:

                                       | case A                  | case B                     | case C
    -----------------------------------+-------------------------+----------------------------+---------------------------
    configure the env vars fleet needs | FleetLauncher.launch    | paddle.distributed.launch  | paddle.distributed.launch
    spawn the processes                | FleetLauncher.launch    | paddle.distributed.launch  | paddle.distributed.launch
    call the fleet.init function       | PaddleFleetDriver.setup | PaddleFleetDriver.setup    | called by the user
    set PaddleFleetDriver's            | PaddleFleetDriver.setup | PaddleFleetDriver.setup    | PaddleFleetDriver.setup
    world_size and global_rank         |                         |                            |

Part 3: Remaining handling details:
1. Environment variables;
fastNLP's `PaddleFleetDriver` needs two kinds of environment variables at runtime: those that paddle fleet itself needs to run, and fastNLP's
own. The former are configured as shown in the table above, while most of the latter are already set when the user imports fastNLP;
2. The relation between parallel_device, model_device and data_device;
parallel_device is a parameter of `PaddleFleetDriver`, while model_device and data_device are both attributes of the driver;
data_device is specified by the user only in case C; if it is not None, the data is moved onto data_device during the model's forward pass;
model_device is always one single torch.device;

                    | case A                           | case B               | case C
    ----------------+----------------------------------+----------------------+---------------------------
    parallel_device | decided by the device parameter  | CUDA_VISIBLE_DEVICES | CUDA_VISIBLE_DEVICES
                    | the user passes to the trainer;  |                      |
                    | must be a list of ints           |                      |
    model_device    | parallel_device[local_rank]      | parallel_device      | None
    data_device     | model_device                     | model_device         | decided by the data_device
                    |                                  |                      | parameter the user passes to the trainer

3. What _DDPWrappingModel is for;
because we both need to call the model's `train_step`, `evaluate_step` and `test_step` methods and need `DataParallel`'s forward function to help
us synchronize the gradients across devices, we first wrap the model in an extra layer; during forward, the call first goes through `DataParallel`'s forward method
and then through `_DDPWrappingModel`'s forward method, inside which we decide whether to call the model's own forward function or
its `train_step`, `evaluate_step` or `test_step` method.

4. How `PaddleFleetDriver` handles the situation where one of the processes hits an exception;

whatever the situation, `PaddleFleetDriver` proactively records the pids of all processes at the end of its `setup` function, so that when one process hits an exception,
the driver's on_exception function is called by the trainer, and it runs the os.kill instruction to kill the other processes;
""" """
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs)


@@ -78,6 +193,7 @@ class PaddleFleetDriver(PaddleDriver):
"when your value of parameter `device` is `None` in your `Trainer` instance.") "when your value of parameter `device` is `None` in your `Trainer` instance.")
# 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 # 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的
# 这个参数会在 initialize_paddle_drvier 中设置。
self.is_pull_by_paddle_run = is_pull_by_paddle_run self.is_pull_by_paddle_run = is_pull_by_paddle_run
self.parallel_device = parallel_device self.parallel_device = parallel_device
# 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu # 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu
@@ -98,7 +214,7 @@ class PaddleFleetDriver(PaddleDriver):


self.outside_fleet = True
# The model can only be wrapped with DataParallel after it has been moved onto the right machine, so if the user initialized Fleet outside, then in PaddleFleetDriver
# we simply set model_device to None;
self._model_device = None


# When the parameter `device` is None and this parameter is not None, the corresponding data is moved onto the specified machine;
@@ -119,9 +235,12 @@ class PaddleFleetDriver(PaddleDriver):


self.world_size = None
self.global_rank = 0
self.gloo_rendezvous_dir = None


# Other settings for the distributed environment
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
# The distributed-strategy settings for fleet.init; see the official PaddlePaddle documentation for details
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
self.is_collective = self._fleet_kwargs.get("is_collective", True)
if not self.is_collective:
@@ -145,7 +264,10 @@ class PaddleFleetDriver(PaddleDriver):


def setup(self):
"""
Spawns the other child processes from the main process, making the main process rank 0
Performs different setup depending on the situation:
1. When launched via the paddle.distributed.launch method, reads the distributed attributes from the
environment that has already been set up.
2. Otherwise, spawns the child processes via the FleetLauncher class
"""
if self._has_setup:
return
@@ -174,7 +296,7 @@ class PaddleFleetDriver(PaddleDriver):
# At this point parallel_helper._is_parallel_ctx_initialized() must be False
# parallel_device is a list,
if not parallel_helper._is_parallel_ctx_initialized():
# The distributed environment has not been initialized, and this is the main process
# Spawn the child processes and set the corresponding attributes
self.init_fleet_and_set()
# The user initialized another trainer before this one, also using PaddleFleetDriver;
else:
@@ -216,12 +338,13 @@ class PaddleFleetDriver(PaddleDriver):
# If this is rank0, spawn the other child processes
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
launcher.launch()
self.gloo_rendezvous_dir = launcher.gloo_rendezvous_dir
# Set the parameters and initialize the distributed environment
fleet.init(self.role_maker, self.is_collective, self.strategy)
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))


# Normally these asserts never fire, but check just in case
assert self.global_rank is not None
assert self.world_size is not None
assert self.world_size == len(self.parallel_device)
@@ -235,10 +358,19 @@ class PaddleFleetDriver(PaddleDriver):
self.global_rank = paddledist.get_rank()


def barrier(self):
r"""
Synchronizes the progress of the processes in a multi-process job: fast processes wait here for the slow ones, and all processes continue
only once every process has reached this function; only used in distributed training scenarios.

Note that this function's behavior is affected by FASTNLP_NO_SYNC. The barrier is actually executed only when FASTNLP_NO_SYNC is absent from os.environ or smaller than 1.
"""
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:  # only actually executed when FASTNLP_NO_SYNC < 1
paddledist.barrier()
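The FASTNLP_NO_SYNC gate described above can be sketched standalone (the env-var name is taken from this diff; the helper itself is illustrative, not fastNLP code):

    import os

    def barrier_enabled() -> bool:
        return int(os.environ.get("FASTNLP_NO_SYNC", 0)) < 1

    os.environ["FASTNLP_NO_SYNC"] = "2"
    assert not barrier_enabled()   # a value >= 1 turns the barrier into a no-op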


def configure_fleet(self):
"""
Wraps the model with DataParallel and our custom wrapper type
"""
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
self.model = DataParallel(
_FleetWrappingModel(self.model),
@@ -247,8 +379,14 @@ class PaddleFleetDriver(PaddleDriver):
self._has_fleetwrapped = True


def on_exception(self):
if os.path.exists(self.gloo_rendezvous_dir):
shutil.rmtree(self.gloo_rendezvous_dir)
"""
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉;

因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的
pid 的信息;
"""
rank_zero_rm(self.gloo_rendezvous_dir)
super().on_exception()


@property
@@ -282,6 +420,17 @@ class PaddleFleetDriver(PaddleDriver):
return self.model_device


def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
"""
通过调用 `fn` 来实现训练时的前向传播过程;
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的
函数;

:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
:param fn: 调用该函数进行一次计算。
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
"""
if self._has_fleetwrapped:
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
wo_auto_param_call=self.wo_auto_param_call)
@@ -292,6 +441,27 @@ class PaddleFleetDriver(PaddleDriver):
return fn(batch)


def get_model_call_fn(self, fn: str) -> Tuple:
"""
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数;
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用;

之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上;
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中;

这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示:
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
函数,然后给出 warning;
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
可能需要额外标记最初传入 driver 的模型是哪种形式的;

:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法;
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入;
"""
model = self.unwrap_model()
if self._has_fleetwrapped:
if hasattr(model, fn):
@@ -316,7 +486,25 @@ class PaddleFleetDriver(PaddleDriver):
return self.model, model.forward


def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
reproducible: bool = False):
r"""
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。

:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;

:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
可以可以加载。
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
"""
# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
@@ -429,10 +617,7 @@ class PaddleFleetDriver(PaddleDriver):


@staticmethod
def _check_optimizer_legality(optimizers):
"""
paddle存在设置分布式optimizers的函数,返回值为fleet.meta_optimizers.HybridParallelOptimizer
重写是为了防止单卡下也传入了分布式的优化器
"""
# paddle 存在设置分布式 optimizers 的函数,返回值为 fleet.meta_optimizers.HybridParallelOptimizer
DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer
for each_optimizer in optimizers:
if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):


+ 25
- 7
fastNLP/core/drivers/paddle_driver/fleet_launcher.py

@@ -20,7 +20,7 @@ from .utils import (
# Records information about each process
class SubTrainer(object):
"""
Unrelated to fastnlp's Trainer; only used to collect some information about the different training runs within a node
Collects information about the different training processes within a node; unrelated to the fastnlp Trainer
"""
def __init__(self, endpoint=None, rank=None):
self.devices = []
@@ -30,8 +30,8 @@ class SubTrainer(object):


class FleetLauncher:
"""
Reproduces paddle's launch_collective function, integrating it into a single class
Only supports single-machine multi-card launching
Reproduces paddle's launch_collective function in simplified form inside a single class
Only supports the single-machine (multi-card) case.
""" """
def __init__(
self,
@@ -45,17 +45,26 @@ class FleetLauncher:
self.setup()


def setup(self):

"""
Performs initial setup: finds the ports to use for distributed training based on the devices passed in
"""
self.set_endpoints()
self.sub_trainers = self.get_process_info()


def launch(self) -> int:
def launch(self):
"""
用于启动分布式进程。
首先设置 PaddlePaddle 分布式训练需要设置的环境变量,然后建立新的子进程
"""
# Set the environment variables
self.global_envs = self.get_global_env()
self.open_subprocess()
reset_seed()


def open_subprocess(self):
"""
Reads each rank's information from sub_trainers and creates the new child processes with subprocess.Popen.
"""


if __main__.__spec__ is None:
# Script called as `python a/b/c.py`
@@ -77,6 +86,7 @@ class FleetLauncher:


current_env = copy.copy(self.global_envs)
for idx, t in enumerate(self.sub_trainers):
# Set the environment variables according to each rank
proc_env = {
# global_rank
"PADDLE_TRAINER_ID": f"{t.rank}",
@@ -108,6 +118,14 @@ class FleetLauncher:
os.environ.update(current_env)


def get_global_env(self):
"""
Sets the global variables distributed training needs, including:
1. GLOO-related settings
2. `PADDLE_TRAINERS_NUM`: the total number of processes
3. `PADDLE_TRAINER_ENDPOINTS`: all the addresses used, together with their ports
4. `PADDLE_WORLD_DEVICE_IDS`: all the devices used
5. FASTNLP_DISTRIBUTED_CHECK: marks that the child processes were created by fastNLP, and stores the devices used for distributed training
"""


global_envs = copy.copy(os.environ.copy())
self.gloo_rendezvous_dir = tempfile.mkdtemp()
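A hypothetical snapshot of the per-process variables listed in the docstring above, as one spawned child might see them (illustrative values only, not taken from this diff):

    proc_env = {
        "PADDLE_TRAINER_ID": "1",                                    # global rank
        "PADDLE_TRAINERS_NUM": "2",                                  # total process count
        "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:6170,127.0.0.1:6171",
        "PADDLE_WORLD_DEVICE_IDS": "0,1",
        "FASTNLP_DISTRIBUTED_CHECK": "0,1",                          # marker: spawned by fastNLP
    }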
@@ -137,7 +155,7 @@ class FleetLauncher:


def set_endpoints(self):
"""
Reference to `get_cluster_from_args`
Finds the user-specified ports, or free ports, to use for distributed training; adapted from PaddlePaddle's `get_cluster_from_args` function
"""
self.node_ip = "127.0.0.1"


@@ -157,7 +175,7 @@ class FleetLauncher:


def get_process_info(self):
"""
Reference to `get_cluster`
Collects the device, rank and port information of every training process; adapted from PaddlePaddle's `get_cluster` function.
"""
sub_trainers = []
assert len(self.endpoints) >= len(


+ 7
- 10
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py

@@ -17,14 +17,16 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
model: paddle.nn.Layer, **kwargs) -> PaddleDriver:
r"""
Determines and initializes a concrete `Driver` instance from the parameters `driver` and `device`, then returns it;
note that if the input `device` does not match `driver`, an error is raised directly;
1. If the current process is detected to have been spawned by the user via `python -m paddle.distributed.launch xxx.py`, the
device is automatically set to the device the user specified (thanks to the special setup performed when fastNLP is imported, it can be read from `CUDA_VISIBLE_DEVICES`)
2. If the input `driver` is `paddle` but `device` contains several devices, we emit a warning and automatically return the multi-card Driver
3. If the input `driver` is `fleet` but `device` holds only one device, we emit a warning but still return the multi-card Driver


:param driver: Should be one of ["paddle", "fleet"];
:param device: Uses the same format the `Trainer` requires for its `device` parameter;
:param model: The concrete model to train or evaluate;


:return: A tuple whose first element is the concrete pytorch-based `Driver` instance and whose second element is the driver's name (used to
check that successive drivers appear in the correct order within a script);
:return: The constructed `Driver` instance.
"""
if is_in_paddle_launch_dist():
if device is not None:
@@ -47,9 +49,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
if device >= _could_use_device_num: if device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.") raise ValueError("The gpu device that parameter `device` specifies is not existed.")
if device != -1:
device = f"gpu:{device}"
else:
if device == -1:
device = list(range(_could_use_device_num))
elif isinstance(device, Sequence) and not isinstance(device, str):
device = list(set(device))
@@ -61,9 +61,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
elif each >= _could_use_device_num:
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than"
" the available gpu number.")
if len(device) == 1:
# An input like [1] is treated as single-card.
device = device[0]
elif device is not None and not isinstance(device, str):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")


@@ -82,6 +79,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will"
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should "
"choose `paddle` driver.") "choose `paddle` driver.")
return PaddleFleetDriver(model, device, **kwargs)
return PaddleFleetDriver(model, [device], **kwargs)
else:
return PaddleFleetDriver(model, device, **kwargs)

+ 12
- 4
fastNLP/core/drivers/paddle_driver/paddle_driver.py

@@ -19,7 +19,12 @@ from fastNLP.envs import (
rank_zero_call,
)
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
RandomBatchSampler,
RandomSampler,
)


if _NEED_IMPORT_PADDLE:
import paddle
@@ -29,7 +34,7 @@ if _NEED_IMPORT_PADDLE:
Dataset,
Sampler,
BatchSampler,
RandomSampler,
RandomSampler as PaddleRandomSampler,
)
from paddle.optimizer import Optimizer


@@ -333,6 +338,9 @@ class PaddleDriver(Driver):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif isinstance(dataloader_args.sampler, PaddleRandomSampler):
sampler = RandomSampler(dataloader_args.sampler.data_source)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
"`ReproducibleSampler`.")
@@ -464,7 +472,7 @@ class PaddleDriver(Driver):
res.sampler = dataloader.batch_sampler.sampler
if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
res.shuffle = dataloader.batch_sampler.sampler.shuffle
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler):
elif isinstance(dataloader.batch_sampler.sampler, PaddleRandomSampler):
res.shuffle = True
else:
res.shuffle = False
@@ -474,7 +482,7 @@ class PaddleDriver(Driver):
res.sampler = batch_sampler.sampler
if hasattr(batch_sampler.sampler, "shuffle"):
res.shuffle = dataloader.batch_sampler.sampler.shuffle
elif isinstance(batch_sampler.sampler, RandomSampler):
elif isinstance(batch_sampler.sampler, PaddleRandomSampler):
res.shuffle = True
else:
res.shuffle = False


+ 56
- 0
fastNLP/core/drivers/paddle_driver/single_device.py

@@ -31,6 +31,9 @@ __all__ = [
]


class PaddleSingleDriver(PaddleDriver):
"""
A driver supporting paddle training on cpu or a single gpu
"""
def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs):
if isinstance(model, DataParallel):
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")
@@ -59,18 +62,53 @@ class PaddleSingleDriver(PaddleDriver):
self.world_size = 1


def setup(self):
r"""
Initializes the training environment: sets the device used for the current run and moves the model onto it.
"""
device = self.model_device
device = get_device_from_visible(device, output_type=str)
paddle.device.set_device(device)
self.model.to(device)


def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
"""
Implements the forward pass during training by calling `fn`;
note that the Trainer and Evaluator call this function to run the network's forward pass, and the parameter `fn` passed in here is the function returned by
`get_model_call_fn`;

:param batch: The data of the current batch; may be a dict or any other type;
:param fn: The function called to run one computation.
:param signature_fn: A signature function passed in by the Trainer for one forward pass of the network; when batch is a Dict we automatically call the auto_param_call
function, and some wrapped models need to expose their real signature, e.g. DistributedDataParallel is invoked through forward but its signature needs to be model.module.forward;
:return: Whatever `fn` returns (it should be a dict or a dataclass, but we do not need to check that);
"""
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return fn(batch)


def get_model_call_fn(self, fn: str) -> Tuple:
"""
This function receives the Trainer's train_fn or the Evaluator's evaluate_fn and returns the actual function argument passed in when driver.model_call is called;
it is called by the Trainer and the Evaluator after the driver.setup function;

The purpose of this function is to pull the concrete model_call function out of the driver and attach it to the Trainer or the Evaluator;
this is because, in the new design, which model method is used for the `train step` or `evaluate step` is determined by the extra parameters `train_fn` and
`evaluate_fn`, which in turn are controlled by the Trainer and the Evaluator respectively; the logic that picks the concrete `train step fn` and
`evaluate step fn` therefore cannot live in each driver's initialization (when the Trainer initializes the first driver, the Evaluator is not initialized yet,
while picking the `evaluate step fn` requires the Evaluator to be initialized), so we abstract this logic into this function;

This function should decide which actual function to return based on the parameter `fn`, with the following logic:
1. if fn == "train_step" or "evaluate_step", check the passed-in model; if the model does not define a method `fn`, call the model's `forward`
function by default and emit a warning;
2. if fn is any other string, raise an error directly when the model does not define a method `fn`;
Note that individual drivers need extra checks; e.g. in DDPDriver, when the passed-in model is itself a DistributedDataParallel, we can only call the model's
forward function, so an extra warning is needed; a particularly subtle point here is that the driver may itself change the model during setup (DDPDriver), so it
may need to additionally record what form the model originally passed to the driver had;

:param fn: Should be a string; this function uses it to decide which method of the model to return;
:return: A tuple containing two functions, passed in when driver.model_call is called;
"""
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
@@ -95,6 +133,24 @@ class PaddleSingleDriver(PaddleDriver):


def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):
r"""
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。

:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;

:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
可以可以加载。
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
"""


# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \


+ 6
- 18
fastNLP/core/drivers/paddle_driver/utils.py

@@ -69,7 +69,6 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) ->
os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}"
return seed



def reset_seed() -> None:
"""
fleet spawns multiple processes, so when the user calls seed_everything in their script, each spawned script will re-
@@ -80,16 +79,10 @@ def reset_seed() -> None:
if seed is not None:
paddle_seed_everything(int(seed), workers=bool(int(workers)))


class ForwardState(IntEnum):
TRAIN = 0
VALIDATE = 1
TEST = 2
PREDICT = 3

class _FleetWrappingModel(Layer):
"""
Following _DDPWrappingModel: paddle's distributed training also needs the model wrapped in paddle.nn.DataParallel, handled in a
way similar to pytorch
"""
def __init__(self, model: 'nn.Layer'): def __init__(self, model: 'nn.Layer'):
super(_FleetWrappingModel, self).__init__() super(_FleetWrappingModel, self).__init__()
@@ -109,7 +102,6 @@ class _FleetWrappingModel(Layer):
class DummyGradScaler:
    """
    A stand-in GradScaler object, used to avoid writing large numbers of repeated if checks
    """
    def __init__(self, *args, **kwargs):
        pass
@@ -152,6 +144,9 @@ def _build_fp16_env(dummy=False):
return auto_cast, GradScaler


def find_free_ports(num):
    """
    Find `num` ports among the currently free ports
    """
def __free_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
@@ -178,18 +173,11 @@ def find_free_ports(num):


return None
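For reference, the bind-to-port-0 trick that `__free_port` relies on looks like this in isolation (standard library only; a sketch, not the diffed code):

import socket
from contextlib import closing

def free_port() -> int:
    # binding to port 0 lets the OS pick any currently free port
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        return s.getsockname()[1]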


def get_host_name_ip():
try:
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
return host_name, host_ip
except:
return None

def get_device_from_visible(device: Union[str, int], output_type=int):
    """
    When CUDA_VISIBLE_DEVICES is set, get the corresponding device.
    For example, with CUDA_VISIBLE_DEVICES=2,3 and device=3, 1 is returned.

    :param device: the device name before translation
    :param output_type: the type of the return value
    :return: the translated device id
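A simplified sketch of that mapping, assuming CUDA_VISIBLE_DEVICES is a plain comma-separated list (the helper name is hypothetical):

import os

def device_id_under_visible(device) -> int:
    # with CUDA_VISIBLE_DEVICES=2,3 a request for device 3 maps to local id 1
    idx = device if isinstance(device, int) else int(str(device).split(":")[-1])
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return idx
    return [int(v) for v in visible.split(",")].index(idx)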


+ 1
- 1
fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

@@ -76,7 +76,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will "
            "still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should "
            "choose `torch` driver.")
return TorchDDPDriver(model, device, **kwargs)
return TorchDDPDriver(model, [device], **kwargs)
else:
    return TorchDDPDriver(model, device, **kwargs)
elif driver == "fairscale":


+ 13
- 1
fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -218,6 +218,8 @@ class TorchDriver(Driver):
# 2. save the model state;
if should_save_model:
    model = self.unwrap_model()
if not os.path.exists(folder):
os.mkdir(folder)
if only_state_dict:
    model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
    # for a single-card driver we actually should not (for now) consider users running single-card mode inside a DDP environment and causing an efficiency loss;
@@ -401,7 +403,17 @@ class TorchDriver(Driver):
res.sampler = dataloader.batch_sampler.sampler
if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
    res.shuffle = dataloader.batch_sampler.sampler.shuffle
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler):
elif isinstance(dataloader.batch_sampler.sampler, TorchRandomSampler):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler
if hasattr(batch_sampler.sampler, "shuffle"):
res.shuffle = batch_sampler.sampler.shuffle
elif isinstance(batch_sampler.sampler, TorchRandomSampler):
res.shuffle = True
else:
    res.shuffle = False


+ 13
- 0
fastNLP/core/log/logger.py

@@ -173,6 +173,19 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
kwargs["extra"] = extra
return kwargs


def setLevel(self, level) -> None:
    """
    Set the log level of the current logger and of its handlers

    :param level: the log level, an int or a (case-insensitive) level name
    :return:
    """
    if isinstance(level, str):
        level = level.upper()
    super().setLevel(level)
    for handler in self.handlers:
        handler.setLevel(level)
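Usage is then case-insensitive and reaches every attached handler; using the import path the tests below also use:

from fastNLP.core.log import logger

logger.setLevel("debug")   # lower-case names are upper-cased internally
logger.debug("now visible on every handler")
logger.setLevel("WARNING")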



def _get_level(level):
    if not isinstance(level, int):


+ 3
- 3
fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -416,7 +416,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
@property
def batch_idx_in_epoch(self):
    if self.drop_last:
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size
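A quick sanity check of the corrected drop_last branch, with assumed numbers len(dataset)=40, num_replicas=2, batch_size=4 and num_left_samples=12 on this rank:

# per-rank batches in the epoch:      40 // 2 // 4 = 5
# batches still to consume:                12 // 4 = 3
# batches already yielded this epoch:       5 - 3 = 2
assert 40 // 2 // 4 - 12 // 4 == 2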

+ 14
- 1
fastNLP/core/utils/paddle_utils.py

@@ -22,6 +22,13 @@ from .utils import apply_to_collection




def paddle_to(data, device: Union[str, int]):
    """
    Move `data` to the specified `device`

    :param data: the tensor to move
    :param device: the target device, a `str` or an `int`
    :return: the moved tensor
    """


if device == "cpu":
    return data.cpu()
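Expected behaviour, assuming a working paddle installation (a usage sketch, not part of the diff):

import paddle

t = paddle.ones([2, 2])
t_cpu = paddle_to(t, "cpu")  # dispatches to data.cpu()
t_gpu = paddle_to(t, 0)      # an integer id addresses a gpu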
@@ -31,6 +38,9 @@ def paddle_to(data, device: Union[str, int]):
def get_paddle_gpu_str(device: Union[str, int]):
    """
    Get a device name of the `gpu:x` form

:param device: a device id or device name
:return: the corresponding device name in `gpu:x` format
"""
if isinstance(device, str):
    return device.replace("cuda", "gpu")
@@ -38,7 +48,10 @@ def get_paddle_gpu_str(device: Union[str, int]):


def get_paddle_device_id(device: Union[str, int]):
    """
    Get the device id of a gpu; note, do not pass in `cpu`.
    Get the device id of a gpu

    :param device: a device id or device name
    :return: the id corresponding to the device
    """
    if isinstance(device, int):
        return device


+ 22
- 1
fastNLP/core/utils/rich_progress.py

@@ -14,6 +14,7 @@ __all__ = [
]


from fastNLP.envs import get_global_rank
from .utils import is_notebook




class Singleton(type):
@@ -34,6 +35,14 @@ class DummyFRichProgress:
# guard against calls like DummyFRichProgress.console.print()
return None


@property
def dummy_rich(self) -> bool:
    """
    Whether the current object is the dummy rich object.

    :return: True
    """
    return True


class FRichProgress(Progress, metaclass=Singleton):
    """
@@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id)
super().remove_task(task_id)
self.refresh()  # so the bar does not linger
if len(self._tasks) == 0:
super().stop()


def start(self) -> None:
    super().start()
@@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton):
if refresh:
    self.refresh()


@property
def dummy_rich(self) -> bool:
    """
    Whether the current object is the dummy rich object.

    :return: False
    """
    return False



class SpeedColumn(ProgressColumn):
    """
@@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn):
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed')




if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \
get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess(
    "[progress.description]{task.description}",
    "[progress.percentage]{task.percentage:>3.0f}%",


+ 20
- 1
fastNLP/core/utils/utils.py

@@ -696,4 +696,23 @@ def get_class_that_defined_method(method):
None)
if isinstance(cls, type):
    return cls
return getattr(method, '__objclass__', None)  # handle special descriptor objects


def is_notebook():
    """
    Check whether the current runtime environment is jupyter

    :return: True when running inside a jupyter kernel, otherwise False
    """
try:
from IPython import get_ipython

if "IPKernelApp" not in get_ipython().config: # pragma: no cover
raise ImportError("console")
if "VSCODE_PID" in os.environ: # pragma: no cover
raise ImportError("vscode")
except:
return False
else: # pragma: no cover
return True

+ 2
- 2
tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -16,7 +16,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDatset
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy
from fastNLP.core.log import logger


@@ -53,7 +53,7 @@ def model_and_optimizers(request):
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
dataset = TorchArgMaxDatset(
dataset = TorchArgMaxDataset(
feature_dimension=ArgMaxDatasetConfig.feature_dimension,
data_num=ArgMaxDatasetConfig.data_num,
seed=ArgMaxDatasetConfig.seed


+ 2
- 2
tests/core/callbacks/test_load_best_model_callback_torch.py

@@ -19,7 +19,7 @@ from fastNLP.core import Evaluator
from fastNLP.core.utils.utils import safe_rm
from fastNLP.core.drivers.torch_driver import TorchSingleDriver
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDatset
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from tests.helpers.utils import magic_argv_env_context




@@ -55,7 +55,7 @@ def model_and_optimizers(request):
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01)
dataset = TorchArgMaxDatset(
dataset = TorchArgMaxDataset(
feature_dimension=ArgMaxDatasetConfig.feature_dimension,
data_num=ArgMaxDatasetConfig.data_num,
seed=ArgMaxDatasetConfig.seed


+ 2
- 2
tests/core/callbacks/test_more_evaluate_callback.py

@@ -24,7 +24,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDatset
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger
@@ -64,7 +64,7 @@ def model_and_optimizers(request):
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
dataset = TorchArgMaxDatset(
dataset = TorchArgMaxDataset(
feature_dimension=ArgMaxDatasetConfig.feature_dimension,
data_num=ArgMaxDatasetConfig.data_num,
seed=ArgMaxDatasetConfig.seed


+ 0
- 0
tests/core/collators/__init__.py


+ 0
- 0
tests/core/collators/padders/__init__.py


+ 139
- 0
tests/core/collators/padders/test_get_padder.py

@@ -0,0 +1,139 @@
import pytest
import numpy as np

from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \
_get_element_shape_dtype
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR


def test_get_element_shape_dtype():
catalog = _get_element_shape_dtype([[1], [2, 3], [3], 2])
catalog = _get_element_shape_dtype([['1'], [2, 3]])
catalog = _get_element_shape_dtype([['1'], [2, 3]])
catalog = _get_element_shape_dtype([['1'], ['2', '3']])
catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))])


@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch")
if not _NEED_IMPORT_PADDLE and backend == 'paddle':
pytest.skip("No paddle")
if not _NEED_IMPORT_JITTOR and backend == 'jittor':
pytest.skip("No jittor")
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

if backend is not None:
# cannot pad
batch_field = [[1], [2, 3], [3], 2]
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')

# cannot pad
batch_field = [['2'], ['2'], ['2', '2']]
with pytest.raises(DtypeError) as exec_info:
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')

batch_field = [np.zeros(3), np.zeros((3, 1))]
with pytest.raises(InconsistencyError) as exec_info:
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') # no pad

batch_field = [np.zeros((3, 1)), np.zeros((4, 1))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')


def test_raw_padder():
backend = 'raw'
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert pad_batch == batch_field

batch_field = [[1], [2, 2], [3, 3, 3]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3)

batch_field = [[[1]], [[2, 2], [2]], [[3], [3], [3]]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3, 2)


def test_numpy_padder():
backend = 'numpy'
target_type = np.ndarray
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert (pad_batch == np.array(batch_field)).sum()==len(batch_field)

batch_field = [[1], [2, 2], [3, 3, 3]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert np.shape(pad_batch) == (3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==3

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,3))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==9

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))]
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')


def test_torch_padder():
if not _NEED_IMPORT_TORCH:
pytest.skip("No torch.")
import torch
backend = 'torch'
target_type = torch.Tensor
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field)

batch_field = [[1], [2, 2], [3, 3, 3]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3

batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9

batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12

batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))]
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')


+ 81
- 0
tests/core/collators/padders/test_numpy_padder.py

@@ -0,0 +1,81 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH


class TestNumpyNumberPadder:
def test_run(self):
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
assert isinstance(padder(a), np.ndarray)
assert (padder(a) == np.array(a)).sum() == 3


class TestNumpySequencePadder:
def test_run(self):
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (2, 3)
b = np.array([[1, 2, 3], [3, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
if _NEED_IMPORT_TORCH:
import torch
with pytest.raises(DtypeError):
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)


class TestNumpyTensorPadder:
def test_run(self):
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1)
a = [np.zeros(3), np.zeros(2), np.zeros(0)]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (3, 3)
b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (3, 3, 2)
b = np.array([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (3, 3, 2)
b = np.array([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
if _NEED_IMPORT_TORCH:
import torch
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)




+ 29
- 0
tests/core/collators/padders/test_raw_padder.py

@@ -0,0 +1,29 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder
from fastNLP.core.collators.padders.exceptions import DtypeError


class TestRawNumberPadder:
def test_run(self):
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
assert padder(a) == a


class TestRawSequencePadder:
def test_run(self):
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = np.shape(a)
assert shape == (2, 3)
b = np.array([[1, 2, 3], [3, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)

+ 105
- 0
tests/core/collators/padders/test_torch_padder.py

@@ -0,0 +1,105 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch


class TestTorchNumberPadder:
def test_run(self):
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, torch.Tensor)
assert (t_a == torch.LongTensor(a)).sum() == 3


class TestTorchSequencePadder:
def test_run(self):
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (2, 3)
b = torch.LongTensor([[1, 2, 3], [3, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]])
assert (a>67).sum()==0  # 322 overflows int8 (range -128..127) and wraps around to 66
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)



class TestTorchTensorPadder:
def test_run(self):
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3)
b = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, 0], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)




+ 90
- 0
tests/core/collators/padders/test_utils.py View File

@@ -0,0 +1,90 @@
import pytest
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \
get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number


def test_get_shape():
a = [[1, 2, 3], [3]]
assert get_shape(a) == [2, 3]

a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
assert get_shape(a) == [2, 3, 3]

a = [[[1], [2], [3, 4]], [[]]]
assert get_shape(a) == [2, 3, 2]


def test_get_padded_numpy_array():
a = [[1, 2, 3], [3]]
a = get_padded_numpy_array(a, dtype=int, pad_val=-1)
assert a.shape == (2, 3)

a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
a = get_padded_numpy_array(a, dtype=int, pad_val=-1)
assert a.shape == (2, 3, 3)

a = [[[1], [2], [3, 4]], [[]]]
a = get_padded_numpy_array(a, dtype=int, pad_val=-1)
assert a.shape == (2, 3, 2)


def test_get_padded_nest_list():
a = [[1, 2, 3], [3]]
a = get_padded_nest_list(a, pad_val=-1)
assert np.shape(a) == (2, 3)

a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
a = get_padded_nest_list(a, pad_val=-1)
assert np.shape(a) == (2, 3, 3)

a = [[[1], [2], [3, 4]], [[]]]
a = get_padded_nest_list(a, pad_val=-1)
assert np.shape(a) == (2, 3, 2)


def test_is_number_or_numpy_number():
assert is_number_or_numpy_number(type(3)) is True
assert is_number_or_numpy_number(type(3.1)) is True
assert is_number_or_numpy_number(type(True)) is True
assert is_number_or_numpy_number(type('3')) is False
assert is_number_or_numpy_number(np.zeros(3).dtype) is True
assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True
assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False

if _NEED_IMPORT_TORCH:
import torch
dtype = torch.ones(3).dtype
assert is_number_or_numpy_number(dtype) is False


def test_is_number():
assert is_number(type(3)) is True
assert is_number(type(3.1)) is True
assert is_number(type(True)) is True
assert is_number(type('3')) is False
assert is_number(np.zeros(3).dtype) is False
assert is_number(np.zeros(3, dtype=int).dtype) is False
assert is_number(np.zeros(3, dtype=object).dtype) is False

if _NEED_IMPORT_TORCH:
import torch
dtype = torch.ones(3).dtype
assert is_number(dtype) is False


def test_is_numpy_number():
assert is_numpy_number_dtype(type(3)) is False
assert is_numpy_number_dtype(type(3.1)) is False
assert is_numpy_number_dtype(type(True)) is False
assert is_numpy_number_dtype(type('3')) is False
assert is_numpy_number_dtype(np.zeros(3).dtype) is True
assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True
assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False

if _NEED_IMPORT_TORCH:
import torch
dtype = torch.ones(3).dtype
assert is_numpy_number_dtype(dtype) is False

+ 225
- 0
tests/core/collators/test_new_collator.py

@@ -0,0 +1,225 @@

import numpy as np
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.new_collator import Collator


def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()


def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))


def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)


class TestCollator:
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))

if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))

def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}

# test ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# test set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

# test setting the pad value
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))


# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))








+ 37
- 0
tests/core/collators/test_utils.py

@@ -0,0 +1,37 @@

from fastNLP.core.collators.utils import *


def test_unpack_batch_mapping():
batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]
assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]}


def test_unpack_batch_nested_mapping():
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]}

batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]}

batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}},
{'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}


def test_pack_batch_nested_mapping():
batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}
new_batch = pack_batch_nested_mapping(batch)
assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2],
'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}}


def test_unpack_batch_sequence():
batch = [[1, 2, 3], [2, 4, 6]]
new_batch = unpack_batch_sequence(batch)
assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]}




+ 2
- 2
tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -11,7 +11,7 @@ from torchmetrics import Accuracy


from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDatset
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context


@@ -80,7 +80,7 @@ def model_and_optimizers(request):
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
dataset = TorchArgMaxDatset(
dataset = TorchArgMaxDataset(
feature_dimension=ArgMaxDatasetConfig.feature_dimension,
data_num=ArgMaxDatasetConfig.data_num,
seed=ArgMaxDatasetConfig.seed


+ 2
- 2
tests/core/drivers/paddle_driver/test_fleet.py

@@ -527,7 +527,7 @@ class TestSaveLoad:
@classmethod
def setup_class(cls):
    # setup must happen here, otherwise an error is raised
cls.driver = generate_driver(10, 10)
cls.driver = generate_driver(10, 10, device=[0,1])


def setup_method(self):
    self.dataset = PaddleRandomMaxDataset(20, 10)
@@ -633,7 +633,7 @@ class TestSaveLoad:
batch_sampler=BucketedBatchSampler(
    self.dataset,
    length=[10 for i in range(len(self.dataset))],
batch_size=4,
batch_size=2,
)
)
dataloader.batch_sampler.set_distributed(


+ 4
- 4
tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py

@@ -19,7 +19,7 @@ def test_incorrect_driver():


@pytest.mark.parametrize(
    "device",
["cpu", "gpu:0", 0, [1]]
["cpu", "gpu:0", 0]
) )
@pytest.mark.parametrize(
    "driver",
@@ -27,7 +27,7 @@ def test_incorrect_driver():
)
def test_get_single_device(driver, device):
    """
    Test normal initialization of PaddleSingleDriver
    """


model = PaddleNormalModel_Classification_1(2, 100)
@@ -36,7 +36,7 @@ def test_get_single_device(driver, device):


@pytest.mark.parametrize(
    "device",
[0, 1]
[0, 1, [1]]
) )
@pytest.mark.parametrize(
    "driver",
@@ -45,7 +45,7 @@ def test_get_single_device(driver, device):
@magic_argv_env_context
def test_get_fleet_2(driver, device):
    """
    Test fleet multi-card initialization
    Test fleet multi-card initialization, but with a single gpu passed in
    """


model = PaddleNormalModel_Classification_1(64, 10)


+ 21
- 25
tests/core/drivers/paddle_driver/test_single_device.py

@@ -34,7 +34,7 @@ class TestPaddleDriverFunctions:


def test_check_single_optimizer_legality(self):
    """
    Test the behaviour when a single optimizer is passed in
    """
optimizer = paddle.optimizer.Adam(
    parameters=self.driver.model.parameters(),
@@ -50,7 +50,7 @@ class TestPaddleDriverFunctions:


def test_check_optimizers_legality(self):
    """
    Test the behaviour when a list of optimizers is passed in
    """
optimizers = [
    paddle.optimizer.Adam(
@@ -70,13 +70,13 @@ class TestPaddleDriverFunctions:


def test_check_dataloader_legality_in_train(self):
    """
    Test the behaviour of the _check_dataloader_legality function when the `is_train` parameter is True
    """
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
dataloader = DataLoader(PaddleNormalDataset())
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)


# the case where batch_size and batch_sampler are both None
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
    PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)


@@ -90,29 +90,29 @@ class TestPaddleDriverFunctions:


def test_check_dataloader_legality_in_test(self):
    """
    Test the behaviour of the _check_dataloader_legality function when the `is_train` parameter is False
    """
    # in this case a dict should be passed in
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset())
"train": DataLoader(PaddleNormalDataset()),
"test":DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)


# the case where batch_size and batch_sampler are both None
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
"train": DataLoader(PaddleNormalDataset()),
"test":DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
    PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)


# passing something that is not a dict should raise an error
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
dataloader = DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
    PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)


# create a torch dataloader
train_loader = torch.utils.data.DataLoader(
    TorchNormalDataset(),
    batch_size=32, shuffle=True
@@ -127,7 +127,7 @@ class TestPaddleDriverFunctions:


def test_tensor_to_numeric(self):
    """
    Test the tensor_to_numeric function
    """
    # a single tensor
    tensor = paddle.to_tensor(3)
@@ -180,7 +180,7 @@ class TestPaddleDriverFunctions:


def test_set_model_mode(self):
    """
    Test the set_model_mode function
    """
    self.driver.set_model_mode("train")
    assert self.driver.model.training
@@ -192,14 +192,14 @@ class TestPaddleDriverFunctions:


def test_move_model_to_device_cpu(self):
    """
    Test the move_model_to_device function
    """
    PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
    assert self.driver.model.linear1.weight.place.is_cpu_place()


def test_move_model_to_device_gpu(self):
    """
    Test the move_model_to_device function
    """
    PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
    assert self.driver.model.linear1.weight.place.is_gpu_place()
@@ -207,7 +207,7 @@ class TestPaddleDriverFunctions:


def test_worker_init_function(self):
    """
    Test worker_init_function
    """
    # first make sure it does not break execution
    # TODO: correctness
@@ -215,7 +215,7 @@ class TestPaddleDriverFunctions:


def test_set_deterministic_dataloader(self):
    """
    Test set_deterministic_dataloader
    """
    # first make sure it does not break execution
    # TODO: correctness
@@ -224,7 +224,7 @@ class TestPaddleDriverFunctions:


def test_set_sampler_epoch(self):
    """
    Test set_sampler_epoch
    """
    # first make sure it does not break execution
    # TODO: correctness
@@ -336,7 +336,7 @@ class TestSingleDeviceFunction:


def test_move_data_to_device(self):
    """
    This function only calls paddle_move_data_to_device; its test cases live in tests/core/utils/test_paddle_utils.py,
    so it is not tested again here
    """
    self.driver.move_data_to_device(paddle.rand((32, 64)))
@@ -490,9 +490,6 @@ class TestSetDistReproDataloader:
else:
    sampler_states = replaced_loader.batch_sampler.sampler.state_dict()


# load num_consumed_samples_array and set the correct number of batches already taken
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

# reload; the remaining content should come out, and for PaddleNormalDataset the sorted result should be a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@@ -510,7 +507,6 @@ class TestSetDistReproDataloader:
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
    batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_samples = num_consumed_batches * batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)


+ 788
- 0
tests/core/drivers/torch_driver/test_ddp.py

@@ -0,0 +1,788 @@
import pytest
import os
from pathlib import Path

os.environ["FASTNLP_BACKEND"] = "torch"
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
from fastNLP.core.samplers import (
RandomSampler,
UnrepeatedSampler,
BucketedBatchSampler,
UnrepeatedRandomSampler,
UnrepeatedSequentialSampler,
)
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import rank_zero_rm

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler

def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device]
driver = TorchDDPDriver(
model=torch_model,
parallel_device=device,
fp16=fp16,
output_from_new_proc=output_from_new_proc
)
driver.set_optimizers(torch_opt)
driver.setup()

return driver

def dataloader_with_bucketedbatchsampler(dataset, length, batch_size, shuffle, drop_last):
    """
    Build a dataloader whose batch_sampler is a BucketedBatchSampler
    """
dataloader = DataLoader(
dataset=dataset,
batch_sampler=BucketedBatchSampler(
dataset,
length,
batch_size,
shuffle=shuffle,
drop_last=drop_last,
),
)

return dataloader

def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0, unrepeated=False):
    """
    Build a dataloader whose sampler is a RandomSampler
    """
if unrepeated:
sampler = UnrepeatedRandomSampler(dataset, shuffle, seed)
else:
sampler = RandomSampler(dataset, shuffle, seed=seed)
dataloader = DataLoader(
dataset,
sampler=sampler,
drop_last=drop_last,
batch_size=batch_size
)
return dataloader

############################################################################
#
# Tests for some functions of TorchDDPDriver
#
############################################################################

class TestDDPDriverFunction:
    """
    Test class for some simple functions of TorchDDPDriver; mostly checks that things run and that there are no import errors
    """

@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)

@magic_argv_env_context
def test_multi_drivers(self):
    """
    Test the case where multiple TorchDDPDrivers are used.
    """
driver2 = generate_driver(20, 10)
with pytest.raises(RuntimeError):
# different device settings should raise an error
driver3 = generate_driver(20, 3, device=[0,1,2])
assert False
dist.barrier()

@magic_argv_env_context
def test_move_data_to_device(self):
    """
    This function only calls torch_move_data_to_device; its test cases live in tests/core/utils/test_torch_utils.py,
    so it is not tested again here
    """
self.driver.move_data_to_device(torch.rand((32, 64)))

dist.barrier()

@magic_argv_env_context
def test_is_distributed(self):
    """
    Test the is_distributed function
    """
assert self.driver.is_distributed() == True
dist.barrier()

@magic_argv_env_context
def test_get_no_sync_context(self):
    """
    Test the get_no_sync_context function
    """
res = self.driver.get_model_no_sync_context()
dist.barrier()

@magic_argv_env_context
def test_is_global_zero(self):
    """
    Test the is_global_zero function
    """
self.driver.is_global_zero()
dist.barrier()

@magic_argv_env_context
def test_unwrap_model(self):
    """
    Test the unwrap_model function
    """
self.driver.unwrap_model()
dist.barrier()

@magic_argv_env_context
def test_get_local_rank(self):
    """
    Test the get_local_rank function
    """
self.driver.get_local_rank()
dist.barrier()

@magic_argv_env_context
def test_all_gather(self):
    """
    Test the all_gather function
    The detailed tests are done in test_dist_utils.py
    """
obj = {
"rank": self.driver.global_rank
}
obj_list = self.driver.all_gather(obj, group=None)
for i, res in enumerate(obj_list):
assert res["rank"] == i

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_broadcast_object(self, src_rank):
    """
    Test the broadcast_object function
    The detailed tests are done in test_dist_utils.py
    """
if self.driver.global_rank == src_rank:
obj = {
"rank": self.driver.global_rank
}
else:
obj = None
res = self.driver.broadcast_object(obj, src=src_rank)
assert res["rank"] == src_rank

############################################################################
#
# Tests for the set_dist_repro_dataloader function
#
############################################################################

class TestSetDistReproDataloader:

@classmethod
def setup_class(cls):
cls.device = [0, 1]
cls.driver = generate_driver(10, 10, device=cls.device)

def setup_method(self):
self.dataset = TorchNormalDataset(40)

"""
The case where the `dist` parameter passed in is a concrete ReproducibleSampler or ReproducibleBatchSampler,
which corresponds to the situation inside driver.load
"""

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_batch_sampler(self, shuffle):
    """
    Test the behaviour of set_dist_repro_dataloader when dist is a BucketedBatchSampler;
    the batch_sampler should then be replaced by the BucketedBatchSampler given as dist
    """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_sampler(self, shuffle):
    """
    Test the behaviour of set_dist_repro_dataloader when dist is a RandomSampler;
    batch_sampler.sampler should then be replaced by the RandomSampler given as dist
    """
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

dist.barrier()
"""
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler`
参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。
当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定
是否重新实例化 dataloader
"""

@magic_argv_env_context
def test_with_dist_none_reproducible_true(self):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError):
            # a RuntimeError should be raised
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)

dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
时的表现
此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler
和原 dataloader 相同
"""
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
pad=True
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其
batch_sampler.sampler 和原 dataloader 相同
"""
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 4
        assert not replaced_loader.batch_sampler.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
此时直接返回原来的 dataloader,不做任何处理。
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert replaced_loader is dataloader
dist.barrier()

"""
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数
为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader
"""

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler
的表现
此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性
"""
        dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关
的属性
"""
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_dist_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关
的属性
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

"""
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数
为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader
"""

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关
的属性
"""
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
    def test_with_dist_unrepeat_dataloader_unrepeated_sampler(self, shuffle):
        """
        Test set_dist_repro_dataloader when dist is 'unrepeatdist' and
        dataloader.batch_sampler.sampler is an UnrepeatedSampler.
        A new dataloader should be returned with the original sampler re-instantiated.
        """
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关
的属性
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

def check_distributed_sampler(self, sampler):
"""
测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确
"""
assert sampler.num_replicas == dist.get_world_size()
assert sampler.rank == dist.get_rank()
if not isinstance(sampler, UnrepeatedSampler):
            assert sampler.pad

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确
"""
# 迭代两个 batch
num_replicas = len(self.device)
num_consumed_batches = 2
already_seen_idx = set()
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
dist.barrier()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

        # Reload; the remaining samples should come out, and for TorchNormalDataset the indices,
        # once sorted, should form a contiguous range.
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
            # rebuild the dataloader
new_loader = dataloader_with_bucketedbatchsampler(
replaced_loader.dataset,
length=replaced_loader.dataset._data,
batch_size=batch_size,
shuffle=shuffle,
drop_last=False,
)
new_loader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
pad=True
)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
            # rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
new_loader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas


############################################################################
#
# Tests for save- and load-related functionality
#
############################################################################
class TestSaveLoad:
"""
测试多卡情况下 save 和 load 相关函数的表现
"""

@classmethod
def setup_class(cls):
        # setting up anywhere other than here raises an error
cls.driver = generate_driver(10, 10)

def setup_method(self):
self.dataset = TorchArgMaxDataset(10, 20)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(self, only_state_dict):
"""
测试 save_model 和 load_model 函数
"""
try:
path = "model"

dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)

self.driver1.save_model(path, only_state_dict)

            # synchronize
dist.barrier()
self.driver2.load_model(path, only_state_dict)

for idx, batch in enumerate(dataloader):
batch = self.driver1.move_data_to_device(batch)
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)

assert torch.equal(res1["preds"], res2["preds"])
finally:
rank_zero_rm(path)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@pytest.mark.parametrize("device", ([[0,1]]))
def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16):
"""
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况
"""

try:
path = "model.ckp"
num_replicas = len(device)

self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
dataloader = dataloader_with_bucketedbatchsampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
batch_size=4,
shuffle=True,
drop_last=False
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
pad=True
)
num_consumed_batches = 2

already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

            # synchronize
dist.barrier()

            # save the state
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            # load
            # change the batch_size
dataloader = dataloader_with_bucketedbatchsampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
batch_size=2,
shuffle=True,
drop_last=False
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
            # 1. check the optimizer's state
            # TODO the optimizer's state_dict is always empty

            # 2. check that the batch_sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas

            # 3. check that fp16 was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)

            # 4. check that the model parameters are correct
            # 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
            for idx, batch in enumerate(replaced_loader):
                left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
assert torch.equal(res1["preds"], res2["preds"])

assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
finally:
rank_zero_rm(path)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
@pytest.mark.parametrize("device", ([[0,1]]))
def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16):
"""
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
"""

try:
path = "model.ckp"

num_replicas = len(device)

self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)

dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
pad=True
)
num_consumed_batches = 2

already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

            # synchronize
dist.barrier()

            # save the state
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
            # load
            # change the batch_size
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
pad=True
)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

            # 1. check the optimizer's state
            # TODO the optimizer's state_dict is always empty

            # 2. check that the sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
            # 3. check that fp16 was loaded
if fp16:
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)

            # 4. check that the model parameters are correct
            # 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
            for idx, batch in enumerate(replaced_loader):
                left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model.module.model.evaluate_step,
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model.module.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
assert torch.equal(res1["preds"], res2["preds"])

assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

finally:
rank_zero_rm(path)

+ 103
- 0
tests/core/drivers/torch_driver/test_initialize_torch_driver.py View File

@@ -0,0 +1,103 @@
import os
import pytest

os.environ["FASTNLP_BACKEND"] = "torch"

from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
from fastNLP.envs import get_gpu_count
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.utils import magic_argv_env_context

import torch

def test_incorrect_driver():

model = TorchNormalModel_Classification_1(2, 100)
with pytest.raises(ValueError):
driver = initialize_torch_driver("paddle", 0, model)

@pytest.mark.parametrize(
"device",
["cpu", "cuda:0", 0, torch.device("cuda:0")]
)
@pytest.mark.parametrize(
"driver",
["torch"]
)
def test_get_single_device(driver, device):
"""
测试正常情况下初始化TorchSingleDriver的情况
"""

model = TorchNormalModel_Classification_1(2, 100)
driver = initialize_torch_driver(driver, device, model)
assert isinstance(driver, TorchSingleDriver)

@pytest.mark.parametrize(
"device",
[0, 1]
)
@pytest.mark.parametrize(
"driver",
["torch_ddp"]
)
@magic_argv_env_context
def test_get_ddp_2(driver, device):
"""
测试 ddp 多卡的初始化情况,但传入了单个 gpu
"""

model = TorchNormalModel_Classification_1(64, 10)
driver = initialize_torch_driver(driver, device, model)

assert isinstance(driver, TorchDDPDriver)

@pytest.mark.parametrize(
"device",
[[0, 2, 3], -1]
)
@pytest.mark.parametrize(
"driver",
["torch", "torch_ddp"]
)
@magic_argv_env_context
def test_get_ddp(driver, device):
"""
测试 ddp 多卡的初始化情况
"""

model = TorchNormalModel_Classification_1(64, 10)
driver = initialize_torch_driver(driver, device, model)

assert isinstance(driver, TorchDDPDriver)

@pytest.mark.parametrize(
("driver", "device"),
[("torch_ddp", "cpu")]
)
@magic_argv_env_context
def test_get_ddp_cpu(driver, device):
"""
测试试图在 cpu 上初始化分布式训练的情况
"""
model = TorchNormalModel_Classification_1(64, 10)
with pytest.raises(ValueError):
driver = initialize_torch_driver(driver, device, model)

@pytest.mark.parametrize(
"device",
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1]
)
@pytest.mark.parametrize(
"driver",
["torch", "torch_ddp"]
)
@magic_argv_env_context
def test_device_out_of_range(driver, device):
"""
测试传入的device超过范围的情况
"""
model = TorchNormalModel_Classification_1(2, 100)
with pytest.raises(ValueError):
driver = initialize_torch_driver(driver, device, model)

+ 697
- 0
tests/core/drivers/torch_driver/test_single_device.py View File

@@ -0,0 +1,697 @@
import os
os.environ["FASTNLP_BACKEND"] = "torch"
import pytest
from pathlib import Path

from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from fastNLP.core import rank_zero_rm

import torch
from torch.utils.data import DataLoader, BatchSampler
import paddle

def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
"""
建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader
"""
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(
sampler, batch_size=batch_size, drop_last=drop_last
),
batch_size=batch_size,
drop_last=drop_last,
),
)

return dataloader
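
# Example usage (a hedged sketch; TorchNormalDataset comes from the test helpers
# imported above):
#
#     loader = dataloader_with_randombatchsampler(TorchNormalDataset(20), 4, True, False)
#     assert isinstance(loader.batch_sampler, RandomBatchSampler)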

def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0):
"""
建立一个 sampler 为 RandomSampler 的 dataloader
"""
dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, shuffle, seed=seed),
drop_last=drop_last,
batch_size=batch_size
)
return dataloader
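
# Example usage (same caveats as above):
#
#     loader = dataloader_with_randomsampler(TorchNormalDataset(20), 4, True, False)
#     assert isinstance(loader.batch_sampler.sampler, RandomSampler)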

############################################################################
#
# Tests for some simple functions of the TorchDriver base class
#
############################################################################

class TestTorchDriverFunctions:
"""
使用 TorchSingleDriver 测试基类的函数
"""

    @classmethod
    def setup_class(cls):
        model = TorchNormalModel_Classification_1(10, 32)
        cls.driver = TorchSingleDriver(model, device="cpu")

def test_check_single_optimizer_legality(self):
"""
测试传入单个 optimizer 时的表现
"""
optimizer = torch.optim.Adam(
params=self.driver.model.parameters(),
lr=0.01
)

self.driver.set_optimizers(optimizer)

optimizer = paddle.optimizer.Adam(
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
learning_rate=0.01,
)
        # passing a paddle optimizer should raise a ValueError
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)

def test_check_optimizers_legality(self):
"""
测试传入 optimizer list 的表现
"""
optimizers = [
torch.optim.Adam(
params=self.driver.model.parameters(),
lr=0.01
) for i in range(10)
]

self.driver.set_optimizers(optimizers)

optimizers += [
paddle.optimizer.Adam(
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
learning_rate=0.01,
)
]

with pytest.raises(ValueError):
self.driver.set_optimizers(optimizers)

def test_check_dataloader_legality_in_train(self):
"""
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
"""
dataloader = DataLoader(TorchNormalDataset())
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

        # create a paddle dataloader
dataloader = paddle.io.DataLoader(
PaddleNormalDataset(),
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

def test_check_dataloader_legality_in_test(self):
"""
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现
"""
# 此时传入的应该是dict
dataloader = {
"train": DataLoader(TorchNormalDataset()),
"test": DataLoader(TorchNormalDataset())
}
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

        # not a dict, so an error should be raised
dataloader = DataLoader(TorchNormalDataset())
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

        # create paddle dataloaders
train_loader = paddle.io.DataLoader(
PaddleNormalDataset(),
batch_size=32, shuffle=True
)
test_loader = paddle.io.DataLoader(
PaddleNormalDataset(),
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

def test_tensor_to_numeric(self):
"""
测试 tensor_to_numeric 函数
"""
# 单个张量
tensor = torch.tensor(3)
res = TorchSingleDriver.tensor_to_numeric(tensor)
assert res == 3

tensor = torch.rand((3, 4))
res = TorchSingleDriver.tensor_to_numeric(tensor)
assert res == tensor.tolist()

        # a list of tensors
tensor_list = [torch.rand((6, 4, 2)) for i in range(10)]
res = TorchSingleDriver.tensor_to_numeric(tensor_list)
assert isinstance(res, list)
tensor_list = [t.tolist() for t in tensor_list]
assert res == tensor_list

        # a tuple of tensors
tensor_tuple = tuple([torch.rand((6, 4, 2)) for i in range(10)])
res = TorchSingleDriver.tensor_to_numeric(tensor_tuple)
assert isinstance(res, tuple)
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
assert res == tensor_tuple

        # a dict of tensors
tensor_dict = {
"tensor": torch.rand((3, 4)),
"list": [torch.rand((6, 4, 2)) for i in range(10)],
"dict":{
"list": [torch.rand((6, 4, 2)) for i in range(10)],
"tensor": torch.rand((3, 4))
},
"int": 2,
"string": "test string"
}

res = TorchSingleDriver.tensor_to_numeric(tensor_dict)
assert isinstance(res, dict)
assert res["tensor"] == tensor_dict["tensor"].tolist()
assert isinstance(res["list"], list)
for r, d in zip(res["list"], tensor_dict["list"]):
assert r == d.tolist()
assert isinstance(res["int"], int)
assert isinstance(res["string"], str)
assert isinstance(res["dict"], dict)
assert isinstance(res["dict"]["list"], list)
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

def test_set_model_mode(self):
"""
测试set_model_mode函数
"""
self.driver.set_model_mode("train")
assert self.driver.model.training
self.driver.set_model_mode("eval")
assert not self.driver.model.training
        # should raise an error
with pytest.raises(AssertionError):
self.driver.set_model_mode("test")

def test_move_model_to_device_cpu(self):
"""
测试move_model_to_device函数
"""
TorchSingleDriver.move_model_to_device(self.driver.model, "cpu")
assert self.driver.model.linear1.weight.device.type == "cpu"

def test_move_model_to_device_gpu(self):
"""
测试move_model_to_device函数
"""
TorchSingleDriver.move_model_to_device(self.driver.model, "cuda")
assert self.driver.model.linear1.weight.device.type == "cuda"
assert self.driver.model.linear1.weight.device.index == 0

def test_worker_init_function(self):
"""
测试worker_init_function
"""
# 先确保不影响运行
# TODO:正确性
TorchSingleDriver.worker_init_function(0)

def test_set_deterministic_dataloader(self):
"""
测试set_deterministic_dataloader
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(TorchNormalDataset())
self.driver.set_deterministic_dataloader(dataloader)

def test_set_sampler_epoch(self):
"""
测试set_sampler_epoch
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(TorchNormalDataset())
self.driver.set_sampler_epoch(dataloader, 0)

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
"""
测试正常情况下 get_dataloader_args 的表现
"""
dataloader = DataLoader(
TorchNormalDataset(),
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
)
res = TorchSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
if shuffle:
assert isinstance(res.sampler, torch.utils.data.RandomSampler)
else:
assert isinstance(res.sampler, torch.utils.data.SequentialSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
"""
测试替换了 batch_sampler 后 get_dataloader_args 的表现
"""
dataset = TorchNormalDataset()
dataloader = dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last)
res = TorchSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
if shuffle:
assert isinstance(res.sampler, torch.utils.data.RandomSampler)
else:
assert isinstance(res.sampler, torch.utils.data.SequentialSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
"""
测试替换了 sampler 后 get_dataloader_args 的表现
"""
dataset = TorchNormalDataset()
dataloader = dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last)
res = TorchSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
assert isinstance(res.sampler, RandomSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last


############################################################################
#
# Tests for some simple functions of TorchSingleDriver
#
############################################################################

class TestSingleDeviceFunction:
"""
测试其它函数的测试例
"""

@classmethod
def setup_class(cls):
model = TorchNormalModel_Classification_1(10, 784)
cls.driver = TorchSingleDriver(model, device="cpu")

def test_unwrap_model(self):
"""
测试能否运行
"""
res = self.driver.unwrap_model()
assert res is self.driver.model

def test_is_distributed(self):
        assert not self.driver.is_distributed()

def test_move_data_to_device(self):
"""
这个函数仅调用了 torch_move_data_to_device ,测试例在 tests/core/utils/test_torch_utils.py 中
就不重复测试了
"""
self.driver.move_data_to_device(torch.rand((32, 64)))


############################################################################
#
# 测试 set_dist_repro_dataloader 函数
#
############################################################################

class TestSetDistReproDataloader:
"""
专门测试 set_dist_repro_dataloader 函数的类
"""
def setup_method(self):
self.dataset = TorchNormalDataset(20)
model = TorchNormalModel_Classification_1(10, 32)
self.driver = TorchSingleDriver(model, device="cpu")
def test_with_reproducible_false(self):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现
当dist为字符串时,此时应该返回原来的 dataloader
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert replaced_loader is dataloader

@pytest.mark.parametrize("shuffle", [True, False])
def test_with_reproducible_true(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True),
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

assert not (replaced_loader is dataloader)
if shuffle:
            # the sampler is replaced in this case
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
            # the batch_sampler is replaced in this case
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
dataloader = dataloader_with_randombatchsampler(self.dataset, 4, shuffle, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

@pytest.mark.parametrize("shuffle", ([True, False]))
def test_with_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
dataloader = dataloader_with_randomsampler(self.dataset, 2, shuffle, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确
"""
# 迭代两个 batch
num_consumed_batches = 2
already_seen_idx = set()
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

        # Reload; the remaining samples should come out, and for TorchNormalDataset the indices,
        # once sorted, should form a contiguous range.
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
            # rebuild the dataloader
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
            # rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
assert len(left_idxes | already_seen_idx) == len(self.dataset)

############################################################################
#
# Tests for save- and load-related functionality
#
############################################################################

def generate_random_driver(features, labels, fp16=False, device="cpu"):
"""
生成driver
"""
model = TorchNormalModel_Classification_1(labels, features)
opt = torch.optim.Adam(params=model.parameters(), lr=0.01)
driver = TorchSingleDriver(model, device=device, fp16=fp16)
driver.set_optimizers(opt)
driver.setup()

return driver

@pytest.fixture
def prepare_test_save_load():
dataset = TorchArgMaxDataset(10, 40)
dataloader = DataLoader(dataset, batch_size=4)
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
return driver1, driver2, dataloader

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
"""
测试 save_model 和 load_model 函数
"""
try:
path = "model"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path, only_state_dict)
driver2.load_model(path, only_state_dict)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

assert torch.equal(res1["preds"], res2["preds"])
finally:
rank_zero_rm(path)

@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
"""
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况
"""

try:
path = "model.ckp"
dataset = TorchArgMaxDataset(10, 40)
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False)
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda")

num_consumed_batches = 2

already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
        # load
        # change the batch_size

dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
        # 1. check the optimizer's state
        # TODO the optimizer's state_dict is always empty

        # 2. check that the batch_sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

        # 3. check that fp16 was loaded
if fp16:
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

        # 4. check that the model parameters are correct
        # 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
        for idx, batch in enumerate(replaced_loader):
            batch = driver2.move_data_to_device(batch)
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert torch.equal(res1["preds"], res2["preds"])

assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset)
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
finally:
rank_zero_rm(path)

@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict, fp16):
"""
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
"""

try:
path = "model.ckp"

driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda")
dataset = TorchArgMaxDataset(10, 40)
dataloader = dataloader_with_randomsampler(dataset, 4, True, False)
num_consumed_batches = 2

already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
        # load
        # change the batch_size
dataloader = dataloader_with_randomsampler(dataset, 2, True, False)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

        # 1. check the optimizer's state
        # TODO the optimizer's state_dict is always empty

        # 2. check that the sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]

        # 3. check that fp16 was loaded
if fp16:
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

        # 4. check that the model parameters are correct
        # 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
        for idx, batch in enumerate(replaced_loader):
            batch = driver2.move_data_to_device(batch)
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert torch.equal(res1["preds"], res2["preds"])

assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset)
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
finally:
rank_zero_rm(path)

+ 36
- 35
tests/core/drivers/torch_driver/test_utils.py View File

@@ -1,35 +1,36 @@
from torch.utils.data.sampler import SequentialSampler, RandomSampler

from fastNLP.core.samplers.sampler import ReproduceSampler
from tests.helpers.datasets.normal_data import NormalIterator


class TestReproduceSampler:

def test_sequentialsampler(self):
normal_iterator = NormalIterator(num_of_data=20)
sequential_sampler = SequentialSampler(normal_iterator)

reproduce_sampler = ReproduceSampler(sequential_sampler)
# iter_seq_sampler = iter(sequential_sampler)
# for each in iter_seq_sampler:
# print(each)
iter_reproduce_sampler = iter(reproduce_sampler)
forward_step = 3
for _ in range(forward_step):
next(iter_reproduce_sampler)
state = reproduce_sampler.save_state()
assert state["current_batch_idx"] == forward_step

new_repro_sampler = ReproduceSampler(sequential_sampler)
assert new_repro_sampler.save_state()["current_batch_idx"] == 0

new_repro_sampler.load_state(state)
iter_new_repro_sampler = iter(new_repro_sampler)
new_index_list = []
for each in iter_new_repro_sampler:
new_index_list.append(each)
assert new_index_list == list(range(3, 20))



import os
import pytest
os.environ["FASTNLP_BACKEND"] = "torch"

from fastNLP.core.drivers.torch_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from torch.utils.data import DataLoader, BatchSampler

from tests.helpers.datasets.torch_data import TorchNormalDataset

def test_replace_batch_sampler():
dataset = TorchNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.dataset, TorchNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16

def test_replace_sampler():
dataset = TorchNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
sampler = RandomSampler(dataset)

replaced_loader = replace_sampler(dataloader, sampler)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)

+ 5
- 5
tests/helpers/callbacks/helper_callbacks.py View File

@@ -38,7 +38,7 @@ class RecordMetricCallback(Callback):
        self.metric_threshold = metric_threshold
        self.metric_begin_value = None

-    def on_validate_end(self, trainer, results):
+    def on_evaluate_end(self, trainer, results):
        self.metric = results[self.monitor]
        if self.metric_begin_value is None:
            self.metric_begin_value = self.metric
@@ -113,11 +113,11 @@ class RecordTrainerEventTriggerCallback(Callback):
    def on_after_zero_grad(self, trainer, optimizers):
        print("on_after_zero_grad")

-    def on_validate_begin(self, trainer):
-        print("on_validate_begin")
+    def on_evaluate_begin(self, trainer):
+        print("on_evaluate_begin")

-    def on_validate_end(self, trainer, results):
-        print("on_validate_end")
+    def on_evaluate_end(self, trainer, results):
+        print("on_evaluate_end")

+ 1
- 1
tests/helpers/datasets/torch_data.py View File

@@ -38,7 +38,7 @@ class TorchNormalDataset_Classification(Dataset):
return {"x": self.x[item], "y": self.y[item]} return {"x": self.x[item], "y": self.y[item]}




class TorchArgMaxDatset(Dataset):
class TorchArgMaxDataset(Dataset):
def __init__(self, feature_dimension=10, data_num=1000, seed=0): def __init__(self, feature_dimension=10, data_num=1000, seed=0):
self.num_labels = feature_dimension self.num_labels = feature_dimension
self.feature_dimension = feature_dimension self.feature_dimension = feature_dimension


+ 1009
- 0
tutorials/fastnlp_tutorial_0.ipynb
File diff suppressed because it is too large
View File

