@@ -12,6 +12,34 @@ from fastNLP.core.callbacks.callback_events import _SingleEventState | |||||
class Callback: | class Callback: | ||||
r""" | r""" | ||||
实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | ||||
callback 调用时机顺序大概如下 | |||||
Trainer.__init__(): | |||||
on_after_trainer_initialized() | |||||
Trainer.run(): | |||||
if num_eval_sanity_batch>0: | |||||
on_sanity_check_begin() # 如果设置了num_eval_sanity_batch | |||||
on_sanity_check_end() | |||||
try: | |||||
on_train_begin() | |||||
while cur_epoch_idx < n_epochs: | |||||
on_train_epoch_begin() | |||||
while batch_idx_in_epoch<=num_batches_per_epoch: | |||||
on_fetch_data_begin() | |||||
on_fetch_data_end() | |||||
on_train_batch_begin() | |||||
on_before_backward() | |||||
on_after_backward() | |||||
on_before_zero_grad() # 实际调用受到 accumulation_steps 影响 | |||||
on_after_zero_grad() # 实际调用受到 accumulation_steps 影响 | |||||
on_before_optimizers_step() # 实际调用受到 accumulation_steps 影响 | |||||
on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响 | |||||
on_train_batch_end() | |||||
on_train_epoch_end() | |||||
except BaseException: | |||||
self.on_exception() | |||||
finally: | |||||
on_train_end() | |||||
其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将 | |||||
""" | """ | ||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
@@ -221,9 +249,9 @@ class Callback: | |||||
""" | """ | ||||
pass | pass | ||||
def on_validate_begin(self, trainer): | |||||
def on_evaluate_begin(self, trainer): | |||||
""" | """ | ||||
在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | |||||
在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | |||||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | ||||
:param trainer: | :param trainer: | ||||
@@ -231,9 +259,9 @@ class Callback: | |||||
""" | """ | ||||
pass | pass | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
""" | """ | ||||
结束 validate 时调用,并把 validate 的结果传入。 | |||||
结束 evaluate 时调用,并把 evaluate 的结果传入。 | |||||
:param trainer: | :param trainer: | ||||
:param results: Evaluate 的结果,一般是个 dict 。 | :param results: Evaluate 的结果,一般是个 dict 。 | ||||
@@ -96,8 +96,8 @@ class Events(EventEnum): | |||||
on_after_optimizers_step = "on_after_optimizers_step" | on_after_optimizers_step = "on_after_optimizers_step" | ||||
on_before_zero_grad = "on_before_zero_grad" | on_before_zero_grad = "on_before_zero_grad" | ||||
on_after_zero_grad = "on_after_zero_grad" | on_after_zero_grad = "on_after_zero_grad" | ||||
on_validate_begin = "on_validate_begin" | |||||
on_validate_end = "on_validate_end" | |||||
on_evaluate_begin = "on_evaluate_begin" | |||||
on_evaluate_end = "on_evaluate_end" | |||||
class EventsList: | class EventsList: | ||||
@@ -8,7 +8,6 @@ __all__ = [ | |||||
from .callback_events import Events | from .callback_events import Events | ||||
from .callback import Callback | from .callback import Callback | ||||
from .progress_callback import ProgressCallback, choose_progress_callback | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -35,7 +34,7 @@ class CallbackManager: | |||||
class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; | class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; | ||||
callback_fns: dict | callback_fns: dict | ||||
def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'): | |||||
def __init__(self, callbacks: Optional[List[Callback]]): | |||||
r""" | r""" | ||||
注意 callback 的调用顺序: | 注意 callback 的调用顺序: | ||||
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | ||||
@@ -46,7 +45,6 @@ class CallbackManager: | |||||
""" | """ | ||||
self._need_reproducible_sampler = False | self._need_reproducible_sampler = False | ||||
_has_progress_callback = False | |||||
_callbacks = [] | _callbacks = [] | ||||
if callbacks is not None: | if callbacks is not None: | ||||
if isinstance(callbacks, Callback): | if isinstance(callbacks, Callback): | ||||
@@ -57,16 +55,7 @@ class CallbackManager: | |||||
for _callback in callbacks: | for _callback in callbacks: | ||||
if not isinstance(_callback, Callback): | if not isinstance(_callback, Callback): | ||||
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | ||||
if isinstance(_callback, ProgressCallback): | |||||
_has_progress_callback = True | |||||
_callbacks += callbacks | _callbacks += callbacks | ||||
if not _has_progress_callback: | |||||
# 添加 progress callback | |||||
progress_callback = choose_progress_callback(progress_bar=progress_bar) | |||||
if progress_callback is None: | |||||
logger.info("There is no progress bar, Trainer will not output training progress.") | |||||
else: | |||||
_callbacks.append(progress_callback) | |||||
self.callback_fns = defaultdict(list) | self.callback_fns = defaultdict(list) | ||||
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | ||||
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | ||||
@@ -292,9 +281,9 @@ class CallbackManager: | |||||
pass | pass | ||||
@_transfer | @_transfer | ||||
def on_validate_begin(self, trainer): | |||||
def on_evaluate_begin(self, trainer): | |||||
pass | pass | ||||
@_transfer | @_transfer | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
pass | pass |
@@ -114,7 +114,7 @@ class CheckpointCallback(Callback): | |||||
if self.topk_saver.topk_queue and trainer.evaluator is None: | if self.topk_saver.topk_queue and trainer.evaluator is None: | ||||
logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
# 如果发生了保存,则返回的 folder 不为 None | # 如果发生了保存,则返回的 folder 不为 None | ||||
folder = self.topk_saver.save_topk(trainer, results) | folder = self.topk_saver.save_topk(trainer, results) | ||||
@@ -16,13 +16,13 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | 果(字典类型),返回一个 float 值作为 monitor 的结果。 | ||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param patience: 多少次 validate 不没有提升就停止。 | |||||
:param patience: 多少次 evaluate 不没有提升就停止。 | |||||
""" | """ | ||||
super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | ||||
self.wait = 0 | self.wait = 0 | ||||
self.patience = patience | self.patience = patience | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
monitor_value = self.get_monitor_value(results) | monitor_value = self.get_monitor_value(results) | ||||
if monitor_value is None: | if monitor_value is None: | ||||
return | return | ||||
@@ -32,13 +32,13 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
self.wait += 1 | self.wait += 1 | ||||
def on_fetch_data_begin(self, trainer): | def on_fetch_data_begin(self, trainer): | ||||
# 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
# 当是 step evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
if self.wait >= self.patience: | if self.wait >= self.patience: | ||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | raise EarlyStopException(f"After {self.wait} validations, no improvement for " | ||||
f"metric `{self._real_monitor}`") | f"metric `{self._real_monitor}`") | ||||
def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
# 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
# 当是 epoch evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 | |||||
if self.wait >= self.patience: | if self.wait >= self.patience: | ||||
raise EarlyStopException(f"After {self.wait} validations, no improvement for " | raise EarlyStopException(f"After {self.wait} validations, no improvement for " | ||||
f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | f"metric `{self._real_monitor}`(best value: {self.monitor_value})") | ||||
@@ -216,6 +216,6 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||||
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') | ||||
self.execute_fn = execute_fn | self.execute_fn = execute_fn | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
if self.is_better_results(results): | if self.is_better_results(results): | ||||
self.execute_fn() | self.execute_fn() |
@@ -76,7 +76,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
if self.real_save_folder: | if self.real_save_folder: | ||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
@@ -95,25 +95,14 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
self.buffer.seek(0) | self.buffer.seek(0) | ||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
if self.delete_after_after: | |||||
if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
# 只需要 rank 0 执行删除。 | |||||
logger.info(f"Deleting {self.real_save_folder}...") | |||||
shutil.rmtree(self.real_save_folder) | |||||
try: | |||||
# 如果是 emtpy 的,就会被删除掉 | |||||
os.rmdir(self.save_folder) | |||||
except: | |||||
pass | |||||
elif hasattr(self, 'buffer'): | |||||
self.buffer.close() | |||||
del self.buffer | |||||
self._delete_after_after(trainer) | |||||
def on_exception(self, trainer, exception): | |||||
def _delete_after_after(self, trainer): | |||||
trainer.driver.barrier() | |||||
if self.delete_after_after: | if self.delete_after_after: | ||||
if self.real_save_folder: # 这里,谁处异常,谁删除 | |||||
if self.real_save_folder: | |||||
logger.info(f"Deleting {self.real_save_folder}...") | logger.info(f"Deleting {self.real_save_folder}...") | ||||
shutil.rmtree(self.real_save_folder) | |||||
shutil.rmtree(self.real_save_folder, ignore_errors=True) | |||||
try: | try: | ||||
# 如果是 emtpy 的,就会被删除掉 | # 如果是 emtpy 的,就会被删除掉 | ||||
os.rmdir(self.save_folder) | os.rmdir(self.save_folder) | ||||
@@ -31,8 +31,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
:param dataloaders: 需要评估的数据 | :param dataloaders: 需要评估的数据 | ||||
:param metrics: 使用的 metrics 。 | :param metrics: 使用的 metrics 。 | ||||
:param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch validate 一次;(2) 为正整数则表示每隔几个 batch | |||||
evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 validate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 | |||||
:param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch evaluate 一次;(2) 为正整数则表示每隔几个 batch | |||||
evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 | |||||
一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | ||||
:param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | ||||
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | ||||
@@ -108,7 +108,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
'metrics': self.metrics, | 'metrics': self.metrics, | ||||
'driver': self.kwargs.get('driver', trainer.driver), | 'driver': self.kwargs.get('driver', trainer.driver), | ||||
'device': self.kwargs.get('device', trainer.device), | 'device': self.kwargs.get('device', trainer.device), | ||||
'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn), | |||||
'evaluate_batch_step_fn': self.kwargs.get('evaluate_batch_step_fn', trainer.evaluate_batch_step_fn), | |||||
'evaluate_fn': self.evaluate_fn, | 'evaluate_fn': self.evaluate_fn, | ||||
'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping), | 'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping), | ||||
'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping), | 'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping), | ||||
@@ -128,7 +128,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) | results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) | ||||
self.topk_saver.get_monitor_value(results) | self.topk_saver.get_monitor_value(results) | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
results = self.evaluator.run() | results = self.evaluator.run() | ||||
self.topk_saver.save_topk(trainer, results) | self.topk_saver.save_topk(trainer, results) | ||||
@@ -137,8 +137,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
if self.watch_monitor is not None: | if self.watch_monitor is not None: | ||||
return | return | ||||
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | ||||
validate_every = -self.evaluate_every | |||||
if trainer.cur_epoch_idx % validate_every == 0: | |||||
evaluate_every = -self.evaluate_every | |||||
if trainer.cur_epoch_idx % evaluate_every == 0: | |||||
results = self.evaluator.run() | results = self.evaluator.run() | ||||
self.topk_saver.save_topk(trainer, results) | self.topk_saver.save_topk(trainer, results) | ||||
@@ -1,6 +1,6 @@ | |||||
import json | import json | ||||
import sys | import sys | ||||
from typing import Union | |||||
__all__ = [ | __all__ = [ | ||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
@@ -11,11 +11,22 @@ __all__ = [ | |||||
from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils.utils import is_notebook | |||||
def choose_progress_callback(progress_bar:str): | |||||
class ProgressCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | |||||
f_rich_progress.stop() | |||||
@property | |||||
def name(self): # progress bar的名称 | |||||
return 'auto' | |||||
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: | |||||
if progress_bar == 'auto': | if progress_bar == 'auto': | ||||
if (sys.stdin and sys.stdin.isatty()): | |||||
if not f_rich_progress.dummy_rich: | |||||
progress_bar = 'rich' | progress_bar = 'rich' | ||||
else: | else: | ||||
progress_bar = 'raw' | progress_bar = 'raw' | ||||
@@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str): | |||||
return RichCallback() | return RichCallback() | ||||
elif progress_bar == 'raw': | elif progress_bar == 'raw': | ||||
return RawTextCallback() | return RawTextCallback() | ||||
elif isinstance(progress_bar, ProgressCallback): | |||||
return progress_bar | |||||
else: | else: | ||||
return None | return None | ||||
class ProgressCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | |||||
f_rich_progress.stop() | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
format_json=True): | format_json=True): | ||||
@@ -92,7 +100,7 @@ class RichCallback(ProgressCallback): | |||||
self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | ||||
advance=self.epoch_bar_update_advance, refresh=True) | advance=self.epoch_bar_update_advance, refresh=True) | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
if len(results)==0: | if len(results)==0: | ||||
return | return | ||||
rule_style = '' | rule_style = '' | ||||
@@ -114,9 +122,6 @@ class RichCallback(ProgressCallback): | |||||
else: | else: | ||||
self.progress_bar.print(results) | self.progress_bar.print(results) | ||||
def on_exception(self, trainer, exception): | |||||
self.clear_tasks() | |||||
def clear_tasks(self): | def clear_tasks(self): | ||||
for key, taskid in self.task2id.items(): | for key, taskid in self.task2id.items(): | ||||
self.progress_bar.destroy_task(taskid) | self.progress_bar.destroy_task(taskid) | ||||
@@ -124,6 +129,10 @@ class RichCallback(ProgressCallback): | |||||
self.task2id = {} | self.task2id = {} | ||||
self.loss = 0 | self.loss = 0 | ||||
@property | |||||
def name(self): # progress bar的名称 | |||||
return 'rich' | |||||
class RawTextCallback(ProgressCallback): | class RawTextCallback(ProgressCallback): | ||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
@@ -166,7 +175,7 @@ class RawTextCallback(ProgressCallback): | |||||
f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | ||||
logger.info(text) | logger.info(text) | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
if len(results)==0: | if len(results)==0: | ||||
return | return | ||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
@@ -184,3 +193,7 @@ class RawTextCallback(ProgressCallback): | |||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | ||||
else: | else: | ||||
logger.info(results) | logger.info(results) | ||||
@property | |||||
def name(self): # progress bar的名称 | |||||
return 'raw' |
@@ -0,0 +1,181 @@ | |||||
from typing import List, Union, Dict, Callable, Sequence, Mapping | |||||
from fastNLP.core.log import logger | |||||
from .padders.get_padder import get_padder | |||||
import re | |||||
from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ | |||||
pack_batch_sequence, NESTED_DICT_SEPARATOR | |||||
sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | |||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] | |||||
class Collator: | |||||
def __init__(self, backend='torch'): | |||||
""" | |||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||||
若为 None ,则不进行 padding 。 | |||||
""" | |||||
self.unpack_batch_func = None | |||||
self.pack_batch_func = None | |||||
self.ignore_fields = set() | |||||
self.padders = {} | |||||
self.input_fields = {} | |||||
self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。 | |||||
self.set_backend(backend) | |||||
def __call__(self, batch)->Union[List, Dict]: | |||||
""" | |||||
batch可能存在三种可能性 | |||||
List[Dict], List[List], List[Sample] | |||||
第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 | |||||
第二步:使用每个 field 各自的 padder 进行 pad 。 | |||||
第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。 | |||||
第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个 | |||||
list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample | |||||
的类别。 | |||||
第一次调用会根据当前 field 决定对应的 Padder 。 | |||||
""" | |||||
if self.unpack_batch_func is None: | |||||
# 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型 | |||||
if self.batch_data_type is None: | |||||
if isinstance(batch[0], Mapping): | |||||
self.batch_data_type = 'd' | |||||
elif isinstance(batch[0], Sequence): # 这里存在误判的风险 | |||||
self.batch_data_type = 'l' | |||||
else: | |||||
self.batch_data_type = 's' | |||||
logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " | |||||
f"is {self.batch_data_type}") | |||||
if self.batch_data_type == 's': | |||||
self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 | |||||
self.pack_batch_func = lambda x:x['_single'] | |||||
elif self.batch_data_type == 'l': | |||||
self.unpack_batch_func = unpack_batch_sequence | |||||
self.pack_batch_func = pack_batch_sequence | |||||
elif self.batch_data_type == 'd': | |||||
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} | |||||
self.unpack_batch_func = unpack_batch_nested_mapping | |||||
self.pack_batch_func = pack_batch_nested_mapping | |||||
else: | |||||
self.unpack_batch_func = unpack_batch_mapping | |||||
self.pack_batch_func = lambda x:x | |||||
unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 | |||||
pad_batch = {} | |||||
if len(self.padders)==0: # 第一次运行,准备 padder | |||||
for key in unpack_batch.keys(): | |||||
if key not in self.input_fields and key not in self.ignore_fields: | |||||
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | |||||
for field_name, setting in self.input_fields.items(): | |||||
pad_fn = setting.get('pad_fn', None) | |||||
if callable(pad_fn): | |||||
padder = pad_fn | |||||
else: | |||||
batch_field = unpack_batch.get(field_name) | |||||
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | |||||
dtype=setting['dtype'], backend=setting['backend'], | |||||
field_name=field_name) | |||||
self.padders[field_name] = padder | |||||
if self.batch_data_type == 'l': | |||||
self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 | |||||
for key, padder in self.padders.items(): | |||||
batch = unpack_batch.get(key) | |||||
pad_batch[key] = padder(batch) | |||||
return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | |||||
def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
pad_fn:Callable=None) -> "Collator": | |||||
""" | |||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
""" | |||||
self.padders.clear() # 重新生成 | |||||
if self.batch_data_type is not None: | |||||
if self.batch_data_type == 's': | |||||
logger.debug("Set as single field mode.") | |||||
self.input_fields.clear() | |||||
elif self.batch_data_type == 'd': | |||||
assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ | |||||
f"index, but other field is set as dict mode." | |||||
elif self.batch_data_type == 'l': | |||||
assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ | |||||
f"field name is {field_name}" | |||||
if field_name == '_single': | |||||
self.batch_data_type = 's' | |||||
elif sequence_idx_str.match(field_name): | |||||
self.batch_data_type = 'l' | |||||
else: | |||||
self.batch_data_type = 'd' | |||||
if field_name in self.ignore_fields: | |||||
logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") | |||||
if backend is None: | |||||
backend = self.backend | |||||
else: | |||||
assert backend in SUPPORTED_BACKENDS | |||||
self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn} | |||||
return self | |||||
def set_backend(self, backend:str): | |||||
""" | |||||
设置可以 pad 的 field 默认 pad 为什么类型的 tensor | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||||
若为 None ,则不进行 padding 。 | |||||
:return: | |||||
""" | |||||
assert backend in SUPPORTED_BACKENDS | |||||
self.padders.clear() | |||||
self.backend = backend | |||||
def set_ignore(self, *field_names) -> "Collator": | |||||
""" | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||||
Ex:: | |||||
collator.set_ignore('field1', 'field2') | |||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||||
:return: 返回 Collator 自身 | |||||
""" | |||||
for field_name in field_names: | |||||
if field_name in self.input_fields: | |||||
self.input_fields.pop(field_name) | |||||
logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") | |||||
self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 | |||||
self.ignore_fields.add(field_name) | |||||
return self | |||||
@@ -0,0 +1,44 @@ | |||||
__all__ = [ | |||||
'InconsistencyError', | |||||
'EleDtypeUnsupportedError', | |||||
'EleDtypeDtypeConversionError', | |||||
'DtypeUnsupportedError', | |||||
"DtypeError" | |||||
] | |||||
class InconsistencyError(BaseException): | |||||
""" | |||||
当一个 batch 中的数据存在 shape,dtype 之类的不一致时的报错。 | |||||
""" | |||||
def __init__(self, msg, *args): | |||||
super(InconsistencyError, self).__init__(msg, *args) | |||||
class DtypeError(BaseException): | |||||
def __init__(self, msg, *args): | |||||
super(DtypeError, self).__init__(msg, *args) | |||||
self.msg = msg | |||||
class EleDtypeUnsupportedError(DtypeError): | |||||
""" | |||||
当 batch 中的 element 的类别本身无法 pad 的时候报错。 | |||||
例如要求 str 类型的数据进行 padding 。 | |||||
""" | |||||
class EleDtypeDtypeConversionError(DtypeError): | |||||
""" | |||||
当 batch 中的 element 的类别无法转换为 dtype 类型时报错。 | |||||
""" | |||||
class DtypeUnsupportedError(DtypeError): | |||||
""" | |||||
当当前 backend 不支持这种类型的 dtype 时报错。 | |||||
""" |
@@ -0,0 +1,193 @@ | |||||
from typing import Dict | |||||
from typing import Sequence, Any, Union, Dict | |||||
from abc import ABC | |||||
from fastNLP.core.log import logger | |||||
from .padder import Padder, NullPadder | |||||
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder | |||||
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder | |||||
from .raw_padder import RawNumberPadder, RawSequencePadder | |||||
from .exceptions import * | |||||
def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: | |||||
""" | |||||
根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 | |||||
:param batch_field: 将某 field 的内容组合成一个 batch 传入。 | |||||
:param pad_val: | |||||
:param backend: | |||||
:param dtype: | |||||
:param field_name: 方便报错的。 | |||||
:return: | |||||
""" | |||||
logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field)) | |||||
if pad_val is None: | |||||
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") | |||||
return NullPadder() | |||||
if backend is None: | |||||
logger.debug(f"The backend for field:{field_name} is None, not padding this field.") | |||||
return NullPadder() | |||||
# 首先判断当前 field 是否是必须要 pad ,根据用户设置的 pad_val、dtype 等判断。 | |||||
must_pad = False | |||||
if pad_val != 0 or dtype is not None: | |||||
must_pad = True | |||||
catalog = _get_element_shape_dtype(batch_field) # 首先获取数据的基本信息。 | |||||
# 根据 catalog 来判定当前是否可以进行 pad 。 | |||||
# 首先检查是否所有的 key 是一样长的,表明深度是一致的 | |||||
depths = set(map(len, catalog.keys())) | |||||
num_depth = len(depths) | |||||
if num_depth != 1: | |||||
msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \ | |||||
f"information please set logger's level to DEBUG." | |||||
if must_pad: | |||||
raise InconsistencyError(msg) | |||||
logger.debug(msg) | |||||
return NullPadder() | |||||
# 再检查所有的元素 shape 是否一致? | |||||
shape_lens = set([len(v[0]) for v in catalog.values()]) | |||||
num_shape = len(shape_lens) | |||||
if num_shape != 1: | |||||
msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \ | |||||
f"information please set logger's level to DEBUG." | |||||
if must_pad: | |||||
raise InconsistencyError(msg) | |||||
logger.debug(msg) | |||||
return NullPadder() | |||||
# 再检查所有的元素 type 是否一致 | |||||
ele_dtypes = set([v[1] for v in catalog.values()]) | |||||
num_eletypes = len(ele_dtypes) | |||||
if num_eletypes != 1: | |||||
msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \ | |||||
f"information please set logger's level to DEBUG." | |||||
if must_pad: | |||||
raise InconsistencyError(msg) | |||||
logger.debug(msg) | |||||
return NullPadder() | |||||
depth = depths.pop() | |||||
shape_len = shape_lens.pop() | |||||
ele_dtype = ele_dtypes.pop() | |||||
# 需要由 padder 自己决定是否能够 pad 。 | |||||
try: | |||||
if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | |||||
if backend == 'raw': | |||||
return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
elif backend == 'numpy': | |||||
return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
elif backend == 'torch': | |||||
return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | |||||
if backend == 'raw': | |||||
return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
elif backend == 'numpy': | |||||
return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
elif backend == 'torch': | |||||
return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
if depth == 1 and shape_len != 0: | |||||
if backend == 'numpy': | |||||
return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
elif backend == 'torch': | |||||
return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
if shape_len != 0 and depth>1: | |||||
msg = "Does not support pad tensor under nested list. If you need this, please report." | |||||
if must_pad: | |||||
raise RuntimeError(msg) | |||||
logger.debug(msg) | |||||
return NullPadder() | |||||
except DtypeError as e: | |||||
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ | |||||
"information please set logger's level to DEBUG." | |||||
if must_pad: | |||||
raise type(e)(msg=msg) | |||||
logger.debug(msg) | |||||
return NullPadder() | |||||
except BaseException as e: | |||||
raise e | |||||
return NullPadder() | |||||
class HasShapeDtype(ABC): | |||||
""" | |||||
检测拥有 shape 和 dtype 属性的对象。一般就是 np.ndarray 或者各类 tensor 。 | |||||
""" | |||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is HasShapeDtype: | |||||
if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'): | |||||
return True | |||||
return False | |||||
return NotImplemented | |||||
def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: | |||||
""" | |||||
获取对象的中 element 的基本信息,用于判断是否可以 padding。 | |||||
:param content: | |||||
:param tuple parent: | |||||
:param dict catalog: 记录元素信息的 dict。其中的 index 记录的是每一个元素的 拓扑 结构。 | |||||
例如: [1, 2, 3] -> {(0,): ((), <class 'int'>), (1,): ((), <class 'int'>), (2,): ((), <class 'int'>)} | |||||
例如: [1, [2, 3], 4] -> {(0,): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (1, 1): ((), <class 'int'>), (2,): ((), <class 'int'>)} | |||||
例如: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), <class 'int'>), (0, 1): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (2, 0): ((), <class 'int'>), (2, 1): ((), <class 'int'>)} | |||||
例如: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)] | |||||
-> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)} | |||||
:return: | |||||
""" | |||||
if catalog is None: | |||||
catalog = {} | |||||
if parent is None: | |||||
parent = () | |||||
if isinstance(content, HasShapeDtype): # 各类 tensor 或者 np.ndarray | |||||
shape = content.shape | |||||
dtype = content.dtype | |||||
catalog[parent] = (shape, dtype) | |||||
elif isinstance(content, (tuple, list)): | |||||
for i, c in enumerate(content): | |||||
_get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog) | |||||
else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 | |||||
catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 | |||||
return catalog | |||||
""" | |||||
from numbers import Number | |||||
issubclass(type(3), Number) # True | |||||
issubclass(type(3.1), Number) # True | |||||
issubclass(type('3'), Number) # False | |||||
issubclass(type(True), Number) # True | |||||
issubclass(type(np.zeros(3)[0]), Number) # True | |||||
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True | |||||
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True | |||||
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 | |||||
is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) | |||||
""" | |||||
@@ -0,0 +1,72 @@ | |||||
__all__ = [ | |||||
'NumpyNumberPadder', | |||||
'NumpySequencePadder', | |||||
] | |||||
from numbers import Number | |||||
from abc import ABC | |||||
from typing import Any, Union | |||||
import numpy as np | |||||
from .padder import Padder | |||||
from .utils import get_padded_numpy_array, is_number_or_numpy_number | |||||
from .exceptions import * | |||||
def _get_dtype(ele_dtype, dtype, class_name): | |||||
if not is_number_or_numpy_number(ele_dtype): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||||
f"or numpy numbers but get `{ele_dtype}`.") | |||||
if dtype is None: | |||||
dtype = ele_dtype | |||||
else: | |||||
if not is_number_or_numpy_number(dtype): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or numpy numbers but get `{dtype}`.") | |||||
dtype = dtype | |||||
return dtype | |||||
class NumpyNumberPadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
return np.array(batch_field, dtype=dtype) | |||||
class NumpySequencePadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | |||||
class NumpyTensorPadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
""" | |||||
pad 类似于 [np.array([3, 4], np.array([1])] 的 field | |||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
shapes = [field.shape for field in batch_field] | |||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||||
array = np.full(max_shape, fill_value=pad_val, dtype=dtype) | |||||
for i, field in enumerate(batch_field): | |||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||||
array[slices] = field | |||||
return array | |||||
@@ -0,0 +1,21 @@ | |||||
class Padder: | |||||
def __init__(self, pad_val, dtype): | |||||
self.pad_val = pad_val | |||||
self.dtype = dtype | |||||
def __call__(self, batch_field): | |||||
return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
raise NotImplementedError() | |||||
class NullPadder(Padder): | |||||
def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
def __call__(self, batch_field): | |||||
# 直接返回,不调用 pad() 方法加快速度。 | |||||
return batch_field |
@@ -0,0 +1,48 @@ | |||||
from .padder import Padder | |||||
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array | |||||
from .exceptions import * | |||||
def _get_dtype(ele_dtype, dtype, class_name): | |||||
if is_number(ele_dtype): | |||||
if dtype is None: | |||||
dtype = ele_dtype | |||||
elif not is_number(dtype): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " | |||||
f"get `{dtype}`.") | |||||
else: | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||||
f"but get `{ele_dtype}`.") | |||||
return dtype | |||||
class RawNumberPadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
def __call__(self, batch_field): | |||||
return batch_field | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
raise NotImplementedError() | |||||
class RawSequencePadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
""" | |||||
:param batch_field: | |||||
:param pad_val: | |||||
:param dtype: 该参数无意义。 | |||||
:return: | |||||
""" | |||||
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() |
@@ -0,0 +1,157 @@ | |||||
from inspect import isclass | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
numpy_to_torch_dtype_dict = { | |||||
np.bool_: torch.bool, | |||||
np.uint8: torch.uint8, | |||||
np.int8: torch.int8, | |||||
np.int16: torch.int16, | |||||
np.int32: torch.int32, | |||||
np.int64: torch.int64, | |||||
np.float16: torch.float16, | |||||
np.float32: torch.float32, | |||||
np.float64: torch.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 | |||||
np.complex64: torch.complex64, | |||||
np.complex128: torch.complex128 | |||||
} | |||||
number_to_torch_dtype_dict = { | |||||
float: torch.float32, # 因为 torch.tensor([1], dtype=float)是torch.float64 | |||||
int: torch.int64, | |||||
bool: torch.bool | |||||
} | |||||
from .padder import Padder | |||||
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class | |||||
from .exceptions import * | |||||
def is_torch_tensor(dtype): | |||||
if not isclass(dtype) and isinstance(dtype, torch.dtype): | |||||
return True | |||||
return False | |||||
def _get_dtype(ele_dtype, dtype, class_name): | |||||
if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||||
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | |||||
if dtype is not None: | |||||
if not (is_torch_tensor(dtype) or is_number(dtype)): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or torch.dtype but get `{dtype}`.") | |||||
dtype = number_to_torch_dtype_dict.get(dtype, dtype) | |||||
else: | |||||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||||
return dtype | |||||
class TorchNumberPadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
return torch.tensor(batch_field, dtype=dtype) | |||||
class TorchSequencePadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||||
return tensor | |||||
class TorchTensorPadder(Padder): | |||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
""" | |||||
目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | |||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
shapes = [field.shape for field in batch_field] | |||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||||
if isinstance(dtype, np.dtype): | |||||
print(dtype) | |||||
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | |||||
for i, field in enumerate(batch_field): | |||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||||
if isinstance(field, np.ndarray): | |||||
field = torch.from_numpy(field) | |||||
tensor[slices] = field | |||||
return tensor | |||||
def fill_tensor(batch_field, padded_batch, dtype): | |||||
""" | |||||
将 batch_field 中的值填入到 tensor 中。 | |||||
:param batch_field: 需要填充进入 array 中的内容 | |||||
:param padded_batch: 待填充的 tensor | |||||
:param dtype: 数据的类别 | |||||
:return: | |||||
""" | |||||
if padded_batch.ndim == 2: | |||||
for i, content_i in enumerate(batch_field): | |||||
padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype) | |||||
elif padded_batch.ndim == 3: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype) | |||||
elif padded_batch.ndim == 4: | |||||
try: # 应该是图像,所以直接应该就 ok 了。 | |||||
padded_batch = np.array(batch_field) | |||||
except: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
for k, content_iii in enumerate(content_ii): | |||||
padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype) | |||||
elif padded_batch.ndim == 1: | |||||
padded_batch[:] = torch.tensor(batch_field, dtype=dtype) | |||||
else: | |||||
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||||
"report.") | |||||
return padded_batch | |||||
def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0): | |||||
""" | |||||
例如: | |||||
[[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]]) | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param dtype: 目标类别是什么 | |||||
:param pad_val: pad 的 value | |||||
:return: | |||||
""" | |||||
shapes = get_shape(batch_field) | |||||
tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val) | |||||
tensor = fill_tensor(batch_field, tensor, dtype=dtype) | |||||
return tensor |
@@ -0,0 +1,20 @@ | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
def is_torch_tensor_dtype(dtype) -> bool: | |||||
""" | |||||
返回当前 dtype 是否是 torch 的 dtype 类型 | |||||
:param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 | |||||
:return: | |||||
""" | |||||
try: | |||||
return isinstance(dtype, torch.dtype) | |||||
except: | |||||
return False |
@@ -0,0 +1,173 @@ | |||||
from typing import Sequence, List | |||||
from numbers import Number | |||||
import re | |||||
from inspect import isclass | |||||
import numpy as np | |||||
np_str_obj_array_pattern = re.compile(r'[SaUO]') | |||||
def get_shape(batch_field:List, shape=None): | |||||
""" | |||||
给定 field 返回这个 field pad 完成之后的 shape 。 | |||||
例如: [[1, 2, 3], [3]] -> [2, 3] | |||||
[[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3] | |||||
:param batch_field: list,第 0 维一般为 batch 维度。 | |||||
:param shape: 无需传入。 | |||||
:return: | |||||
""" | |||||
if shape is None: | |||||
shape = [] | |||||
if isinstance(batch_field, Sequence): | |||||
num_ele = len(batch_field) | |||||
_shape = shape + [num_ele] | |||||
try: | |||||
shapes = [] | |||||
if isinstance(batch_field[0], Sequence): | |||||
for _field in batch_field: | |||||
shapes.append(get_shape(_field, _shape)) | |||||
max_shape = [max(_) for _ in zip(*shapes)] | |||||
return max_shape | |||||
except IndexError: # 空的shape | |||||
pass | |||||
return _shape # 说明是一个空的 sequence | |||||
else: | |||||
return shape | |||||
def fill_array(batch_field:List, padded_batch:np.ndarray): | |||||
""" | |||||
将 batch_field 中的值填入到 array 中。 | |||||
:param batch_field: 需要填充进入 array 中的内容 | |||||
:param padded_batch: 待填充的 np.ndarray | |||||
:return: | |||||
""" | |||||
if padded_batch.ndim == 2: | |||||
for i, content_i in enumerate(batch_field): | |||||
padded_batch[i, :len(content_i)] = content_i | |||||
elif padded_batch.ndim == 3: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
padded_batch[i, j, :len(content_ii)] = content_ii | |||||
elif padded_batch.ndim == 4: | |||||
try: # 应该是图像,所以直接应该就 ok 了。 | |||||
padded_batch = np.array(batch_field) | |||||
except: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
for k, content_iii in enumerate(content_ii): | |||||
padded_batch[i, j, k, :len(content_iii)] = content_iii | |||||
elif padded_batch.ndim == 1: | |||||
padded_batch[:] = batch_field | |||||
else: | |||||
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||||
"report.") | |||||
return padded_batch | |||||
def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: | |||||
""" | |||||
例如: | |||||
[[1,2], [3]] -> np.array([[1, 2], [3, 0]]) | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param dtype: 目标类别是什么 | |||||
:param pad_val: pad 的 value | |||||
:return: | |||||
""" | |||||
shapes = get_shape(batch_field) | |||||
array = np.full(shapes, dtype=dtype, fill_value=pad_val) | |||||
array = fill_array(batch_field, array) | |||||
return array | |||||
def get_padded_nest_list(batch_field: List, pad_val=0) -> List: | |||||
""" | |||||
例如: | |||||
[[1,2], [3]] -> [[1, 2], [3, 0]] | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param pad_val: pad 的 value | |||||
:return: | |||||
""" | |||||
array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist() | |||||
return array | |||||
def is_number_or_numpy_number(dtype): | |||||
""" | |||||
判断 dtype 是否是数字类型,或者 numpy 的数字类型。 | |||||
is_number_or_numpy_number(type(3)) # True | |||||
is_number_or_numpy_number(type(3.1)) # True | |||||
is_number_or_numpy_number(type('3')) # False | |||||
is_number_or_numpy_number(type(True)) # True | |||||
is_number_or_numpy_number(type(np.zeros(3)[0])) # True | |||||
is_number_or_numpy_number(np.zeros(3, dtype=float).dtype) # True | |||||
is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) # True | |||||
is_number_or_numpy_number(np.zeros(3, dtype=str).dtype) # False | |||||
is_number_or_numpy_number(np.array([1, [2]]).dtype) # False | |||||
:param dtype: | |||||
:return: | |||||
""" | |||||
if is_number(dtype): | |||||
return True | |||||
else: | |||||
if isclass(dtype): | |||||
return is_numpy_generic_class(dtype) | |||||
elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: | |||||
return True | |||||
return False | |||||
def is_numpy_number_dtype(dtype): | |||||
if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: | |||||
return True | |||||
return False | |||||
def is_numpy_generic_class(dtype): | |||||
""" | |||||
形如 np.int64,或者 np.zeros(1).dtype.type 的值 | |||||
:param dtype: | |||||
:return: | |||||
""" | |||||
if isclass(dtype) and issubclass(dtype, np.generic): | |||||
return True | |||||
return False | |||||
def is_number(dtype): | |||||
try: | |||||
if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ | |||||
and not is_numpy_number_dtype(dtype): | |||||
return True | |||||
except: | |||||
return False | |||||
if __name__ == '__main__': | |||||
# a = [[[1]], [1, 2, 3], [3]] | |||||
# a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||||
# b = get_padded_nest_list(a) | |||||
# print(type(b[0])) | |||||
# print(b) | |||||
# import torch | |||||
print(is_number_or_numpy_number(type(3))) # True | |||||
print(is_number_or_numpy_number(type(3.1))) # True | |||||
print(is_number_or_numpy_number(type('3'))) # False | |||||
print(is_number_or_numpy_number(type(True))) # True | |||||
print(is_number_or_numpy_number(type(np.zeros(3)[0]))) # True | |||||
print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)) # True | |||||
print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)) # True | |||||
print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)) # False | |||||
print(is_number_or_numpy_number(np.array([1, [2]]).dtype)) # False | |||||
@@ -0,0 +1,103 @@ | |||||
from collections import defaultdict | |||||
from functools import reduce | |||||
from typing import Sequence, Mapping, Dict | |||||
NESTED_DICT_SEPARATOR = '@@' | |||||
def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: | |||||
""" | |||||
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} | |||||
:param batch: | |||||
:return: | |||||
""" | |||||
dict_batch = defaultdict(list) | |||||
for sample in batch: | |||||
for key, value in sample.items(): | |||||
dict_batch[key].append(value) | |||||
return dict_batch | |||||
def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: | |||||
""" | |||||
将 nested 的 dict 中的内容展开到一个 flat dict 中 | |||||
:param batch: | |||||
:param _parent: 内部使用 | |||||
:return: | |||||
""" | |||||
dict_batch = defaultdict(list) | |||||
if _parent != '': | |||||
_parent += NESTED_DICT_SEPARATOR | |||||
for sample in batch: | |||||
for key, value in sample.items(): | |||||
if isinstance(value, Mapping): | |||||
_dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) | |||||
for key, value in _dict_batch.items(): | |||||
dict_batch[key].append(value) | |||||
else: | |||||
dict_batch[_parent + key].append(value) | |||||
return dict_batch | |||||
def _unpack_batch_nested_mapping(value, _parent)->Dict: | |||||
_dict = {} | |||||
_parent += NESTED_DICT_SEPARATOR | |||||
for k, v in value.items(): | |||||
if isinstance(v, Mapping): | |||||
__dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) | |||||
_dict.update(__dict) | |||||
else: | |||||
_dict[_parent + k] = v | |||||
return _dict | |||||
def pack_batch_nested_mapping(batch:Mapping) -> Dict: | |||||
""" | |||||
需要恢复出 nested 的 dict 原来的样式 | |||||
:param batch: | |||||
:return: | |||||
""" | |||||
dicts = [] | |||||
for key, value in batch.items(): | |||||
keys = key.split(NESTED_DICT_SEPARATOR) | |||||
d = {keys[-1]: value} | |||||
for key in keys[:-1:][::-1]: | |||||
d = {key: d} | |||||
dicts.append(d) | |||||
return reduce(_merge_dict, dicts) | |||||
def _merge_dict(a, b, path=None): | |||||
"merges b into a" | |||||
if path is None: path = [] | |||||
for key in b: | |||||
if key in a: | |||||
if isinstance(a[key], dict) and isinstance(b[key], dict): | |||||
_merge_dict(a[key], b[key], path + [str(key)]) | |||||
else: | |||||
raise Exception('Conflict at %s' % '.'.join(path + [str(key)])) | |||||
else: | |||||
a[key] = b[key] | |||||
return a | |||||
def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: | |||||
""" | |||||
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||||
:param batch: | |||||
:return: | |||||
""" | |||||
dict_batch = defaultdict(list) | |||||
for sample in batch: | |||||
for i, content in enumerate(sample): | |||||
dict_batch[f'_{i}'].append(content) | |||||
return dict_batch | |||||
def pack_batch_sequence(batch:Mapping)->Sequence: | |||||
return list(batch.values()) |
@@ -20,47 +20,31 @@ from fastNLP.core.log import logger | |||||
class Evaluator: | class Evaluator: | ||||
""" | |||||
1. 我们目前不直接提供每一个 metric 对应一个或者特殊的多个 dataloader 的功能,默认就是所有 metric 处理所有 dataloader,如果用户有这种 | |||||
需求,请使用多个 Tester 进行操作; | |||||
2. Trainer 的 validate dataloader 只允许传进去一个,而 Tester 则可以多个;因为 Trainer 涉及到保存 topk 模型的逻辑,而 Tester | |||||
则只需要给出评测的结果即可; | |||||
""" | |||||
driver: Driver | driver: Driver | ||||
_evaluate_batch_loop: Loop | _evaluate_batch_loop: Loop | ||||
def __init__( | |||||
self, | |||||
model, | |||||
dataloaders, | |||||
metrics: Optional[Union[Dict, Metric]] = None, | |||||
driver: Union[str, Driver] = 'torch', | |||||
device: Optional[Union[int, List[int], str]] = None, | |||||
batch_step_fn: Optional[callable] = None, | |||||
evaluate_fn: Optional[str] = None, | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||||
model_wo_auto_param_call: bool = False, | |||||
fp16: bool = False, | |||||
verbose: int = 1, | |||||
**kwargs | |||||
): | |||||
def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, | |||||
driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, | |||||
evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||||
output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, | |||||
fp16: bool = False, verbose: int = 1, **kwargs): | |||||
""" | """ | ||||
用于对数据进行评测。 | |||||
:param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 | :param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 | ||||
:param dataloaders: 待评测的数据集。 | |||||
:param dataloaders: 待评测的数据集。如果为多个,请使用 dict 传入。 | |||||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | ||||
metric ,torchmetrics,allennlpmetrics等。 | |||||
metric ,torchmetrics,allennlpmetrics 等。 | |||||
:param driver: 使用 driver 。 | :param driver: 使用 driver 。 | ||||
:param device: 使用的设备。 | :param device: 使用的设备。 | ||||
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | |||||
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | |||||
batch_step_fn 函数。 | |||||
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, | |||||
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 | |||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | ||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | ||||
找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 | ||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | |||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对 | |||||
model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。 | |||||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | ||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | ||||
@@ -69,7 +53,8 @@ class Evaluator: | |||||
:param verbose: 是否打印 evaluate 的结果。 | :param verbose: 是否打印 evaluate 的结果。 | ||||
:param kwargs: | :param kwargs: | ||||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | ||||
与 batch normalization 将会关闭。默认为True。 | |||||
与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 | |||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 | |||||
TODO 还没完成。 | TODO 还没完成。 | ||||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | ||||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | ||||
@@ -96,9 +81,9 @@ class Evaluator: | |||||
self.device = device | self.device = device | ||||
self.verbose = verbose | self.verbose = verbose | ||||
if batch_step_fn is not None: | |||||
_check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') | |||||
self.batch_step_fn = batch_step_fn | |||||
if evaluate_batch_step_fn is not None: | |||||
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn') | |||||
self.evaluate_batch_step_fn = evaluate_batch_step_fn | |||||
self.input_mapping = input_mapping | self.input_mapping = input_mapping | ||||
self.output_mapping = output_mapping | self.output_mapping = output_mapping | ||||
@@ -106,14 +91,14 @@ class Evaluator: | |||||
if not isinstance(dataloaders, dict): | if not isinstance(dataloaders, dict): | ||||
dataloaders = {None: dataloaders} | dataloaders = {None: dataloaders} | ||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | |||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) | |||||
self.driver.setup() | self.driver.setup() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
self.separator = kwargs.get('separator', '#') | self.separator = kwargs.get('separator', '#') | ||||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | ||||
use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed()) | |||||
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) | |||||
if use_dist_sampler: | if use_dist_sampler: | ||||
self._dist_sampler = "unrepeatdist" | self._dist_sampler = "unrepeatdist" | ||||
else: | else: | ||||
@@ -134,7 +119,7 @@ class Evaluator: | |||||
self.progress_bar = kwargs.get('progress_bar', 'auto') | self.progress_bar = kwargs.get('progress_bar', 'auto') | ||||
if self.progress_bar == 'auto': | if self.progress_bar == 'auto': | ||||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||||
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' | |||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -235,8 +220,8 @@ class Evaluator: | |||||
@evaluate_batch_loop.setter | @evaluate_batch_loop.setter | ||||
def evaluate_batch_loop(self, loop: Loop): | def evaluate_batch_loop(self, loop: Loop): | ||||
if self.batch_step_fn is not None: | |||||
logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||||
if self.evaluate_batch_step_fn is not None: | |||||
logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||||
"when the `evaluate_batch_loop` is also customized.") | "when the `evaluate_batch_loop` is also customized.") | ||||
self._evaluate_batch_loop = loop | self._evaluate_batch_loop = loop | ||||
@@ -249,15 +234,15 @@ class Evaluator: | |||||
""" | """ | ||||
self.metrics_wrapper.reset() | self.metrics_wrapper.reset() | ||||
def update(self, *args, **kwargs): | |||||
def update(self, batch, outputs): | |||||
""" | """ | ||||
调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。 | |||||
自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 | |||||
:param args: | |||||
:param kwargs: | |||||
:param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 | |||||
:param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 | |||||
:return: | :return: | ||||
""" | """ | ||||
self.metrics_wrapper.update(*args, **kwargs) | |||||
self.metrics_wrapper.update(batch, outputs) | |||||
def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: | def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: | ||||
""" | """ | ||||
@@ -271,7 +256,7 @@ class Evaluator: | |||||
@property | @property | ||||
def metrics_wrapper(self): | def metrics_wrapper(self): | ||||
""" | """ | ||||
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 batch_step_fn )中使用,同时也为了支持 | |||||
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 | |||||
不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper | 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper | ||||
进行操作。 | 进行操作。 | ||||
@@ -283,11 +268,11 @@ class Evaluator: | |||||
def evaluate_step(self, batch): | def evaluate_step(self, batch): | ||||
""" | """ | ||||
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 | |||||
将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 | |||||
返回。 | 返回。 | ||||
:param batch: | |||||
:return: | |||||
:param batch: {evaluate_fn} 函数支持的输入类型 | |||||
:return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 | |||||
""" | """ | ||||
outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) | ||||
outputs = match_and_substitute_params(self.output_mapping, outputs) | outputs = match_and_substitute_params(self.output_mapping, outputs) | ||||
@@ -43,7 +43,7 @@ class TrainBatchLoop(Loop): | |||||
trainer.check_batch_step_fn() | trainer.check_batch_step_fn() | ||||
trainer.on_train_batch_end() | trainer.on_train_batch_end() | ||||
trainer.step_validate() | |||||
trainer.step_evaluate() | |||||
trainer.batch_idx_in_epoch = 0 | trainer.batch_idx_in_epoch = 0 | ||||
@staticmethod | @staticmethod | ||||
@@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated | |||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | ||||
from fastNLP.core.callbacks.callback import _CallbackWrapper | from fastNLP.core.callbacks.callback import _CallbackWrapper | ||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.callbacks.progress_callback import choose_progress_callback | |||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | ||||
@@ -82,10 +83,10 @@ class Trainer(TrainerEventTrigger): | |||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; | :param n_epochs: 训练总共的 epoch 的数量,默认为 20; | ||||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | ||||
为 None; | 为 None; | ||||
:param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 | |||||
`batch`;默认为 None; | |||||
:param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | |||||
两个参数必须为 `evaluator` 和 `batch`;默认为 None; | |||||
:param batch_step_fn: 定制每次 train batch 执行的函数。该函数应接受两个参数为 `trainer` 和`batch`,不需要要返回值;可以 | |||||
参考 fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop中的batch_step_fn函数。 | |||||
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, | |||||
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 | |||||
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`; | :param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`; | ||||
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法, | 默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法, | ||||
则使用模型默认的前向传播函数。 | 则使用模型默认的前向传播函数。 | ||||
@@ -102,10 +103,12 @@ class Trainer(TrainerEventTrigger): | |||||
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | ||||
类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; | 类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; | ||||
注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时); | 注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时); | ||||
如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。 | |||||
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 | :param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 | ||||
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, | 函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, | ||||
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | ||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | ||||
如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。 | |||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | ||||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; | 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; | ||||
@@ -125,14 +128,17 @@ class Trainer(TrainerEventTrigger): | |||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | ||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | ||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | ||||
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过 | |||||
callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。 | |||||
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 | |||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | |||||
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 | |||||
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | |||||
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 | |||||
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 | |||||
""" | """ | ||||
self.model = model | self.model = model | ||||
self.marker = marker | self.marker = marker | ||||
@@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger): | |||||
self.evaluate_dataloaders = evaluate_dataloaders | self.evaluate_dataloaders = evaluate_dataloaders | ||||
self.optimizers = optimizers | self.optimizers = optimizers | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
self.input_mapping = input_mapping | |||||
self.output_mapping = output_mapping | |||||
train_input_mapping = kwargs.get('train_input_mapping', None) | |||||
train_output_mapping = kwargs.get('train_output_mapping', None) | |||||
evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None) | |||||
evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None) | |||||
train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \ | |||||
_get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, | |||||
evaluate_input_mapping, evaluate_output_mapping) | |||||
self.input_mapping = train_input_mapping | |||||
self.output_mapping = train_output_mapping | |||||
self.evaluate_fn = evaluate_fn | self.evaluate_fn = evaluate_fn | ||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
@@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger): | |||||
callbacks=callbacks, | callbacks=callbacks, | ||||
metrics=metrics, | metrics=metrics, | ||||
evaluate_every=evaluate_every, | evaluate_every=evaluate_every, | ||||
input_mapping=input_mapping, | |||||
output_mapping=output_mapping, | |||||
input_mapping=evaluate_input_mapping, | |||||
output_mapping=evaluate_output_mapping, | |||||
model_wo_auto_param_call=model_wo_auto_param_call, | model_wo_auto_param_call=model_wo_auto_param_call, | ||||
accumulation_steps=accumulation_steps, | accumulation_steps=accumulation_steps, | ||||
fp16=fp16, | fp16=fp16, | ||||
@@ -195,8 +211,20 @@ class Trainer(TrainerEventTrigger): | |||||
) | ) | ||||
self.driver.set_optimizers(optimizers=optimizers) | self.driver.set_optimizers(optimizers=optimizers) | ||||
# 根据 progress_bar 参数选择 ProgressBarCallback | |||||
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) | |||||
if progress_bar_callback is not None: | |||||
if callbacks is None: | |||||
callbacks = [] | |||||
elif not isinstance(callbacks, Sequence): | |||||
callbacks = [callbacks] | |||||
callbacks = list(callbacks) + [progress_bar_callback] | |||||
else: | |||||
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " | |||||
"during training.") | |||||
# 初始化 callback manager; | # 初始化 callback manager; | ||||
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | |||||
self.callback_manager = CallbackManager(callbacks) | |||||
# 添加所有的函数式 callbacks; | # 添加所有的函数式 callbacks; | ||||
self._fetch_matched_fn_callbacks() | self._fetch_matched_fn_callbacks() | ||||
# 添加所有的类 callbacks; | # 添加所有的类 callbacks; | ||||
@@ -237,21 +265,15 @@ class Trainer(TrainerEventTrigger): | |||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and evaluate_dataloaders is not None: | if metrics is not None and evaluate_dataloaders is not None: | ||||
check_evaluate_every(evaluate_every) | check_evaluate_every(evaluate_every) | ||||
self.evaluator = Evaluator( | |||||
model=model, | |||||
dataloaders=evaluate_dataloaders, | |||||
metrics=metrics, | |||||
driver=self.driver, | |||||
device=device, | |||||
batch_step_fn=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, | |||||
input_mapping=input_mapping, | |||||
output_mapping=output_mapping, | |||||
fp16=fp16, | |||||
verbose=0, | |||||
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), | |||||
progress_bar=kwargs.get('progress_bar', 'auto') | |||||
) | |||||
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 | |||||
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 | |||||
progress_bar = progress_bar.name | |||||
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, | |||||
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, | |||||
evaluate_fn=evaluate_fn, input_mapping=input_mapping, | |||||
output_mapping=output_mapping, fp16=fp16, verbose=0, | |||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), | |||||
progress_bar=progress_bar) | |||||
if train_fn is not None and not isinstance(train_fn, str): | if train_fn is not None and not isinstance(train_fn, str): | ||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | ||||
@@ -317,11 +339,11 @@ class Trainer(TrainerEventTrigger): | |||||
self.num_batches_per_epoch = len(self.dataloader) | self.num_batches_per_epoch = len(self.dataloader) | ||||
self.total_batches = self.num_batches_per_epoch * self.n_epochs | self.total_batches = self.num_batches_per_epoch * self.n_epochs | ||||
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | ||||
self.on_train_begin() | |||||
self.driver.barrier() | |||||
self.driver.zero_grad(self.set_grad_to_none) | |||||
try: | try: | ||||
self.on_train_begin() | |||||
self.driver.barrier() | |||||
self.driver.zero_grad(self.set_grad_to_none) | |||||
while self.cur_epoch_idx < self.n_epochs: | while self.cur_epoch_idx < self.n_epochs: | ||||
# 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | ||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | ||||
@@ -334,10 +356,8 @@ class Trainer(TrainerEventTrigger): | |||||
self.cur_epoch_idx += 1 | self.cur_epoch_idx += 1 | ||||
self.on_train_epoch_end() | self.on_train_epoch_end() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
self.epoch_validate() | |||||
self.epoch_evaluate() | |||||
self.driver.barrier() | self.driver.barrier() | ||||
self.on_train_end() | |||||
self.driver.barrier() | |||||
except EarlyStopException as e: | except EarlyStopException as e: | ||||
logger.info(f"Catch early stop exception: {e.msg}.") | logger.info(f"Catch early stop exception: {e.msg}.") | ||||
@@ -351,17 +371,20 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.on_exception() | self.driver.on_exception() | ||||
self.on_exception(e) | self.on_exception(e) | ||||
raise e | raise e | ||||
finally: | |||||
self.on_train_end() | |||||
self.driver.barrier() | |||||
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | ||||
def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: | |||||
trainer.on_validate_begin() | |||||
_validate_res: dict = validate_fn() | |||||
trainer.on_validate_end(_validate_res) | |||||
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: | |||||
trainer.on_evaluate_begin() | |||||
_evaluate_res: dict = evaluate_fn() | |||||
trainer.on_evaluate_end(_evaluate_res) | |||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
def step_validate(self): | |||||
def step_evaluate(self): | |||||
""" | """ | ||||
在每个 batch 结束后调用,根据设置执行 evaluate 。 | 在每个 batch 结束后调用,根据设置执行 evaluate 。 | ||||
@@ -374,7 +397,7 @@ class Trainer(TrainerEventTrigger): | |||||
elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: | ||||
self.run_evaluate() | self.run_evaluate() | ||||
def epoch_validate(self): | |||||
def epoch_evaluate(self): | |||||
""" | """ | ||||
在每个 epoch 结束后调用,根据设置执行 evaluate 。 | 在每个 epoch 结束后调用,根据设置执行 evaluate 。 | ||||
@@ -382,8 +405,8 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | ||||
validate_every = -self.evaluate_every | |||||
if self.cur_epoch_idx % validate_every == 0: | |||||
evaluate_every = -self.evaluate_every | |||||
if self.cur_epoch_idx % evaluate_every == 0: | |||||
self.run_evaluate() | self.run_evaluate() | ||||
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | ||||
@@ -576,7 +599,7 @@ class Trainer(TrainerEventTrigger): | |||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
if not callable(model_load_fn): | if not callable(model_load_fn): | ||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | ||||
rank_zero_call(model_load_fn)(folder) | |||||
model_load_fn(folder) | |||||
else: | else: | ||||
if isinstance(folder, str): | if isinstance(folder, str): | ||||
folder = Path(folder) | folder = Path(folder) | ||||
@@ -653,7 +676,7 @@ class Trainer(TrainerEventTrigger): | |||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
if not callable(model_load_fn): | if not callable(model_load_fn): | ||||
raise ValueError("Parameter `model_save_fn` should be `Callable`.") | raise ValueError("Parameter `model_save_fn` should be `Callable`.") | ||||
rank_zero_call(model_load_fn)(folder) | |||||
model_load_fn(folder) | |||||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) | ||||
else: | else: | ||||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | ||||
@@ -839,6 +862,32 @@ class Trainer(TrainerEventTrigger): | |||||
self._evaluate_dataloaders = evaluate_dataloaders | self._evaluate_dataloaders = evaluate_dataloaders | ||||
def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, | |||||
evaluate_input_mapping, evaluate_output_mapping): | |||||
if train_input_mapping is not None and input_mapping is not None: | |||||
raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.") | |||||
if evaluate_input_mapping is not None and input_mapping is not None: | |||||
raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.") | |||||
if train_output_mapping is not None and output_mapping is not None: | |||||
raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.") | |||||
if evaluate_output_mapping is not None and output_mapping is not None: | |||||
raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.") | |||||
if train_input_mapping is None: | |||||
train_input_mapping = input_mapping | |||||
if evaluate_input_mapping is None: | |||||
evaluate_input_mapping = input_mapping | |||||
if train_output_mapping is None: | |||||
train_output_mapping = output_mapping | |||||
if evaluate_output_mapping is None: | |||||
evaluate_output_mapping = output_mapping | |||||
return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping | |||||
@@ -81,12 +81,12 @@ class TrainerEventTrigger: | |||||
def on_after_zero_grad(self, optimizers): | def on_after_zero_grad(self, optimizers): | ||||
self.callback_manager.on_after_zero_grad(self, optimizers) | self.callback_manager.on_after_zero_grad(self, optimizers) | ||||
def on_validate_begin(self): | |||||
self.callback_manager.on_validate_begin(self) | |||||
def on_evaluate_begin(self): | |||||
self.callback_manager.on_evaluate_begin(self) | |||||
def on_validate_end(self, results): | |||||
def on_evaluate_end(self, results): | |||||
self.trainer_state.save_on_this_step = True | self.trainer_state.save_on_this_step = True | ||||
self.callback_manager.on_validate_end(self, results) | |||||
self.callback_manager.on_evaluate_end(self, results) | |||||
class _TruncatedDataLoader: | class _TruncatedDataLoader: | ||||
@@ -126,8 +126,8 @@ class _TruncatedDataLoader: | |||||
return getattr(self.dataloader, item) | return getattr(self.dataloader, item) | ||||
def check_evaluate_every(validate_every): | |||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||||
def check_evaluate_every(evaluate_every): | |||||
if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): | |||||
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
if callable(validate_every): | |||||
_check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||||
if callable(evaluate_every): | |||||
_check_valid_parameters_number(evaluate_every, expected_params=['trainer']) |
@@ -63,7 +63,7 @@ class JittorDriver(Driver): | |||||
def check_evaluator_mode(self, mode: str): | def check_evaluator_mode(self, mode: str): | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if mode == "validate": | |||||
if mode == "evaluate": | |||||
if not hasattr(model, "evaluate_step"): | if not hasattr(model, "evaluate_step"): | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning_once( | logger.warning_once( | ||||
@@ -19,6 +19,7 @@ from fastNLP.core.utils import ( | |||||
check_user_specific_params, | check_user_specific_params, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | is_in_paddle_dist, | ||||
rank_zero_rm | |||||
) | ) | ||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
RandomBatchSampler, | RandomBatchSampler, | ||||
@@ -55,20 +56,134 @@ class PaddleFleetDriver(PaddleDriver): | |||||
fp16: bool = False, | fp16: bool = False, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
""" | |||||
采用fleet接口进行并行paddle训练的driver | |||||
PaddleFleetDriver 目前考虑支持的三种启动方式: | |||||
1. 用户自己不进行 fleet 的任何操作,直接使用我们的 Trainer,并且只运行一个 main 脚本,这时是由我们自己使用 open_subprocesses 拉起 | |||||
多个进程,然后由 Driver 自己进行初始化 | |||||
2. 其它情况同 1,但是用户自己使用 python -m paddle.distributed.launch 拉起; | |||||
3. 用户自己在外面初始化 Fleet,并且通过 python -m paddle.distributed.launch 拉起; | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动; | |||||
如果用户自己在外面初始化了 fleet,那么 | |||||
parallel_device 为 None; | |||||
data_device 为 表示单卡的一个参数; | |||||
dist.is_initialized 为 true; | |||||
r""" | |||||
通过使用 PaddlePaddle 的 Fleet 框架启动多卡进程的 Driver。 | |||||
需要注意的一点是,由于 PaddlePaddle 框架的特性,如果直接使用在 rank0 拉起其它进程的方法的话,如果不加以任何限制,PaddlePaddle会出现 | |||||
第一次前向传播后卡住或占用所有显卡的现象;为了解决这一问题,我们在引入 FastNLP 时,会使用 `CUDA_VISIBLE_DEVICES` 将设备限制在卡0上, | |||||
而用户如果使用了这一环境变量,我们会将其储存在 `USER_CUDA_VISIBLE_DEVICES` 中,并且通过一定的手段实现了转换(详细的设置请参见: | |||||
`fastNLP/envs/set_backend.py`)。在拉起其它进程的时候,我们会如法炮制,将环境限制在对应的设备上。 | |||||
`PaddleFleetDriver` 目前支持的三种启动方式: | |||||
1. 用户自己不进行分布式的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `FleetLauncher` 拉起多个进程, | |||||
然后 `PaddleFleetDriver` 自己通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 A) | |||||
2. 用户同样不在 Trainer 之外初始化分布式训练,但是用户自己使用 python -m paddle.distributed.launch 拉起来创建多个进程,这时我们仍旧 | |||||
会通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 B) | |||||
3. 用户自己在外面初始化分布式,并且通过 python -m paddle.distributed.launch 拉起,这时无论是多个进程的拉起和通信组的建立 | |||||
都由用户自己操作,我们只会在 driver.setup 的时候对 `PaddleFleetDriver` 设置一些必要的属性值;(情况 C) | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动;因此我们不会在 `PaddleFleetDriver` 中保存 | |||||
任何当前有多少台机器的信息; | |||||
Part 1:三种启动方式的具体分析: | |||||
(1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, | |||||
`PaddleFleetDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: | |||||
-> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DataParallel` 包裹的model), | |||||
因为 `Parallel` 的使用一定要求 fleet.init 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 | |||||
用户需要使用 2 张以上的显卡,那么其必然需要使用 paddle.distributed.launch 来启动,意味着就不是情况 A 了; | |||||
这时我们首先会调用 `FleetLauncher.launch` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu | |||||
的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); | |||||
接着我们会调用 `fleet.init` 来初始化各个进程之间的通信组; | |||||
这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 | |||||
才会去真正地运行 `FleetLauncher.launch`;进程 0 运行到 `fleet.init`,paddle 会阻塞进程 0 继续 | |||||
向前运行,直到其它进程也运行到这里; | |||||
最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DataParallel` 将模型包裹; | |||||
至此,paddle 分布式的环境配置过程全部完成; | |||||
-> 情况 B:注意这种情况我们直接限定了用户是通过 paddle.distributed.launch 拉起,并且没有自己建立分布式的通信组。这时在 | |||||
`PaddleFleetDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, | |||||
这时每个进程所使用的 gpu 是我们直接通过 `CUDA_VISIBLE_DEVICE` 来配置的;因此,如果用户想要实现使用特定 gpu | |||||
设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现,我们会通过一定的手段将其保存起来); | |||||
剩下的操作和情况 A 类似; | |||||
-> 情况 C:注意这种情况我们限定了用户是通过 paddle.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 | |||||
与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DataParallel` 包裹等。 | |||||
(2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: | |||||
注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `PaddleFleetDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 | |||||
检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 | |||||
我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 | |||||
启动方式来实现这一点的: | |||||
我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 | |||||
使用 '情况 A' 来启动 `PaddleFleetDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 | |||||
会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `PaddleFleetDriver` 的初始化和 setup 过程中, | |||||
如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 | |||||
Part 2:对应的代码细节: | |||||
1. 如何判断当前的各进程之间的通信组已经被建立(fleet 已经被初始化); | |||||
parallel_helper._is_parallel_ctx_initialized(); | |||||
2. 如何判断不同的进程是否是由 `python -m paddle.distributed.launch` 拉起还是由我们的 `FleetLauncher.launch()` | |||||
函数拉起; | |||||
我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'PADDLE_RANK_IN_NODE'、'PADDLE_TRAINER_ID' | |||||
以及没有 `FASTNLP_DISTRIBUTED_CHECK`, | |||||
如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m paddle.distributed.launch` | |||||
来拉起多个进程; | |||||
3. 整体的处理判断流程: | |||||
___________________________________ | |||||
|进入 PaddleFleetDriver 的 __init__ 函数| | |||||
——————————————————————————————————— | |||||
↓ | |||||
___________________________________________________ | |||||
| 判断不同的进程是否是由 paddle.distributed.launch 拉起 | | |||||
|(或者我们自己的 FleetLauncher 函数拉起) | --------------> | |||||
——————————————————————————————————————————————————— | | |||||
↓ 是由 paddle.distributed.launch 拉起 | 我们自己的 FleetLauncher 函数拉起多个进程 | |||||
_____________________________ | | |||||
←←←←← | 检测用户是否自己初始化了 fleet | | | |||||
↓ ————————————————————————————— ↓ | |||||
↓ ↓ 是 ________ | |||||
↓ ______ | 情况 A | | |||||
↓ 否 |情况 C| ————————— | |||||
↓ ——————— | |||||
↓ | |||||
↓ ______ | |||||
↓ -----------> |情况 B| | |||||
——————— | |||||
4. 为了完成全部的建立分布式所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
配置 fleet 所 | FleetLauncher.launch | paddle.distributed.launch| paddle.distributed.launch | |||||
需要的环境变量 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
开启多个进程 | FleetLauncher.launch | paddle.distributed.launch| paddle.distributed.launch | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
调用 fleet.init函数 | PaddleFleetDriver.setup | PaddleFleetDriver.setup | 用户自己调用 | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
设置 PaddleFleetDriver | | | | |||||
的 world_size 和 | PaddleFleetDriver.setup | PaddleFleetDriver.setup | PaddleFleetDriver.setup | |||||
global_rank 属性 | | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
Part 3:其它的处理细节: | |||||
1. 环境变量; | |||||
fastNLP 的 `PaddleFleetDriver` 运行时所需要的环境变量分为两种,一种是 paddle fleet 运行所需要的环境变量;另一种是 fastNLP 自己 | |||||
的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; | |||||
2. parallel_device, model_device 和 data_device 的关系; | |||||
parallel_device 为 `PaddleFleetDriver` 的参数,model_device 和 data_device 都为 driver 的属性; | |||||
其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; | |||||
model_device 永远都为单独的一个 torch.device; | |||||
情况 A | 情况 B | 情况 C | |||||
________________________________________________________________________________________________________ | |||||
parallel_device | 由用户传入trainer的参数 | | | |||||
| device 决定,必须是一个list, | 为 CUDA_VISIBLE_DEVICES | 为 CUDA_VISIBLE_DEVICES | |||||
| 其中每一个对象都是 int | | | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
model_device | parallel_device[local_rank] | parallel_device | None | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
data_device | model_device | model_device | 由用户传入 trainer 的参数 | |||||
| | | data_device 决定 | |||||
———————————————————————————————————————————————————————————————————————————————————————————————————————— | |||||
3. _DDPWrappingModel 的作用; | |||||
因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DataParallel` 的forward 函数来帮助 | |||||
我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DataParallel` 的 forward 方法, | |||||
然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 forward 函数,还是 | |||||
`train_step`、`evaluate_step`、`test_step` 方法。 | |||||
4. 当某一个进程出现 exception 后,`PaddleFleetDriver` 的处理; | |||||
不管是什么情况,`PaddleFleetDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||||
""" | """ | ||||
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
@@ -78,6 +193,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | "when your value of parameter `device` is `None` in your `Trainer` instance.") | ||||
# 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | # 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | ||||
# 这个参数会在 initialize_paddle_drvier 中设置。 | |||||
self.is_pull_by_paddle_run = is_pull_by_paddle_run | self.is_pull_by_paddle_run = is_pull_by_paddle_run | ||||
self.parallel_device = parallel_device | self.parallel_device = parallel_device | ||||
# 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu | # 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu | ||||
@@ -98,7 +214,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.outside_fleet = True | self.outside_fleet = True | ||||
# 用户只有将模型上传到对应机器上后才能用 DataParallel 包裹,因此如果用户在外面初始化了 Fleet,那么在 PaddleFleetDriver 中 | # 用户只有将模型上传到对应机器上后才能用 DataParallel 包裹,因此如果用户在外面初始化了 Fleet,那么在 PaddleFleetDriver 中 | ||||
# 我们就直接将 model_device 置为 None; | |||||
# 我们就直接将 model_device 置为 None; | |||||
self._model_device = None | self._model_device = None | ||||
# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | ||||
@@ -119,9 +235,12 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.world_size = None | self.world_size = None | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
self.gloo_rendezvous_dir = None | |||||
# 分布式环境的其它参数设置 | |||||
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | ||||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | ||||
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | |||||
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | ||||
self.is_collective = self._fleet_kwargs.get("is_collective", True) | self.is_collective = self._fleet_kwargs.get("is_collective", True) | ||||
if not self.is_collective: | if not self.is_collective: | ||||
@@ -145,7 +264,10 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def setup(self): | def setup(self): | ||||
""" | """ | ||||
在主进程拉起其它子进程,将主进程作为rank 0 | |||||
根据不同的情况进行不同的设置。 | |||||
1、如果是通过 paddle.distributed.launch 方法启动时,则根据已经设置好的环境获取 | |||||
分布式的属性。 | |||||
2、否则,调用 FleetLauncher 类启动子进程 | |||||
""" | """ | ||||
if self._has_setup: | if self._has_setup: | ||||
return | return | ||||
@@ -174,7 +296,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | # 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | ||||
# parallel_device 是 list, | # parallel_device 是 list, | ||||
if not parallel_helper._is_parallel_ctx_initialized(): | if not parallel_helper._is_parallel_ctx_initialized(): | ||||
# 没有初始化分布式环境,且是主进程 | |||||
# 拉起子进程并设置相应的属性 | |||||
self.init_fleet_and_set() | self.init_fleet_and_set() | ||||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | ||||
else: | else: | ||||
@@ -216,12 +338,13 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 是 rank0 的话,则拉起其它子进程 | # 是 rank0 的话,则拉起其它子进程 | ||||
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | ||||
launcher.launch() | launcher.launch() | ||||
self.gloo_rendezvous_dir = launcher.gloo_rendezvous_dir | |||||
# 设置参数和初始化分布式环境 | # 设置参数和初始化分布式环境 | ||||
fleet.init(self.role_maker, self.is_collective, self.strategy) | fleet.init(self.role_maker, self.is_collective, self.strategy) | ||||
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) | self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) | ||||
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) | self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) | ||||
# 正常情况下不会Assert出问题,但还是保险一下 | |||||
# 正常情况下不会 Assert 出问题,但还是保险一下 | |||||
assert self.global_rank is not None | assert self.global_rank is not None | ||||
assert self.world_size is not None | assert self.world_size is not None | ||||
assert self.world_size == len(self.parallel_device) | assert self.world_size == len(self.parallel_device) | ||||
@@ -235,10 +358,19 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.global_rank = paddledist.get_rank() | self.global_rank = paddledist.get_rank() | ||||
def barrier(self): | def barrier(self): | ||||
r""" | |||||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | |||||
仅在多分布式训练场景中有使用。 | |||||
注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | ||||
paddledist.barrier() | paddledist.barrier() | ||||
def configure_fleet(self): | def configure_fleet(self): | ||||
""" | |||||
将模型用 DataParallel 和自定义的类型包裹起来 | |||||
""" | |||||
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | ||||
self.model = DataParallel( | self.model = DataParallel( | ||||
_FleetWrappingModel(self.model), | _FleetWrappingModel(self.model), | ||||
@@ -247,8 +379,14 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self._has_fleetwrapped = True | self._has_fleetwrapped = True | ||||
def on_exception(self): | def on_exception(self): | ||||
if os.path.exists(self.gloo_rendezvous_dir): | |||||
shutil.rmtree(self.gloo_rendezvous_dir) | |||||
""" | |||||
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | |||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | |||||
因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 | |||||
pid 的信息; | |||||
""" | |||||
rank_zero_rm(self.gloo_rendezvous_dir) | |||||
super().on_exception() | super().on_exception() | ||||
@property | @property | ||||
@@ -282,6 +420,17 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return self.model_device | return self.model_device | ||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
""" | |||||
通过调用 `fn` 来实现训练时的前向传播过程; | |||||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||||
函数; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:param fn: 调用该函数进行一次计算。 | |||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | |||||
if self._has_fleetwrapped: | if self._has_fleetwrapped: | ||||
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | ||||
wo_auto_param_call=self.wo_auto_param_call) | wo_auto_param_call=self.wo_auto_param_call) | ||||
@@ -292,6 +441,27 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return fn(batch) | return fn(batch) | ||||
def get_model_call_fn(self, fn: str) -> Tuple: | def get_model_call_fn(self, fn: str) -> Tuple: | ||||
""" | |||||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||||
函数,然后给出 warning; | |||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||||
可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||||
""" | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if self._has_fleetwrapped: | if self._has_fleetwrapped: | ||||
if hasattr(model, fn): | if hasattr(model, fn): | ||||
@@ -316,7 +486,25 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return self.model, model.forward | return self.model, model.forward | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | ||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
reproducible: bool = False): | |||||
r""" | |||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | |||||
:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 | |||||
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader | |||||
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||||
可以可以加载。 | |||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||||
""" | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
@@ -429,10 +617,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
""" | |||||
paddle存在设置分布式optimizers的函数,返回值为fleet.meta_optimizers.HybridParallelOptimizer | |||||
重写是为了防止单卡下也传入了分布式的优化器 | |||||
""" | |||||
# paddle 存在设置分布式 optimizers 的函数,返回值为 fleet.meta_optimizers.HybridParallelOptimizer | |||||
DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer | DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer | ||||
for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): | if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): | ||||
@@ -20,7 +20,7 @@ from .utils import ( | |||||
# 记录各个进程信息 | # 记录各个进程信息 | ||||
class SubTrainer(object): | class SubTrainer(object): | ||||
""" | """ | ||||
和fastnlp的Triainer没有关系,仅用于统计节点内不同训练的一些信息 | |||||
用于统计节点内不同训练进程的信息,和 fastnlp 的 Triainer 没有关系 | |||||
""" | """ | ||||
def __init__(self, endpoint=None, rank=None): | def __init__(self, endpoint=None, rank=None): | ||||
self.devices = [] | self.devices = [] | ||||
@@ -30,8 +30,8 @@ class SubTrainer(object): | |||||
class FleetLauncher: | class FleetLauncher: | ||||
""" | """ | ||||
复原了 paddle 的 launch_collective 函数,将其集成到一个类里 | |||||
仅支持单机多卡的启动 | |||||
复原了 paddle 的 launch_collective 函数,将其简化后集成到一个类里 | |||||
仅支持每个机器单卡的情况。 | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
@@ -45,17 +45,26 @@ class FleetLauncher: | |||||
self.setup() | self.setup() | ||||
def setup(self): | def setup(self): | ||||
""" | |||||
进行初始化设置的函数,根据传入的设备找到分布式训练使用的端口号 | |||||
""" | |||||
self.set_endpoints() | self.set_endpoints() | ||||
self.sub_trainers = self.get_process_info() | self.sub_trainers = self.get_process_info() | ||||
def launch(self) -> int: | |||||
def launch(self): | |||||
""" | |||||
用于启动分布式进程。 | |||||
首先设置 PaddlePaddle 分布式训练需要设置的环境变量,然后建立新的子进程 | |||||
""" | |||||
# 设置环境变量 | # 设置环境变量 | ||||
self.global_envs = self.get_global_env() | self.global_envs = self.get_global_env() | ||||
self.open_subprocess() | self.open_subprocess() | ||||
reset_seed() | reset_seed() | ||||
def open_subprocess(self): | def open_subprocess(self): | ||||
""" | |||||
从 sub_trainers 中获取各个 rank 的信息,并且使用 subprocess.Popen 建立新的子进程。 | |||||
""" | |||||
if __main__.__spec__ is None: | if __main__.__spec__ is None: | ||||
# Script called as `python a/b/c.py` | # Script called as `python a/b/c.py` | ||||
@@ -77,6 +86,7 @@ class FleetLauncher: | |||||
current_env = copy.copy(self.global_envs) | current_env = copy.copy(self.global_envs) | ||||
for idx, t in enumerate(self.sub_trainers): | for idx, t in enumerate(self.sub_trainers): | ||||
# 根据不同的 rank 设置环境变量 | |||||
proc_env = { | proc_env = { | ||||
# global_rank | # global_rank | ||||
"PADDLE_TRAINER_ID": f"{t.rank}", | "PADDLE_TRAINER_ID": f"{t.rank}", | ||||
@@ -108,6 +118,14 @@ class FleetLauncher: | |||||
os.environ.update(current_env) | os.environ.update(current_env) | ||||
def get_global_env(self): | def get_global_env(self): | ||||
""" | |||||
设置分布式训练需要的全局变量,包括: | |||||
1、GLOO 相关的设置 | |||||
2、`PADDLE_TRAINERS_NUM` :所有的进程数目 | |||||
3、`PADDLE_TRAINER_ENDPOINTS` :使用的所有地址及其端口 | |||||
4、`PADDLE_WORLD_DEVICE_IDS` :使用的所有设备 | |||||
5、FASTNLP_DISTRIBUTED_CHECK:通过 fastNLP 建立子进程的标志,保存分布式训练使用的设备 | |||||
""" | |||||
global_envs = copy.copy(os.environ.copy()) | global_envs = copy.copy(os.environ.copy()) | ||||
self.gloo_rendezvous_dir = tempfile.mkdtemp() | self.gloo_rendezvous_dir = tempfile.mkdtemp() | ||||
@@ -137,7 +155,7 @@ class FleetLauncher: | |||||
def set_endpoints(self): | def set_endpoints(self): | ||||
""" | """ | ||||
Reference to `get_cluster_from_args` | |||||
寻找用户设置的端口或是空闲端口用于分布式训练,参考了 PaddlePaddle 中的 `get_cluster_from_args` 函数 | |||||
""" | """ | ||||
self.node_ip = "127.0.0.1" | self.node_ip = "127.0.0.1" | ||||
@@ -157,7 +175,7 @@ class FleetLauncher: | |||||
def get_process_info(self): | def get_process_info(self): | ||||
""" | """ | ||||
Reference to `get_cluster` | |||||
获取各个训练进程的设备、rank 和端口信息,参考 PaddlePaddle 的 `get_cluster` 函数。 | |||||
""" | """ | ||||
sub_trainers = [] | sub_trainers = [] | ||||
assert len(self.endpoints) >= len( | assert len(self.endpoints) >= len( | ||||
@@ -17,14 +17,16 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
model: paddle.nn.Layer, **kwargs) -> PaddleDriver: | model: paddle.nn.Layer, **kwargs) -> PaddleDriver: | ||||
r""" | r""" | ||||
用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | ||||
注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; | |||||
1、如果检测到当前进程为用户通过 `python -m paddle.distributed.launch xxx.py` 方式拉起的,则将 | |||||
设备自动设置为用户指定的设备(由于我们在引入 fastNLP 进行了特殊的设置,因此可以通过 `CUDA_VISIBLE_DEVICES` 获取) | |||||
2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver | |||||
3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver | |||||
:param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | :param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | ||||
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | :param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | ||||
:param model: 训练或者评测的具体的模型; | :param model: 训练或者评测的具体的模型; | ||||
:return: 返回一个元组,元组的第一个值是具体的基于 pytorch 的 `Driver` 实例,元组的第二个值是该 driver 的名字(用于检测一个脚本中 | |||||
先后 driver 的次序的正确问题); | |||||
:return: 返回构造的 `Driver` 实例。 | |||||
""" | """ | ||||
if is_in_paddle_launch_dist(): | if is_in_paddle_launch_dist(): | ||||
if device is not None: | if device is not None: | ||||
@@ -47,9 +49,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
if device >= _could_use_device_num: | if device >= _could_use_device_num: | ||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | raise ValueError("The gpu device that parameter `device` specifies is not existed.") | ||||
if device != -1: | |||||
device = f"gpu:{device}" | |||||
else: | |||||
if device == -1: | |||||
device = list(range(_could_use_device_num)) | device = list(range(_could_use_device_num)) | ||||
elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
device = list(set(device)) | device = list(set(device)) | ||||
@@ -61,9 +61,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
elif each >= _could_use_device_num: | elif each >= _could_use_device_num: | ||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | ||||
" the available gpu number.") | " the available gpu number.") | ||||
if len(device) == 1: | |||||
# 传入了 [1] 这样的,视为单卡。 | |||||
device = device[0] | |||||
elif device is not None and not isinstance(device, str): | elif device is not None and not isinstance(device, str): | ||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
@@ -82,6 +79,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | ||||
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | ||||
"choose `paddle` driver.") | "choose `paddle` driver.") | ||||
return PaddleFleetDriver(model, device, **kwargs) | |||||
return PaddleFleetDriver(model, [device], **kwargs) | |||||
else: | else: | ||||
return PaddleFleetDriver(model, device, **kwargs) | return PaddleFleetDriver(model, device, **kwargs) |
@@ -19,7 +19,12 @@ from fastNLP.envs import ( | |||||
rank_zero_call, | rank_zero_call, | ||||
) | ) | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
from fastNLP.core.samplers import ( | |||||
ReproducibleBatchSampler, | |||||
ReproducibleSampler, | |||||
RandomBatchSampler, | |||||
RandomSampler, | |||||
) | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -29,7 +34,7 @@ if _NEED_IMPORT_PADDLE: | |||||
Dataset, | Dataset, | ||||
Sampler, | Sampler, | ||||
BatchSampler, | BatchSampler, | ||||
RandomSampler, | |||||
RandomSampler as PaddleRandomSampler, | |||||
) | ) | ||||
from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
@@ -333,6 +338,9 @@ class PaddleDriver(Driver): | |||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | elif isinstance(dataloader_args.sampler, ReproducibleSampler): | ||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
elif isinstance(dataloader_args.sampler, PaddleRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
@@ -464,7 +472,7 @@ class PaddleDriver(Driver): | |||||
res.sampler = dataloader.batch_sampler.sampler | res.sampler = dataloader.batch_sampler.sampler | ||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | ||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
elif isinstance(dataloader.batch_sampler.sampler, PaddleRandomSampler): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
@@ -474,7 +482,7 @@ class PaddleDriver(Driver): | |||||
res.sampler = batch_sampler.sampler | res.sampler = batch_sampler.sampler | ||||
if hasattr(batch_sampler.sampler, "shuffle"): | if hasattr(batch_sampler.sampler, "shuffle"): | ||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
elif isinstance(batch_sampler.sampler, RandomSampler): | |||||
elif isinstance(batch_sampler.sampler, PaddleRandomSampler): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
@@ -31,6 +31,9 @@ __all__ = [ | |||||
] | ] | ||||
class PaddleSingleDriver(PaddleDriver): | class PaddleSingleDriver(PaddleDriver): | ||||
""" | |||||
支持 paddle cpu 或单卡 gpu 训练的 driver | |||||
""" | |||||
def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | ||||
if isinstance(model, DataParallel): | if isinstance(model, DataParallel): | ||||
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | ||||
@@ -59,18 +62,53 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self.world_size = 1 | self.world_size = 1 | ||||
def setup(self): | def setup(self): | ||||
r""" | |||||
该函数用来初始化训练环境,用于设置当前训练的设备,并将模型迁移到对应设备上。 | |||||
""" | |||||
device = self.model_device | device = self.model_device | ||||
device = get_device_from_visible(device, output_type=str) | device = get_device_from_visible(device, output_type=str) | ||||
paddle.device.set_device(device) | paddle.device.set_device(device) | ||||
self.model.to(device) | self.model.to(device) | ||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
""" | |||||
通过调用 `fn` 来实现训练时的前向传播过程; | |||||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||||
函数; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:param fn: 调用该函数进行一次计算。 | |||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | return auto_param_call(fn, batch, signature_fn=signature_fn) | ||||
else: | else: | ||||
return fn(batch) | return fn(batch) | ||||
def get_model_call_fn(self, fn: str) -> Tuple: | def get_model_call_fn(self, fn: str) -> Tuple: | ||||
""" | |||||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||||
函数,然后给出 warning; | |||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||||
可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||||
""" | |||||
if hasattr(self.model, fn): | if hasattr(self.model, fn): | ||||
fn = getattr(self.model, fn) | fn = getattr(self.model, fn) | ||||
if not callable(fn): | if not callable(fn): | ||||
@@ -95,6 +133,24 @@ class PaddleSingleDriver(PaddleDriver): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
r""" | |||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | |||||
:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 | |||||
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader | |||||
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||||
可以可以加载。 | |||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||||
""" | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
@@ -69,7 +69,6 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||||
os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" | os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" | ||||
return seed | return seed | ||||
def reset_seed() -> None: | def reset_seed() -> None: | ||||
""" | """ | ||||
fleet 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 | fleet 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 | ||||
@@ -80,16 +79,10 @@ def reset_seed() -> None: | |||||
if seed is not None: | if seed is not None: | ||||
paddle_seed_everything(int(seed), workers=bool(int(workers))) | paddle_seed_everything(int(seed), workers=bool(int(workers))) | ||||
class ForwardState(IntEnum): | |||||
TRAIN = 0 | |||||
VALIDATE = 1 | |||||
TEST = 2 | |||||
PREDICT = 3 | |||||
class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
""" | """ | ||||
参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和 | |||||
pytorch相似的处理方式 | |||||
参考 _DDPWrappingModel , paddle 的分布式训练也需要用 paddle.nn.DataParallel 进行包装,采用和 | |||||
pytorch 相似的处理方式 | |||||
""" | """ | ||||
def __init__(self, model: 'nn.Layer'): | def __init__(self, model: 'nn.Layer'): | ||||
super(_FleetWrappingModel, self).__init__() | super(_FleetWrappingModel, self).__init__() | ||||
@@ -109,7 +102,6 @@ class _FleetWrappingModel(Layer): | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
""" | """ | ||||
用于仿造的GradScaler对象,防止重复写大量的if判断 | 用于仿造的GradScaler对象,防止重复写大量的if判断 | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
pass | pass | ||||
@@ -152,6 +144,9 @@ def _build_fp16_env(dummy=False): | |||||
return auto_cast, GradScaler | return auto_cast, GradScaler | ||||
def find_free_ports(num): | def find_free_ports(num): | ||||
""" | |||||
在空闲的端口中找到 num 个端口 | |||||
""" | |||||
def __free_port(): | def __free_port(): | ||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | ||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, | s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, | ||||
@@ -178,18 +173,11 @@ def find_free_ports(num): | |||||
return None | return None | ||||
def get_host_name_ip(): | |||||
try: | |||||
host_name = socket.gethostname() | |||||
host_ip = socket.gethostbyname(host_name) | |||||
return host_name, host_ip | |||||
except: | |||||
return None | |||||
def get_device_from_visible(device: Union[str, int], output_type=int): | def get_device_from_visible(device: Union[str, int], output_type=int): | ||||
""" | """ | ||||
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | ||||
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | ||||
:param device: 未转化的设备名 | :param device: 未转化的设备名 | ||||
:param output_type: 返回值的类型 | :param output_type: 返回值的类型 | ||||
:return: 转化后的设备id | :return: 转化后的设备id | ||||
@@ -76,7 +76,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " | logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " | ||||
"still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " | "still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " | ||||
"choose `torch` driver.") | "choose `torch` driver.") | ||||
return TorchDDPDriver(model, device, **kwargs) | |||||
return TorchDDPDriver(model, [device], **kwargs) | |||||
else: | else: | ||||
return TorchDDPDriver(model, device, **kwargs) | return TorchDDPDriver(model, device, **kwargs) | ||||
elif driver == "fairscale": | elif driver == "fairscale": | ||||
@@ -218,6 +218,8 @@ class TorchDriver(Driver): | |||||
# 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
if should_save_model: | if should_save_model: | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
if only_state_dict: | if only_state_dict: | ||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | ||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | ||||
@@ -401,7 +403,17 @@ class TorchDriver(Driver): | |||||
res.sampler = dataloader.batch_sampler.sampler | res.sampler = dataloader.batch_sampler.sampler | ||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | ||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
elif isinstance(dataloader.batch_sampler.sampler, TorchRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
# RandomBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||||
res.sampler = batch_sampler.sampler | |||||
if hasattr(batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(batch_sampler.sampler, TorchRandomSampler): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
@@ -173,6 +173,19 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
kwargs["extra"] = extra | kwargs["extra"] = extra | ||||
return kwargs | return kwargs | ||||
def setLevel(self, level) -> None: | |||||
""" | |||||
设置当前 logger 以及其 handler 的 log 级别 | |||||
:param level: | |||||
:return: | |||||
""" | |||||
if isinstance(level, str): | |||||
level = level.upper() | |||||
super().setLevel(level) | |||||
for handler in self.handlers: | |||||
handler.setLevel(level) | |||||
def _get_level(level): | def _get_level(level): | ||||
if not isinstance(level, int): | if not isinstance(level, int): | ||||
@@ -416,7 +416,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size | |||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | else: | ||||
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \ | |||||
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size |
@@ -22,6 +22,13 @@ from .utils import apply_to_collection | |||||
def paddle_to(data, device: Union[str, int]): | def paddle_to(data, device: Union[str, int]): | ||||
""" | |||||
将 `data` 迁移到指定的 `device` 上 | |||||
:param data: 要迁移的张量 | |||||
:param device: 目标设备,可以是 `str` 或 `int` | |||||
:return: 迁移后的张量 | |||||
""" | |||||
if device == "cpu": | if device == "cpu": | ||||
return data.cpu() | return data.cpu() | ||||
@@ -31,6 +38,9 @@ def paddle_to(data, device: Union[str, int]): | |||||
def get_paddle_gpu_str(device: Union[str, int]): | def get_paddle_gpu_str(device: Union[str, int]): | ||||
""" | """ | ||||
获得 `gpu:x` 类型的设备名 | 获得 `gpu:x` 类型的设备名 | ||||
:param device: 设备编号或设备名 | |||||
:return: 返回对应的 `gpu:x` 格式的设备名 | |||||
""" | """ | ||||
if isinstance(device, str): | if isinstance(device, str): | ||||
return device.replace("cuda", "gpu") | return device.replace("cuda", "gpu") | ||||
@@ -38,7 +48,10 @@ def get_paddle_gpu_str(device: Union[str, int]): | |||||
def get_paddle_device_id(device: Union[str, int]): | def get_paddle_device_id(device: Union[str, int]): | ||||
""" | """ | ||||
获得 gpu 的设备id,注意不要传入 `cpu` 。 | |||||
获得 gpu 的设备id | |||||
:param: device: 设备编号或设备名 | |||||
:return: 设备对应的编号 | |||||
""" | """ | ||||
if isinstance(device, int): | if isinstance(device, int): | ||||
return device | return device | ||||
@@ -14,6 +14,7 @@ __all__ = [ | |||||
] | ] | ||||
from fastNLP.envs import get_global_rank | from fastNLP.envs import get_global_rank | ||||
from .utils import is_notebook | |||||
class Singleton(type): | class Singleton(type): | ||||
@@ -34,6 +35,14 @@ class DummyFRichProgress: | |||||
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 | # 防止用户通过 DummyFRichProgress.console.print() 这种调用 | ||||
return None | return None | ||||
@property | |||||
def dummy_rich(self)->bool: | |||||
""" | |||||
当前对象是否是 dummy 的 rich 对象。 | |||||
:return: | |||||
""" | |||||
return True | |||||
class FRichProgress(Progress, metaclass=Singleton): | class FRichProgress(Progress, metaclass=Singleton): | ||||
""" | """ | ||||
@@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
super().stop_task(task_id) | super().stop_task(task_id) | ||||
super().remove_task(task_id) | super().remove_task(task_id) | ||||
self.refresh() # 使得bar不残留 | self.refresh() # 使得bar不残留 | ||||
if len(self._tasks) == 0: | |||||
super().stop() | |||||
def start(self) -> None: | def start(self) -> None: | ||||
super().start() | super().start() | ||||
@@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
if refresh: | if refresh: | ||||
self.refresh() | self.refresh() | ||||
@property | |||||
def dummy_rich(self) -> bool: | |||||
""" | |||||
当前对象是否是 dummy 的 rich 对象。 | |||||
:return: | |||||
""" | |||||
return False | |||||
class SpeedColumn(ProgressColumn): | class SpeedColumn(ProgressColumn): | ||||
""" | """ | ||||
@@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn): | |||||
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | ||||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||||
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \ | |||||
get_global_rank() == 0: | |||||
f_rich_progress = FRichProgress().new_progess( | f_rich_progress = FRichProgress().new_progess( | ||||
"[progress.description]{task.description}", | "[progress.description]{task.description}", | ||||
"[progress.percentage]{task.percentage:>3.0f}%", | "[progress.percentage]{task.percentage:>3.0f}%", | ||||
@@ -696,4 +696,23 @@ def get_class_that_defined_method(method): | |||||
None) | None) | ||||
if isinstance(cls, type): | if isinstance(cls, type): | ||||
return cls | return cls | ||||
return getattr(method, '__objclass__', None) # handle special descriptor objects | |||||
return getattr(method, '__objclass__', None) # handle special descriptor objects | |||||
def is_notebook(): | |||||
""" | |||||
检查当前运行环境是否为 jupyter | |||||
:return: | |||||
""" | |||||
try: | |||||
from IPython import get_ipython | |||||
if "IPKernelApp" not in get_ipython().config: # pragma: no cover | |||||
raise ImportError("console") | |||||
if "VSCODE_PID" in os.environ: # pragma: no cover | |||||
raise ImportError("vscode") | |||||
except: | |||||
return False | |||||
else: # pragma: no cover | |||||
return True |
@@ -16,7 +16,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -53,7 +53,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -19,7 +19,7 @@ from fastNLP.core import Evaluator | |||||
from fastNLP.core.utils.utils import safe_rm | from fastNLP.core.utils.utils import safe_rm | ||||
from fastNLP.core.drivers.torch_driver import TorchSingleDriver | from fastNLP.core.drivers.torch_driver import TorchSingleDriver | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@@ -55,7 +55,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01) | trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -24,7 +24,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -64,7 +64,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -0,0 +1,139 @@ | |||||
import pytest | |||||
import numpy as np | |||||
from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ | |||||
_get_element_shape_dtype | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||||
def test_get_element_shape_dtype(): | |||||
catalog = _get_element_shape_dtype([[1], [2, 3], [3], 2]) | |||||
catalog = _get_element_shape_dtype([['1'], [2, 3]]) | |||||
catalog = _get_element_shape_dtype([['1'], [2, 3]]) | |||||
catalog = _get_element_shape_dtype([['1'], ['2', '3']]) | |||||
catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) | |||||
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | |||||
def test_get_padder_run(backend): | |||||
if not _NEED_IMPORT_TORCH and backend == 'torch': | |||||
pytest.skip("No torch") | |||||
if not _NEED_IMPORT_PADDLE and backend == 'paddle': | |||||
pytest.skip("No paddle") | |||||
if not _NEED_IMPORT_PADDLE and backend == 'jittor': | |||||
pytest.skip("No jittor") | |||||
batch_field = [1, 2, 3] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
if backend is not None: | |||||
# 不能 pad | |||||
batch_field = [[1], [2, 3], [3], 2] | |||||
with pytest.raises(InconsistencyError): | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') | |||||
# 不能 pad | |||||
batch_field = [['2'], ['2'], ['2', '2']] | |||||
with pytest.raises(DtypeError) as exec_info: | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') | |||||
batch_field = [np.zeros(3), np.zeros((3, 1))] | |||||
with pytest.raises(InconsistencyError) as exec_info: | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') # no pad | |||||
batch_field = [np.zeros((3, 1)), np.zeros((4, 1))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
def test_raw_padder(): | |||||
backend = 'raw' | |||||
batch_field = [1, 2, 3] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert pad_batch == batch_field | |||||
batch_field = [[1], [2, 2], [3, 3, 3]] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert np.shape(pad_batch) == (3, 3) | |||||
batch_field = [[[1]], [[2, 2], [2]], [[3], [3], [3]]] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert np.shape(pad_batch) == (3, 3, 2) | |||||
def test_numpy_padder(): | |||||
backend = 'numpy' | |||||
target_type = np.ndarray | |||||
batch_field = [1, 2, 3] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert (pad_batch == np.array(batch_field)).sum()==len(batch_field) | |||||
batch_field = [[1], [2, 2], [3, 3, 3]] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert np.shape(pad_batch) == (3, 3) | |||||
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==3 | |||||
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,3))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert np.shape(pad_batch) == (3, 3, 3) | |||||
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==9 | |||||
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert np.shape(pad_batch) == (3, 3, 3) | |||||
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | |||||
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))] | |||||
with pytest.raises(InconsistencyError): | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
def test_torch_padder(): | |||||
if not _NEED_IMPORT_TORCH: | |||||
pytest.skip("No torch.") | |||||
import torch | |||||
backend = 'torch' | |||||
target_type = torch.Tensor | |||||
batch_field = [1, 2, 3] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field) | |||||
batch_field = [[1], [2, 2], [3, 3, 3]] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3) | |||||
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3 | |||||
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3, 3) | |||||
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9 | |||||
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3, 3) | |||||
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 | |||||
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))] | |||||
with pytest.raises(InconsistencyError): | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
@@ -0,0 +1,81 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
class TestNumpyNumberPadder: | |||||
def test_run(self): | |||||
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [1, 2, 3] | |||||
assert isinstance(a, np.ndarray) | |||||
assert (padder(a) == np.array(a)).sum() == 3 | |||||
class TestNumpySequencePadder: | |||||
def test_run(self): | |||||
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [[1, 2, 3], [3]] | |||||
a = padder(a) | |||||
shape = np.shape(a) | |||||
assert isinstance(a, np.ndarray) | |||||
assert shape == (2, 3) | |||||
b = np.array([[1, 2, 3], [3, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
def test_dtype_check(self): | |||||
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
with pytest.raises(DtypeError): | |||||
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
class TestNumpyTensorPadder: | |||||
def test_run(self): | |||||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
a = [np.zeros(3), np.zeros(2), np.zeros(0)] | |||||
a = padder(a) | |||||
shape = np.shape(a) | |||||
assert isinstance(a, np.ndarray) | |||||
assert shape == (3, 3) | |||||
b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))] | |||||
a = padder(a) | |||||
shape = np.shape(a) | |||||
assert isinstance(a, np.ndarray) | |||||
assert shape == (3, 3, 2) | |||||
b = np.array([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||||
a = padder(a) | |||||
shape = np.shape(a) | |||||
assert isinstance(a, np.ndarray) | |||||
assert shape == (3, 3, 2) | |||||
b = np.array([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
def test_dtype_check(self): | |||||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
with pytest.raises(DtypeError): | |||||
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||||
@@ -0,0 +1,29 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
class TestRawNumberPadder: | |||||
def test_run(self): | |||||
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [1, 2, 3] | |||||
assert padder(a) == a | |||||
class TestRawSequencePadder: | |||||
def test_run(self): | |||||
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [[1, 2, 3], [3]] | |||||
a = padder(a) | |||||
shape = np.shape(a) | |||||
assert shape == (2, 3) | |||||
b = np.array([[1, 2, 3], [3, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
def test_dtype_check(self): | |||||
with pytest.raises(DtypeError): | |||||
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) |
@@ -0,0 +1,105 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
class TestTorchNumberPadder: | |||||
def test_run(self): | |||||
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [1, 2, 3] | |||||
t_a = padder(a) | |||||
assert isinstance(t_a, torch.Tensor) | |||||
assert (t_a == torch.LongTensor(a)).sum() == 3 | |||||
class TestTorchSequencePadder: | |||||
def test_run(self): | |||||
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [[1, 2, 3], [3]] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, torch.Tensor) | |||||
assert tuple(shape) == (2, 3) | |||||
b = torch.LongTensor([[1, 2, 3], [3, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
def test_dtype_check(self): | |||||
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) | |||||
a = padder([[1], [2, 322]]) | |||||
assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | |||||
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
class TestTorchTensorPadder: | |||||
def test_run(self): | |||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, torch.Tensor) | |||||
assert tuple(shape) == (3, 3) | |||||
b = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, torch.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, 0], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, torch.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, torch.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) | |||||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, torch.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
def test_dtype_check(self): | |||||
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||||
@@ -0,0 +1,90 @@ | |||||
import pytest | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \ | |||||
get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number | |||||
def test_get_shape(): | |||||
a = [[1, 2, 3], [3]] | |||||
assert get_shape(a) == [2, 3] | |||||
a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||||
assert get_shape(a) == [2, 3, 3] | |||||
a = [[[1], [2], [3, 4]], [[]]] | |||||
assert get_shape(a) == [2, 3, 2] | |||||
def test_get_padded_numpy_array(): | |||||
a = [[1, 2, 3], [3]] | |||||
a = get_padded_numpy_array(a, dtype=int, pad_val=-1) | |||||
assert a.shape == (2, 3) | |||||
a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||||
a = get_padded_numpy_array(a, dtype=int, pad_val=-1) | |||||
assert a.shape == (2, 3, 3) | |||||
a = [[[1], [2], [3, 4]], [[]]] | |||||
a = get_padded_numpy_array(a, dtype=int, pad_val=-1) | |||||
assert a.shape == (2, 3, 2) | |||||
def test_get_padded_nest_list(): | |||||
a = [[1, 2, 3], [3]] | |||||
a = get_padded_nest_list(a, pad_val=-1) | |||||
assert np.shape(a) == (2, 3) | |||||
a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||||
a = get_padded_nest_list(a, pad_val=-1) | |||||
assert np.shape(a) == (2, 3, 3) | |||||
a = [[[1], [2], [3, 4]], [[]]] | |||||
a = get_padded_nest_list(a, pad_val=-1) | |||||
assert np.shape(a) == (2, 3, 2) | |||||
def test_is_number_or_numpy_number(): | |||||
assert is_number_or_numpy_number(type(3)) is True | |||||
assert is_number_or_numpy_number(type(3.1)) is True | |||||
assert is_number_or_numpy_number(type(True)) is True | |||||
assert is_number_or_numpy_number(type('3')) is False | |||||
assert is_number_or_numpy_number(np.zeros(3).dtype) is True | |||||
assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True | |||||
assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
dtype = torch.ones(3).dtype | |||||
assert is_number_or_numpy_number(dtype) is False | |||||
def test_is_number(): | |||||
assert is_number(type(3)) is True | |||||
assert is_number(type(3.1)) is True | |||||
assert is_number(type(True)) is True | |||||
assert is_number(type('3')) is False | |||||
assert is_number(np.zeros(3).dtype) is False | |||||
assert is_number(np.zeros(3, dtype=int).dtype) is False | |||||
assert is_number(np.zeros(3, dtype=object).dtype) is False | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
dtype = torch.ones(3).dtype | |||||
assert is_number(dtype) is False | |||||
def test_is_numpy_number(): | |||||
assert is_numpy_number_dtype(type(3)) is False | |||||
assert is_numpy_number_dtype(type(3.1)) is False | |||||
assert is_numpy_number_dtype(type(True)) is False | |||||
assert is_numpy_number_dtype(type('3')) is False | |||||
assert is_numpy_number_dtype(np.zeros(3).dtype) is True | |||||
assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True | |||||
assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
dtype = torch.ones(3).dtype | |||||
assert is_numpy_number_dtype(dtype) is False |
@@ -0,0 +1,225 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.collators.new_collator import Collator | |||||
def _assert_equal(d1, d2): | |||||
try: | |||||
if 'torch' in str(type(d1)): | |||||
if 'float64' in str(d2.dtype): | |||||
print(d2.dtype) | |||||
assert (d1 == d2).all().item() | |||||
else: | |||||
assert all(d1 == d2) | |||||
except TypeError: | |||||
assert d1 == d2 | |||||
except ValueError: | |||||
assert (d1 == d2).all() | |||||
def findDictDiff(d1, d2, path=""): | |||||
for k in d1: | |||||
if k in d2: | |||||
if isinstance(d1[k], dict): | |||||
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) | |||||
else: | |||||
_assert_equal(d1[k], d2[k]) | |||||
else: | |||||
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) | |||||
def findListDiff(d1, d2): | |||||
assert len(d1)==len(d2) | |||||
for _d1, _d2 in zip(d1, d2): | |||||
if isinstance(_d1, list): | |||||
findListDiff(_d1, _d2) | |||||
else: | |||||
_assert_equal(_d1, _d2) | |||||
class TestCollator: | |||||
def test_run(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'a': 1, 'b':[1, 2]} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'a': 2, 'b': [1, 2]} | |||||
} | |||||
] | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||||
collator = Collator(backend='raw') | |||||
assert raw_pad_batch == collator(dict_batch) | |||||
collator = Collator(backend='raw') | |||||
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
collator = Collator(backend='numpy') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), | |||||
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), | |||||
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), | |||||
'b': np.array([[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='numpy') | |||||
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), | |||||
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), | |||||
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(numpy_pad_lst, collator(list_batch)) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
collator = Collator(backend='torch') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), | |||||
'lst_int': torch.LongTensor([[1, 0], [1, 2]]), | |||||
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
'float': torch.FloatTensor([1.1, 2.1]), | |||||
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), | |||||
'numpy': torch.FloatTensor([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), | |||||
'b': torch.LongTensor( | |||||
[[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='torch') | |||||
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), | |||||
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), | |||||
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(torch_pad_lst, collator(list_batch)) | |||||
def test_pad(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'a': 1, 'b':[1, 2]} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'a': 2, 'b': [1, 2]} | |||||
} | |||||
] | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||||
# 测试 ignore | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# 测试 set_pad | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad('str', pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
# 测试设置 pad 值 | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad('nest_lst_int', pad_val=100) | |||||
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# 设置 backend 和 type | |||||
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||||
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
# [{'1'}, {'2'}]] | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('_0', '_3', '_1') | |||||
collator.set_pad('_4', pad_val=None) | |||||
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad('_0', pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('_0', '_3', '_1') | |||||
collator.set_pad('_2', backend='numpy') | |||||
collator.set_pad('_4', backend='numpy', pad_val=100) | |||||
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
# _single | |||||
collator = Collator() | |||||
collator.set_pad('_single') | |||||
findListDiff(list_batch, collator(list_batch)) | |||||
@@ -0,0 +1,37 @@ | |||||
from fastNLP.core.collators.utils import * | |||||
def test_unpack_batch_mapping(): | |||||
batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] | |||||
assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||||
def test_unpack_batch_nested_mapping(): | |||||
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] | |||||
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} | |||||
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] | |||||
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} | |||||
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, | |||||
{'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] | |||||
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||||
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||||
def test_pack_batch_nested_mapping(): | |||||
batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||||
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||||
new_batch = pack_batch_nested_mapping(batch) | |||||
assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], | |||||
'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} | |||||
def test_unpack_batch_sequence(): | |||||
batch = [[1, 2, 3], [2, 4, 6]] | |||||
new_batch = unpack_batch_sequence(batch) | |||||
assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} | |||||
@@ -11,7 +11,7 @@ from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@@ -80,7 +80,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -527,7 +527,7 @@ class TestSaveLoad: | |||||
@classmethod | @classmethod | ||||
def setup_class(cls): | def setup_class(cls): | ||||
# 不在这里 setup 的话会报错 | # 不在这里 setup 的话会报错 | ||||
cls.driver = generate_driver(10, 10) | |||||
cls.driver = generate_driver(10, 10, device=[0,1]) | |||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = PaddleRandomMaxDataset(20, 10) | self.dataset = PaddleRandomMaxDataset(20, 10) | ||||
@@ -633,7 +633,7 @@ class TestSaveLoad: | |||||
batch_sampler=BucketedBatchSampler( | batch_sampler=BucketedBatchSampler( | ||||
self.dataset, | self.dataset, | ||||
length=[10 for i in range(len(self.dataset))], | length=[10 for i in range(len(self.dataset))], | ||||
batch_size=4, | |||||
batch_size=2, | |||||
) | ) | ||||
) | ) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
@@ -19,7 +19,7 @@ def test_incorrect_driver(): | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", 0, [1]] | |||||
["cpu", "gpu:0", 0] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
@@ -27,7 +27,7 @@ def test_incorrect_driver(): | |||||
) | ) | ||||
def test_get_single_device(driver, device): | def test_get_single_device(driver, device): | ||||
""" | """ | ||||
测试正常情况下初始化PaddleSingleDriver的情况 | |||||
测试正常情况下初始化 PaddleSingleDriver 的情况 | |||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
@@ -36,7 +36,7 @@ def test_get_single_device(driver, device): | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[0, 1] | |||||
[0, 1, [1]] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
@@ -45,7 +45,7 @@ def test_get_single_device(driver, device): | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_get_fleet_2(driver, device): | def test_get_fleet_2(driver, device): | ||||
""" | """ | ||||
测试 fleet 多卡的初始化情况 | |||||
测试 fleet 多卡的初始化情况,但传入了单个 gpu | |||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(64, 10) | model = PaddleNormalModel_Classification_1(64, 10) | ||||
@@ -34,7 +34,7 @@ class TestPaddleDriverFunctions: | |||||
def test_check_single_optimizer_legality(self): | def test_check_single_optimizer_legality(self): | ||||
""" | """ | ||||
测试传入单个optimizer时的表现 | |||||
测试传入单个 optimizer 时的表现 | |||||
""" | """ | ||||
optimizer = paddle.optimizer.Adam( | optimizer = paddle.optimizer.Adam( | ||||
parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
@@ -50,7 +50,7 @@ class TestPaddleDriverFunctions: | |||||
def test_check_optimizers_legality(self): | def test_check_optimizers_legality(self): | ||||
""" | """ | ||||
测试传入optimizer list的表现 | |||||
测试传入 optimizer list 的表现 | |||||
""" | """ | ||||
optimizers = [ | optimizers = [ | ||||
paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
@@ -70,13 +70,13 @@ class TestPaddleDriverFunctions: | |||||
def test_check_dataloader_legality_in_train(self): | def test_check_dataloader_legality_in_train(self): | ||||
""" | """ | ||||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
""" | """ | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
dataloader = DataLoader(PaddleNormalDataset()) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
@@ -90,29 +90,29 @@ class TestPaddleDriverFunctions: | |||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
""" | """ | ||||
测试is_train参数为False时,_check_dataloader_legality函数的表现 | |||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
""" | """ | ||||
# 此时传入的应该是dict | # 此时传入的应该是dict | ||||
dataloader = { | dataloader = { | ||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset()) | |||||
"train": DataLoader(PaddleNormalDataset()), | |||||
"test":DataLoader(PaddleNormalDataset()) | |||||
} | } | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = { | dataloader = { | ||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
"train": DataLoader(PaddleNormalDataset()), | |||||
"test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
} | } | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
# 传入的不是dict,应该报错 | |||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
# 传入的不是 dict ,应该报错 | |||||
dataloader = DataLoader(PaddleNormalDataset()) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
# 创建torch的dataloader | |||||
# 创建 torch 的 dataloader | |||||
train_loader = torch.utils.data.DataLoader( | train_loader = torch.utils.data.DataLoader( | ||||
TorchNormalDataset(), | TorchNormalDataset(), | ||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
@@ -127,7 +127,7 @@ class TestPaddleDriverFunctions: | |||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
""" | """ | ||||
测试tensor_to_numeric函数 | |||||
测试 tensor_to_numeric 函数 | |||||
""" | """ | ||||
# 单个张量 | # 单个张量 | ||||
tensor = paddle.to_tensor(3) | tensor = paddle.to_tensor(3) | ||||
@@ -180,7 +180,7 @@ class TestPaddleDriverFunctions: | |||||
def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
""" | """ | ||||
测试set_model_mode函数 | |||||
测试 set_model_mode 函数 | |||||
""" | """ | ||||
self.driver.set_model_mode("train") | self.driver.set_model_mode("train") | ||||
assert self.driver.model.training | assert self.driver.model.training | ||||
@@ -192,14 +192,14 @@ class TestPaddleDriverFunctions: | |||||
def test_move_model_to_device_cpu(self): | def test_move_model_to_device_cpu(self): | ||||
""" | """ | ||||
测试move_model_to_device函数 | |||||
测试 move_model_to_device 函数 | |||||
""" | """ | ||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | ||||
assert self.driver.model.linear1.weight.place.is_cpu_place() | assert self.driver.model.linear1.weight.place.is_cpu_place() | ||||
def test_move_model_to_device_gpu(self): | def test_move_model_to_device_gpu(self): | ||||
""" | """ | ||||
测试move_model_to_device函数 | |||||
测试 move_model_to_device 函数 | |||||
""" | """ | ||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | ||||
assert self.driver.model.linear1.weight.place.is_gpu_place() | assert self.driver.model.linear1.weight.place.is_gpu_place() | ||||
@@ -207,7 +207,7 @@ class TestPaddleDriverFunctions: | |||||
def test_worker_init_function(self): | def test_worker_init_function(self): | ||||
""" | """ | ||||
测试worker_init_function | |||||
测试 worker_init_function | |||||
""" | """ | ||||
# 先确保不影响运行 | # 先确保不影响运行 | ||||
# TODO:正确性 | # TODO:正确性 | ||||
@@ -215,7 +215,7 @@ class TestPaddleDriverFunctions: | |||||
def test_set_deterministic_dataloader(self): | def test_set_deterministic_dataloader(self): | ||||
""" | """ | ||||
测试set_deterministic_dataloader | |||||
测试 set_deterministic_dataloader | |||||
""" | """ | ||||
# 先确保不影响运行 | # 先确保不影响运行 | ||||
# TODO:正确性 | # TODO:正确性 | ||||
@@ -224,7 +224,7 @@ class TestPaddleDriverFunctions: | |||||
def test_set_sampler_epoch(self): | def test_set_sampler_epoch(self): | ||||
""" | """ | ||||
测试set_sampler_epoch | |||||
测试 set_sampler_epoch | |||||
""" | """ | ||||
# 先确保不影响运行 | # 先确保不影响运行 | ||||
# TODO:正确性 | # TODO:正确性 | ||||
@@ -336,7 +336,7 @@ class TestSingleDeviceFunction: | |||||
def test_move_data_to_device(self): | def test_move_data_to_device(self): | ||||
""" | """ | ||||
这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||||
这个函数仅调用了 paddle_move_data_to_device ,测试例在 tests/core/utils/test_paddle_utils.py 中 | |||||
就不重复测试了 | 就不重复测试了 | ||||
""" | """ | ||||
self.driver.move_data_to_device(paddle.rand((32, 64))) | self.driver.move_data_to_device(paddle.rand((32, 64))) | ||||
@@ -490,9 +490,6 @@ class TestSetDistReproDataloader: | |||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | ||||
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | ||||
@@ -510,7 +507,6 @@ class TestSetDistReproDataloader: | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
num_consumed_samples = num_consumed_batches * batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# 重新构造 dataloader | # 重新构造 dataloader | ||||
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) | batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) | ||||
@@ -0,0 +1,788 @@ | |||||
import pytest | |||||
import os | |||||
from pathlib import Path | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from fastNLP.core.samplers import ( | |||||
RandomSampler, | |||||
UnrepeatedSampler, | |||||
BucketedBatchSampler, | |||||
UnrepeatedRandomSampler, | |||||
UnrepeatedSequentialSampler, | |||||
) | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import rank_zero_rm | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | |||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
device = [torch.device(i) for i in device] | |||||
driver = TorchDDPDriver( | |||||
model=torch_model, | |||||
parallel_device=device, | |||||
fp16=fp16, | |||||
output_from_new_proc=output_from_new_proc | |||||
) | |||||
driver.set_optimizers(torch_opt) | |||||
driver.setup() | |||||
return driver | |||||
def dataloader_with_bucketedbatchsampler(dataset, length, batch_size, shuffle, drop_last): | |||||
""" | |||||
建立一个 batch_sampler 为 BucketedBatchSampler 的 dataloader | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=BucketedBatchSampler( | |||||
dataset, | |||||
length, | |||||
batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
), | |||||
) | |||||
return dataloader | |||||
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0, unrepeated=False): | |||||
""" | |||||
建立一个 sampler 为 RandomSampler 的 dataloader | |||||
""" | |||||
if unrepeated: | |||||
sampler = UnrepeatedRandomSampler(dataset, shuffle, seed) | |||||
else: | |||||
sampler = RandomSampler(dataset, shuffle, seed=seed) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=sampler, | |||||
drop_last=drop_last, | |||||
batch_size=batch_size | |||||
) | |||||
return dataloader | |||||
############################################################################ | |||||
# | |||||
# 测试 TorchDDPDriver 的一些函数 | |||||
# | |||||
############################################################################ | |||||
class TestDDPDriverFunction: | |||||
""" | |||||
测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | |||||
def test_multi_drivers(self): | |||||
""" | |||||
测试使用了多个 TorchDDPDriver 的情况。 | |||||
""" | |||||
driver2 = generate_driver(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
# 设备设置不同,应该报错 | |||||
driver3 = generate_driver(20, 3, device=[0,1,2]) | |||||
assert False | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_move_data_to_device(self): | |||||
""" | |||||
这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中 | |||||
就不重复测试了 | |||||
""" | |||||
self.driver.move_data_to_device(torch.rand((32, 64))) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_distributed(self): | |||||
""" | |||||
测试 is_distributed 函数 | |||||
""" | |||||
assert self.driver.is_distributed() == True | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_no_sync_context(self): | |||||
""" | |||||
测试 get_no_sync_context 函数 | |||||
""" | |||||
res = self.driver.get_model_no_sync_context() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_global_zero(self): | |||||
""" | |||||
测试 is_global_zero 函数 | |||||
""" | |||||
self.driver.is_global_zero() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_unwrap_model(self): | |||||
""" | |||||
测试 unwrap_model 函数 | |||||
""" | |||||
self.driver.unwrap_model() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_local_rank(self): | |||||
""" | |||||
测试 get_local_rank 函数 | |||||
""" | |||||
self.driver.get_local_rank() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_all_gather(self): | |||||
""" | |||||
测试 all_gather 函数 | |||||
详细的测试在 test_dist_utils.py 中完成 | |||||
""" | |||||
obj = { | |||||
"rank": self.driver.global_rank | |||||
} | |||||
obj_list = self.driver.all_gather(obj, group=None) | |||||
for i, res in enumerate(obj_list): | |||||
assert res["rank"] == i | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("src_rank", ([0, 1])) | |||||
def test_broadcast_object(self, src_rank): | |||||
""" | |||||
测试 broadcast_object 函数 | |||||
详细的函数在 test_dist_utils.py 中完成 | |||||
""" | |||||
if self.driver.global_rank == src_rank: | |||||
obj = { | |||||
"rank": self.driver.global_rank | |||||
} | |||||
else: | |||||
obj = None | |||||
res = self.driver.broadcast_object(obj, src=src_rank) | |||||
assert res["rank"] == src_rank | |||||
############################################################################ | |||||
# | |||||
# 测试 set_dist_repro_dataloader 函数 | |||||
# | |||||
############################################################################ | |||||
class TestSetDistReproDataloader: | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.device = [0, 1] | |||||
cls.driver = generate_driver(10, 10, device=cls.device) | |||||
def setup_method(self): | |||||
self.dataset = TorchNormalDataset(40) | |||||
""" | |||||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 | |||||
此时对应 driver.load 中的情况 | |||||
""" | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 | |||||
此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | |||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler is batch_sampler | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 | |||||
此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | |||||
sampler = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is sampler | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
""" | |||||
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` | |||||
参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 | |||||
当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 | |||||
是否重新实例化 dataloader | |||||
""" | |||||
@magic_argv_env_context | |||||
def test_with_dist_none_reproducible_true(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 | |||||
当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
with pytest.raises(RuntimeError): | |||||
# 应当抛出 RuntimeError | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
# @pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler | |||||
时的表现 | |||||
此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler | |||||
和原 dataloader 相同 | |||||
""" | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
pad=True | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
self.check_distributed_sampler(dataloader.batch_sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 | |||||
此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 | |||||
batch_sampler.sampler 和原 dataloader 相同 | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.drop_last == False | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 | |||||
此时直接返回原来的 dataloader,不做任何处理。 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert replaced_loader is dataloader | |||||
dist.barrier() | |||||
""" | |||||
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||||
为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader | |||||
""" | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler | |||||
的表现 | |||||
此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset=self.dataset, | |||||
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | |||||
) | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||||
的表现 | |||||
此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 | |||||
的属性 | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_normal(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 | |||||
此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 | |||||
的属性 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
""" | |||||
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 | |||||
为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader | |||||
""" | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler | |||||
的表现 | |||||
此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 | |||||
的属性 | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler | |||||
的表现 | |||||
此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_normal(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 | |||||
此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 | |||||
的属性 | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
def check_distributed_sampler(self, sampler): | |||||
""" | |||||
测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确 | |||||
""" | |||||
assert sampler.num_replicas == dist.get_world_size() | |||||
assert sampler.rank == dist.get_rank() | |||||
if not isinstance(sampler, UnrepeatedSampler): | |||||
assert sampler.pad == True | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
""" | |||||
测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||||
""" | |||||
# 迭代两个 batch | |||||
num_replicas = len(self.device) | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch) | |||||
dist.barrier() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range | |||||
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||||
# 重新改造 dataloader | |||||
new_loader = dataloader_with_bucketedbatchsampler( | |||||
replaced_loader.dataset, | |||||
length=replaced_loader.dataset._data, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False, | |||||
) | |||||
new_loader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
pad=True | |||||
) | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||||
else: | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||||
# 重新构造 dataloader | |||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | |||||
new_loader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
) | |||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||||
############################################################################ | |||||
# | |||||
# 测试 save 和 load 相关的功能 | |||||
# | |||||
############################################################################ | |||||
class TestSaveLoad: | |||||
""" | |||||
测试多卡情况下 save 和 load 相关函数的表现 | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
# 不在这里 setup 的话会报错 | |||||
cls.driver = generate_driver(10, 10) | |||||
def setup_method(self): | |||||
self.dataset = TorchArgMaxDataset(10, 20) | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_model(self, only_state_dict): | |||||
""" | |||||
测试 save_model 和 load_model 函数 | |||||
""" | |||||
try: | |||||
path = "model" | |||||
dataloader = DataLoader(self.dataset, batch_size=2) | |||||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
self.driver1.save_model(path, only_state_dict) | |||||
# 同步 | |||||
dist.barrier() | |||||
self.driver2.load_model(path, only_state_dict) | |||||
for idx, batch in enumerate(dataloader): | |||||
batch = self.driver1.move_data_to_device(batch) | |||||
res1 = self.driver1.model( | |||||
batch, | |||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
res2 = self.driver2.model( | |||||
batch, | |||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
@pytest.mark.parametrize("device", ([[0,1]])) | |||||
def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
num_replicas = len(device) | |||||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
generate_driver(10, 10, device=device, fp16=False) | |||||
dataloader = dataloader_with_bucketedbatchsampler( | |||||
self.dataset, | |||||
length=[10 for i in range(len(self.dataset))], | |||||
batch_size=4, | |||||
shuffle=True, | |||||
drop_last=False | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
pad=True | |||||
) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
# 同步 | |||||
dist.barrier() | |||||
# 保存状态 | |||||
sampler_states = dataloader.batch_sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | |||||
# 更改 batch_size | |||||
dataloader = dataloader_with_bucketedbatchsampler( | |||||
self.dataset, | |||||
length=[10 for i in range(len(self.dataset))], | |||||
batch_size=2, | |||||
shuffle=True, | |||||
drop_last=False | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
pad=True | |||||
) | |||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 batch_sampler 是否被正确地加载和替换 | |||||
assert not (replaced_loader is dataloader) | |||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | |||||
# 3. 检查 fp16 是否被加载 | |||||
if fp16: | |||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. 检查 model 的参数是否正确 | |||||
# 5. 检查 batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = self.driver1.model( | |||||
batch, | |||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
res2 = self.driver2.model( | |||||
batch, | |||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
@pytest.mark.parametrize("device", ([[0,1]])) | |||||
def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
num_replicas = len(device) | |||||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
pad=True | |||||
) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
# 同步 | |||||
dist.barrier() | |||||
# 保存状态 | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
if only_state_dict: | |||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | |||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
# 加载 | |||||
# 更改 batch_size | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
pad=True | |||||
) | |||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 sampler 是否被正确地加载和替换 | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. 检查 fp16 是否被加载 | |||||
if fp16: | |||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. 检查 model 的参数是否正确 | |||||
# 5. 检查 batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = self.driver1.model( | |||||
batch, | |||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
res2 = self.driver2.model( | |||||
batch, | |||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | |||||
finally: | |||||
rank_zero_rm(path) |
@@ -0,0 +1,103 @@ | |||||
import os | |||||
import pytest | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | |||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | |||||
from fastNLP.envs import get_gpu_count | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
import torch | |||||
def test_incorrect_driver(): | |||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver("paddle", 0, model) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
["cpu", "cuda:0", 0, torch.device("cuda:0")] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch"] | |||||
) | |||||
def test_get_single_device(driver, device): | |||||
""" | |||||
测试正常情况下初始化TorchSingleDriver的情况 | |||||
""" | |||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchSingleDriver) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[0, 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp_2(driver, device): | |||||
""" | |||||
测试 ddp 多卡的初始化情况,但传入了单个 gpu | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchDDPDriver) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[[0, 2, 3], -1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch", "torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp(driver, device): | |||||
""" | |||||
测试 ddp 多卡的初始化情况 | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchDDPDriver) | |||||
@pytest.mark.parametrize( | |||||
("driver", "device"), | |||||
[("torch_ddp", "cpu")] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp_cpu(driver, device): | |||||
""" | |||||
测试试图在 cpu 上初始化分布式训练的情况 | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch", "torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_device_out_of_range(driver, device): | |||||
""" | |||||
测试传入的device超过范围的情况 | |||||
""" | |||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver(driver, device, model) |
@@ -0,0 +1,697 @@ | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
import pytest | |||||
from pathlib import Path | |||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from fastNLP.core import rank_zero_rm | |||||
import torch | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
import paddle | |||||
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | |||||
""" | |||||
建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader | |||||
""" | |||||
if shuffle: | |||||
sampler = torch.utils.data.RandomSampler(dataset) | |||||
else: | |||||
sampler = torch.utils.data.SequentialSampler(dataset) | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=RandomBatchSampler( | |||||
BatchSampler( | |||||
sampler, batch_size=batch_size, drop_last=drop_last | |||||
), | |||||
batch_size=batch_size, | |||||
drop_last=drop_last, | |||||
), | |||||
) | |||||
return dataloader | |||||
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): | |||||
""" | |||||
建立一个 sampler 为 RandomSampler 的 dataloader | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=RandomSampler(dataset, shuffle, seed=seed), | |||||
drop_last=drop_last, | |||||
batch_size=batch_size | |||||
) | |||||
return dataloader | |||||
############################################################################ | |||||
# | |||||
# 测试基类 TorchDrvier 中的一些简单函数 | |||||
# | |||||
############################################################################ | |||||
class TestTorchDriverFunctions: | |||||
""" | |||||
使用 TorchSingleDriver 测试基类的函数 | |||||
""" | |||||
@classmethod | |||||
def setup_class(self): | |||||
model = TorchNormalModel_Classification_1(10, 32) | |||||
self.driver = TorchSingleDriver(model, device="cpu") | |||||
def test_check_single_optimizer_legality(self): | |||||
""" | |||||
测试传入单个 optimizer 时的表现 | |||||
""" | |||||
optimizer = torch.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
learning_rate=0.01, | |||||
) | |||||
# 传入 torch 的 optimize r时,应该报错 ValueError | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizer) | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
测试传入 optimizer list 的表现 | |||||
""" | |||||
optimizers = [ | |||||
torch.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
optimizers += [ | |||||
paddle.optimizer.Adam( | |||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
learning_rate=0.01, | |||||
) | |||||
] | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizers) | |||||
def test_check_dataloader_legality_in_train(self): | |||||
""" | |||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | |||||
""" | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
# 创建 paddle 的 dataloader | |||||
dataloader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
def test_check_dataloader_legality_in_test(self): | |||||
""" | |||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | |||||
""" | |||||
# 此时传入的应该是dict | |||||
dataloader = { | |||||
"train": DataLoader(TorchNormalDataset()), | |||||
"test": DataLoader(TorchNormalDataset()) | |||||
} | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 传入的不是 dict,应该报错 | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# 创建 paddle 的 dataloader | |||||
train_loader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
测试 tensor_to_numeric 函数 | |||||
""" | |||||
# 单个张量 | |||||
tensor = torch.tensor(3) | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == 3 | |||||
tensor = torch.rand((3, 4)) | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == tensor.tolist() | |||||
# 张量list | |||||
tensor_list = [torch.rand((6, 4, 2)) for i in range(10)] | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor_list) | |||||
assert isinstance(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
assert res == tensor_list | |||||
# 张量tuple | |||||
tensor_tuple = tuple([torch.rand((6, 4, 2)) for i in range(10)]) | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor_tuple) | |||||
assert isinstance(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
assert res == tensor_tuple | |||||
# 张量dict | |||||
tensor_dict = { | |||||
"tensor": torch.rand((3, 4)), | |||||
"list": [torch.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [torch.rand((6, 4, 2)) for i in range(10)], | |||||
"tensor": torch.rand((3, 4)) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor_dict) | |||||
assert isinstance(res, dict) | |||||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||||
assert isinstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
assert r == d.tolist() | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
assert r == d.tolist() | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||||
def test_set_model_mode(self): | |||||
""" | |||||
测试set_model_mode函数 | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
assert self.driver.model.training | |||||
self.driver.set_model_mode("eval") | |||||
assert not self.driver.model.training | |||||
# 应该报错 | |||||
with pytest.raises(AssertionError): | |||||
self.driver.set_model_mode("test") | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
TorchSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
assert self.driver.model.linear1.weight.device.type == "cpu" | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
TorchSingleDriver.move_model_to_device(self.driver.model, "cuda") | |||||
assert self.driver.model.linear1.weight.device.type == "cuda" | |||||
assert self.driver.model.linear1.weight.device.index == 0 | |||||
def test_worker_init_function(self): | |||||
""" | |||||
测试worker_init_function | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
TorchSingleDriver.worker_init_function(0) | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
测试set_deterministic_dataloader | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
测试set_sampler_epoch | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试正常情况下 get_dataloader_args 的表现 | |||||
""" | |||||
dataloader = DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
) | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, TorchNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, torch.utils.data.SequentialSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试替换了 batch_sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = TorchNormalDataset() | |||||
dataloader = dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last) | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, TorchNormalDataset) | |||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, torch.utils.data.SequentialSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试替换了 sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = TorchNormalDataset() | |||||
dataloader = dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last) | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, TorchNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
assert isinstance(res.sampler, RandomSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
############################################################################ | |||||
# | |||||
# 测试 TorchSingleDrvier 中的一些简单函数 | |||||
# | |||||
############################################################################ | |||||
class TestSingleDeviceFunction: | |||||
""" | |||||
测试其它函数的测试例 | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
model = TorchNormalModel_Classification_1(10, 784) | |||||
cls.driver = TorchSingleDriver(model, device="cpu") | |||||
def test_unwrap_model(self): | |||||
""" | |||||
测试能否运行 | |||||
""" | |||||
res = self.driver.unwrap_model() | |||||
assert res is self.driver.model | |||||
def test_is_distributed(self): | |||||
assert self.driver.is_distributed() == False | |||||
def test_move_data_to_device(self): | |||||
""" | |||||
这个函数仅调用了 torch_move_data_to_device ,测试例在 tests/core/utils/test_torch_utils.py 中 | |||||
就不重复测试了 | |||||
""" | |||||
self.driver.move_data_to_device(torch.rand((32, 64))) | |||||
############################################################################ | |||||
# | |||||
# 测试 set_dist_repro_dataloader 函数 | |||||
# | |||||
############################################################################ | |||||
class TestSetDistReproDataloader: | |||||
""" | |||||
专门测试 set_dist_repro_dataloader 函数的类 | |||||
""" | |||||
def setup_method(self): | |||||
self.dataset = TorchNormalDataset(20) | |||||
model = TorchNormalModel_Classification_1(10, 32) | |||||
self.driver = TorchSingleDriver(model, device="cpu") | |||||
def test_with_reproducible_false(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||||
当dist为字符串时,此时应该返回原来的 dataloader | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert replaced_loader is dataloader | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
def test_with_reproducible_true(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||||
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), | |||||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||||
assert not (replaced_loader is dataloader) | |||||
if shuffle: | |||||
# 此时会替换 sampler | |||||
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
else: | |||||
# 此时会替换 batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||||
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | |||||
dist = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is dist | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = dataloader_with_randombatchsampler(self.dataset, 4, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
""" | |||||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||||
""" | |||||
# 迭代两个 batch | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch) | |||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range | |||||
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# 重新改造 dataloader | |||||
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||||
else: | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# 重新构造 dataloader | |||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||||
############################################################################ | |||||
# | |||||
# 测试 save 和 load 相关的功能 | |||||
# | |||||
############################################################################ | |||||
def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
""" | |||||
生成driver | |||||
""" | |||||
model = TorchNormalModel_Classification_1(labels, features) | |||||
opt = torch.optim.Adam(params=model.parameters(), lr=0.01) | |||||
driver = TorchSingleDriver(model, device=device, fp16=fp16) | |||||
driver.set_optimizers(opt) | |||||
driver.setup() | |||||
return driver | |||||
@pytest.fixture | |||||
def prepare_test_save_load(): | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
return driver1, driver2, dataloader | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
""" | |||||
测试 save_model 和 load_model 函数 | |||||
""" | |||||
try: | |||||
path = "model" | |||||
driver1, driver2, dataloader = prepare_test_save_load | |||||
driver1.save_model(path, only_state_dict) | |||||
driver2.load_model(path, only_state_dict) | |||||
for batch in dataloader: | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | |||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
sampler_states = dataloader.batch_sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | |||||
# 更改 batch_size | |||||
dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 batch_sampler 是否被正确地加载和替换 | |||||
assert not (replaced_loader is dataloader) | |||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | |||||
# 3. 检查 fp16 是否被加载 | |||||
if fp16: | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. 检查 model 的参数是否正确 | |||||
# 5. 检查 batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
batch = driver2.move_data_to_device(batch) | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) | |||||
assert len(left_x_batches | already_seen_x_set) == len(dataset) | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | |||||
# 更改 batch_size | |||||
dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. 检查 optimizer 的状态 | |||||
# TODO optimizer 的 state_dict 总是为空 | |||||
# 2. 检查 sampler 是否被正确地加载和替换 | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. 检查 fp16 是否被加载 | |||||
if fp16: | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. 检查 model 的参数是否正确 | |||||
# 5. 检查 batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
batch = driver2.move_data_to_device(batch) | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) | |||||
assert len(left_x_batches | already_seen_x_set) == len(dataset) | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||||
finally: | |||||
rank_zero_rm(path) |
@@ -1,35 +1,36 @@ | |||||
from torch.utils.data.sampler import SequentialSampler, RandomSampler | |||||
from fastNLP.core.samplers.sampler import ReproduceSampler | |||||
from tests.helpers.datasets.normal_data import NormalIterator | |||||
class TestReproduceSampler: | |||||
def test_sequentialsampler(self): | |||||
normal_iterator = NormalIterator(num_of_data=20) | |||||
sequential_sampler = SequentialSampler(normal_iterator) | |||||
reproduce_sampler = ReproduceSampler(sequential_sampler) | |||||
# iter_seq_sampler = iter(sequential_sampler) | |||||
# for each in iter_seq_sampler: | |||||
# print(each) | |||||
iter_reproduce_sampler = iter(reproduce_sampler) | |||||
forward_step = 3 | |||||
for _ in range(forward_step): | |||||
next(iter_reproduce_sampler) | |||||
state = reproduce_sampler.save_state() | |||||
assert state["current_batch_idx"] == forward_step | |||||
new_repro_sampler = ReproduceSampler(sequential_sampler) | |||||
assert new_repro_sampler.save_state()["current_batch_idx"] == 0 | |||||
new_repro_sampler.load_state(state) | |||||
iter_new_repro_sampler = iter(new_repro_sampler) | |||||
new_index_list = [] | |||||
for each in iter_new_repro_sampler: | |||||
new_index_list.append(each) | |||||
assert new_index_list == list(range(3, 20)) | |||||
import os | |||||
import pytest | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
from fastNLP.core.drivers.torch_driver.utils import ( | |||||
replace_batch_sampler, | |||||
replace_sampler, | |||||
) | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
def test_replace_batch_sampler(): | |||||
dataset = TorchNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, TorchNormalDataset) | |||||
assert len(replaced_loader.dataset) == len(dataset) | |||||
assert replaced_loader.batch_sampler.batch_size == 16 | |||||
def test_replace_sampler(): | |||||
dataset = TorchNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
sampler = RandomSampler(dataset) | |||||
replaced_loader = replace_sampler(dataloader, sampler) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
@@ -38,7 +38,7 @@ class RecordMetricCallback(Callback): | |||||
self.metric_threshold = metric_threshold | self.metric_threshold = metric_threshold | ||||
self.metric_begin_value = None | self.metric_begin_value = None | ||||
def on_validate_end(self, trainer, results): | |||||
def on_evaluate_end(self, trainer, results): | |||||
self.metric = results[self.monitor] | self.metric = results[self.monitor] | ||||
if self.metric_begin_value is None: | if self.metric_begin_value is None: | ||||
self.metric_begin_value = self.metric | self.metric_begin_value = self.metric | ||||
@@ -113,11 +113,11 @@ class RecordTrainerEventTriggerCallback(Callback): | |||||
def on_after_zero_grad(self, trainer, optimizers): | def on_after_zero_grad(self, trainer, optimizers): | ||||
print("on_after_zero_grad") | print("on_after_zero_grad") | ||||
def on_validate_begin(self, trainer): | |||||
print("on_validate_begin") | |||||
def on_evaluate_begin(self, trainer): | |||||
print("on_evaluate_begin") | |||||
def on_validate_end(self, trainer, results): | |||||
print("on_validate_end") | |||||
def on_evaluate_end(self, trainer, results): | |||||
print("on_evaluate_end") | |||||
@@ -38,7 +38,7 @@ class TorchNormalDataset_Classification(Dataset): | |||||
return {"x": self.x[item], "y": self.y[item]} | return {"x": self.x[item], "y": self.y[item]} | ||||
class TorchArgMaxDatset(Dataset): | |||||
class TorchArgMaxDataset(Dataset): | |||||
def __init__(self, feature_dimension=10, data_num=1000, seed=0): | def __init__(self, feature_dimension=10, data_num=1000, seed=0): | ||||
self.num_labels = feature_dimension | self.num_labels = feature_dimension | ||||
self.feature_dimension = feature_dimension | self.feature_dimension = feature_dimension | ||||