@@ -0,0 +1,23 @@ | |||||
__all__ = [ | |||||
'Callback', | |||||
'Events', | |||||
'EventsList', | |||||
'Filter', | |||||
'CallbackManager', | |||||
'CheckpointCallback', | |||||
'choose_progress_callback', | |||||
'ProgressCallback', | |||||
'RichCallback', | |||||
"LRSchedCallback", | |||||
'LoadBestModelCallback' | |||||
] | |||||
from .callback import Callback | |||||
from .callback_events import EventsList, Events, Filter | |||||
from .callback_manager import CallbackManager | |||||
from .checkpoint_callback import CheckpointCallback | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||||
from .lr_scheduler_callback import LRSchedCallback | |||||
from .load_best_model_callback import LoadBestModelCallback | |||||
@@ -0,0 +1,153 @@ | |||||
from typing import Union, Callable, Dict, Optional | |||||
__all__ = [ | |||||
'Callback', | |||||
] | |||||
from .callback_events import Events, EventsList, Filter | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||||
class Callback: | |||||
r""" | |||||
实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | |||||
""" | |||||
def on_after_trainer_initialized(self, trainer, driver): | |||||
r""" | |||||
在 `Trainer` 初始化后会被触发; | |||||
""" | |||||
pass | |||||
def on_sanity_check_begin(self, trainer): | |||||
r""" | |||||
在 '预跑'检测 开始前会被触发; | |||||
""" | |||||
pass | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
r""" | |||||
在 '预跑'检测 开始后会被触发; | |||||
""" | |||||
pass | |||||
def on_train_begin(self, trainer): | |||||
r""" | |||||
在训练开始前会被触发; | |||||
""" | |||||
pass | |||||
def on_train_end(self, trainer): | |||||
r""" | |||||
在训练完成后会被触发; | |||||
""" | |||||
pass | |||||
def on_train_epoch_begin(self, trainer): | |||||
r""" | |||||
在训练过程中的每一个 epoch 开始前会被触发; | |||||
""" | |||||
pass | |||||
def on_train_epoch_end(self, trainer): | |||||
r""" | |||||
在训练过程中的每一个 epoch 完成后会被触发; | |||||
""" | |||||
pass | |||||
def on_fetch_data_begin(self, trainer): | |||||
r""" | |||||
在训练过程中拿到当前的具体的一个 batch 前会被触发; | |||||
""" | |||||
pass | |||||
def on_fetch_data_end(self, trainer): | |||||
r""" | |||||
在训练过程中拿到当前的具体的一个 batch 后会被触发; | |||||
""" | |||||
pass | |||||
def on_train_batch_begin(self, trainer, batch, indices=None): | |||||
r""" | |||||
在训练过程中开始具体的一个 batch 前会被触发; | |||||
:param trainer: `fastNLP.Trainer` | |||||
:param batch: 当前正在运行的一个 batch; | |||||
:param indices: 当前的 batch 在一个 epoch 中的位置,用于用户方便地通过该 callback 函数定位具体的数据; | |||||
""" | |||||
pass | |||||
def on_train_batch_end(self, trainer): | |||||
pass | |||||
def on_exception(self, trainer, exception): | |||||
pass | |||||
def on_save_model(self, trainer): | |||||
pass | |||||
def on_load_model(self, trainer): | |||||
pass | |||||
def on_save_checkpoint(self, trainer) -> Dict: | |||||
""" | |||||
当确定前后两个 callback 是一样的(callback_name 相同,意味着它们所起的职能相同)时,它们在该函数中则应当保存使该 callback 正常 | |||||
工作的状态;而不应该让该函数去判断两个 callback 是否一样; | |||||
""" | |||||
pass | |||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||||
r""" | |||||
如果一个 callback 在断点重训前没有保存状态,或者其 `callback_name` 与其余的 callback 重名时,`states` 为 None; | |||||
""" | |||||
pass | |||||
def on_before_backward(self, trainer, outputs): | |||||
pass | |||||
def on_after_backward(self, trainer): | |||||
pass | |||||
def on_before_optimizer_step(self, trainer, optimizers): | |||||
pass | |||||
def on_before_zero_grad(self, trainer, optimizers): | |||||
pass | |||||
def on_validate_begin(self, trainer): | |||||
pass | |||||
def on_validate_end(self, trainer, results): | |||||
pass | |||||
@property | |||||
def callback_name(self): | |||||
return self.__class__.__name__ | |||||
class _CallbackWrapper(Callback): | |||||
""" | |||||
对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的 | |||||
这一个 callback 函数; | |||||
""" | |||||
def __init__(self, event: Union[Events, EventsList], fn: Callable): | |||||
r""" | |||||
:param event: 具体的 callback 时机,例如 'on_train_begin' 等;可以多个时机,此时 `event` 的 type 应当为 'EventsList'; | |||||
:param fn: 用户定制的 callback 函数; | |||||
""" | |||||
self.fn = fn | |||||
if isinstance(event, EventsList): | |||||
for each_event in event: | |||||
_filter = Filter(each_event.every, each_event.once, each_event.filter_fn) | |||||
setattr(self, each_event.value, _filter(fn)) | |||||
elif isinstance(event, _SingleEventState): | |||||
_filter = Filter(event.every, event.once, event.filter_fn) | |||||
setattr(self, event.value, _filter(fn)) | |||||
@property | |||||
def callback_name(self): | |||||
return self.fn.__name__ | |||||
@@ -0,0 +1,218 @@ | |||||
from enum import Enum, unique | |||||
from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict | |||||
from types import DynamicClassAttribute | |||||
from functools import wraps | |||||
import fastNLP | |||||
__all__ = [ | |||||
'Events', | |||||
'EventsList', | |||||
'Filter' | |||||
] | |||||
class _SingleEventState: | |||||
every: Optional[int] | |||||
once: Optional[int] | |||||
def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, | |||||
filter_fn: Optional[Callable] = None, name: Optional[str] = None): | |||||
# 具体的检测参数对错的逻辑放在具体的 Filter 里; | |||||
if every is None and once is None and filter_fn is None: | |||||
self.every = 1 | |||||
self.once = None | |||||
self.filter_fn = None | |||||
else: | |||||
self.every = every | |||||
self.once = once | |||||
self.filter_fn = filter_fn | |||||
if not hasattr(self, "_value_"): | |||||
self._value_ = value | |||||
if not hasattr(self, "_name_") and name is not None: | |||||
self._name_ = name | |||||
# copied to be compatible to enum | |||||
@DynamicClassAttribute | |||||
def name(self) -> str: | |||||
"""The name of the Enum member.""" | |||||
return self._name_ | |||||
@DynamicClassAttribute | |||||
def value(self) -> str: | |||||
"""The value of the Enum member.""" | |||||
return self._value_ | |||||
def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): | |||||
return _SingleEventState(self.value, every, once, filter_fn, self.name) | |||||
def __str__(self): | |||||
return "<event={0}, every={1}, once={2}, filter fn is None:{3}>".format(self.name, self.every, self.once, | |||||
self.filter_fn) | |||||
def __eq__(self, other) -> bool: | |||||
if isinstance(other, _SingleEventState): | |||||
return self.name == other.name | |||||
elif isinstance(other, str): | |||||
return self.name == other | |||||
else: | |||||
raise NotImplemented | |||||
def __hash__(self): | |||||
return hash(self._name_) | |||||
def __or__(self, other) -> "EventsList": | |||||
return EventsList() | self | other | |||||
class EventEnum(_SingleEventState, Enum): | |||||
pass | |||||
@unique | |||||
class Events(EventEnum): | |||||
ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized" | |||||
ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin" | |||||
ON_SANITY_CHECK_END = "on_sanity_check_end" | |||||
ON_TRAIN_BEGIN = "on_train_begin" | |||||
ON_TRAIN_END = "on_train_end" | |||||
ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin" | |||||
ON_TRAIN_EPOCH_END = "on_train_epoch_end" | |||||
ON_FETCH_DATA_BEGIN = "on_fetch_data_begin" | |||||
ON_FETCH_DATA_END = "on_fetch_data_end" | |||||
ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin" | |||||
ON_TRAIN_BATCH_END = "on_train_batch_end" | |||||
ON_EXCEPTION = "on_exception" | |||||
ON_SAVE_MODEL = "on_save_model" | |||||
ON_LOAD_MODEL = "on_load_model" | |||||
ON_SAVE_CHECKPOINT = "on_save_checkpoint" | |||||
ON_LOAD_CHECKPOINT = "on_load_checkpoint" | |||||
ON_BEFORE_BACKWARD = "on_before_backward" | |||||
ON_AFTER_BACKWARD = "on_after_backward" | |||||
ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step" | |||||
ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" | |||||
ON_VALIDATE_BEGIN = "on_validate_begin" | |||||
ON_VALIDATE_END = "on_validate_end" | |||||
class EventsList: | |||||
"""Collection of events stacked by operator `__or__`. | |||||
""" | |||||
def __init__(self) -> None: | |||||
self._events = [] # type: List[Union[Events, _SingleEventState]] | |||||
def _append(self, event: Union[Events, _SingleEventState]) -> None: | |||||
if not isinstance(event, (Events, _SingleEventState)): | |||||
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}") | |||||
self._events.append(event) | |||||
def __getitem__(self, item: int) -> Union[Events, _SingleEventState]: | |||||
return self._events[item] | |||||
def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]: | |||||
return iter(self._events) | |||||
def __len__(self) -> int: | |||||
return len(self._events) | |||||
def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList": | |||||
self._append(event=other) | |||||
return self | |||||
class Filter: | |||||
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): | |||||
r""" | |||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; | |||||
:param every: 表示一个函数隔多少次运行一次; | |||||
:param once: 表示一个函数只在第多少次时运行一次; | |||||
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 | |||||
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; | |||||
""" | |||||
if (every is None) and (once is None) and (filter_fn is None): | |||||
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") | |||||
if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): | |||||
raise ValueError("These three values should be only set one.") | |||||
if (filter_fn is not None) and not callable(filter_fn): | |||||
raise TypeError("Argument event_filter should be a callable") | |||||
if (every is not None) and not (isinstance(every, int) and every > 0): | |||||
raise ValueError("Argument every should be integer and greater than zero") | |||||
if (once is not None) and not (isinstance(once, int) and once > 0): | |||||
raise ValueError("Argument once should be integer and positive") | |||||
# 设置变量,包括全局变量; | |||||
self.num_called = 0 | |||||
self.num_executed = 0 | |||||
if every is not None: | |||||
self._every = every | |||||
self._filter = self.every_filter | |||||
elif once is not None: | |||||
self._once = once | |||||
self._filter = self.once_filter | |||||
else: | |||||
self._filter = filter_fn | |||||
def __call__(self, fn: Callable): | |||||
@wraps(fn) | |||||
def wrapper(*args, **kwargs) -> Callable: | |||||
self.num_called += 1 | |||||
# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; | |||||
# 因此我们就可以这样进行操作,将 trainer 从 callback 函数的输入中取出来,送到我们的 trainer 里去,从而实现一些复杂的逻辑; | |||||
# 与此同时,当我们发现 Filter 所修饰的函数的输入第一个参数不是 trainer 时,我们就只传入一个 self 到 _filter 函数中; | |||||
# 提取参数的逻辑; | |||||
trainer = kwargs.get("trainer", None) | |||||
if trainer is None and len(args) > 0: | |||||
trainer = args[0] | |||||
if isinstance(trainer, fastNLP.Trainer): # 这里因为重复调用的问题,我们不能直接使用 fastNLP.Trainer,因为 Trainer | |||||
# 也会调用这个 module,但是 Controller 不会; | |||||
param = (self, trainer) | |||||
else: | |||||
param = (self, ) | |||||
if self._filter(*param): | |||||
self.num_executed += 1 | |||||
return fn(*args, **kwargs) | |||||
wrapper.__fastNLP_filter__ = self | |||||
return wrapper | |||||
def every_filter(self, *args): | |||||
return self.num_called % self._every == 0 | |||||
def once_filter(self, *args): | |||||
return self.num_called == self._once | |||||
def state_dict(self) -> Dict: | |||||
r""" | |||||
通过该函数来保存该 `Filter` 的状态; | |||||
""" | |||||
return {"num_called": self.num_called, "num_executed": self.num_executed} | |||||
def load_state_dict(self, state: Dict): | |||||
r""" | |||||
通过该函数来加载 `Filter` 的状态; | |||||
:param state: 通过 `Filter.state_dict` 函数保存的状态元组; | |||||
""" | |||||
self.num_called = state["num_called"] | |||||
self.num_executed = state["num_executed"] | |||||
@@ -0,0 +1,294 @@ | |||||
import inspect | |||||
from typing import List, Optional, Dict, Sequence | |||||
from collections import defaultdict | |||||
__all__ = [ | |||||
'CallbackManager' | |||||
] | |||||
from .callback_events import Events | |||||
from .callback import Callback | |||||
from .checkpoint_callback import CheckpointCallback | |||||
from .progress_callback import ProgressCallback, choose_progress_callback | |||||
from fastNLP.core.log import logger | |||||
def _transfer(func): | |||||
r""" | |||||
装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||||
需要注意这里的 wrapper 内的函数不会运行 `func` 本身,因此如果有什么需要直接在 callback 函数内运行的代码,请放在 TrainerCallback 内; | |||||
""" | |||||
def wrapper(manager, *arg, **kwargs): | |||||
manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1; | |||||
returns = [] | |||||
for callback_fn in manager.callback_fns[func.__name__]: | |||||
returns.append(callback_fn(*arg, **kwargs)) | |||||
return returns | |||||
return wrapper | |||||
class CallbackManager: | |||||
r""" | |||||
用来管理训练过程中的所有的 callback 实例; | |||||
""" | |||||
all_callbacks: List[Callback] | |||||
class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; | |||||
callback_fns: dict | |||||
def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'): | |||||
r""" | |||||
注意 callback 的调用顺序: | |||||
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | |||||
2. 通过 `Trainer` 的参数 `callbacks` 添加的 callback 类; | |||||
3. 通过 `Trainer.add_callback_fn` 添加的 callback 函数; | |||||
:param callbacks: 初始化时可以传入的一系列 callback 类,通常为用户在初始化 'Trainer' 时直接传入的 callback 类; | |||||
""" | |||||
self._has_trainer_checkpoint = False | |||||
_has_progress_callback = False | |||||
_callbacks = [] | |||||
if callbacks is not None: | |||||
if isinstance(callbacks, Callback): | |||||
callbacks = [callbacks] | |||||
if not isinstance(callbacks, Sequence): | |||||
raise ValueError("Parameter `callbacks` should be type 'List' or 'Tuple'.") | |||||
callbacks = list(callbacks) | |||||
for _callback in callbacks: | |||||
if not isinstance(_callback, Callback): | |||||
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | |||||
if isinstance(_callback, ProgressCallback): | |||||
_has_progress_callback = True | |||||
_callbacks += callbacks | |||||
if not _has_progress_callback: | |||||
# 添加 progress callback | |||||
progress_callback = choose_progress_callback(progress_bar=progress_bar) | |||||
if progress_callback is None: | |||||
logger.info("There is no progress bar, Trainer will not output training progress.") | |||||
else: | |||||
_callbacks.append(progress_callback) | |||||
self.callback_fns = defaultdict(list) | |||||
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | |||||
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | |||||
self.callback_counter = defaultdict(lambda: 0) | |||||
if len(_callbacks): | |||||
# 这一对象是为了保存原始的类 callback 对象来帮助用户进行 debug,理论上在正常的使用中你并不会需要它; | |||||
self.class_callbacks = _callbacks | |||||
else: | |||||
self.class_callbacks: Optional[List[Callback]] = [] | |||||
# 预跑需要拿到每一个被 `Filter` 修饰的函数的 `Filter` 实例,从而在预跑结束后重置它们的内部状态; | |||||
self._callback_filters = [] # [(callback_name, fn_name, filter 实例), ] | |||||
# 保留所有 callback 的引用,用于断点重训;包括全部的三种callback:函数修饰器 callback;类 callback;纯函数 callback; | |||||
# 因为所有的 callback 都是通过函数 `self.add_one_callback` 添加,因此我们选择在其下进行添加; | |||||
# 一个比较重要的概念在于在训练过程运行的时候,两个 callback 的 callback_name 可以是一样的,并且理论上不会造成任何影响;但是当 | |||||
# `on_load_checkpoint` 时,我们需要处理两个 callback_name 一样这种情况了; | |||||
# 因此这里的 `all_callbacks` 为了避免正常训练过程的运行,只能是一个 List,而不能是一个 dict,`_callback_filters` 也是一样; | |||||
self.all_callbacks = [] | |||||
def initialize_class_callbacks(self): | |||||
r""" | |||||
在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 | |||||
一个个 callback 时机,也就是 `Events` 的类别; | |||||
如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; | |||||
:param callbacks: | |||||
:return: | |||||
""" | |||||
for each_callback in self.class_callbacks: | |||||
if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint: | |||||
self._has_trainer_checkpoint = True | |||||
self.dissect_one_callback(each_callback) | |||||
def dissect_one_callback(self, callback: Callback): | |||||
r""" | |||||
将具体的一个 callback 实例的所有 callback 函数拆分后按时机插入到字典中; | |||||
:param callback: 一个具体的 callback 实例; | |||||
""" | |||||
self.all_callbacks.append(callback) | |||||
for name, member in Events.__members__.items(): | |||||
_fn = getattr(callback, member.value) | |||||
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): | |||||
self.callback_fns[member.value].append(_fn) | |||||
self.extract_callback_filter_state(callback.callback_name, _fn) | |||||
def extract_callback_filter_state(self, callback_name, callback_fn): | |||||
r""" | |||||
将一个具体的 callback 函数的 filter 的状态抽取出来; | |||||
""" | |||||
if hasattr(callback_fn, "__fastNLP_filter__"): | |||||
# 注意我们的 `Filter` 使用了 `@wraps` 来保证被修饰的函数的 `__name__` 属性仍旧是其真实的名字; | |||||
self._callback_filters.append((callback_name, callback_fn.__name__, callback_fn.__fastNLP_filter__)) | |||||
def on_save_checkpoint(self, trainer) -> Dict: | |||||
r""" | |||||
用于断点重训的 callback 的保存函数; | |||||
该函数主要涉及两个方面: | |||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 | |||||
断点重训应当保存的状态; | |||||
2. 每一个具体的 callback 函数的 filter 的状态; | |||||
:return: 一个包含上述内容的字典; | |||||
{ | |||||
"callback_name_1": { | |||||
"states": {...}, | |||||
"filter_states": {"on_train_begin": filter1.state_dict(), ...} | |||||
} | |||||
} | |||||
""" | |||||
states = {} | |||||
# 1. 每一个 callback 的状态; | |||||
# 如果有两个 callback 的 name 相同,那么我们只会保存第一个; | |||||
_duplicated_callbacks = [] | |||||
for each_callback in self.all_callbacks: | |||||
if each_callback.callback_name in states: | |||||
_duplicated_callbacks.append(each_callback.callback_name) | |||||
# 对于 callback_name 有重复的 callback,我们仍旧会调用其 `on_save_checkpoint` 函数,就如同调用其它 callback 函数 | |||||
# 一样,但是其结果并不会存储在 states 中返回; | |||||
each_callback.on_save_checkpoint(trainer) | |||||
else: | |||||
states[each_callback.callback_name] = {} | |||||
states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer) | |||||
if len(_duplicated_callbacks) > 0: | |||||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, " | |||||
f"and we will only save the first callback's state we meet.") | |||||
# 2. 每一个具体的 callback 函数的 filter 的状态; | |||||
_record_duplicated_callback_names = set() | |||||
for each_callback_filters in self._callback_filters: | |||||
if each_callback_filters[0] not in _record_duplicated_callback_names: | |||||
_record_duplicated_callback_names.add(each_callback_filters[0]) | |||||
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() | |||||
# 3. 保存 callback_counter; | |||||
# callback_counter 不应当被保存,因为其在断点重训时会由新的 callback_manager 重新初始化; | |||||
# 对于断点重训,我们不会保存 Trainer 的所有参数,例如 batch_step_fn;如果在断点重训时重新初始化 Trainer 发现 batch_step_fn | |||||
# 不为 None,那么 Trainer 就会调用实际的 check_batch_step_fn 函数,从而需要 callback_counter 为全新的状态; | |||||
return states | |||||
def on_load_checkpoint(self, trainer, states: Dict): | |||||
r""" | |||||
用于断点重训的加载函数; | |||||
对应于断点重训的保存函数; | |||||
:param trainer: `Trainer` | |||||
:param states: 见 `on_save_checkpoint` 函数的返回值; | |||||
""" | |||||
# 1. 先恢复每一个具体的 callback 函数的 filter 的状态; | |||||
# self._callback_filters 是当前的 Trainer 的 callback 的 filter 状态,是我们要去维护的对象; | |||||
_already_loaded_callback_names = set() | |||||
_duplicated_callback_names = set() | |||||
for each_callback_filters in self._callback_filters: | |||||
if each_callback_filters[0] in states: | |||||
if each_callback_filters[0] not in _already_loaded_callback_names: | |||||
_already_loaded_callback_names.add(each_callback_filters[0]) | |||||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) | |||||
else: | |||||
_duplicated_callback_names.add(each_callback_filters[0]) | |||||
if len(_duplicated_callback_names) > 0: | |||||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, " | |||||
f"and we will only load the first callback's state we meet.") | |||||
# 2. 再恢复每一个 callback 的单独的状态; | |||||
# 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样, | |||||
# 那么它们就应该是同一个对象; | |||||
_already_loaded_callback_names = set() | |||||
for each_callback in self.all_callbacks: | |||||
if each_callback.callback_name in states and each_callback.callback_name not in _already_loaded_callback_names: | |||||
_already_loaded_callback_names.add(each_callback.callback_name) | |||||
# 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态; | |||||
each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"]) | |||||
else: | |||||
each_callback.on_load_checkpoint(trainer, None) | |||||
@property | |||||
def has_trainer_chechpoint(self) -> bool: | |||||
return self._has_trainer_checkpoint | |||||
@_transfer | |||||
def on_after_trainer_initialized(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_sanity_check_begin(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_sanity_check_end(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_train_begin(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_train_end(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_train_epoch_begin(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_train_epoch_end(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_fetch_data_begin(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_fetch_data_end(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_train_batch_begin(self, trainer, batch, indices=None): | |||||
pass | |||||
@_transfer | |||||
def on_train_batch_end(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_exception(self, trainer, exception): | |||||
pass | |||||
@_transfer | |||||
def on_save_model(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_load_model(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_before_backward(self, trainer, outputs): | |||||
pass | |||||
@_transfer | |||||
def on_after_backward(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_before_optimizer_step(self, trainer, optimizers): | |||||
pass | |||||
@_transfer | |||||
def on_before_zero_grad(self, trainer, optimizers): | |||||
pass | |||||
@_transfer | |||||
def on_validate_begin(self, trainer): | |||||
pass | |||||
@_transfer | |||||
def on_validate_end(self, trainer, results): | |||||
pass |
@@ -0,0 +1,267 @@ | |||||
import os | |||||
from typing import Union, Optional, Callable, Dict, Sequence | |||||
from pathlib import Path | |||||
from functools import partial | |||||
from time import sleep | |||||
__all__ = [ | |||||
'CheckpointCallback' | |||||
] | |||||
import fastNLP | |||||
from .callback import Callback, Filter | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
class CheckpointCallback(Callback): | |||||
""" | |||||
1. 因为只有 'Trainer' 才有 callback,因此评测 metric 实际上就是 validate 时干的事情; | |||||
2. 默认 'save_last' 为 True,即 model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型,模型名字为 last.pth.tar; | |||||
3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视,如果用户在训练过程中指定了多个 monitor 的监视,例如 "acc1", | |||||
"acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例; | |||||
4. 理论上,在实际保存的过程中,topk 模式和 固定频率保存的模式是完全独立的,我们确实应当采取一些措施至少保证两者的名字不一样; | |||||
""" | |||||
def __init__( | |||||
self, | |||||
monitor, | |||||
is_trainer_checkpoint: Optional[bool] = False, | |||||
save_folder: Optional[Union[str, Path]] = None, | |||||
save_every_n_epochs: Optional[int] = None, | |||||
save_every_n_global_batches: Optional[int] = None, | |||||
save_last: bool = True, | |||||
save_topk: Optional[int] = None, | |||||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, | |||||
larger_better: bool = True, | |||||
only_state_dict: bool = True, | |||||
model_save_fn: Optional[Callable] = None, | |||||
**kwargs, | |||||
): | |||||
if monitor is None and save_topk is not None: | |||||
raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") | |||||
if monitor is not None and not isinstance(monitor, str): | |||||
raise ValueError("Parameter `monitor` should be of 'str' type.") | |||||
if not isinstance(is_trainer_checkpoint, bool): | |||||
raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.") | |||||
if save_folder is None: | |||||
logger.warning( | |||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | |||||
save_folder = Path.cwd() | |||||
if not save_folder.exists(): | |||||
raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") | |||||
elif save_folder.is_file(): | |||||
raise ValueError("Parameter `save_folder` should be a directory instead of a file.") | |||||
if save_every_n_epochs is not None: | |||||
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: | |||||
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") | |||||
# 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的 | |||||
# 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样; | |||||
self._filter_every_n_epochs = Filter(every=save_every_n_epochs) | |||||
if save_every_n_global_batches is not None: | |||||
if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1: | |||||
raise ValueError( | |||||
"parameter save_every_n_global_batches should be an int and greater than or equal to 1.") | |||||
self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches) | |||||
if save_topk is not None: | |||||
if not isinstance(save_topk, int) or save_topk < 1: | |||||
raise ValueError("parameter save_topk should be an int and greater than or equal to 1.") | |||||
if save_on_exception is not None: | |||||
if not isinstance(save_on_exception, Sequence): | |||||
save_on_exception = [save_on_exception] | |||||
for exception in save_on_exception: | |||||
if not issubclass(exception, BaseException): | |||||
raise TypeError("Each exception in parameter `save_on_exception` can only be " | |||||
"`BaseException` type.") | |||||
self.monitor = monitor | |||||
self.is_trainer_checkpoint = is_trainer_checkpoint | |||||
self.save_folder = Path(save_folder) | |||||
self.save_every_n_epochs = save_every_n_epochs | |||||
self.save_every_n_global_batches = save_every_n_global_batches | |||||
self.save_last = save_last | |||||
self.save_topk = save_topk | |||||
self.larger_better = larger_better | |||||
self.only_state_dict = only_state_dict | |||||
self.model_save_fn = model_save_fn | |||||
self.save_on_exception = save_on_exception | |||||
self.kwargs = kwargs | |||||
# 这些参数是专门留给 topk 模式专门使用的; | |||||
self._topk_model = {} | |||||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | |||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个 | |||||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | |||||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | |||||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | |||||
self._real_monitor = self.monitor | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | |||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | |||||
self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | |||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | |||||
synchronize_mkdir(self.log_filepath) | |||||
def on_validate_end(self, trainer, validate_res): | |||||
self._save_topk(trainer, validate_res) | |||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | |||||
self._save_every_n_epochs(trainer) | |||||
self._save_last(trainer) | |||||
def on_train_batch_end(self, trainer): | |||||
self._save_every_n_global_batches(trainer) | |||||
def on_exception(self, trainer, exception: BaseException): | |||||
if self.save_on_exception is not None and exception.__class__ in self.save_on_exception: | |||||
folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None) | |||||
folder = folder + f"_{exception.__class__.__name__}" | |||||
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
self._get_validate_metric(sanity_check_res) | |||||
def on_save_checkpoint(self, trainer) -> Dict: | |||||
""" | |||||
我们需要保存 CheckpointCallback 内部的几个 filter 的状态; | |||||
""" | |||||
states = {} | |||||
if self.save_every_n_epochs is not None: | |||||
states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict() | |||||
if self.save_every_n_global_batches is not None: | |||||
states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict() | |||||
states["real_monitor"] = self._real_monitor | |||||
return states | |||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||||
if self.save_every_n_epochs is not None: | |||||
self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"]) | |||||
if self.save_every_n_global_batches is not None: | |||||
self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"]) | |||||
self._real_monitor = states["real_monitor"] | |||||
def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"): | |||||
if self.save_every_n_epochs is not None: | |||||
if self.is_trainer_checkpoint: | |||||
_fn_every_n_epochs = trainer.save | |||||
else: | |||||
_fn_every_n_epochs = trainer.save_model | |||||
_fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None) | |||||
_fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs) | |||||
_fn_every_n_epochs() | |||||
def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"): | |||||
if self.save_every_n_global_batches is not None: | |||||
if self.is_trainer_checkpoint: | |||||
_fn_every_n_global_batches = trainer.save | |||||
else: | |||||
_fn_every_n_global_batches = trainer.save_model | |||||
_fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None) | |||||
_fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches) | |||||
_fn_every_n_global_batches() | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | |||||
if self.save_topk is not None: | |||||
_metric_value = self._get_validate_metric(validate_res) | |||||
_saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value) | |||||
_should_save = False | |||||
if self._topn < self.save_topk: | |||||
self._topk_model[_saved_name] = _metric_value | |||||
self._topn += 1 | |||||
_should_save = True | |||||
else: | |||||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||||
key=lambda x: self._topk_model[x]) | |||||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | |||||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | |||||
self._topk_model[_saved_name] = _metric_value | |||||
_should_save = True | |||||
self._topk_model.pop(_least_valuable_model) | |||||
synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model)) | |||||
assert len(self._topk_model) == self.save_topk == self._topn | |||||
if _should_save: | |||||
self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name) | |||||
def _save_last(self, trainer: "fastNLP.Trainer"): | |||||
if self.save_last: | |||||
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last") | |||||
def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None, | |||||
substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None): | |||||
# 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者 | |||||
# epoch-batch-monitor-monitor_value; | |||||
if substitute_folder is None: | |||||
folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric)) | |||||
else: | |||||
folder = self.log_filepath.joinpath(substitute_folder) | |||||
synchronize_mkdir(folder) | |||||
# 然后再调用 trainer 的 save_model(用于保存模型)或者 save(用于断点重训)函数; | |||||
if substitute_fn is not None: | |||||
_fn = substitute_fn | |||||
else: | |||||
if self.is_trainer_checkpoint: | |||||
_fn = trainer.save | |||||
else: | |||||
_fn = trainer.save_model | |||||
_fn( | |||||
folder=folder, | |||||
only_state_dict=self.only_state_dict, | |||||
model_save_fn=self.model_save_fn, | |||||
**self.kwargs | |||||
) | |||||
def _get_validate_metric(self, res: Dict): | |||||
""" | |||||
该函数用于从 `Evaluator` 的结果中找到属于当前 CheckpointCallback 的 metric result(根据 monitor); | |||||
如果用户输入在 res 中没有找到,我们会查询所有的 validate 结果字典的键值,根据 最长公共字符串 匹配,使用最长匹配的结果值; | |||||
:param res: | |||||
:return: | |||||
""" | |||||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | |||||
self._real_monitor = use_monitor | |||||
return value | |||||
def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False, | |||||
metric: Optional[Union[int, float]] = None) -> str: | |||||
""" | |||||
获取当前保存模型的真正地名字; | |||||
metric 参数仅当 mode 为 'topk' 时起作用; | |||||
""" | |||||
cur_epoch_idx = trainer.cur_epoch_idx | |||||
global_forward_batches = trainer.global_forward_batches | |||||
_other = "" | |||||
if topk: | |||||
_other = f"_{metric}" | |||||
return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}" | |||||
@property | |||||
def callback_name(self): | |||||
""" | |||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; | |||||
:return: | |||||
""" | |||||
return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}" | |||||
@@ -0,0 +1,129 @@ | |||||
__all__ = [ | |||||
'LoadBestModelCallback' | |||||
] | |||||
import os | |||||
from typing import Optional, Callable | |||||
from .callback import Callback | |||||
from .utils import _get_monitor_value | |||||
from io import BytesIO | |||||
import shutil | |||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import all_rank_call | |||||
class LoadBestModelCallback(Callback): | |||||
def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, | |||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | |||||
model_load_fn:Optional[Callable] = None, | |||||
delete_after_train:bool = True): | |||||
""" | |||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | |||||
:param str monitor: 监控的 metric 值。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | |||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | |||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | |||||
:param only_state_dict: 是否只保存模型的参数。当 model_save_fn 不为空时,该值无效。 | |||||
:param model_save_fn: 保存 model 的函数,与 model_load_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的保存。 | |||||
:param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, | |||||
请在函数内完成对模型的加载。 | |||||
:param delete_after_train: 在加载了最佳模型之后是否删掉模型。 | |||||
""" | |||||
if model_load_fn is not None: | |||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object." | |||||
assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | |||||
if model_save_fn is not None: | |||||
assert callable(model_save_fn), "`model_save_fn` must be a callable object." | |||||
assert model_load_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | |||||
if model_save_fn is not None: | |||||
assert save_folder is not None, "When passing `model_save_fn`, `save_folder` must be provided." | |||||
if save_folder is not None: | |||||
if os.path.exists(save_folder): | |||||
assert os.path.isdir(save_folder), f"`save_folder` must be a directory." | |||||
else: | |||||
os.makedirs(save_folder, exist_ok=True) | |||||
save_folder = os.path.join(save_folder, os.environ.get(FASTNLP_LAUNCH_TIME)) | |||||
self.real_save_folder = os.path.join(save_folder, 'best_so_far') | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
os.makedirs(self.real_save_folder) | |||||
else: # 创建出一个 stringio | |||||
self.real_save_folder = None | |||||
self.buffer = BytesIO() | |||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
self.save_folder = save_folder | |||||
self.only_state_dict = only_state_dict | |||||
self.model_save_fn = model_save_fn | |||||
self.model_load_fn = model_load_fn | |||||
self.delete_after_after = delete_after_train | |||||
self._real_monitor = None | |||||
self.monitor_value = float('-inf') if larger_better else float('inf') | |||||
def on_after_trainer_initialized(self, trainer, driver): | |||||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | |||||
# 如果需要保存,但是又是不是 fastNLP 拉起的, 需要同步一下 folder | |||||
try: | |||||
self.real_save_folder = driver.broadcast_object(self.real_save_folder, src=0, group=None) | |||||
logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | |||||
except NotImplementedError: | |||||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | |||||
f"save best model when launch using script.") | |||||
def on_validate_end(self, trainer, results): | |||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (monitor_value < self.monitor_value and self.larger_better is False) or \ | |||||
(monitor_value > self.monitor_value and self.larger_better): | |||||
self.monitor_value = monitor_value | |||||
if self.real_save_folder: | |||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_save_fn=self.model_save_fn) | |||||
else: | |||||
self.buffer.seek(0) | |||||
with all_rank_call(): | |||||
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
def on_train_end(self, trainer): | |||||
logger.info(f"Loading best model with {self._real_monitor}: {self.monitor_value}...") | |||||
if self.real_save_folder: | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
if self.delete_after_after: | |||||
if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
# 只需要 rank 0 执行删除。 | |||||
logger.info(f"Deleting {self.real_save_folder}...") | |||||
shutil.rmtree(self.real_save_folder) | |||||
try: | |||||
# 如果是 emtpy 的,就会被删除掉 | |||||
os.rmdir(self.save_folder) | |||||
except: | |||||
pass | |||||
elif hasattr(self, 'buffer'): | |||||
self.buffer.close() | |||||
del self.buffer | |||||
def on_exception(self, trainer, exception): | |||||
if self.delete_after_after: | |||||
if self.real_save_folder: # 这里,谁处异常,谁删除 | |||||
logger.info(f"Deleting {self.real_save_folder}...") | |||||
shutil.rmtree(self.real_save_folder) | |||||
try: | |||||
# 如果是 emtpy 的,就会被删除掉 | |||||
os.rmdir(self.save_folder) | |||||
except: | |||||
pass | |||||
elif hasattr(self, 'buffer'): | |||||
self.buffer.close() | |||||
del self.buffer |
@@ -0,0 +1,27 @@ | |||||
from .callback import Callback | |||||
__all__ = [ | |||||
'LRSchedCallback' | |||||
] | |||||
class LRSchedCallback(Callback): | |||||
def __init__(self, scheduler, step_on:str='batch'): | |||||
""" | |||||
根据 step_on 参数在合适的时机调用 scheduler 的 step 函数。 | |||||
:param scheduler: 实现了 step() 函数的对象 | |||||
:param step_on: 可选 ['batch', 'epoch'] 表示在何时调用 scheduler 的 step 函数 | |||||
""" | |||||
assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \ | |||||
"step function." | |||||
self.scheduler = scheduler | |||||
self.step_on = 0 if step_on == 'batch' else 1 | |||||
def on_train_batch_end(self, trainer): | |||||
if self.step_on == 0: | |||||
self.scheduler.step() | |||||
def on_train_epoch_end(self, trainer): | |||||
if self.step_on == 1: | |||||
self.scheduler.step() |
@@ -0,0 +1,207 @@ | |||||
import json | |||||
import sys | |||||
__all__ = [ | |||||
'choose_progress_callback', | |||||
'ProgressCallback', | |||||
'RichCallback' | |||||
] | |||||
from .callback import Callback | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from fastNLP.core.utils import f_rich_progress | |||||
from fastNLP.core.log import logger | |||||
def choose_progress_callback(progress_bar:str): | |||||
if progress_bar == 'auto': | |||||
if (sys.stdin and sys.stdin.isatty()): | |||||
progress_bar = 'rich' | |||||
else: | |||||
progress_bar = 'raw' | |||||
if progress_bar == 'rich': | |||||
return RichCallback() | |||||
elif progress_bar == 'raw': | |||||
return RawTextCallback() | |||||
else: | |||||
return None | |||||
class ProgressCallback(Callback): | |||||
def on_train_end(self, trainer): | |||||
f_rich_progress.stop() | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | |||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=sanity_check_res) | |||||
class RichCallback(ProgressCallback): | |||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||||
format_json=True): | |||||
""" | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | |||||
:param format_json: 是否format json再打印 | |||||
""" | |||||
super().__init__() | |||||
self.print_every = print_every | |||||
self.progress_bar = f_rich_progress | |||||
self.task2id = {} | |||||
self.loss = 0 | |||||
self.loss_round_ndigit = loss_round_ndigit | |||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = monitor | |||||
self.format_json = format_json | |||||
def on_after_trainer_initialized(self, trainer, driver): | |||||
if not self.progress_bar.disable: | |||||
self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) | |||||
def on_train_begin(self, trainer): | |||||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | |||||
completed=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||||
def on_train_epoch_begin(self, trainer): | |||||
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | |||||
if 'batch' in self.task2id: | |||||
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) | |||||
else: | |||||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', total=trainer.num_batches_per_epoch) | |||||
def on_train_epoch_end(self, trainer): | |||||
self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | |||||
advance=None, completed=trainer.cur_epoch_idx, refresh=True) | |||||
def on_train_end(self, trainer): | |||||
self.clear_tasks() | |||||
def on_before_backward(self, trainer, outputs): | |||||
loss = trainer.extract_loss_from_outputs(outputs) | |||||
loss = trainer.driver.tensor_to_numeric(loss, reduce='sum') | |||||
self.loss += loss | |||||
def on_train_batch_end(self, trainer): | |||||
if trainer.global_forward_batches % self.print_every == 0: | |||||
loss = self.loss/self.print_every | |||||
self.loss = 0 | |||||
self.progress_bar.update(self.task2id['batch'], description=f'Batch:{trainer.batch_idx_in_epoch}', | |||||
advance=self.print_every, | |||||
post_desc=f'Loss:{round(loss, self.loss_round_ndigit)}', refresh=True) | |||||
self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | |||||
advance=self.epoch_bar_update_advance, refresh=True) | |||||
def on_validate_end(self, trainer, results): | |||||
if len(results)==0: | |||||
return | |||||
rule_style = '' | |||||
text_style = '' | |||||
characters = '-' | |||||
if self.monitor is not None: | |||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
if abs(self.monitor_value) != float('inf'): | |||||
rule_style = 'spring_green3' | |||||
text_style = '[bold]' | |||||
characters = '+' | |||||
self.monitor_value = monitor_value | |||||
self.progress_bar.print() | |||||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | |||||
f"Batch:{trainer.batch_idx_in_epoch}", | |||||
style=rule_style, characters=characters) | |||||
if self.format_json: | |||||
self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
else: | |||||
self.progress_bar.print(results) | |||||
def on_exception(self, trainer, exception): | |||||
self.clear_tasks() | |||||
def clear_tasks(self): | |||||
for key, taskid in self.task2id.items(): | |||||
self.progress_bar.destroy_task(taskid) | |||||
self.progress_bar.stop() | |||||
self.task2id = {} | |||||
self.loss = 0 | |||||
class RawTextCallback(ProgressCallback): | |||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||||
format_json=True): | |||||
""" | |||||
通过向命令行打印进度的方式显示 | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | |||||
:param format_json: 是否format json再打印 | |||||
""" | |||||
super().__init__() | |||||
self.print_every = print_every | |||||
self.task2id = {} | |||||
self.loss = 0 | |||||
self.loss_round_ndigit = loss_round_ndigit | |||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = monitor | |||||
self.format_json = format_json | |||||
self.num_signs = 10 | |||||
def on_train_epoch_begin(self, trainer): | |||||
logger.info('\n' + "*"*self.num_signs + f'Epoch:{trainer.cur_epoch_idx} starts' + '*'*self.num_signs) | |||||
def on_before_backward(self, trainer, outputs): | |||||
loss = trainer.extract_loss_from_outputs(outputs) | |||||
loss = trainer.driver.tensor_to_numeric(loss, reduce='sum') | |||||
self.loss += loss | |||||
def on_train_batch_end(self, trainer): | |||||
if trainer.global_forward_batches % self.print_every == 0: | |||||
loss = self.loss/self.print_every | |||||
self.loss = 0 | |||||
text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \ | |||||
f'loss:{round(loss, self.loss_round_ndigit)}, ' \ | |||||
f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | |||||
logger.info(text) | |||||
def on_validate_end(self, trainer, results): | |||||
if len(results)==0: | |||||
return | |||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||||
text = '' | |||||
if self.monitor is not None: | |||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
if abs(self.monitor_value) != float('inf'): | |||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||||
self.monitor_value = monitor_value | |||||
if len(text) == 0: | |||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||||
logger.info(text) | |||||
if self.format_json: | |||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
else: | |||||
logger.info(results) |
@@ -0,0 +1,41 @@ | |||||
from typing import Optional | |||||
from fastNLP.core.log.logger import logger | |||||
from difflib import SequenceMatcher | |||||
def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float): | |||||
""" | |||||
从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 | |||||
匹配。 | |||||
:param monitor: | |||||
:param real_monitor: | |||||
:param res: | |||||
:return: 返回两个值(str, value),其中str就是最终要到的key,value就是这个key对应的value | |||||
""" | |||||
if len(res)==0: | |||||
return monitor, 0 | |||||
if monitor in res: | |||||
return monitor, res[monitor] | |||||
pairs = [] | |||||
for idx, (key, value) in enumerate(res.items()): | |||||
match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) | |||||
pairs.append((key, value, match.size, idx)) | |||||
pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) | |||||
key, value, match_size = pairs[0][:3] | |||||
if real_monitor is not None and real_monitor in res and real_monitor != key: | |||||
# 如果 real_monitor 比新找的更长就继续用之前的。 | |||||
match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) | |||||
if match.size > match_size: | |||||
return real_monitor, res[real_monitor] | |||||
logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||||
f"we use the `{key}` as the monitor.") | |||||
real_monitor = key | |||||
return real_monitor, value | |||||
@@ -0,0 +1,15 @@ | |||||
__all__ = [ | |||||
'Loop', | |||||
'EvaluateBatchLoop', | |||||
'TrainBatchLoop', | |||||
'State', | |||||
'TrainerState', | |||||
'Evaluator', | |||||
'Trainer', | |||||
] | |||||
from .loops import Loop, EvaluateBatchLoop, TrainBatchLoop | |||||
from .utils import State, TrainerState | |||||
from .evaluator import Evaluator | |||||
from .trainer import Trainer | |||||
@@ -0,0 +1,434 @@ | |||||
from typing import Union, List, Optional, Dict, Callable | |||||
from functools import partial | |||||
from dataclasses import is_dataclass | |||||
import sys | |||||
__all__ = [ | |||||
'Evaluator' | |||||
] | |||||
from fastNLP.core.drivers import Driver | |||||
from fastNLP.core.drivers.utils import choose_driver | |||||
from .loops import Loop, EvaluateBatchLoop | |||||
from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \ | |||||
match_and_substitute_params, f_rich_progress | |||||
from fastNLP.core.metrics import Metric | |||||
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | |||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | |||||
from fastNLP.core.log import logger | |||||
class Evaluator: | |||||
""" | |||||
1. 我们目前不直接提供每一个 metric 对应一个或者特殊的多个 dataloader 的功能,默认就是所有 metric 处理所有 dataloader,如果用户有这种 | |||||
需求,请使用多个 Tester 进行操作; | |||||
2. Trainer 的 validate dataloader 只允许传进去一个,而 Tester 则可以多个;因为 Trainer 涉及到保存 topk 模型的逻辑,而 Tester | |||||
则只需要给出评测的结果即可; | |||||
""" | |||||
driver: Driver | |||||
_evaluate_batch_loop: Loop | |||||
def __init__( | |||||
self, | |||||
model, | |||||
dataloaders, | |||||
metrics: Optional[Union[Dict, Metric]] = None, | |||||
driver: Union[str, Driver] = 'single', | |||||
device: Optional[Union[int, List[int], str]] = None, | |||||
batch_step_fn: Optional[callable] = None, | |||||
mode: str = "validate", | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||||
fp16: Optional[bool] = False, | |||||
verbose: int = 1, | |||||
**kwargs | |||||
): | |||||
""" | |||||
:param dataloaders: | |||||
:param model: | |||||
:param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 | |||||
metric ,torchmetrics,allennlpmetrics等。 | |||||
:param driver: 使用 driver 。 | |||||
:param device: 使用的设备。 | |||||
:param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 | |||||
DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 | |||||
batch_step_fn 函数。 | |||||
:param mode: 可选 ["validate", "test"], 当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 | |||||
寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, | |||||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | |||||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | |||||
:param fp16: 是否使用 fp16 。 | |||||
:param verbose: 是否打印 evaluate 的结果。 | |||||
:param kwargs: | |||||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | |||||
与 batch normalization 将会关闭。默认为True。 | |||||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | |||||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | |||||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | |||||
不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, | |||||
metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 | |||||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。如果为True,将使得每个进程上 | |||||
的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。请确保使用的 metrics 支持自动分布式累积。 | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
progress_bar: evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 | |||||
到当前terminal为交互型则使用 rich,否则使用 raw。 | |||||
""" | |||||
self.model = model | |||||
self.metrics = metrics | |||||
self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs) | |||||
self.device = device | |||||
self.verbose = verbose | |||||
assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \ | |||||
"two parameters." | |||||
self.batch_step_fn = batch_step_fn | |||||
self.mode = mode | |||||
assert mode in {'validate', 'test'}, "Parameter `mode` should only be 'validate' or 'test'." | |||||
self.input_mapping = input_mapping | |||||
self.output_mapping = output_mapping | |||||
if not isinstance(dataloaders, dict): | |||||
dataloaders = {None: dataloaders} | |||||
if mode == "validate": | |||||
self._evaluate_step = self.driver.validate_step | |||||
self.driver.set_dataloader(validate_dataloaders=dataloaders) | |||||
else: | |||||
self._evaluate_step = self.driver.test_step | |||||
self.driver.set_dataloader(test_dataloaders=dataloaders) | |||||
self.mode = mode | |||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) | |||||
self.separator = kwargs.get('separator', '#') | |||||
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) | |||||
use_dist_sampler = kwargs.get("use_dist_sampler", False) # 如果是 Evaluator 自身的默认值的话,应当为 False; | |||||
if use_dist_sampler: | |||||
self._dist_sampler = "unrepeatdist" | |||||
else: | |||||
self._dist_sampler = None | |||||
self._metric_wrapper = None | |||||
_ = self.metrics_wrapper # 触发检查 | |||||
assert self.driver.has_validate_dataloaders() or self.driver.has_test_dataloaders() | |||||
self.driver.setup() | |||||
self.driver.barrier() | |||||
self.dataloaders = {} | |||||
for name, dl in dataloaders.items(): # 替换为正确的 sampler | |||||
dl = self.driver.replace_sampler( | |||||
dataloader=dl, | |||||
dist_sampler=self._dist_sampler, | |||||
reproducible=False | |||||
) | |||||
self.dataloaders[name] = dl | |||||
self.progress_bar = kwargs.get('progress_bar', 'auto') | |||||
if self.progress_bar == 'auto': | |||||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||||
self.driver.barrier() | |||||
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | |||||
""" | |||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name | |||||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||||
:return: | |||||
""" | |||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | |||||
assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." | |||||
self.driver.check_evaluator_mode(self.mode) | |||||
if self.mode == 'validate': | |||||
assert self.driver.has_validate_dataloaders() | |||||
else: | |||||
assert self.driver.has_test_dataloaders() | |||||
metric_results = {} | |||||
self.reset() | |||||
evaluate_context = self.driver.get_evaluate_context() | |||||
self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | |||||
with evaluate_context(): | |||||
try: | |||||
for dataloader_name, dataloader in self.dataloaders.items(): | |||||
self.driver.barrier() | |||||
if num_eval_batch_per_dl != -1: | |||||
dataloader = _TruncatedDataLoader(dataloader, num_eval_batch_per_dl) | |||||
self.driver.set_sampler_epoch(dataloader, -1) | |||||
self.start_progress_bar(total=len(dataloader), dataloader_name=dataloader_name) | |||||
self.cur_dataloader_name = dataloader_name | |||||
results = self.evaluate_batch_loop.run(self, dataloader) | |||||
self.remove_progress_bar(dataloader_name) | |||||
metric_results.update(results) | |||||
self.reset() | |||||
self.driver.barrier() | |||||
except BaseException as e: | |||||
raise e | |||||
finally: | |||||
self.finally_progress_bar() | |||||
self.driver.set_model_mode(mode='train') | |||||
if self.verbose: | |||||
if self.progress_bar == 'rich': | |||||
f_rich_progress.print(metric_results) | |||||
else: | |||||
logger.info(metric_results) | |||||
return metric_results | |||||
def start_progress_bar(self, total:int, dataloader_name): | |||||
if self.progress_bar == 'rich': | |||||
if dataloader_name is None: | |||||
desc = f'Eval. Batch:0' | |||||
else: | |||||
desc = f'Eval. on {dataloader_name} Batch:0' | |||||
self._rich_task_id = f_rich_progress.add_task(description=desc, total=total) | |||||
elif self.progress_bar == 'raw': | |||||
desc = 'Evaluation starts' | |||||
if dataloader_name is not None: | |||||
desc += f' on {dataloader_name}' | |||||
logger.info('\n' + "*" * 10 + desc + '*' * 10) | |||||
def update_progress_bar(self, batch_idx, dataloader_name, **kwargs): | |||||
if dataloader_name is None: | |||||
desc = f'Eval. Batch:{batch_idx}' | |||||
else: | |||||
desc = f'Eval. on {dataloader_name} Batch:{batch_idx}' | |||||
if self.progress_bar == 'rich': | |||||
assert hasattr(self, '_rich_task_id'), "You must first call `start_progress_bar()` before calling " \ | |||||
"update_progress_bar()" | |||||
f_rich_progress.update(self._rich_task_id, description=desc, post_desc=kwargs.get('post_desc', ''), | |||||
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | |||||
visible=kwargs.get('visible', True)) | |||||
elif self.progress_bar == 'raw': | |||||
if self.verbose>1: | |||||
logger.info(desc) | |||||
def remove_progress_bar(self, dataloader_name): | |||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||||
f_rich_progress.destroy_task(self._rich_task_id) | |||||
delattr(self, '_rich_task_id') | |||||
elif self.progress_bar == 'raw': | |||||
desc = 'Evaluation ends' | |||||
if dataloader_name is not None: | |||||
desc += f' on {dataloader_name}' | |||||
logger.info("*" * 10 + desc + '*' * 10 + '\n') | |||||
def finally_progress_bar(self): | |||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||||
f_rich_progress.destroy_task(self._rich_task_id) | |||||
delattr(self, '_rich_task_id') | |||||
@property | |||||
def eval_dataloaders(self): | |||||
if self.mode == "validate": | |||||
return self.driver.validate_dataloaders | |||||
else: | |||||
return self.driver.test_dataloaders | |||||
@property | |||||
def evaluate_batch_loop(self): | |||||
return self._evaluate_batch_loop | |||||
@evaluate_batch_loop.setter | |||||
def evaluate_batch_loop(self, loop: Loop): | |||||
if self.batch_step_fn is not None: | |||||
logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||||
"when the `evaluate_batch_loop` is also customized.") | |||||
self._evaluate_batch_loop = loop | |||||
def reset(self): | |||||
""" | |||||
调用所有 metric 的 reset() 方法,清除累积的状态。 | |||||
Returns: | |||||
""" | |||||
self.metrics_wrapper.reset() | |||||
def update(self, *args, **kwargs): | |||||
""" | |||||
调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。 | |||||
:param args: | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
self.metrics_wrapper.update(*args, **kwargs) | |||||
def get_dataloader_metric(self, dataloader_name:Optional[str]='') -> Dict: | |||||
""" | |||||
获取当前dataloader的metric结果 | |||||
:param str dataloader_name: 当前dataloader的名字 | |||||
:return: | |||||
""" | |||||
return self.metrics_wrapper.get_metric(dataloader_name=dataloader_name, separator=self.separator) | |||||
@property | |||||
def metrics_wrapper(self): | |||||
""" | |||||
由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 batch_step_fn )中使用,同时也为了支持 | |||||
不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper | |||||
进行操作。 | |||||
Returns: | |||||
""" | |||||
if self._metric_wrapper is None: | |||||
self._metric_wrapper = _MetricsWrapper(self.metrics, evaluator=self) | |||||
return self._metric_wrapper | |||||
def evaluate_step(self, batch): | |||||
""" | |||||
将 batch 传递到model中进行处理,根据当前 mode 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 | |||||
返回。 | |||||
:param batch: | |||||
:return: | |||||
""" | |||||
outputs = self._evaluate_step(batch) | |||||
outputs = match_and_substitute_params(self.output_mapping, outputs) | |||||
return outputs | |||||
@property | |||||
def metrics(self): | |||||
""" | |||||
返回用户传入的 metrics 对象。 | |||||
:return: | |||||
""" | |||||
return self._metrics | |||||
@metrics.setter | |||||
def metrics(self, metrics): | |||||
self._metrics = metrics | |||||
def move_data_to_device(self, batch): | |||||
return self.driver.move_data_to_device(batch) | |||||
class _MetricsWrapper: | |||||
""" | |||||
注意 metrics 的输入只支持:Dict[str, Metric]; | |||||
并且通过对 update() , reset() , get_metric() 函数的封装,实现支持 fastNLP 的 metric 以及 torchmetrics 或者更多。 | |||||
""" | |||||
def __init__(self, metrics, evaluator): | |||||
self.evaluator = evaluator | |||||
self._metrics = [] | |||||
self._metric_names = [] | |||||
if metrics is not None: | |||||
if not isinstance(metrics, Dict): | |||||
raise TypeError("Parameter `metrics` can only be `Dict` type.") | |||||
for metric_name, metric in metrics.items(): | |||||
# 因为 torchmetrics 是一个 nn.Module,因此我们需要先将其移到对应的机器上; | |||||
if _is_torchmetrics_metric(metric): | |||||
# torchmetrics 是默认自动开启了多卡的 | |||||
evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) | |||||
elif isinstance(metric, Metric): | |||||
if evaluator._dist_sampler is not None and evaluator.driver.is_distributed() \ | |||||
and metric.aggregate_when_get_metric is False: | |||||
logger.warning("You have replace the sampler as distributed sampler when evaluation, but your " | |||||
f"metric:{metric_name}' `aggregate_when_get_metric` is False.") | |||||
if evaluator._dist_sampler is None and evaluator.driver.is_distributed() \ | |||||
and metric.aggregate_when_get_metric is True: | |||||
pass # 这种情况无所谓,因为 | |||||
metric.to(evaluator.driver.data_device) | |||||
self._metric_names.append(metric_name) | |||||
self._metrics.append(metric) | |||||
def update(self, batch, outputs): | |||||
if is_dataclass(outputs): | |||||
outputs = dataclass_to_dict(outputs) | |||||
for metric in self._metrics: | |||||
if not isinstance(batch, dict): | |||||
raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly" | |||||
f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.") | |||||
if not isinstance(outputs, dict): | |||||
raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly" | |||||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||||
if isinstance(metric, Metric): | |||||
auto_param_call(metric.update, batch, outputs) | |||||
elif _is_torchmetrics_metric(metric): | |||||
auto_param_call(metric.update, batch, outputs) | |||||
elif _is_allennlp_metric(metric): | |||||
auto_param_call(metric.__call__, batch, outputs) | |||||
elif _is_paddle_metric(metric): | |||||
res = auto_param_call(metric.compute, batch, outputs) | |||||
metric.update(res) | |||||
def reset(self): | |||||
for metric in self._metrics: | |||||
if _is_allennlp_metric(metric): | |||||
metric.get_metric(reset=True) | |||||
elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric): | |||||
metric.reset() | |||||
def get_metric(self, dataloader_name:str, separator:str) -> Dict: | |||||
""" | |||||
将所有 metric 结果展平到一个一级的字典中,这个字典中 key 的命名规则是 | |||||
indicator_name{separator}metric_name{separator}dataloader_name | |||||
例如: f1#F1PreRec#dev | |||||
:param dataloader_name: 当前metric对应的dataloader的名字。若为空,则不显示在最终的key上面。 | |||||
:param separator: 用于间隔不同称呼。 | |||||
:return: 返回一个一级结构的字典,其中 key 为区别一个 metric 的名字,value 为该 metric 的值; | |||||
""" | |||||
results = {} | |||||
for metric_name, metric in zip(self._metric_names, self._metrics): | |||||
if isinstance(metric, Metric): | |||||
_results = metric.get_metric() | |||||
elif _is_allennlp_metric(metric): | |||||
_results = metric.get_metric(reset=False) | |||||
elif _is_torchmetrics_metric(metric): | |||||
_results = metric.compute() | |||||
# 我们规定了 evaluator 中的 metrics 的输入只能是一个 dict,这样如果 metric 是一个 torchmetrics 时,如果 evaluator | |||||
# 没有传入 func_post_proc,那么我们就自动使用该 metric 的 metric name 当做其的 indicator name 将其自动转换成一个字典; | |||||
elif _is_paddle_metric(metric): | |||||
_results = metric.accumulate() | |||||
if not isinstance(_results, Dict): | |||||
name = _get_metric_res_name(dataloader_name, metric_name, '', separator) | |||||
results[name] = _results | |||||
else: | |||||
for indicator_name, value in _results.items(): | |||||
name = _get_metric_res_name(dataloader_name, metric_name, indicator_name, separator) | |||||
results[name] = value | |||||
return results | |||||
def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indicator_name: str, separator='#') -> str: | |||||
""" | |||||
:param dataloader_name: dataloder的名字 | |||||
:param metric_name: metric的名字 | |||||
:param indicator_name: metric中的各项metric名称,例如f, precision, recall | |||||
:param separator: 用以间隔不同对象的间隔符 | |||||
:return: | |||||
""" | |||||
names = [] | |||||
if indicator_name: | |||||
names.append(indicator_name) | |||||
if metric_name: | |||||
names.append(metric_name) | |||||
if dataloader_name: | |||||
names.append(dataloader_name) | |||||
if len(names) == 0: | |||||
raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.") | |||||
return separator.join(names) |
@@ -0,0 +1,9 @@ | |||||
__all__ = [ | |||||
'EvaluateBatchLoop', | |||||
'Loop', | |||||
'TrainBatchLoop' | |||||
] | |||||
from .loop import Loop | |||||
from .evaluate_batch_loop import EvaluateBatchLoop | |||||
from .train_batch_loop import TrainBatchLoop |
@@ -0,0 +1,50 @@ | |||||
from typing import Optional, Callable, Dict | |||||
__all__ = [ | |||||
'EvaluateBatchLoop' | |||||
] | |||||
from .loop import Loop | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.utils import match_and_substitute_params | |||||
class EvaluateBatchLoop(Loop): | |||||
def __init__(self, batch_step_fn:Optional[Callable]=None): | |||||
if batch_step_fn is not None: | |||||
self.batch_step_fn = batch_step_fn | |||||
def run(self, evaluator, dataloader) -> Dict: | |||||
""" | |||||
需要返回在传入的 dataloader 中的 evaluation 结果 | |||||
:param evaluator: Evaluator 对象 | |||||
:param dataloader: 当前需要进行 evaluate 的dataloader | |||||
:return: | |||||
""" | |||||
iterator = iter(dataloader) | |||||
batch_idx = 0 | |||||
while True: | |||||
try: | |||||
batch = next(iterator) | |||||
batch = match_and_substitute_params(evaluator.input_mapping, batch) | |||||
batch = evaluator.move_data_to_device(batch) | |||||
except StopIteration: | |||||
break | |||||
except BaseException as e: | |||||
if callable(getattr(dataloader, 'get_batch_indices', None)): | |||||
indices = dataloader.get_batch_indices() | |||||
logger.debug(f"The following exception happens when running on samples: {indices}") | |||||
raise e | |||||
self.batch_step_fn(evaluator, batch) | |||||
batch_idx += 1 | |||||
evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name) | |||||
# 获取metric结果。返回的dict内容示例为{'f1#F1Metric#dl1': 0.93, 'pre#F1Metric#dl1': 0.95, ...} | |||||
results = evaluator.get_dataloader_metric(dataloader_name=evaluator.cur_dataloader_name) | |||||
return results | |||||
@staticmethod | |||||
def batch_step_fn(evaluator, batch): | |||||
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | |||||
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 |
@@ -0,0 +1,17 @@ | |||||
__all__ = [ | |||||
'Loop' | |||||
] | |||||
class Loop: | |||||
def run(self, *args, **kwargs): | |||||
""" | |||||
该循环的主要运行过程; | |||||
""" | |||||
def step(self, *args, **kwargs): | |||||
""" | |||||
该循环运行过程中的一步; | |||||
""" |
@@ -0,0 +1,56 @@ | |||||
__all__ = [ | |||||
'TrainBatchLoop' | |||||
] | |||||
from typing import Optional, Callable | |||||
from .loop import Loop | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.utils import match_and_substitute_params | |||||
class TrainBatchLoop(Loop): | |||||
def __init__(self, batch_step_fn: Optional[Callable] = None): | |||||
if batch_step_fn is not None: | |||||
self.batch_step_fn = batch_step_fn | |||||
def run(self, trainer, dataloader): | |||||
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | |||||
else lambda *args, **kwargs: None | |||||
dataloader = iter(dataloader) | |||||
indices = None | |||||
while True: | |||||
try: | |||||
trainer.on_fetch_data_begin() | |||||
batch = next(dataloader) | |||||
batch = match_and_substitute_params(trainer.input_mapping, batch) | |||||
indices = get_batch_indices() | |||||
batch = trainer.move_data_to_device(batch) | |||||
trainer.on_fetch_data_end() | |||||
except StopIteration: | |||||
break | |||||
except BaseException as e: # TODO 把这里的信息写入进去 | |||||
if indices: | |||||
logger.debug(f"The following exception happens when running on samples: {indices}") | |||||
raise e | |||||
trainer.on_train_batch_begin(batch, indices) | |||||
self.batch_step_fn(trainer, batch) | |||||
trainer.global_forward_batches += 1 | |||||
trainer.batch_idx_in_epoch += 1 | |||||
trainer.check_batch_step_fn() | |||||
trainer.on_train_batch_end() | |||||
trainer.step_validate() | |||||
trainer.batch_idx_in_epoch = 0 | |||||
@staticmethod | |||||
def batch_step_fn(trainer, batch): | |||||
outputs = trainer.train_step(batch) | |||||
trainer.backward(outputs) | |||||
trainer.step() | |||||
trainer.zero_grad() | |||||
@@ -0,0 +1,806 @@ | |||||
from typing import Union, Optional, List, Callable, Dict, Sequence, BinaryIO, IO | |||||
from functools import partial | |||||
from collections import defaultdict | |||||
import copy | |||||
from contextlib import contextmanager | |||||
from dataclasses import is_dataclass | |||||
import os | |||||
from pathlib import Path | |||||
import io | |||||
__all__ = [ | |||||
'Trainer', | |||||
] | |||||
from .loops import Loop, TrainBatchLoop | |||||
from .utils import State, TrainerState | |||||
from .evaluator import Evaluator | |||||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | |||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter | |||||
from fastNLP.core.callbacks.callback import _CallbackWrapper | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||||
from fastNLP.core.drivers import Driver | |||||
from fastNLP.core.drivers.utils import choose_driver | |||||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||||
class Trainer(TrainerEventTrigger): | |||||
_custom_callbacks: dict = defaultdict(list) | |||||
def __init__( | |||||
self, | |||||
model, | |||||
driver, | |||||
train_dataloader, | |||||
optimizers, | |||||
device: Optional[Union[int, List[int], str]] = "cpu", | |||||
n_epochs: int = 20, | |||||
validate_dataloaders=None, | |||||
batch_step_fn: Optional[Callable] = None, | |||||
validate_batch_step_fn: Optional[Callable] = None, | |||||
validate_mode: str = "validate", | |||||
callbacks: Union[List[Callback], Callback, None] = None, | |||||
metrics: Optional[dict] = None, | |||||
validate_every: Optional[Union[int, callable]] = -1, | |||||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||||
accumulation_steps: int = 1, | |||||
fp16: bool = False, | |||||
marker: Optional[str] = None, | |||||
**kwargs | |||||
): | |||||
r""" | |||||
`Trainer` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 | |||||
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需 | |||||
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | |||||
:param model: 训练所需要的模型,目前支持 pytorch; | |||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle | |||||
等国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 | |||||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | |||||
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | |||||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你 | |||||
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 | |||||
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 | |||||
自己构造 DDP 的多进程场景); | |||||
device 的可选输入如下所示: | |||||
1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; | |||||
2. torch.device:将模型装载到torch.device上; | |||||
3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`; | |||||
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | |||||
5. None: 为None则不对模型进行任何处理; | |||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; | |||||
:param validate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||||
为 None; | |||||
:param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 | |||||
`batch`;默认为 None; | |||||
:param validate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 | |||||
两个参数必须为 `evaluator` 和 `batch`;默认为 None; | |||||
:param validate_mode: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,其值应当为以下之一:["validate", "test"]; | |||||
默认为 "validate";当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 | |||||
寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, | |||||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||||
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; | |||||
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; | |||||
:param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; | |||||
为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的参数应该为 (filter, trainer) , 其中的 filter 对象 | |||||
中自动记录了两个变量: filter.num_called 表示有多少次尝试 validate (实际等同于到当前时刻 batch 的总数), filter.num_executed | |||||
表示 validate 实际被执行了多少次;trainer 参数即为 Trainer 对象。 函数返回值应为 bool ,返回为 True 说明需要进行 validate 。 | |||||
例如: (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) 表示在第 10 个 epoch | |||||
之后,每个 epoch 结束进行一次 validate 。 | |||||
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 | |||||
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的 | |||||
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 | |||||
类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; | |||||
注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时); | |||||
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 | |||||
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, | |||||
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | |||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | |||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | |||||
:param fp16: 是否开启混合精度训练;默认为 False; | |||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | |||||
:param kwargs: 一些其它的可能需要的参数; | |||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; | |||||
注意如果 model_device 为 None,那么 data_device 不会起作用; | |||||
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数; | |||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
use_dist_sampler: 表示在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
use_eval_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过 | |||||
callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。 | |||||
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 | |||||
""" | |||||
# TODO 是不是可以加一个参数让用户现在关掉参数匹配。 | |||||
self.marker = marker | |||||
self.model = model | |||||
self.driver_name = driver | |||||
self.device = device | |||||
self.fp16 = fp16 | |||||
self.input_mapping = input_mapping | |||||
self.output_mapping = output_mapping | |||||
assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \ | |||||
"two parameters." | |||||
self.batch_step_fn = batch_step_fn | |||||
if batch_step_fn is not None: | |||||
self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True) | |||||
else: | |||||
self.check_batch_step_fn = lambda *args, **kwargs: ... | |||||
# 该变量表示是否检测过 `train_batch_loop`,主要用于当用户通过属性替换的方式使用自己定制的 `train_batch_loop` 时,我们需要检测 | |||||
# 用户是否正确地调用了 callback 函数以及是否正确地更新了 `trainer_state` 的状态; | |||||
# 我们将其默认值置为 True,这表示默认的 `train_batch_loop` 已经检测过,不需要再进行检测; | |||||
# 我们只会在第一个 epoch 运行完后进行检测,之后的 epoch 不会再进行检测; | |||||
self.has_checked_train_batch_loop = True | |||||
self._train_batch_loop = TrainBatchLoop(batch_step_fn=batch_step_fn) | |||||
if not isinstance(accumulation_steps, int): | |||||
raise ValueError("Parameter `accumulation_steps` can only be `int` type.") | |||||
elif accumulation_steps < 0: | |||||
raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.") | |||||
self.accumulation_steps = accumulation_steps | |||||
self.driver = choose_driver( | |||||
model=model, | |||||
driver=driver, | |||||
train_dataloader=train_dataloader, | |||||
optimizers=optimizers, | |||||
device=device, | |||||
n_epochs=n_epochs, | |||||
validate_dataloaders=validate_dataloaders, | |||||
batch_step_fn=batch_step_fn, | |||||
validate_batch_step_fn=validate_batch_step_fn, | |||||
validate_mode=validate_mode, | |||||
callbacks=callbacks, | |||||
metrics=metrics, | |||||
validate_every=validate_every, | |||||
input_mapping=input_mapping, | |||||
output_mapping=output_mapping, | |||||
accumulation_steps=accumulation_steps, | |||||
fp16=fp16, | |||||
marker=marker, | |||||
**kwargs | |||||
) | |||||
self.driver.set_optimizers(optimizers=optimizers) | |||||
if train_dataloader is not None: | |||||
self.driver.set_dataloader(train_dataloader=train_dataloader) | |||||
# 初始化 callback manager; | |||||
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | |||||
# 添加所有的函数式 callbacks; | |||||
self._fetch_matched_fn_callbacks() | |||||
# 添加所有的类 callbacks; | |||||
self.callback_manager.initialize_class_callbacks() | |||||
# 初始化 state,包括提供给用户的接口和我们自己使用的接口; | |||||
self.state = State() | |||||
self.trainer_state = TrainerState( | |||||
n_epochs=n_epochs, | |||||
cur_epoch_idx=0, | |||||
global_forward_batches=0, | |||||
batch_idx_in_epoch=0, | |||||
num_batches_per_epoch=None, # 会在具体的 train_batch_loop 中进行初始化; | |||||
total_batches=None | |||||
) | |||||
use_dist_sampler = kwargs.get("use_dist_sampler", True) | |||||
if use_dist_sampler: | |||||
_dist_sampler = "dist" | |||||
else: | |||||
_dist_sampler = None | |||||
""" 设置内部的 Evaluator """ | |||||
if metrics is None and validate_dataloaders is not None: | |||||
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.") | |||||
if metrics is not None and validate_dataloaders is None: | |||||
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") | |||||
# 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; | |||||
# _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; | |||||
self.evaluator = None | |||||
self.epoch_validate = lambda *args, **kwargs: ... | |||||
self.step_validate = lambda *args, **kwargs: ... | |||||
if metrics is not None and validate_dataloaders is not None: | |||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | |||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | |||||
self.evaluator = Evaluator( | |||||
model=model, | |||||
dataloaders=validate_dataloaders, | |||||
metrics=metrics, | |||||
driver=self.driver, | |||||
device=device, | |||||
batch_step_fn=validate_batch_step_fn, | |||||
mode=validate_mode, | |||||
input_mapping=input_mapping, | |||||
output_mapping=output_mapping, | |||||
fp16=fp16, | |||||
verbose=0, | |||||
use_dist_sampler=kwargs.get("use_eval_dist_sampler", use_dist_sampler), | |||||
progress_bar=kwargs.get('progress_bar', 'auto') | |||||
) | |||||
if callable(validate_every): | |||||
self._step_validate_filter = Filter(filter_fn=validate_every) | |||||
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " | |||||
"and in this way, the kind of controlling frequency is depending on the 'step'.") | |||||
elif validate_every < 0: | |||||
self._epoch_validate_filter = Filter(every=-validate_every) | |||||
else: | |||||
# validate_every > 0 | |||||
self._step_validate_filter = Filter(every=validate_every) | |||||
self.metrics = metrics | |||||
self.validate_every = validate_every | |||||
assert self.driver.has_train_dataloader() | |||||
self.driver.setup() | |||||
self.driver.barrier() | |||||
self.dataloader = self.train_dataloader | |||||
self.driver.set_deterministic_dataloader(self.dataloader) | |||||
self.dataloader = self.driver.replace_sampler( | |||||
dataloader=self.train_dataloader, | |||||
dist_sampler=_dist_sampler, | |||||
reproducible=self.callback_manager.has_trainer_chechpoint | |||||
) | |||||
self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | |||||
self.on_after_trainer_initialized(self.driver) | |||||
self.driver.barrier() | |||||
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, | |||||
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, | |||||
catch_KeyboardInterrupt=True): | |||||
""" | |||||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint | |||||
去保存断点重训的文件; | |||||
:param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 | |||||
:param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止,-1 为根据 dataloader 有多少个 batch 决定。 | |||||
:param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误。为 0 表示不检测。 | |||||
:param resume_from: 从哪个路径下恢复 trainer 的状态 | |||||
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 | |||||
:param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 | |||||
行。 | |||||
:return: | |||||
""" | |||||
if self.driver.is_distributed(): | |||||
if catch_KeyboardInterrupt: | |||||
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||||
"driver. And we are gonna to set it to False.") | |||||
catch_KeyboardInterrupt = False | |||||
self._set_num_eval_batch_per_dl(num_eval_batch_per_dl) | |||||
if resume_from is not None: | |||||
if os.path.exists(resume_from): | |||||
self.load(resume_from, resume_training=resume_training) | |||||
else: | |||||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | |||||
if self.evaluator is not None and num_eval_sanity_batch > 0: | |||||
self.on_sanity_check_begin() | |||||
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | |||||
self.on_sanity_check_end(sanity_check_res) | |||||
if num_train_batch_per_epoch != -1: | |||||
self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch) | |||||
self.num_batches_per_epoch = len(self.dataloader) | |||||
self.total_batches = self.num_batches_per_epoch * self.n_epochs | |||||
self.on_train_begin() | |||||
self.driver.barrier() | |||||
self.driver.zero_grad(self.set_grad_to_none) | |||||
try: | |||||
while self.cur_epoch_idx < self.n_epochs: | |||||
self.driver.set_model_mode("train") | |||||
self.on_train_epoch_begin() | |||||
self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | |||||
self.train_batch_loop.run(self, self.dataloader) | |||||
if not self.has_checked_train_batch_loop: | |||||
self._check_train_batch_loop_legality() | |||||
self.cur_epoch_idx += 1 | |||||
self.on_train_epoch_end() | |||||
self.driver.barrier() | |||||
self.epoch_validate() | |||||
self.driver.barrier() | |||||
self.on_train_end() | |||||
self.driver.barrier() | |||||
except KeyboardInterrupt as e: | |||||
self.driver.on_exception() | |||||
self.on_exception(e) | |||||
if not catch_KeyboardInterrupt: | |||||
raise e | |||||
except BaseException as e: | |||||
self.driver.on_exception() | |||||
self.on_exception(e) | |||||
raise e | |||||
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | |||||
def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None: | |||||
trainer.on_validate_begin() | |||||
_validate_res: dict = validate_fn() | |||||
trainer.on_validate_end(_validate_res) | |||||
if self.evaluator is not None: | |||||
if callable(self.validate_every): | |||||
self.step_validate = self._step_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
elif self.validate_every < 0: | |||||
self.epoch_validate = self._epoch_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
else: | |||||
# validate_every > 0 | |||||
self.step_validate = self._step_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | |||||
r""" | |||||
在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; | |||||
这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; | |||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | |||||
:param fn: 具体的 callback 函数; | |||||
""" | |||||
if not isinstance(event, (_SingleEventState, EventsList)): | |||||
raise ValueError("parameter event should only be `Events` or `EventsList` type.") | |||||
_custom_callback = _CallbackWrapper(event, fn) | |||||
self.callback_manager.dissect_one_callback(_custom_callback) | |||||
@classmethod | |||||
def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None): | |||||
r""" | |||||
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | |||||
注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; | |||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | |||||
:param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, | |||||
表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer | |||||
实例使用; | |||||
:return: 返回原函数; | |||||
""" | |||||
def wrapper(fn: Callable) -> Callable: | |||||
cls._custom_callbacks[marker].append((event, fn)) | |||||
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \ | |||||
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \ | |||||
"`Callback` with the same callback time." | |||||
return fn | |||||
return wrapper | |||||
def _fetch_matched_fn_callbacks(self): | |||||
""" | |||||
因为对于使用装饰器加入的函数 callback,我们是加在类属性中,因此在初始化一个具体的 trainer 实例后,我们需要从 Trainer 的 | |||||
callback 类属性中将属于其的 callback 函数拿到,然后加入到 callback_manager 中; | |||||
""" | |||||
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"]) | |||||
_own_callbacks.extend(self._custom_callbacks[None]) | |||||
self._custom_callbacks[None] = [] | |||||
if self.marker is not None: | |||||
if len(self._custom_callbacks[self.marker]) == 0: | |||||
print(f"You have set `trainer.marker = {self.marker}`, but there are no callback function matched " | |||||
f"`{self.marker}` that is added through function `Trainer.on`") | |||||
_own_callbacks += self._custom_callbacks[self.marker] | |||||
for each_callback in _own_callbacks: | |||||
self.add_callback_fn(*each_callback) | |||||
def _check_callback_called_legality(self, check_mode: bool = True): | |||||
""" | |||||
1. 函数的调用时机: | |||||
当检测 'batch_step_fn' 时,这个函数应当在 'train_batch_loop.run' 的 while 循环的最后进行调用; | |||||
当检测 'TrainBatchLoop' 时,这个函数应当在每一个 epoch 的最后进行调用; | |||||
2. 函数作用 | |||||
这一函数的作用在于检查用户定制的 batch_step_fn / TrainBatchLoop 是否能够正确地调用 callback 函数,更准确地说,当用户实际 | |||||
定制了 ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") / | |||||
("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | |||||
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
这些 callabck_fn 后,如果其同样也定制了 batch_step_fn / TrainBatchLoop,那么其有可能忘记了在自己的 batch_step_fn 中 | |||||
上述的这些 callback 函数,而这个函数的作用就在于检测用户是否产生了这一行为; | |||||
注意,这一函数只会在 batch_step_fn 不为 None 时或者 TrainBatchLoop 没有被替换时才会被调用; | |||||
:param check_mode: 用来判断该函数是用来检测 'batch_step_fn' 还是用来检测 'TrainBatchLoop' 的参数,为 True 时表示检测 | |||||
'batch_step_fn',为 False 时表示检测 'TrainBatchLoop'; | |||||
""" | |||||
if check_mode: | |||||
callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
else: | |||||
callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end", | |||||
"on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") | |||||
_not_called_callback_fns = [] | |||||
for each_callback_fn in callbacks: | |||||
if each_callback_fn in self.callback_manager.callback_fns: | |||||
if self.callback_manager.callback_counter[each_callback_fn] == 0: | |||||
_not_called_callback_fns.append(each_callback_fn) | |||||
if check_mode: | |||||
logger.warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||||
f"callback_fns: {_not_called_callback_fns}, but it seems that" | |||||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||||
# 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | |||||
# 函数; | |||||
self.check_batch_step_fn = lambda *args, **kwargs: ... | |||||
else: | |||||
logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " | |||||
f"{_not_called_callback_fns}, but it seems that" | |||||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||||
def _check_train_batch_loop_legality(self): | |||||
r""" | |||||
该函数用于检测用户定制的 `train_batch_loop` 是否正确地调用了 callback 函数以及是否正确地更新了 `trainer_state` 的状态; | |||||
该函数仅当用户通过属性更换用自己的定制的 `train_batch_loop` 替换了默认的 `TrainBatchLoop` 对象后才会被调用; | |||||
当被调用时,该函数仅当第一次被调用时被调用; | |||||
""" | |||||
# 1. 检测用户定制的 `train_batch_loop` 是否正确地调用了 callback 函数; | |||||
self._check_callback_called_legality(check_mode=False) | |||||
# 2. 检测用户定制的 `train_batch_loop` 是否正确地更新了 `trainer_state` 的状态; | |||||
# 因为该检测函数只会在第一个 epoch 运行完后调用,因此我们只需要检测这些 `trainer_state` 的值是否正确即可; | |||||
if self.batch_idx_in_epoch == 0: | |||||
logger.warning("You have customized your `train_batch_loop`, but it seemed that you forget to update the " | |||||
"`trainer_state.batch_idx_in_epoch` in your process of training. Look the origin class " | |||||
"`TrainBatchLoop`.") | |||||
if self.global_forward_batches == 0: | |||||
logger.warning("You have customized your `train_batch_loop`, but it seemed that you forget to update the " | |||||
"`trainer_state.global_forward_batches` in your process of training. Look the origin class " | |||||
"`TrainBatchLoop`.") | |||||
self.has_checked_train_batch_loop = True | |||||
""" Trainer 需要的一些 property """ | |||||
@property | |||||
def train_dataloader(self): | |||||
return self.driver.train_dataloader | |||||
@property | |||||
def driver(self): | |||||
return self._driver | |||||
@driver.setter | |||||
def driver(self, driver: Driver): | |||||
driver.trainer = self | |||||
driver.model = self.model | |||||
self._driver = driver | |||||
@property | |||||
def train_batch_loop(self): | |||||
return self._train_batch_loop | |||||
@train_batch_loop.setter | |||||
def train_batch_loop(self, loop: Loop): | |||||
self.has_checked_train_batch_loop = False | |||||
if self.batch_step_fn is not None: | |||||
logger.warning("`batch_step_fn` was customized in the Trainer initialization, it will be ignored " | |||||
"when the `train_batch_loop` is also customized.") | |||||
# 如果用户定制了 TrainBatchLoop,那么我们不需要再专门去检测 batch_step_fn,因为该函数一定会被忽略; | |||||
self.check_batch_step_fn = lambda *args, **kwargs: ... | |||||
self._train_batch_loop = loop | |||||
def save_model(self, folder: Union[str, os.PathLike, BinaryIO, io.BytesIO], only_state_dict: bool = False, | |||||
model_save_fn: Optional[Callable] = None, **kwargs): | |||||
r""" | |||||
用于帮助用户保存模型的辅助函数,具体实际的保存模型的操作由具体的 driver 实现; | |||||
:param folder: 保存模型的地址; | |||||
:param only_state_dict: 是否只保存模型的 `state_dict`; | |||||
:param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; | |||||
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它; | |||||
""" | |||||
self.on_save_model() | |||||
self.driver.barrier() | |||||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | |||||
if model_save_fn is not None: | |||||
if not callable(model_save_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
rank_zero_call(model_save_fn)(folder) | |||||
else: | |||||
if isinstance(folder, str): | |||||
folder = Path(folder) | |||||
self.driver.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
else: | |||||
if model_save_fn is not None: | |||||
raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | |||||
"`io.BytesIO` type.") | |||||
self.driver.save_model(folder, only_state_dict, **kwargs) | |||||
self.driver.barrier() | |||||
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False, | |||||
model_load_fn: Optional[Callable] = None, **kwargs): | |||||
self.on_load_model() | |||||
self.driver.barrier() | |||||
if not isinstance(folder, (io.BytesIO, BinaryIO)): | |||||
if model_load_fn is not None: | |||||
if not callable(model_load_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
rank_zero_call(model_load_fn)(folder) | |||||
else: | |||||
if isinstance(folder, str): | |||||
folder = Path(folder) | |||||
self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||||
else: | |||||
if model_load_fn is not None: | |||||
raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being " | |||||
"`io.BytesIO` type.") | |||||
self.driver.load_model(folder, only_state_dict, **kwargs) | |||||
self.driver.barrier() | |||||
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): | |||||
r""" | |||||
用于断点重训的保存函数; | |||||
""" | |||||
self.driver.barrier() | |||||
# 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | |||||
# 2. trainer_state; | |||||
states = {"callback_states": self.on_save_checkpoint(), | |||||
"trainer_state": self.trainer_state.state_dict()} | |||||
# 3. validate filter state; | |||||
if self.evaluator is not None: | |||||
val_filter_state = {} | |||||
if hasattr(self.step_validate, "__fastNLP_filter__"): | |||||
val_filter_state["step_validate"] = self.step_validate.__fastNLP_filter__.state_dict() | |||||
if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||||
val_filter_state["epoch_validate"] = self.epoch_validate.__fastNLP_filter__.state_dict() | |||||
states["val_filter_state"] = val_filter_state | |||||
else: | |||||
states["val_filter_state"] = None | |||||
# 4. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 | |||||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||||
dataloader_args = self.driver.get_dataloader_args(self.dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
states['sampler_states'] = sampler.state_dict() | |||||
else: | |||||
raise RuntimeError( | |||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||||
if isinstance(folder, str): | |||||
folder = Path(folder) | |||||
if model_save_fn is not None: | |||||
if not callable(model_save_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
rank_zero_call(model_save_fn)(folder) | |||||
self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs) | |||||
else: | |||||
self.driver.save(folder=folder, states=states, | |||||
only_state_dict=only_state_dict, should_save_model=True, **kwargs) | |||||
self.driver.barrier() | |||||
def load(self, folder: str, resume_training: bool = True, only_state_dict: bool = True, | |||||
model_load_fn: Optional[Callable] = None, **kwargs): | |||||
r""" | |||||
用于断点重训的加载函数; | |||||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | |||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; | |||||
注意我们目前不支持单卡到多卡的断点重训; | |||||
TODO:注意我们目前不支持 RandomSampler、BucketedSampler 或者 SortedSampler 之间的断点重训; | |||||
因此如果用户自己需要使用 BucketedSampler,那么其需要自己在 Trainer 之前初始化 BucketedSampler,然后替换原始 Dataloader 中的 | |||||
sampler,不管其是第一次断点重训,还是之后的加载的重新训练; | |||||
:param folder: 保存断点重训 states 的文件地址; | |||||
:param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 | |||||
只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; | |||||
""" | |||||
self.driver.barrier() | |||||
if isinstance(folder, str): | |||||
folder = Path(folder) | |||||
if model_load_fn is not None: | |||||
if not callable(model_load_fn): | |||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") | |||||
rank_zero_call(model_load_fn)(folder) | |||||
states = self.driver.load(folder=folder, should_load_model=False, **kwargs) | |||||
else: | |||||
states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs) | |||||
if not resume_training: | |||||
return | |||||
# 1. 恢复 sampler 的状态; | |||||
dataloader_args = self.driver.get_dataloader_args(self.dataloader) | |||||
sampler = dataloader_args.sampler | |||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): | |||||
# 说明这里需要使用 ReproduceSampler 来弄一下了 | |||||
if self.driver.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use single device checkpoint retraining before but ddp now.") | |||||
sampler = ReproducibleBatchSampler( | |||||
batch_sampler=sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(states['sampler_states']) | |||||
self.driver.replace_sampler(self.dataloader, sampler) | |||||
# 2. validate filter state; | |||||
if self.evaluator is not None: | |||||
val_filter_state = states["val_filter_state"] | |||||
if hasattr(self.step_validate, "__fastNLP_filter__"): | |||||
self.step_validate.__fastNLP_filter__.load_state_dict(val_filter_state["step_validate"]) | |||||
if hasattr(self.epoch_validate, "__fastNLP_filter__"): | |||||
self.epoch_validate.__fastNLP_filter__.load_state_dict(val_filter_state["epoch_validate"]) | |||||
# 3. 恢复 trainer_state 的状态; | |||||
self.trainer_state.load_state_dict(states["trainer_state"]) | |||||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
# 5. 恢复所有 callback 的状态; | |||||
self.on_load_checkpoint(states["callback_states"]) | |||||
self.driver.barrier() | |||||
""" 这四个函数是用来方便用户定制自己的 batch_step_fn(用于替换 train_batch_loop 当中的 step 函数) 的 """ | |||||
def train_step(self, batch): | |||||
with self.driver.auto_cast(): | |||||
outputs = self.driver.train_step(batch) | |||||
outputs = match_and_substitute_params(self.output_mapping, outputs) | |||||
return outputs | |||||
def backward(self, outputs): | |||||
self.on_before_backward(outputs) | |||||
loss = self.extract_loss_from_outputs(outputs) | |||||
loss = loss / self.accumulation_steps | |||||
with self.get_no_sync_context(): | |||||
self.driver.backward(loss) | |||||
self.on_after_backward() | |||||
def zero_grad(self): | |||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | |||||
self.on_before_zero_grad(self.driver.optimizers) | |||||
self.driver.zero_grad(self.set_grad_to_none) | |||||
def step(self): | |||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | |||||
self.on_before_optimizer_step(self.driver.optimizers) | |||||
self.driver.step() | |||||
def move_data_to_device(self, batch): | |||||
return self.driver.move_data_to_device(batch) | |||||
@staticmethod | |||||
def extract_loss_from_outputs(outputs): | |||||
r""" | |||||
用来从用户模型的输出对象中抽取 `loss` 对象; | |||||
目前支持 `outputs` 对象为 'Dict' 或者 'dataclass'; | |||||
:return: 返回被抽取出来的 `loss` 对象,如果当前运行的是 'pytorch' 的 `Driver`,那么返回的就是一个 tensor; | |||||
""" | |||||
if isinstance(outputs, Dict): | |||||
try: | |||||
loss = outputs["loss"] | |||||
except: | |||||
raise KeyError(f"We cannot find `loss` from your model output(with keys:{outputs.keys()}). Please either " | |||||
f"directly return it from your model or use `output_mapping` to prepare it.") | |||||
elif is_dataclass(outputs): | |||||
try: | |||||
loss = outputs.loss | |||||
except: | |||||
raise AttributeError("We cannot find `loss` from your model output. Please either directly return it from" | |||||
" your model or use `output_mapping` to prepare it.") | |||||
else: | |||||
raise ValueError("The `outputs` from your model could only be of `dataclass` or `Dict` type. Or you can use " | |||||
"the parameter `output_mapping` to prepare loss.") | |||||
return loss | |||||
@contextmanager | |||||
def get_no_sync_context(self): | |||||
r""" | |||||
用于在梯度累积并且使用 DDP 时,由于在前 `accumulation_steps` - 1 的时间内不需要进行梯度的同步,因此通过使用该 context 上下文 | |||||
环境来避免梯度的同步; | |||||
:return: 一个 no_sync 的 context; | |||||
""" | |||||
if (self.global_forward_batches + 1) % self.accumulation_steps != 0: | |||||
_no_sync_context = self.driver.get_no_sync_context() | |||||
else: | |||||
_no_sync_context = nullcontext | |||||
with _no_sync_context(): | |||||
yield | |||||
""" trainer state property """ | |||||
@property | |||||
def n_epochs(self) -> int: | |||||
return self.trainer_state.n_epochs | |||||
@n_epochs.setter | |||||
def n_epochs(self, n_epochs: int): | |||||
self.trainer_state.n_epochs = n_epochs | |||||
@property | |||||
def cur_epoch_idx(self) -> int: | |||||
return self.trainer_state.cur_epoch_idx | |||||
@cur_epoch_idx.setter | |||||
def cur_epoch_idx(self, cur_epoch_idx: int): | |||||
self.trainer_state.cur_epoch_idx = cur_epoch_idx | |||||
@property | |||||
def global_forward_batches(self) -> int: | |||||
return self.trainer_state.global_forward_batches | |||||
@global_forward_batches.setter | |||||
def global_forward_batches(self, global_forward_batches: int): | |||||
self.trainer_state.global_forward_batches = global_forward_batches | |||||
@property | |||||
def batch_idx_in_epoch(self) -> int: | |||||
return self.trainer_state.batch_idx_in_epoch | |||||
@batch_idx_in_epoch.setter | |||||
def batch_idx_in_epoch(self, batch_idx_in_epoch: int): | |||||
self.trainer_state.batch_idx_in_epoch = batch_idx_in_epoch | |||||
@property | |||||
def num_batches_per_epoch(self) -> int: | |||||
return self.trainer_state.num_batches_per_epoch | |||||
@num_batches_per_epoch.setter | |||||
def num_batches_per_epoch(self, num_batches_per_epoch: int): | |||||
self.trainer_state.num_batches_per_epoch = num_batches_per_epoch | |||||
@property | |||||
def total_batches(self) -> int: | |||||
return self.trainer_state.total_batches | |||||
@total_batches.setter | |||||
def total_batches(self, total_batches: int): | |||||
self.trainer_state.total_batches = total_batches | |||||
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
'State', | |||||
'TrainerState' | |||||
] | |||||
from .state import State, TrainerState |
@@ -0,0 +1,93 @@ | |||||
""" | |||||
该 Module 用来实现一个用于记载用户 callback 实时数据的 state,该 state 实际上是一个 字典,我们通过复用 __getattr__ 方法来实现类似 | |||||
类属性的字典调用方式; | |||||
提供该类的主要目的在于与 Filter 中的特殊的 filter_fn 合作,方便用户能够使用到自己想要的一切特殊的定制方式; | |||||
这一特殊的 Filter 实现需要用户记录一些特殊的状态值,例如 accuracy 等,但是我们不希望用户将这些状态值直接挂在 trainer 实例上,因为这样会 | |||||
污染 trainer 自己的类属性,从而可能导致一些莫名其妙的 bug; | |||||
我们开放 state 用于用户这一特殊的定制选择; | |||||
""" | |||||
from dataclasses import dataclass | |||||
from typing import Optional, Dict | |||||
__all__ = [ | |||||
'State', | |||||
'TrainerState' | |||||
] | |||||
class State(dict): | |||||
r""" | |||||
提供给用户使用的 state; | |||||
为了实现断点重训,用户应当保证其保存的信息都是可序列化的; | |||||
推荐的使用方式: | |||||
>>> state = State() | |||||
>>> state["best_accuracy"] = 0.9 | |||||
>>> print(state["best_accuracy"]) | |||||
or | |||||
>>> print(state.best_accuracy) | |||||
""" | |||||
__slots__ = () # 用户不应当使用 state.name = "name" 来使用此类,因此我们限制用户不可自己对该类设置属性,但是可以通过属性访问字典; | |||||
def __init__(self, *args, **kwargs): | |||||
super(State, self).__init__(*args, **kwargs) | |||||
def __getattr__(self, item): | |||||
if item in self: | |||||
_value = self[item] | |||||
if isinstance(_value, dict): | |||||
return State(_value) | |||||
else: | |||||
return _value | |||||
else: | |||||
raise ValueError(f"key '{item}' is not existed!") | |||||
@dataclass | |||||
class TrainerState: | |||||
r""" | |||||
该类用于我们 fastNLP 自己内部为了训练流程所记录的一些状态,当然是要暴露给用户给用户使用的; | |||||
我们保存的state大部分上是 trainer 断点重训 需要重新加载的; | |||||
专属于 `Trainer` 的状态记载的类; | |||||
n_epochs: 训练过程中总共的 epoch 的数量; | |||||
cur_epoch_idx: 当前正在运行第几个 epoch; | |||||
global_forward_batches: 当前模型总共 forward 了多少个 step; | |||||
batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||||
total_batches: 每一个 epoch 会 forward 多少个 step; | |||||
total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||||
""" | |||||
n_epochs: Optional[int] = None # 无论如何重新算 | |||||
cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; | |||||
global_forward_batches: Optional[int] = None # 断点重训 | |||||
batch_idx_in_epoch: Optional[int] = None # 断点重训 | |||||
num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | |||||
total_batches: Optional[int] = None # 无论如何重新算 | |||||
def state_dict(self) -> Dict: | |||||
r""" | |||||
:return: 返回用于断点重训来保存的状态字典; | |||||
""" | |||||
return {"cur_epoch_idx": self.cur_epoch_idx, "global_forward_batches": self.global_forward_batches, | |||||
"batch_idx_in_epoch": self.batch_idx_in_epoch} | |||||
def load_state_dict(self, state_dict: Dict): | |||||
r""" | |||||
用于断点重训来重新加载保存的状态字典; | |||||
:param state_dict: 用于加载的状态字典; | |||||
""" | |||||
for key in state_dict: | |||||
assert key in {"cur_epoch_idx", "global_forward_batches", "batch_idx_in_epoch"}, "Wrong state_dict for `TrainerState`." | |||||
setattr(self, key, state_dict[key]) | |||||
@@ -0,0 +1,123 @@ | |||||
from collections.abc import Iterator | |||||
from typing import Dict | |||||
from fastNLP.core.callbacks import CallbackManager | |||||
from .state import TrainerState | |||||
class TrainerEventTrigger: | |||||
""" | |||||
为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' | |||||
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 | |||||
模型。 | |||||
""" | |||||
callback_manager: CallbackManager | |||||
trainer_state: TrainerState | |||||
def on_after_trainer_initialized(self, driver): | |||||
self.callback_manager.on_after_trainer_initialized(self, driver) | |||||
def on_sanity_check_begin(self): | |||||
self.callback_manager.on_sanity_check_begin(self) | |||||
def on_sanity_check_end(self, sanity_check_res): | |||||
self.callback_manager.on_sanity_check_end(self, sanity_check_res) | |||||
def on_train_begin(self): | |||||
self.callback_manager.on_train_begin(self) | |||||
def on_train_end(self): | |||||
self.callback_manager.on_train_end(self) | |||||
def on_train_epoch_begin(self): | |||||
self.callback_manager.on_train_epoch_begin(self) | |||||
def on_train_epoch_end(self): | |||||
self.callback_manager.on_train_epoch_end(self) | |||||
def on_fetch_data_begin(self): | |||||
self.callback_manager.on_fetch_data_begin(self) | |||||
def on_fetch_data_end(self): | |||||
self.callback_manager.on_fetch_data_end(self) | |||||
def on_train_batch_begin(self, batch, indices=None): | |||||
self.callback_manager.on_train_batch_begin(self, batch, indices) | |||||
def on_train_batch_end(self): | |||||
self.callback_manager.on_train_batch_end(self) | |||||
def on_exception(self, exception): | |||||
self.callback_manager.on_exception(self, exception) | |||||
def on_save_model(self): | |||||
self.callback_manager.on_save_model(self) | |||||
def on_load_model(self): | |||||
self.callback_manager.on_load_model(self) | |||||
def on_save_checkpoint(self) -> Dict: | |||||
return self.callback_manager.on_save_checkpoint(self) | |||||
def on_load_checkpoint(self, states): | |||||
self.callback_manager.on_load_checkpoint(self, states) | |||||
def on_before_backward(self, outputs): | |||||
self.callback_manager.on_before_backward(self, outputs) | |||||
def on_after_backward(self): | |||||
self.callback_manager.on_after_backward(self) | |||||
def on_before_optimizer_step(self, optimizers): | |||||
self.callback_manager.on_before_optimizer_step(self, optimizers) | |||||
def on_before_zero_grad(self, optimizers): | |||||
self.callback_manager.on_before_zero_grad(self, optimizers) | |||||
def on_validate_begin(self): | |||||
self.callback_manager.on_validate_begin(self) | |||||
def on_validate_end(self, results): | |||||
self.trainer_state.save_on_this_step = True | |||||
self.callback_manager.on_validate_end(self, results) | |||||
class _TruncatedDataLoader: | |||||
def __init__(self, dataloader, num_batches: int): | |||||
""" | |||||
限制 | |||||
:param dataloader: 可迭代的 dataloader 。 | |||||
:param num_batches: 迭代多少个 batch 就停止。 | |||||
""" | |||||
self.dataloader = dataloader | |||||
self._num_batches = min(num_batches, len(dataloader)) | |||||
self._count = 0 | |||||
def __len__(self): | |||||
r""" | |||||
为了在外部调用 `len` 方法时正确地返回当前会迭代的长度; | |||||
""" | |||||
return self._num_batches | |||||
def __iter__(self): | |||||
# 将初试的 `dataloader` 转换成一个 `Iterator` 的逻辑应该放在这里,即只有当外界真正的调用 iter(dataloader) 的时候才需要返回一个 Iterator; | |||||
# TODO 测试一下 | |||||
self._iterator = iter(self.dataloader) | |||||
self._count = 0 | |||||
return self | |||||
def __next__(self): | |||||
if self._count >= self._num_batches: | |||||
raise StopIteration | |||||
self._count += 1 | |||||
# 注意 dataloader 数据不足时会自己本身触发 `StopIteration`; | |||||
return next(self._iterator) | |||||
def __getattr__(self, item): | |||||
return getattr(self.dataloader, item) | |||||
@@ -0,0 +1,26 @@ | |||||
__all__ = [ | |||||
'Driver', | |||||
'TorchDriver', | |||||
"TorchSingleDriver", | |||||
"TorchDDPDriver", | |||||
"PaddleDriver", | |||||
"PaddleSingleDriver", | |||||
"PaddleFleetDriver", | |||||
"JittorDriver", | |||||
"JittorSingleDriver", | |||||
"JittorMPIDriver", | |||||
"TorchPaddleDriver", | |||||
'torch_seed_everything', | |||||
'paddle_seed_everything', | |||||
'optimizer_state_to_device' | |||||
] | |||||
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device | |||||
from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver | |||||
from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything | |||||
from .torch_paddle_driver import TorchPaddleDriver | |||||
from .driver import Driver | |||||
@@ -0,0 +1,476 @@ | |||||
import os | |||||
import signal | |||||
import sys | |||||
from typing import Any, Sequence, List, Optional, Callable, Dict, Union | |||||
from abc import ABC | |||||
from datetime import datetime | |||||
from pathlib import Path | |||||
from io import BytesIO | |||||
__all__ = [ | |||||
'Driver' | |||||
] | |||||
from fastNLP.core.utils import nullcontext | |||||
# todo 航总 check 一下哪一些方法需要 @abstractmethod; | |||||
class Driver(ABC): | |||||
r""" | |||||
用来初始化 `Driver` 的基类,所有定制的 `driver` 都需要继承此类; | |||||
fastNLP 提供的 driver 实例都会同时被 Trainer 和 Evaluator 调用; | |||||
""" | |||||
def __init__(self, model): | |||||
r""" | |||||
:param model: 训练或者评测的模型,需要注意该模型可能为用户已经使用类似 `torch.nn.DataParallel` 或者 | |||||
`torch.nn.parallel.DistributedDataParallel` 包裹过的模型; | |||||
""" | |||||
self.model = model | |||||
# 这些属性用于 open_subprocess 和 on_exception 函数协同配合; | |||||
# self._consensus_file: Optional[Union[str, Path]] = None | |||||
self._pids: Optional[List[int]] = None | |||||
def setup(self): | |||||
r""" | |||||
该函数用来初始化训练环境,例如将模型迁移到对应的设备上等; | |||||
多卡的 driver 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值; | |||||
""" | |||||
def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False): | |||||
r""" | |||||
因为一些特殊的情况需要替换 dataloader 的 sampler,而每一个 driver 中的该函数会提供该功能;例如在多卡训练的中,我们 | |||||
需要将 sampler 替换为 distributed sampler;以及如果用户在 Trainer 中加入了断点重训的 callback,那么我们就需要将 sampler 替换 | |||||
为 reproducible sampler; | |||||
:param dataloader: 由 trainer 中传入的原始的 dataloader; | |||||
:param dist_sampler: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];用于指定使用怎样的 sampler; | |||||
目前该参数被定制为分布式训练服务,其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist",否则为 None; | |||||
evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||||
:param reproducible: 用于在 `Trainer` 中指定是否替换为断点重训的 sampler(多卡) 或者 batch_sampler(单卡);如果是单卡的 Driver, | |||||
并且该参数为 True,表示当前正在断点重训,那么我们就会使用我们的 `ReproducibleBatchSampler` 来替换 dataloader 原本的 batch_sampler; | |||||
如果是多卡的 Driver,那么我们就会用 `RandomSampler` 替换 dataloader 原本的 sampler; | |||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `replace_sampler` function.") | |||||
def set_deterministic_dataloader(self, dataloader): | |||||
r""" | |||||
为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 torch 的 dataloader,其 | |||||
需要将 worker_init_fn 替换; | |||||
""" | |||||
def set_sampler_epoch(self, dataloader, cur_epoch_idx): | |||||
r""" | |||||
对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||||
:param cur_epoch_idx: 当前是第几个 epoch; | |||||
""" | |||||
def train_step(self, batch): | |||||
""" | |||||
通过调用模型自带的 `train_step` 或者 `forward` 方法来实现训练的前向过程; | |||||
如果检测到用户模型实现了 train_step | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:return: 返回由模型的 `train_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `train_step` function.") | |||||
def validate_step(self, batch): | |||||
""" | |||||
通过调用模型自带的 `validate_step` 或者 `forward` 方法来实现模型评测的前向过程; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:return: 返回由模型的 `validate_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `validate_step` function.") | |||||
def test_step(self, batch): | |||||
""" | |||||
通过调用模型自带的 `test_step` 或者 `forward` 方法来实现模型评测的前向过程; | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||||
:return: 返回由模型的 `test_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `test_step` function.") | |||||
def check_evaluator_mode(self, mode: str): | |||||
r""" | |||||
因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||||
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 | |||||
我们应当提醒用户这一行为; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `predict_step` function.") | |||||
@property | |||||
def model(self): | |||||
return self._model | |||||
@model.setter | |||||
def model(self, model): | |||||
self._model = model | |||||
@property | |||||
def train_dataloader(self): | |||||
return self._train_dataloader | |||||
@train_dataloader.setter | |||||
def train_dataloader(self, train_dataloader: Any): | |||||
self._train_dataloader = train_dataloader | |||||
@property | |||||
def validate_dataloaders(self): | |||||
return self._validate_dataloaders | |||||
@validate_dataloaders.setter | |||||
def validate_dataloaders(self, validate_dataloaders: Any): | |||||
self._validate_dataloaders = validate_dataloaders | |||||
@property | |||||
def test_dataloaders(self): | |||||
return self._test_dataloaders | |||||
@test_dataloaders.setter | |||||
def test_dataloaders(self, test_dataloaders: Any): | |||||
self._test_dataloaders = test_dataloaders | |||||
@property | |||||
def predict_dataloaders(self): | |||||
return self._predict_dataloaders | |||||
@predict_dataloaders.setter | |||||
def predict_dataloaders(self, predict_dataloaders: Any): | |||||
self._predict_dataloaders = predict_dataloaders | |||||
def set_dataloader(self, **kwargs): | |||||
r""" | |||||
设置训练或者检验过程中的数据;用于在 trainer 和 evaluator 中将数据 dataloader 挂载到每一个具体的 driver 上; | |||||
:param kwargs: 输入的数据,应当使用 'keyword-only' 的参数进行设置; | |||||
""" | |||||
if "train_dataloader" in kwargs: | |||||
self.train_dataloader = kwargs["train_dataloader"] | |||||
self._check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) | |||||
if "validate_dataloaders" in kwargs: | |||||
self.validate_dataloaders = kwargs["validate_dataloaders"] | |||||
self._check_dataloader_legality(self.validate_dataloaders, "validate_dataloaders", is_train=False) | |||||
if "test_dataloaders" in kwargs: | |||||
self.test_dataloaders = kwargs["test_dataloaders"] | |||||
self._check_dataloader_legality(self.test_dataloaders, "test_dataloaders", is_train=False) | |||||
if "predict_dataloaders" in kwargs: | |||||
self.predict_dataloaders = kwargs["predict_dataloaders"] | |||||
self._check_dataloader_legality(self.predict_dataloaders, "predict_dataloaders", is_train=False) | |||||
@staticmethod | |||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
r""" | |||||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 | |||||
行为是不相同的; | |||||
:param dataloader: 需要检测的输入的 `dataloader`; | |||||
:param dataloader_name: | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `_check_dataloader_legality` function.") | |||||
def has_train_dataloader(self): | |||||
return "_train_dataloader" in self.__dict__ | |||||
def has_validate_dataloaders(self): | |||||
return "_validate_dataloaders" in self.__dict__ | |||||
def has_test_dataloaders(self): | |||||
return "_test_dataloaders" in self.__dict__ | |||||
def has_predict_dataloaders(self): | |||||
return "_predict_dataloaders" in self.__dict__ | |||||
@property | |||||
def optimizers(self) -> List: | |||||
r""" | |||||
如下所示,driver 返回的 optimizers 一定是一个 List,如果用户直接向 Trainer 传入一个单独的 optimzer,我们会使用一个 List 将其 | |||||
包裹; | |||||
:return: List[optimizer0, optimizer1, optimizer2, ...] | |||||
""" | |||||
return self._optimizers | |||||
@optimizers.setter | |||||
def optimizers(self, optimizers): | |||||
if not isinstance(optimizers, Sequence): | |||||
self._optimizers = [optimizers] | |||||
else: | |||||
self._optimizers = optimizers | |||||
self._check_optimizer_legality(self._optimizers) | |||||
@property | |||||
def model_device(self): | |||||
return self._model_device | |||||
@model_device.setter | |||||
def model_device(self, model_device): | |||||
self._model_device = model_device | |||||
@property | |||||
def data_device(self): | |||||
return self.model_device | |||||
@staticmethod | |||||
def _check_optimizer_legality(optimizers): | |||||
""" | |||||
对于用户传入 trainer 的每一个 optimizer,检测其是否合理,因为不同的深度学习框架所使用的的 optimizer 是不相同的; | |||||
:param optimizers: 需要检测的 `optimizers`; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `_check_optimizer_legality` function.") | |||||
def set_optimizers(self, optimizers=None): | |||||
""" | |||||
trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上; | |||||
:param optimizers: | |||||
:return: | |||||
""" | |||||
self.optimizers = optimizers | |||||
def backward(self, loss): | |||||
""" | |||||
实现深度学习中的反向传播过程; | |||||
:param loss: 用来实现反向传播的损失函数值; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `backward` function.") | |||||
def step(self): | |||||
r""" | |||||
实现深度学习中的参数的优化更新过程,应当直接通过优化器 optimizers 来更新参数; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `step` function.") | |||||
def zero_grad(self, set_to_none: bool = False): | |||||
r""" | |||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | |||||
注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积; | |||||
:param set_to_none: 用来判断是否需要将梯度直接置为 None; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") | |||||
def get_no_sync_context(self): | |||||
r""" | |||||
返回一个用于关闭多进程之间互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数,单卡的 driver 不需要; | |||||
:return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象; | |||||
""" | |||||
return nullcontext | |||||
def get_evaluate_context(self): | |||||
r""" | |||||
返回一个不计算梯度的环境用来对模型进行评测; | |||||
:return: 一个类似 `torch.no_grad` 的 context 上下文对象; | |||||
""" | |||||
return nullcontext | |||||
@property | |||||
def auto_cast(self): | |||||
""" | |||||
fp16 的上下文环境; | |||||
:return: 返回一个用于 fp16 计算的上下文环境; | |||||
""" | |||||
return self._auto_cast | |||||
@auto_cast.setter | |||||
def auto_cast(self, auto_cast): | |||||
self._auto_cast = auto_cast | |||||
def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs): | |||||
r""" | |||||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | |||||
:param filepath: 保存文件的文件位置(需要包括文件名)或一个 BytesIO 对象; | |||||
:param only_state_dict: 是否只保存模型的 `state_dict`; | |||||
:param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `save_model` function.") | |||||
def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs): | |||||
r""" | |||||
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。 | |||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名)或一个 BytesIO 对象; | |||||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | |||||
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | |||||
:return: 返回加载指定文件后的结果; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.") | |||||
def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
r""" | |||||
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True) | |||||
:param folder: 保存断点重训的状态的文件夹;save 函数应该在下面新增两(一)个文件 的 FASTNLP_CHECKPOINT_FILENAME 文件与 | |||||
FASTNLP_MODEL_FILENAME (如果 should_save_model 为 True )。把 model 相关的内容放入到 FASTNLP_MODEL_FILENAME 文件 | |||||
中,将传入的 states 以及自身产生其它状态一并保存在 FASTNLP_CHECKPOINT_FILENAME 里面。 | |||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | |||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的 | |||||
传入的值保持一致。 | |||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `save` function.") | |||||
def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
r""" | |||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; | |||||
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。 | |||||
该函数应该在所有 rank 上执行。 | |||||
:param folder: 读取该 folder 下的 FASTNLP_CHECKPOINT_FILENAME 文件与 FASTNLP_MODEL_FILENAME | |||||
(如果 should_load_model 为True)。 | |||||
:param only_state_dict: 读取的,当 should_save_model 为 False ,该参数无效。如果为 True ,说明保存的内容为权重;如果为 | |||||
False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。 | |||||
:param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有 | |||||
找到对应的模型状态,则报错。 | |||||
:return: 需要返回 save 函数输入的 states 内容; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `load` function.") | |||||
@staticmethod | |||||
def tensor_to_numeric(tensor, reduce: Optional[str]=None): | |||||
r""" | |||||
将一个 `tensor` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 `numeric` 对象;如果 tensor 只包含一个 | |||||
元素则返回 float 或 int 。 | |||||
:param tensor: 需要被转换的 `tensor` 对象 | |||||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||||
float 或 int 对象。 | |||||
:return: 转换后返回的结果 | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.") | |||||
def set_model_mode(self, mode: str): | |||||
r""" | |||||
设置模型为 `train` / `eval` 的模式;目的是为切换模型训练和推理(会关闭dropout等)模式; | |||||
:param mode: 应为二者之一:["train", "eval"]; | |||||
""" | |||||
def unwrap_model(self): | |||||
""" | |||||
保证用户拿到的模型一定是最原始的模型; | |||||
注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 `save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是 | |||||
最为原始的模型; | |||||
需要注意用户本身传入的模型就是经过类似 `torch.nn.DataParallel` 或者 `torch.nn.parallel.DistributedDataParallel` 包裹的模型, | |||||
因此在该函数内需要先判断模型的类别; | |||||
:return: 返回最原始的模型,例如没有被 `DistributedDataParallel` 包裹的模型; | |||||
""" | |||||
@staticmethod | |||||
def move_model_to_device(model, device): | |||||
r""" | |||||
用来将模型转移到指定的 device 上; | |||||
之所以写成 `staticmethod`,是因为一方面在 `Driver` 中我们要使用 `unwrap_model` 来拿到最原始的模型,另一方面,在 `save_model` | |||||
中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数; | |||||
""" | |||||
def move_data_to_device(self, batch): | |||||
r""" | |||||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||||
:return: 将移动到指定机器上的 batch 对象返回; | |||||
""" | |||||
def get_local_rank(self) -> int: | |||||
r""" | |||||
返回当前的local_rank,本函数的返回值只在运行分布式训练的时候有实际含义。 | |||||
:return: 一个整数值,表示当前进程在当前这台机器上的序号; | |||||
""" | |||||
return 0 | |||||
def barrier(self): | |||||
r""" | |||||
用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; | |||||
仅在多分布式训练场景中有使用。 | |||||
""" | |||||
@staticmethod | |||||
def get_dataloader_args(dataloader): | |||||
""" | |||||
用于从 dataloader 中抽取一些属性的值,返回的dataclass中必须包含以下的key: | |||||
sampler, batch_sampler, batch_size, drop_last; | |||||
:param dataloader: | |||||
:return: 返回一个 dataclass,其实例属性应当包括以上的各个属性,并且其名字也应当与这些属性相同,从而方便 trainer 或者其它对象调用; | |||||
""" | |||||
raise NotImplementedError("Each specific driver should implemented its own `get_dataloader_args` function.") | |||||
def is_distributed(self) -> bool: | |||||
""" | |||||
当前的 driver 实例是否是分布式的; | |||||
:return: 返回一个 bool 值,如果当前的 driver 实例是用于分布式的,那么返回 True; | |||||
""" | |||||
return False | |||||
def on_exception(self): | |||||
""" | |||||
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | |||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | |||||
因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 | |||||
pid 的信息; | |||||
""" | |||||
# 单卡 driver 不需要这个函数; | |||||
if self._pids is not None: | |||||
exc_type, exc_value, exc_traceback_obj = sys.exc_info() | |||||
_write_exc_info = { | |||||
'exc_type': exc_type, | |||||
'exc_value': exc_value, | |||||
'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||||
'global_rank': getattr(self, "global_rank", None), | |||||
'rank': self.get_local_rank(), | |||||
} | |||||
sys.stderr.write(str(_write_exc_info)+"\n") | |||||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | |||||
for pid in self._pids: | |||||
if pid != os.getpid(): | |||||
os.kill(pid, signal.SIGKILL) | |||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||||
""" | |||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )broadcast 到其它所有进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||||
:param int src: source 的 global rank 。 | |||||
:param group: 所属的 group | |||||
:param kwargs: | |||||
:return: 输入的 obj 。 | |||||
""" | |||||
if not self.is_distributed(): | |||||
return obj | |||||
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `broadcast_object` method right " | |||||
f"now.") | |||||
def all_gather(self, obj, group)->List: | |||||
""" | |||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||||
pickle 进行序列化,接收到之后再反序列化。 | |||||
:param obj: 可以是 float/int/bool/np.ndarray/{}/[]/Tensor等。 | |||||
:param group: | |||||
:return: 返回值应该是 [obj0, obj1, ...], 其中obj1是rank0上的对象,obj1是rank1上的对象... | |||||
""" | |||||
if not self.is_distributed(): | |||||
return [obj] | |||||
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `all_gather` method right " | |||||
f"now.") | |||||
@@ -0,0 +1,19 @@ | |||||
__all__ = [ | |||||
'TorchDDPDriver', | |||||
'TorchSingleDriver', | |||||
'TorchDriver', | |||||
'torch_seed_everything', | |||||
'optimizer_state_to_device' | |||||
] | |||||
from .ddp import TorchDDPDriver | |||||
# todo 实现 fairscale 后再将 fairscale 导入到这里; | |||||
from .single_device import TorchSingleDriver | |||||
from .torch_driver import TorchDriver | |||||
from .utils import torch_seed_everything, optimizer_state_to_device | |||||
@@ -0,0 +1,477 @@ | |||||
import os | |||||
import sys | |||||
import __main__ | |||||
import socket | |||||
import numpy as np | |||||
from time import sleep | |||||
from typing import List, Optional, Union, Dict | |||||
from functools import partial | |||||
# todo 这个等大家的 __all__ 都弄完后改为 from fastNLP.env import; | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
__all__ = [ | |||||
'TorchDDPDriver' | |||||
] | |||||
from .torch_driver import TorchDriver | |||||
from fastNLP.core.drivers.torch_driver.utils import ( | |||||
_DDPWrappingModel, | |||||
ForwardState, | |||||
_MODE_PARAMETER, | |||||
reset_seed, | |||||
replace_sampler | |||||
) | |||||
from fastNLP.core.drivers.utils import distributed_open_proc | |||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
class TorchDDPDriver(TorchDriver): | |||||
def __init__( | |||||
self, | |||||
model, | |||||
parallel_device: Optional[Union[List["torch.device"], "torch.device"]], | |||||
is_pull_by_torch_run: bool = False, | |||||
fp16: bool = False, | |||||
**kwargs | |||||
): | |||||
""" | |||||
DDP 目前考虑支持的三种启动方式: | |||||
1. 用户自己不进行 ddp 的任何操作,直接使用我们的 Trainer,并且只运行一个 main 脚本,这时是由我们自己使用 open_subprocesses 拉起 | |||||
多个进程,然后 TorchDDPDriver 自己 init_process_group; | |||||
2. 其它情况同 1,但是用户自己使用 python -m torch.distributed.launch 拉起; | |||||
3. 用户自己在外面初始化 DDP,并且通过 python -m torch.distributed.launch 拉起; | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m torch.distributed.launch 启动; | |||||
如果用户自己在外面初始化了 ddp,那么 | |||||
parallel_device 为 None; | |||||
data_device 为 表示单卡的一个参数; | |||||
dist.is_initialized 为 true; | |||||
""" | |||||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
if isinstance(model, torch.nn.DataParallel): | |||||
raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be " | |||||
f"`torch.nn.Module` or `torch.nn.parallel.DistributedDataParallel` type.") | |||||
# 如果用户自己在外面初始化 DDP,那么其一定是通过 python -m torch.distributed.launch 拉起的; | |||||
self.is_pull_by_torch_run = is_pull_by_torch_run | |||||
self.parallel_device = parallel_device | |||||
if not is_pull_by_torch_run and parallel_device is None: | |||||
raise ValueError("Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused " | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||||
# 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; | |||||
if is_pull_by_torch_run: | |||||
self.model_device = parallel_device | |||||
else: | |||||
# 我们的 model_device 一定是 torch.device,而不是一个 list; | |||||
self.model_device = parallel_device[self.local_rank] | |||||
# 如果用户自己在外面初始化了 DDP; | |||||
self.outside_ddp = False | |||||
if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and "fastnlp_special" not in os.environ: | |||||
# 如果用户自己在外面初始化了 DDP,那么我们要求用户传入的模型一定是已经由 DistributedDataParallel 包裹后的模型; | |||||
if not isinstance(model, DistributedDataParallel): | |||||
raise RuntimeError( | |||||
"It is not allowed to input a normal model instead of `DistributedDataParallel` when" | |||||
"you initialize the ddp process out of our control.") | |||||
self.outside_ddp = True | |||||
# 用户只有将模型上传到对应机器上后才能用 DistributedDataParallel 包裹,因此如果用户在外面初始化了 DDP,那么在 TorchDDPDriver 中 | |||||
# 我们就直接将 model_device 置为 None; | |||||
self.model_device = None | |||||
def _running_fn_(batch, step_fn, signature_fn): | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return self._validate_step(batch) | |||||
model = model.module | |||||
if hasattr(model, "train_step"): | |||||
logger.warning( | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | |||||
"model also implements the `train_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `train_step` and you should note that.") | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._train_signature_fn = model.forward | |||||
if hasattr(model, "validate_step"): | |||||
logger.warning( | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | |||||
"model also implements the `validate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._validate_signature_fn = model.forward | |||||
if hasattr(model, "test_step"): | |||||
logger.warning( | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | |||||
"model also implements the `test_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `test_step` and you should note that.") | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._test_signature_fn = model.forward | |||||
# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | |||||
self._data_device = kwargs.get("data_device", None) | |||||
# if self.outside_ddp and self._data_device is None: | |||||
# raise RuntimeError("When you initialize your ddp out of our control, the parameter " | |||||
# "`data_device` can not be None.") | |||||
if isinstance(self._data_device, int): | |||||
if self._data_device < 0: | |||||
raise ValueError("Parameter `data_device` can not be smaller than 0.") | |||||
_could_use_device_num = torch.cuda.device_count() | |||||
if self._data_device >= _could_use_device_num: | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
self._data_device = torch.device(f"cuda:{self._data_device}") | |||||
elif isinstance(self._data_device, str): | |||||
self._data_device = torch.device(self._data_device) | |||||
elif self._data_device is not None and not isinstance(self._data_device, torch.device): | |||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||||
self._master_port = None | |||||
# world_size 表示的就是全局的显卡的数量; | |||||
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | |||||
self.global_rank = 0 | |||||
self._configured = False # 防止重复调用 configure_ddp() 函数使用的 | |||||
self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {}) | |||||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | |||||
if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | |||||
logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set " | |||||
"'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" | |||||
" to 'False' to avoid redundant data translation between different processes.") | |||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | |||||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | |||||
if self.output_from_new_proc not in {"all", "ignore", "only_error"}: | |||||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||||
# 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_setup = False | |||||
def setup(self): | |||||
if self._has_setup: | |||||
return | |||||
self._has_setup = True | |||||
# 如果用户需要使用多机模式,那么一定进入到这里; | |||||
if self.is_pull_by_torch_run: | |||||
if self.outside_ddp: | |||||
self.world_size = dist.get_world_size() | |||||
self.global_rank = dist.get_rank() | |||||
else: | |||||
# dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用; | |||||
self.world_size = int(os.environ.get("WORLD_SIZE")) | |||||
self.global_rank = int(os.environ.get("RANK")) | |||||
reset_seed() | |||||
logger.info(f"World size:{self.world_size}, Global rank:{self.global_rank}") | |||||
if not dist.is_initialized(): | |||||
dist.init_process_group( | |||||
backend="nccl", rank=self.global_rank, world_size=self.world_size | |||||
) | |||||
os.environ["fastnlp_special"] = "yes" | |||||
# 进入到这里的情况时: | |||||
# dist.is_initialized 一定为 False; | |||||
# 一定是单机; | |||||
# self.parallel_device 一定是 List[torch.device]; | |||||
else: | |||||
if not dist.is_initialized(): | |||||
# 这里主要的问题在于要区分 rank0 和其它 rank 的情况; | |||||
self.world_size = len(self.parallel_device) | |||||
self.open_subprocess() | |||||
self.global_rank = self.local_rank # rank 一定是通过环境变量去获取的; | |||||
reset_seed() | |||||
dist.init_process_group( | |||||
backend="nccl", rank=self.global_rank, world_size=self.world_size | |||||
) | |||||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 TorchDDPDriver; | |||||
else: | |||||
# 如果 `dist.is_initialized() == True`,那么说明 TorchDDPDriver 在之前已经初始化并且已经 setup 过一次,那么我们需要保证现在 | |||||
# 使用的(即之后的)TorchDDPDriver 的设置和第一个 TorchDDPDriver 是完全一样的; | |||||
pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) | |||||
if pre_num_processes != len(self.parallel_device): | |||||
raise RuntimeError("Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not" | |||||
"allowed that your second `TorchDDPDriver` has a new setting of parameters " | |||||
"`num_nodes` and `num_processes`.") | |||||
self.world_size = dist.get_world_size() | |||||
self.global_rank = dist.get_rank() | |||||
if not self.outside_ddp: | |||||
torch.cuda.set_device(self.model_device) | |||||
self.model.to(self.model_device) | |||||
self.configure_ddp() | |||||
self.barrier() | |||||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | |||||
self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())] | |||||
dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device)) | |||||
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None | |||||
if local_world_size is None: | |||||
local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device) | |||||
dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX) | |||||
local_world_size = local_world_size.tolist() + 1 | |||||
node_rank = self.global_rank // local_world_size | |||||
self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size] | |||||
self._pids = self.tensor_to_numeric(self._pids) | |||||
def configure_ddp(self): | |||||
if not self._configured and not isinstance(self.model, DistributedDataParallel): | |||||
self.model = DistributedDataParallel( | |||||
# 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; | |||||
_DDPWrappingModel(self.model), device_ids=[self.model_device.index], | |||||
**self._ddp_kwargs | |||||
) | |||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
self._configured = True | |||||
def open_subprocess(self): | |||||
if self.local_rank == 0: | |||||
# self._consensus_file = Path(tempfile.mkstemp()[1]) | |||||
# self._consensus_file.unlink() | |||||
# Script called as `python a/b/c.py` | |||||
if __main__.__spec__ is None: # pragma: no-cover | |||||
# pull out the commands used to run the script and resolve the abs file path | |||||
command = sys.argv | |||||
command[0] = os.path.abspath(command[0]) | |||||
# use the same python interpreter and actually running | |||||
command = [sys.executable] + command | |||||
# Script called as `python -m a.b.c` | |||||
else: | |||||
command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:] | |||||
os.environ['MASTER_ADDR'] = self.master_address | |||||
os.environ['MASTER_PORT'] = self.master_port | |||||
os.environ["LOCAL_RANK"] = str(self.local_rank) | |||||
os.environ["WORLD_SIZE"] = f"{self.world_size}" | |||||
os.environ[FASTNLP_DISTRIBUTED_CHECK] = f"{len(self.parallel_device)}" | |||||
os.environ[FASTNLP_GLOBAL_RANK] = "0" | |||||
interactive_ddp_procs = [] | |||||
for rank in range(1, len(self.parallel_device)): | |||||
env_copy = os.environ.copy() | |||||
env_copy["LOCAL_RANK"] = f"{rank}" | |||||
# 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK; | |||||
env_copy[FASTNLP_GLOBAL_RANK] = str(rank) | |||||
proc = distributed_open_proc(self.output_from_new_proc, command, env_copy, self.global_rank) | |||||
interactive_ddp_procs.append(proc) | |||||
delay = np.random.uniform(1, 5, 1)[0] | |||||
sleep(delay) | |||||
@property | |||||
def master_address(self) -> str: | |||||
return os.environ.get("MASTER_ADDR", "127.0.0.1") | |||||
@property | |||||
def master_port(self) -> str: | |||||
if self.outside_ddp: | |||||
return os.environ.get("MASTER_PORT") | |||||
if self._master_port is None: | |||||
self._master_port = os.environ.get("MASTER_PORT", find_free_network_port()) | |||||
return self._master_port | |||||
@property | |||||
def world_size(self) -> int: | |||||
return self._world_size | |||||
@world_size.setter | |||||
def world_size(self, size: int): | |||||
self._world_size = size | |||||
@property | |||||
def global_rank(self) -> int: | |||||
return self._global_rank | |||||
@global_rank.setter | |||||
def global_rank(self, rank: int) -> None: | |||||
self._global_rank = rank | |||||
@property | |||||
def local_rank(self) -> int: | |||||
return int(os.environ.get("LOCAL_RANK", 0)) | |||||
@property | |||||
def data_device(self): | |||||
if self.outside_ddp: | |||||
return self._data_device | |||||
return self.model_device | |||||
def train_step(self, batch): | |||||
# 注意这里的 self.model 已经是 'fastNLP.drivers.utils._DDPWrappingModel'; | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
return self._train_step(batch) | |||||
def validate_step(self, batch): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
return self._validate_step(batch) | |||||
def test_step(self, batch): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
return self._test_step(batch) | |||||
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False): | |||||
if isinstance(dist_sampler, ReproducibleIterator): | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
dist_sampler = re_instantiate_sampler(dist_sampler) | |||||
return replace_sampler(dataloader, dist_sampler) | |||||
# trainer, evaluator | |||||
if dist_sampler is None: | |||||
if reproducible: | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||||
"control.") | |||||
else: | |||||
return dataloader | |||||
# trainer | |||||
elif dist_sampler == "dist": | |||||
args = self.get_dataloader_args(dataloader) | |||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||||
if isinstance(args.sampler, ReproducibleIterator): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
sampler = RandomSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||||
) | |||||
# todo 这个你写个todo吧,有两个角度;第一个是dataloader即使检测到sampler是我们reproducible,也不能直接set_distributeds; 第二个如果是单卡的,也需要替换sampler乃至切换sampler的状态,方式之前多卡,现在切换成单卡运行 | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
# evaluator | |||||
elif dist_sampler == "unrepeatdist": | |||||
args = self.get_dataloader_args(dataloader) | |||||
sampler = UnrepeatedDistributedSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def is_global_zero(self): | |||||
return self.global_rank == 0 | |||||
def get_no_sync_context(self): | |||||
# 注意此时的 model 是 "DistributedDataParallel" 对象; | |||||
return self.model.no_sync | |||||
def unwrap_model(self): | |||||
_module = self.model.module | |||||
if isinstance(_module, _DDPWrappingModel): | |||||
return _module.model | |||||
else: | |||||
return _module | |||||
def get_local_rank(self) -> int: | |||||
return self.local_rank | |||||
def barrier(self): | |||||
torch.distributed.barrier(async_op=True) | |||||
def is_distributed(self): | |||||
return True | |||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||||
""" | |||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||||
:param int src: source 的 global rank 。 | |||||
:param int dst: target 的 global rank,可以是多个目标 rank | |||||
:param group: 所属的 group | |||||
:param kwargs: | |||||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||||
""" | |||||
return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group) | |||||
def all_gather(self, obj, group) -> List: | |||||
""" | |||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||||
pickle 进行序列化,接收到之后再反序列化。 | |||||
example: | |||||
obj = { | |||||
'a': [1, 1], | |||||
'b': [[1, 2], [1, 2]], | |||||
'c': { | |||||
'd': [1, 2] | |||||
} | |||||
} | |||||
-> | |||||
[ | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
] | |||||
:param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。 | |||||
:param group: | |||||
:return: | |||||
""" | |||||
return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) | |||||
def find_free_network_port() -> str: | |||||
"""Finds a free port on localhost. | |||||
It is useful in single-node training when we don't want to connect to a real master node but have to set the | |||||
`MASTER_PORT` environment variable. | |||||
""" | |||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||||
s.bind(("", 0)) | |||||
s.listen(1) | |||||
port = s.getsockname()[1] | |||||
s.close() | |||||
return str(port) |
@@ -0,0 +1,461 @@ | |||||
import io | |||||
import pickle | |||||
from typing import Mapping | |||||
_pickler = pickle.Pickler | |||||
_unpickler = pickle.Unpickler | |||||
from abc import ABC | |||||
from typing import Any, Union, List | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import distributed as dist | |||||
from fastNLP.core.utils import apply_to_collection | |||||
def all_gather_object(object_list, obj, group=None): | |||||
""" | |||||
Gathers picklable objects from the whole group into a list. Similar to | |||||
:func:`all_gather`, but Python objects can be passed in. Note that the object | |||||
must be picklable in order to be gathered. | |||||
Args: | |||||
object_list (list[Any]): Output list. It should be correctly sized as the | |||||
size of the group for this collective and will contain the output. | |||||
object (Any): Pickable Python object to be broadcast from current process. | |||||
group (ProcessGroup, optional): The process group to work on. If None, | |||||
the default process group will be used. Default is ``None``. | |||||
Returns: | |||||
None. If the calling rank is part of this group, the output of the | |||||
collective will be populated into the input ``object_list``. If the | |||||
calling rank is not part of the group, the passed in ``object_list`` will | |||||
be unmodified. | |||||
.. note:: Note that this API differs slightly from the :func:`all_gather` | |||||
collective since it does not provide an ``async_op`` handle and thus | |||||
will be a blocking call. | |||||
.. note:: For NCCL-based processed groups, internal tensor representations | |||||
of objects must be moved to the GPU device before communication takes | |||||
place. In this case, the device used is given by | |||||
``torch.cuda.current_device()`` and it is the user's responsiblity to | |||||
ensure that this is set so that each rank has an individual GPU, via | |||||
``torch.cuda.set_device()``. | |||||
.. warning:: | |||||
:func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||||
known to be insecure. It is possible to construct malicious pickle data | |||||
which will execute arbitrary code during unpickling. Only call this | |||||
function with data you trust. | |||||
Example:: | |||||
>>> # Note: Process group initialization omitted on each rank. | |||||
>>> import torch.distributed as dist | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = [None for _ in gather_objects] | |||||
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
>>> output | |||||
['foo', 12, {1: 2}] | |||||
""" | |||||
if dist.distributed_c10d._rank_not_in_group(group): | |||||
return | |||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
current_device = torch.device("cpu") | |||||
if dist.is_nccl_available() and isinstance( | |||||
group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL | |||||
): | |||||
# See note about using torch.cuda.current_device() here in docstring. | |||||
# We cannot simply use my_rank since rank == device is not necessarily | |||||
# true. | |||||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist.get_world_size(group=group) | |||||
object_sizes_tensor = torch.zeros( | |||||
group_size, dtype=torch.long, device=current_device | |||||
) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes | |||||
dist.all_gather(object_size_list, local_size, group=group) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor.resize_(max_object_size) | |||||
coalesced_output_tensor = torch.empty( | |||||
max_object_size * group_size, dtype=torch.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
dist.all_gather(output_tensors, input_tensor, group=group) | |||||
# Deserialize outputs back to object. | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(torch.uint8) | |||||
if tensor.device != torch.device("cpu"): | |||||
tensor = tensor.cpu() | |||||
tensor_size = object_size_list[i] | |||||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
if dst == my_rank: | |||||
if not gather_list: | |||||
raise ValueError( | |||||
"Argument ``gather_list`` must be specified on destination rank." | |||||
) | |||||
elif gather_list: | |||||
raise ValueError( | |||||
"Argument ``gather_list`` must NOT be specified " | |||||
"on non-destination ranks." | |||||
) | |||||
def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||||
""" | |||||
Gathers picklable objects from the whole group in a single process. | |||||
Similar to :func:`gather`, but Python objects can be passed in. Note that the | |||||
object must be picklable in order to be gathered. | |||||
Args: | |||||
obj (Any): Input object. Must be picklable. | |||||
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it | |||||
should be correctly sized as the size of the group for this | |||||
collective and will contain the output. Must be ``None`` on non-dst | |||||
ranks. (default is ``None``) | |||||
dst (int, optional): Destination rank. (default is 0) | |||||
group: (ProcessGroup, optional): The process group to work on. If None, | |||||
the default process group will be used. Default is ``None``. | |||||
Returns: | |||||
None. On the ``dst`` rank, ``object_gather_list`` will contain the | |||||
output of the collective. | |||||
.. note:: Note that this API differs slightly from the gather collective | |||||
since it does not provide an async_op handle and thus will be a blocking | |||||
call. | |||||
.. note:: Note that this API is not supported when using the NCCL backend. | |||||
.. warning:: | |||||
:func:`gather_object` uses ``pickle`` module implicitly, which is | |||||
known to be insecure. It is possible to construct malicious pickle data | |||||
which will execute arbitrary code during unpickling. Only call this | |||||
function with data you trust. | |||||
Example:: | |||||
>>> # Note: Process group initialization omitted on each rank. | |||||
>>> import torch.distributed as dist | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = [None for _ in gather_objects] | |||||
>>> dist.gather_object( | |||||
gather_objects[dist.get_rank()], | |||||
output if dist.get_rank() == 0 else None, | |||||
dst=0 | |||||
) | |||||
>>> # On rank 0 | |||||
>>> output | |||||
['foo', 12, {1: 2}] | |||||
""" | |||||
if dist.distributed_c10d._rank_not_in_group(group): | |||||
return | |||||
# Ensure object_gather_list is specified appopriately. | |||||
my_rank = dist.get_rank() | |||||
_validate_output_list_for_rank(my_rank, dst, object_gather_list) | |||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
group_backend = dist.get_backend(group) | |||||
current_device = torch.device("cpu") | |||||
is_nccl_backend = group_backend == dist.Backend.NCCL | |||||
if is_nccl_backend: | |||||
current_device = torch.device('cuda', torch.cuda.current_device()) | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist.get_world_size(group=group) | |||||
object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes. An all-gather is needed here despite this being a | |||||
# gather, since each rank needs to broadcast a tensor of the same (maximal) | |||||
# size. | |||||
dist.all_gather(object_size_list, local_size, group=group) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor.resize_(max_object_size) | |||||
# Avoid populating output tensors if the result won't be gathered on this rank. | |||||
if my_rank == dst: | |||||
coalesced_output_tensor = torch.empty( | |||||
max_object_size * group_size, dtype=torch.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
# All ranks call gather with equal-sized tensors. | |||||
dist.gather( | |||||
input_tensor, | |||||
gather_list=output_tensors if my_rank == dst else None, | |||||
dst=dst, | |||||
group=group, | |||||
) | |||||
if my_rank != dst: | |||||
return | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(torch.uint8) # type: ignore[call-overload] | |||||
tensor_size = object_size_list[i] | |||||
object_gather_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
def _object_to_tensor(obj, device=None): | |||||
f = io.BytesIO() | |||||
_pickler(f).dump(obj) | |||||
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined] | |||||
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. | |||||
# Otherwise, it will casue 100X slowdown. | |||||
# See: https://github.com/pytorch/pytorch/issues/65696 | |||||
byte_tensor = torch.ByteTensor(byte_storage) | |||||
local_size = torch.LongTensor([byte_tensor.numel()]) | |||||
if device is not None: | |||||
byte_tensor = byte_tensor.to(device) | |||||
local_size = local_size.to(device) | |||||
return byte_tensor, local_size | |||||
def _tensor_to_object(tensor, tensor_size): | |||||
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size] | |||||
return _unpickler(io.BytesIO(buf)).load() | |||||
def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | |||||
# src rank send to all other ranks | |||||
size = torch.LongTensor([0]).to(device) | |||||
if cur_rank == src: | |||||
world_size = dist.get_world_size(group=group) | |||||
tensor, size = _object_to_tensor(obj) | |||||
tensor = tensor.to(device) | |||||
size = size.to(device) | |||||
# 首先同步 obj 的 size 的信息; | |||||
dist.broadcast(size, src, group=group) | |||||
for subrank in range(world_size): | |||||
if subrank != src: | |||||
dist.send(tensor=tensor, dst=subrank, group=group, tag=tag) | |||||
else: | |||||
dist.broadcast(size, src, group=group) | |||||
tensor = torch.ByteTensor([0] * size).to(device) | |||||
dist.recv(tensor=tensor, src=src, group=group, tag=tag) | |||||
return _tensor_to_object(tensor.cpu(), size) | |||||
def _all_gather(obj, **kwargs): | |||||
group = kwargs.get('group', None) | |||||
if isinstance(obj, torch.Tensor): | |||||
gathered_tensor = [torch.zeros_like(obj) for _ in | |||||
range(torch.distributed.get_world_size(group=group))] | |||||
torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||||
return gathered_tensor | |||||
elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): | |||||
tensor, size = obj | |||||
# 首先需要同步 size 吧? | |||||
group_size = dist.get_world_size(group=group) | |||||
object_sizes_tensor = torch.zeros( | |||||
group_size, dtype=torch.long, device=tensor.device | |||||
) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
dist.all_gather(object_size_list, size, group=group) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
tensor.resize_(max_object_size) | |||||
coalesced_output_tensor = torch.empty( | |||||
max_object_size * group_size, dtype=torch.uint8, device=tensor.device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
dist.all_gather(output_tensors, tensor, group=group) | |||||
object_list = [] | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(torch.uint8) | |||||
tensor_size = object_size_list[i] | |||||
object_list.append(_tensor_to_object(tensor, tensor_size)) | |||||
return object_list | |||||
elif isinstance(obj, tuple) and len(obj) == 2: | |||||
obj, _type = obj | |||||
gathered_tensor = [torch.zeros_like(obj) for _ in | |||||
range(torch.distributed.get_world_size(group=group))] | |||||
torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||||
if _type == np.ndarray: | |||||
gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] | |||||
else: | |||||
gathered_tensor = [_type(t.item()) for t in gathered_tensor] | |||||
return gathered_tensor | |||||
else: | |||||
raise RuntimeError("Unsupported types to implement all_gather.") | |||||
class CanTransferDataType(ABC): | |||||
""" | |||||
检测可以进行传输的对象。 | |||||
""" | |||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is CanTransferDataType: | |||||
if issubclass(subclass, Mapping): | |||||
return False | |||||
if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray): | |||||
return True | |||||
return False | |||||
return NotImplemented | |||||
def _tensorize(obj, device=None): | |||||
if isinstance(obj, torch.Tensor): | |||||
return obj | |||||
if isinstance(obj, bool): | |||||
return torch.tensor(obj, dtype=torch.uint8, device=device), bool | |||||
if isinstance(obj, float): | |||||
return torch.tensor(obj, dtype=torch.float, device=device), float | |||||
if isinstance(obj, int): | |||||
return torch.tensor(obj, dtype=torch.int, device=device), int | |||||
if isinstance(obj, np.ndarray): | |||||
return torch.from_numpy(obj), np.ndarray | |||||
return _object_to_tensor(obj, device) | |||||
def _to_device(tensor, device): | |||||
return tensor.contiguous().to(device) | |||||
def convert_to_tensors(data: Any, device=None) -> Any: | |||||
data = apply_to_collection(data, CanTransferDataType, _tensorize) | |||||
def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]): | |||||
if isinstance(t, tuple): | |||||
if isinstance(t[1], torch.Tensor): # 说明是 object 转的 | |||||
return t[0].to(device).contiguous(), t[1].to(device) | |||||
else: # 说明第二个元素是type,见 to_dtype_tensor 函数 | |||||
return t[0].to(device).contiguous(), t[1] | |||||
return t.to(device).contiguous() | |||||
data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device) | |||||
return data | |||||
def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||||
""" | |||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||||
example: | |||||
obj = { | |||||
'a': [1, 1], | |||||
'b': [[1, 2], [1, 2]], | |||||
'c': { | |||||
'd': [1, 2] | |||||
} | |||||
} | |||||
-> | |||||
[ | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
] | |||||
:param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 | |||||
:param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 | |||||
:param group: | |||||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||||
""" | |||||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
objs = [None for _ in range(dist.get_world_size(group))] | |||||
dist.all_gather_object(objs, obj) | |||||
return objs | |||||
if device is None: | |||||
device = torch.cuda.current_device() | |||||
group = group if group is not None else torch.distributed.group.WORLD | |||||
data = convert_to_tensors(obj, device=device) | |||||
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | |||||
objs = [] | |||||
def _get_obj_on_idx(obj, idx): | |||||
return obj[idx] | |||||
for i in range(dist.get_world_size(group)): | |||||
objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) | |||||
return objs | |||||
def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||||
""" | |||||
将 src 上的 obj 对象广播到其它 rank 上。 | |||||
:param obj: | |||||
:param src: | |||||
:param device: | |||||
:param group: | |||||
:return: | |||||
""" | |||||
cur_rank = dist.get_rank(group) | |||||
# if cur_rank == src: | |||||
# # 如果有 tensor 全部移动到 cpu 上,方便 pickle | |||||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
if cur_rank!=src: | |||||
get_obj = [None] | |||||
dist.broadcast_object_list(get_obj, src=src, group=group) | |||||
return get_obj[0] | |||||
else: | |||||
dist.broadcast_object_list([obj], src=src, group=group) | |||||
return obj | |||||
if cur_rank == src: | |||||
tensor, size = _object_to_tensor(obj, device=device) | |||||
else: | |||||
size = torch.LongTensor([0]).to(device) | |||||
dist.broadcast(size, src=src, group=group) | |||||
if cur_rank != src: | |||||
tensor = torch.empty( | |||||
size.int().item(), # type: ignore[arg-type] | |||||
dtype=torch.uint8, | |||||
device=device | |||||
) | |||||
dist.broadcast(tensor, src=src, group=group) | |||||
return _tensor_to_object(tensor, tensor_size=size.item()) | |||||
@@ -0,0 +1,63 @@ | |||||
from typing import List | |||||
from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE | |||||
if _NEED_IMPORT_FAIRSCALE: | |||||
import torch | |||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel | |||||
from fairscale.optim import OSS | |||||
__all__ = [ | |||||
'ShardedDriver' | |||||
] | |||||
from .ddp import TorchDDPDriver | |||||
# todo 注意 fairscale 现在几乎所有的功能都没有实现; | |||||
# TODO:预跑前后对模型和 optimizers 的支持; | |||||
# TODO:fairscale 的 fp16 额外的处理; | |||||
class ShardedDriver(TorchDDPDriver): | |||||
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M | |||||
def __init__( | |||||
self, | |||||
model, | |||||
parallel_device: List["torch.device"], | |||||
num_nodes: int = 1, | |||||
fp16: bool = False, | |||||
**kwargs | |||||
): | |||||
super(ShardedDriver, self).__init__( | |||||
model=model, | |||||
parallel_device=parallel_device, | |||||
num_nodes=num_nodes, | |||||
fp16=fp16, | |||||
**kwargs | |||||
) | |||||
def configure_ddp(self): | |||||
if "reduce_buffer_size" not in self._ddp_kwargs: | |||||
# For multi-node training, enabling bucketing will improve performance. | |||||
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 | |||||
self.optimizers = self._wrap_optimizers(self.optimizers) | |||||
self.model = ShardedDataParallel(self.model, sharded_optimizer=self.optimizers, **self._ddp_kwargs) | |||||
def _wrap_optimizers(self, optimizers) -> List["OSS"]: | |||||
# TODO:之后得去研究一下 pytorch lightning 为什么这样写,我们是不是也需要这样写; | |||||
# if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: | |||||
# return optimizers | |||||
return self._reinit_optimizers_with_oss(optimizers) | |||||
def _reinit_optimizers_with_oss(self, optimizers) -> List["OSS"]: | |||||
for x, optimizer in enumerate(optimizers): | |||||
if not isinstance(optimizer, OSS): | |||||
optim_class = type(optimizer) | |||||
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) | |||||
# TODO:具体细节见 pytorch lightning 的这一函数,主要的点在于加入 fp16 相关的一些东西; | |||||
optimizers[x] = zero_optimizer | |||||
del optimizer | |||||
return optimizers | |||||
@@ -0,0 +1,89 @@ | |||||
import os | |||||
from typing import Optional, Union, List, Sequence | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from .torch_driver import TorchDriver | |||||
from .single_device import TorchSingleDriver | |||||
from .ddp import TorchDDPDriver | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | |||||
def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.device, int, List[int]]], | |||||
model: torch.nn.Module, **kwargs) -> TorchDriver: | |||||
r""" | |||||
用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | |||||
注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; | |||||
:param driver: 该参数的值应为以下之一:["torch", "torch_ddp", "fairscale"]; | |||||
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | |||||
:param model: 训练或者评测的具体的模型; | |||||
:return: 返回一个元组,元组的第一个值是具体的基于 pytorch 的 `Driver` 实例,元组的第二个值是该 driver 的名字(用于检测一个脚本中 | |||||
先后 driver 的次序的正确问题); | |||||
""" | |||||
# world_size 和 rank | |||||
if FASTNLP_BACKEND_LAUNCH in os.environ: | |||||
if device is not None: | |||||
logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
"up your script. And we will directly get the local device via " | |||||
"`os.environ['LOCAL_RANK']`.") | |||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | |||||
if driver not in {"torch", "torch_ddp", "fairscale"}: | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].") | |||||
_could_use_device_num = torch.cuda.device_count() | |||||
if isinstance(device, str): | |||||
device = torch.device(device) | |||||
elif isinstance(device, int): | |||||
if device < 0 and device != -1: | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||||
if device >= _could_use_device_num: | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
device = torch.device(f"cuda:{device}") | |||||
elif isinstance(device, Sequence): | |||||
device = list(set(device)) | |||||
for each in device: | |||||
if not isinstance(each, int): | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | |||||
elif each < 0: | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | |||||
elif each >= _could_use_device_num: | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | |||||
" the available gpu number.") | |||||
device = [torch.device(f"cuda:{w}") for w in device] | |||||
elif device is not None and not isinstance(device, torch.device): | |||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||||
if driver == "torch": | |||||
if not isinstance(device, List): | |||||
return TorchSingleDriver(model, device, **kwargs) | |||||
else: | |||||
logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||||
"`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | |||||
"`driver` as `TorchDDPDriver`.") | |||||
return TorchDDPDriver(model, device, **kwargs) | |||||
elif driver == "torch_ddp": | |||||
if device is not None and not isinstance(device, List): | |||||
if device.type == 'cpu': | |||||
raise ValueError("You are using `torch_ddp` driver, but your chosen `device` is 'cpu'.") | |||||
logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " | |||||
"still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " | |||||
"choose `torch` driver.") | |||||
return TorchDDPDriver(model, device, **kwargs) | |||||
else: | |||||
return TorchDDPDriver(model, device, **kwargs) | |||||
elif driver == "fairscale": | |||||
raise NotImplementedError("`fairscale` is not support right now.") | |||||
# if not isinstance(device, List): | |||||
# if device.type == 'cpu': | |||||
# raise ValueError("You are using `fairscale` driver, but your chosen `device` is 'cpu'.") | |||||
# log.info("Notice you are using `fairscale` driver, but your chosen `device` is only one gpu, we will" | |||||
# "still use `fairscale` for you, but if you mean using `TorchSingleDriver`, you should " | |||||
# "choose `torch` driver.") | |||||
# return ShardedDriver(model, [device], **kwargs) | |||||
# else: | |||||
# return ShardedDriver(model, device, **kwargs) |
@@ -0,0 +1,176 @@ | |||||
import os | |||||
from typing import Dict, Union | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.nn import DataParallel | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
__all__ = [ | |||||
'TorchSingleDriver' | |||||
] | |||||
from .torch_driver import TorchDriver | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||||
from fastNLP.core.utils import auto_param_call | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
class TorchSingleDriver(TorchDriver): | |||||
r""" | |||||
用于 cpu 和 单卡 gpu 运算; | |||||
""" | |||||
def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs): | |||||
if isinstance(model, DistributedDataParallel): | |||||
raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`") | |||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) | |||||
if cuda_visible_devices == "": | |||||
device = torch.device("cpu") | |||||
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | |||||
"use `cpu` instead of `gpu` device.") | |||||
super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | |||||
if device is None: | |||||
raise ValueError("Parameter `device` can not be None in `TorchSingleDriver`.") | |||||
self.model_device = device | |||||
self.local_rank = 0 | |||||
self.global_rank = 0 | |||||
self.world_size = 1 | |||||
if isinstance(model, DataParallel): | |||||
model = self.unwrap_model() | |||||
if hasattr(model, "train_step"): | |||||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||||
"model also implements the `train_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `train_step` and you should note that.") | |||||
self._train_step = self.model | |||||
self._train_signature_fn = model.forward | |||||
if hasattr(model, "validate_step"): | |||||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||||
"model also implements the `validate_step` method, which we can not call actually, " | |||||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||||
self._validate_step = self.model | |||||
self._validate_signature_fn = model.forward | |||||
if hasattr(model, "test_step"): | |||||
logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " | |||||
"model also implements the `test_step` method, which we can not call actually, we will" | |||||
" call `forward` function instead of `test_step` and you should note that.") | |||||
self._test_step = self.model | |||||
self._test_signature_fn = model.forward | |||||
else: | |||||
if hasattr(self.model, "train_step"): | |||||
self._train_step = self.model.train_step | |||||
self._train_signature_fn = None | |||||
else: | |||||
self._train_step = self.model | |||||
# 输入的模型是 `DataParallel` 或者 `DistributedDataParallel`,我们需要保证其 signature_fn 是正确的; | |||||
model = self.unwrap_model() | |||||
self._train_signature_fn = model.forward | |||||
if hasattr(self.model, "validate_step"): | |||||
self._validate_step = self.model.validate_step | |||||
self._validate_signature_fn = None | |||||
elif hasattr(self.model, "test_step"): | |||||
self._validate_step = self.model.test_step | |||||
self._validate_signature_fn = self.model.test_step | |||||
else: | |||||
self._validate_step = self.model | |||||
model = self.unwrap_model() | |||||
self._validate_signature_fn = model.forward | |||||
if hasattr(self.model, "test_step"): | |||||
self._test_step = self.model.test_step | |||||
self._test_signature_fn = None | |||||
elif hasattr(self.model, "validate_step"): | |||||
self._test_step = self.model.validate_step | |||||
self._test_signature_fn = self.model.validate_step | |||||
else: | |||||
self._test_step = self.model | |||||
model = self.unwrap_model() | |||||
self._test_signature_fn = model.forward | |||||
def setup(self): | |||||
if self.model_device is not None: | |||||
self.model.to(self.model_device) | |||||
def train_step(self, batch) -> Dict: | |||||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
else: | |||||
return self._train_step(batch) | |||||
def backward(self, loss): | |||||
self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def validate_step(self, batch) -> Dict: | |||||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | |||||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
else: | |||||
return self._validate_step(batch) | |||||
def test_step(self, batch) -> Dict: | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
else: | |||||
return self._test_step(batch) | |||||
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
reproducible: bool = False): | |||||
if isinstance(dist_sampler, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, dist_sampler) | |||||
elif isinstance(dist_sampler, ReproducibleIterator): | |||||
return replace_sampler(dataloader, dist_sampler) | |||||
if reproducible: | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleIterator): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | |||||
return dataloader | |||||
def unwrap_model(self): | |||||
if isinstance(self.model, torch.nn.DataParallel) or \ | |||||
isinstance(self.model, torch.nn.parallel.DistributedDataParallel): | |||||
return self.model.module | |||||
else: | |||||
return self.model | |||||
@property | |||||
def data_device(self): | |||||
""" | |||||
单卡模式不支持 data_device; | |||||
""" | |||||
return self.model_device | |||||
def is_distributed(self): | |||||
return False | |||||
@@ -0,0 +1,330 @@ | |||||
import os | |||||
from typing import Union, Dict, Optional, Callable | |||||
from functools import partial | |||||
from pkg_resources import parse_version | |||||
import numpy as np | |||||
import random | |||||
from dataclasses import dataclass | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from pathlib import Path | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.utils.data import DataLoader, IterableDataset, RandomSampler, Sampler, BatchSampler, Dataset | |||||
from torch.optim import Optimizer | |||||
_reduces = { | |||||
'sum': torch.max, | |||||
'min': torch.min, | |||||
'max': torch.max, | |||||
'mean': torch.mean | |||||
} | |||||
__all__ = [ | |||||
'TorchDriver' | |||||
] | |||||
from .utils import optimizer_state_to_device | |||||
from fastNLP.core.drivers.driver import Driver | |||||
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env | |||||
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||||
from fastNLP.core.log import logger | |||||
class TorchDriver(Driver): | |||||
r""" | |||||
专属于 pytorch 的 driver;因为我们会在同一个 Trainer 框架下提供 jittor、paddle 等训练框架的支持; | |||||
""" | |||||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | |||||
super(TorchDriver, self).__init__(model) | |||||
""" 进行 fp16 的设置 """ | |||||
# 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里; | |||||
self.fp16 = fp16 | |||||
if parse_version(torch.__version__) < parse_version('1.6'): | |||||
raise RuntimeError("Pytorch supports float16 after version 1.6, please upgrade your pytorch version.") | |||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||||
self.grad_scaler = _grad_scaler() | |||||
# 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | |||||
self.non_blocking = kwargs.get("torch_non_blocking", True) | |||||
def zero_grad(self, set_to_none: bool = False): | |||||
for optimizer in self.optimizers: | |||||
self._clear_grad(optimizer, set_to_none) | |||||
def _clear_grad(self, optimizer, set_to_none): | |||||
param_groups = optimizer.param_groups | |||||
for group in param_groups: | |||||
for p in group['params']: | |||||
if p.grad is not None: | |||||
if set_to_none: | |||||
p.grad = None | |||||
else: | |||||
if p.grad.grad_fn is not None: | |||||
p.grad.detach_() | |||||
else: | |||||
p.grad.requires_grad_(False) | |||||
p.grad.zero_() | |||||
@staticmethod | |||||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||||
if is_train: | |||||
if not isinstance(dataloader, DataLoader): | |||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") | |||||
# todo 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | |||||
if isinstance(dataloader.dataset, IterableDataset): | |||||
raise TypeError("`IterableDataset` is not allowed.") | |||||
else: | |||||
if not isinstance(dataloader, Dict): | |||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||||
else: | |||||
for each_dataloader in dataloader.values(): | |||||
if not isinstance(each_dataloader, DataLoader): | |||||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'DataLoader' " | |||||
f"type, not {type(each_dataloader)}.") | |||||
if isinstance(each_dataloader.dataset, IterableDataset): | |||||
raise TypeError("`IterableDataset` is not allowed.") | |||||
@staticmethod | |||||
def _check_optimizer_legality(optimizers): | |||||
for each_optimizer in optimizers: | |||||
if not isinstance(each_optimizer, Optimizer): | |||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
f"not {type(each_optimizer)}.") | |||||
def check_evaluator_mode(self, mode: str): | |||||
model = self.unwrap_model() | |||||
if mode == "validate": | |||||
if not hasattr(model, "validate_step"): | |||||
if hasattr(model, "test_step"): | |||||
logger.warning( | |||||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||||
"are using 'mode=validate', we are going to use 'test_step' to substitute for" | |||||
"'validate_step'.") | |||||
else: | |||||
if not hasattr(model, "test_step"): | |||||
if hasattr(model, "validate_step"): | |||||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||||
"are using 'mode=test', we are going to use 'validate_step' to substitute for" | |||||
"'test_step'.") | |||||
@staticmethod | |||||
def tensor_to_numeric(tensor, reduce=None): | |||||
if tensor is None: | |||||
return None | |||||
def _translate(_data): | |||||
if _data.numel() == 1: | |||||
return _data.item() | |||||
if reduce is None: | |||||
return _data.tolist() | |||||
return _reduces[reduce](_data).item() | |||||
return apply_to_collection( | |||||
data=tensor, | |||||
dtype=torch.Tensor, | |||||
function=_translate | |||||
) | |||||
def set_model_mode(self, mode: str): | |||||
assert mode in {"train", "eval"} | |||||
getattr(self.model, mode)() | |||||
@rank_zero_call | |||||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): | |||||
""" | |||||
保存当前 driver 的模型到 folder 下。 | |||||
:param filepath: 保存到哪个文件夹; | |||||
:param only_state_dict: 是否只保存权重; | |||||
:param model_save_fn: | |||||
:return: | |||||
""" | |||||
model = self.unwrap_model() | |||||
if only_state_dict: | |||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
torch.save(states, filepath) | |||||
else: | |||||
if self.model_device is not None: | |||||
if not self.is_distributed(): | |||||
self.move_model_to_device(model, torch.device("cpu")) | |||||
torch.save(model, filepath) | |||||
if not self.is_distributed(): | |||||
self.move_model_to_device(model, self.model_device) | |||||
else: | |||||
torch.save(model, filepath) | |||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
""" | |||||
从 folder 中加载权重并赋值到当前 driver 的模型上。 | |||||
:param filepath: 加载权重或模型的路径 | |||||
:param load_state_dict: 保存的内容是否只是权重。 | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
model = self.unwrap_model() | |||||
res = torch.load(filepath, map_location='cpu') | |||||
if only_state_dict: | |||||
model.load_state_dict(res) | |||||
else: | |||||
model.load_state_dict(res.state_dict()) | |||||
@rank_zero_call | |||||
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# 1. 保存模型的状态; | |||||
if should_save_model: | |||||
model = self.unwrap_model() | |||||
if only_state_dict: | |||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | |||||
torch.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||||
logger.debug("Save model state dict") | |||||
else: | |||||
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||||
logger.debug("Save model") | |||||
# 2. 保存 optimizers 的状态; | |||||
optimizers_state_dict = {} | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
optimizer_state = optimizer.state_dict() | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
logger.debug("Save optimizer state dict") | |||||
states["optimizers_state_dict"] = optimizers_state_dict | |||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
# 1. 加载 optimizers 的状态; | |||||
optimizers_state_dict = states["optimizers_state_dict"] | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"]) | |||||
logger.debug("Load optimizer state dict.") | |||||
# 2. 加载模型状态; | |||||
if should_load_model: | |||||
model = self.unwrap_model() | |||||
res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu') | |||||
if only_state_dict: | |||||
model.load_state_dict(res) | |||||
logger.debug("Load model state dict.") | |||||
else: | |||||
model.load_state_dict(res.state_dict()) | |||||
logger.debug("Load model.") | |||||
return states | |||||
def get_evaluate_context(self): | |||||
return torch.no_grad | |||||
@staticmethod | |||||
def move_model_to_device(model: "torch.nn.Module", device: "torch.device"): | |||||
if device is not None: | |||||
model.to(device) | |||||
def move_data_to_device(self, batch: "torch.Tensor"): | |||||
return torch_move_data_to_device(batch, self.data_device, self.non_blocking) | |||||
@staticmethod | |||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||||
with ``seed_everything(seed, workers=True)``. | |||||
See also the PyTorch documentation on | |||||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | |||||
""" | |||||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | |||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
process_seed = torch.initial_seed() | |||||
# back out the base seed so we can use all the bits | |||||
base_seed = process_seed - worker_id | |||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||||
# use 128 bits (4 x 32-bit words) | |||||
np.random.seed(ss.generate_state(4)) | |||||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||||
torch_ss, stdlib_ss = ss.spawn(2) | |||||
torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) | |||||
# use 128 bits expressed as an integer | |||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||||
random.seed(stdlib_seed) | |||||
def set_deterministic_dataloader(self, dataloader: "DataLoader"): | |||||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx: int): | |||||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||||
if callable(getattr(dataloader.sampler, "set_epoch", None)): | |||||
dataloader.sampler.set_epoch(cur_epoch_idx) | |||||
@staticmethod | |||||
def get_dataloader_args(dataloader: "DataLoader"): | |||||
""" | |||||
获取 dataloader 的 shuffle 和 drop_last 属性; | |||||
""" | |||||
@dataclass | |||||
class Res: | |||||
dataset: Optional[Dataset] = None | |||||
batch_sampler: Optional[BatchSampler] = None | |||||
sampler: Optional[Sampler] = None | |||||
batch_size: Optional[int] = None | |||||
shuffle: Optional[bool] = None | |||||
drop_last: Optional[bool] = None | |||||
res = Res() | |||||
# pytorch 的 DataLoader 一定会有 dataset 属性; | |||||
res.dataset = dataloader.dataset | |||||
# dataloader 使用的是 sampler; | |||||
if dataloader.batch_sampler is None: | |||||
res.sampler = dataloader.sampler | |||||
res.batch_size = 1 | |||||
res.shuffle = True if isinstance(dataloader.sampler, RandomSampler) else False | |||||
res.drop_last = False | |||||
# dataloader 使用的是 batch_sampler; | |||||
else: | |||||
res.batch_sampler = dataloader.batch_sampler | |||||
if hasattr(dataloader.batch_sampler, "batch_size"): | |||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; | |||||
else: | |||||
dataloader_iter = iter(dataloader) | |||||
pre_sample = next(dataloader_iter) | |||||
res.batch_size = pre_sample.shape[0] | |||||
if hasattr(dataloader.batch_sampler, "sampler"): | |||||
res.sampler = dataloader.batch_sampler.sampler | |||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
else: | |||||
# 如果 dataloader.batch_sampler 没有 sampler 这个属性,那么说明其使用的是自己的 batch_sampler,且没有 "sampler" 属性; | |||||
# 这种情况下 DataLoader 会自己初始化一个 sampler;我们因此将这个默认初始化的 sampler 挂载到 res 上; | |||||
res.sampler = dataloader.sampler | |||||
res.shuffle = False | |||||
if hasattr(dataloader.batch_sampler, "drop_last"): | |||||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; | |||||
else: | |||||
res.drop_last = False | |||||
return res |
@@ -0,0 +1,374 @@ | |||||
import os | |||||
from typing import Any, Dict, Optional | |||||
from enum import IntEnum | |||||
import contextlib | |||||
import random | |||||
import numpy as np | |||||
import inspect | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
# import torch.nn as nn | |||||
from torch.nn import Module | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
from torch.utils.data.sampler import Sampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
__all__ = [ | |||||
'torch_seed_everything', | |||||
'optimizer_state_to_device' | |||||
] | |||||
from fastNLP.core.utils import auto_param_call | |||||
from fastNLP.envs import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS | |||||
from fastNLP.core.log import logger | |||||
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: | |||||
return random.randint(min_seed_value, max_seed_value) | |||||
def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: | |||||
"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, | |||||
sets the following environment variables: | |||||
- `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). | |||||
- `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. | |||||
Args: | |||||
seed: the integer value seed for global random state in Lightning. | |||||
If `None`, will read seed from `PL_GLOBAL_SEED` env variable | |||||
or select it randomly. | |||||
workers: if set to ``True``, will properly configure all dataloaders passed to the | |||||
Trainer with a ``worker_init_fn``. If the user already provides such a function | |||||
for their dataloaders, setting this argument will have no influence. See also: | |||||
:func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`. | |||||
""" | |||||
max_seed_value = np.iinfo(np.uint32).max | |||||
min_seed_value = np.iinfo(np.uint32).min | |||||
if seed is None: | |||||
env_seed = os.environ.get(FASTNLP_GLOBAL_SEED) | |||||
if env_seed is None: | |||||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||||
# rank_zero_warn(f"No seed found, seed set to {seed}") | |||||
else: | |||||
try: | |||||
seed = int(env_seed) | |||||
except ValueError: | |||||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||||
# rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") | |||||
elif not isinstance(seed, int): | |||||
seed = int(seed) | |||||
if not (min_seed_value <= seed <= max_seed_value): | |||||
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | |||||
# rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") | |||||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||||
# using `log.info` instead of `rank_zero_info`, | |||||
# so users can verify the seed is properly set in distributed training. | |||||
# log.info(f"Global seed set to {seed}") | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
torch.manual_seed(seed) | |||||
torch.cuda.manual_seed_all(seed) | |||||
os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" | |||||
return seed | |||||
def reset_seed() -> None: | |||||
""" | |||||
这个函数主要是给 ddp 用的,因为 ddp 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 | |||||
进行随机数的设置; | |||||
If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing. | |||||
""" | |||||
seed = os.environ.get(FASTNLP_GLOBAL_SEED, None) | |||||
workers = os.environ.get(FASTNLP_SEED_WORKERS, "0") | |||||
if seed is not None: | |||||
torch_seed_everything(int(seed), workers=bool(int(workers))) | |||||
class ForwardState(IntEnum): | |||||
TRAIN = 0 | |||||
VALIDATE = 1 | |||||
TEST = 2 | |||||
PREDICT = 3 | |||||
_MODE_PARAMETER = "_forward_state" | |||||
class _DDPWrappingModel(Module): | |||||
""" | |||||
该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; | |||||
之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; | |||||
另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'validate_step' 等; | |||||
然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 | |||||
`model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; | |||||
因此出于以上考虑,我们实现了这一函数; | |||||
对于更详细的解释,可以参考 'pytorch_lightning' 的 ddp 的设计; | |||||
""" | |||||
def __init__(self, model: Module): | |||||
super(_DDPWrappingModel, self).__init__() | |||||
self.model = model | |||||
if hasattr(model, "train_step"): | |||||
self._train_step = model.train_step | |||||
self._train_signature_fn = None | |||||
else: | |||||
self._train_step = model | |||||
self._train_signature_fn = model.forward | |||||
if hasattr(model, "validate_step"): | |||||
self._validate_step = model.validate_step | |||||
self._validate_signature_fn = None | |||||
elif hasattr(model, "test_step"): | |||||
self._validate_step = model.test_step | |||||
self._validate_signature_fn = None | |||||
else: | |||||
self._validate_step = model | |||||
self._validate_signature_fn = model.forward | |||||
if hasattr(model, "test_step"): | |||||
self._test_step = model.test_step | |||||
self._test_signature_fn = None | |||||
elif hasattr(model, "validate_step"): | |||||
self._test_step = model.validate_step | |||||
self._test_signature_fn = None | |||||
else: | |||||
self._test_step = model | |||||
self._test_signature_fn = model.forward | |||||
def forward(self, batch, **kwargs) -> Dict: | |||||
""" | |||||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | |||||
""" | |||||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
if _forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||||
else: | |||||
return self._train_step(batch) | |||||
elif _forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||||
else: | |||||
return self._validate_step(batch) | |||||
elif _forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict): | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||||
else: | |||||
return self._test_step(batch) | |||||
elif _forward_state == ForwardState.PREDICT: | |||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||||
else: | |||||
raise NotImplementedError("You should direct a concrete mode.") | |||||
class DummyGradScaler: | |||||
""" | |||||
用于Dummy pytorch的GradScaler对象,防止重复写大量的if判断 | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
pass | |||||
def get_scale(self): | |||||
return 1.0 | |||||
def is_enabled(self): | |||||
return False | |||||
def scale(self, outputs): | |||||
return outputs | |||||
def step(self, optimizer, *args, **kwargs): | |||||
optimizer.step(*args, **kwargs) | |||||
def update(self, new_scale=None): | |||||
pass | |||||
def unscale_(self, optimizer): | |||||
pass | |||||
def load_state_dict(self, state_dict): | |||||
pass | |||||
def state_dict(self): | |||||
return {} | |||||
def _build_fp16_env(dummy=False): | |||||
if dummy: | |||||
autocast = contextlib.ExitStack | |||||
GradScaler = DummyGradScaler | |||||
else: | |||||
if not torch.cuda.is_available(): | |||||
raise RuntimeError("No cuda") | |||||
if torch.cuda.get_device_capability(0)[0] < 7: | |||||
logger.warning( | |||||
"NOTE: your device does NOT support faster training with fp16, " | |||||
"please switch to FP32 which is likely to be faster" | |||||
) | |||||
try: | |||||
from torch.cuda.amp import autocast, GradScaler | |||||
except ImportError: | |||||
raise RuntimeError("torch version too low (less than 1.6)") | |||||
return autocast, GradScaler | |||||
def replace_sampler(dataloader: "DataLoader", sampler): | |||||
""" | |||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | |||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | |||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader | |||||
的类,而不是直接的 DataLoader; | |||||
如果需要定制自己的 dataloader,保证以下两点: | |||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; | |||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 | |||||
来获取实际的参数的值; | |||||
""" | |||||
# 拿到实例属性; | |||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | |||||
# 'multiprocessing_context' 是 user-defined function; | |||||
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context | |||||
# 拿到 dataloader '__init__' 函数的默认函数签名; | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
# 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 | |||||
# 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 | |||||
# 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader | |||||
# 中寻找; | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||||
if has_variadic_kwargs: | |||||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||||
del init_params["self"] | |||||
# 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; | |||||
non_default_params = {name for name, p in init_params.items() if | |||||
name in instance_attrs and p.default != instance_attrs[name]} | |||||
# add `dataset` as it might have been replaced with `*args` | |||||
non_default_params.add("dataset") | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | |||||
required_args = { | |||||
p.name | |||||
for p in init_params.values() | |||||
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) | |||||
and p.default is p.empty | |||||
and p.name not in reconstruct_args | |||||
} | |||||
# 这种错误针对的是 __init__ 中的参数没有用同样名字的 self 挂上; | |||||
if required_args: | |||||
required_args = sorted(required_args) | |||||
dataloader_self_name = dataloader.__class__.__name__ | |||||
raise Exception( | |||||
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | |||||
f"The missing attributes are {required_args}. " | |||||
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " | |||||
"manually add the `DistributedSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||||
) | |||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | |||||
if not has_variadic_kwargs: | |||||
# the dataloader signature does not allow keyword arguments that need to be passed | |||||
missing_kwargs = reconstruct_args.keys() - init_params.keys() | |||||
if missing_kwargs: | |||||
missing_kwargs = sorted(missing_kwargs) | |||||
dataloader_self_name = dataloader.__class__.__name__ | |||||
raise Exception( | |||||
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. " | |||||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | |||||
f"The missing arguments are {missing_kwargs}. " | |||||
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " | |||||
"manually add the `DistributedSampler` as: " | |||||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||||
) | |||||
return type(dataloader)(**reconstruct_args) | |||||
def _dataloader_init_kwargs_resolve_sampler( | |||||
dataloader: "DataLoader", sampler: Optional["Sampler"] | |||||
) -> Dict[str, Any]: | |||||
""" | |||||
此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | |||||
""" | |||||
batch_sampler = getattr(dataloader, "batch_sampler") | |||||
# checking the batch sampler type is different than PyTorch default. | |||||
if batch_sampler is not None and type(batch_sampler) is not BatchSampler: | |||||
batch_sampler = type(batch_sampler)( | |||||
sampler, | |||||
batch_size=batch_sampler.batch_size, | |||||
drop_last=batch_sampler.drop_last, | |||||
) | |||||
return { | |||||
"sampler": None, | |||||
"shuffle": False, | |||||
"batch_sampler": batch_sampler, | |||||
"batch_size": 1, | |||||
"drop_last": False, | |||||
} | |||||
return {"sampler": sampler, "shuffle": False, "batch_sampler": None} | |||||
def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
"""Helper function to replace current batch sampler of the dataloader by a new batch sampler. Function returns new | |||||
dataloader with new batch sampler. | |||||
Args: | |||||
dataloader: input dataloader | |||||
new_batch_sampler: new batch sampler to use | |||||
Returns: | |||||
DataLoader | |||||
""" | |||||
params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] | |||||
for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: | |||||
if k in params_keys: | |||||
params_keys.remove(k) | |||||
params = {k: getattr(dataloader, k) for k in params_keys} | |||||
params["batch_sampler"] = new_batch_sampler | |||||
return type(dataloader)(**params) | |||||
def optimizer_state_to_device(state, device): | |||||
new_state = {} | |||||
for name, param in state.items(): | |||||
if isinstance(param, dict): | |||||
new_state[name] = optimizer_state_to_device(param, device) | |||||
elif isinstance(param, torch.Tensor): | |||||
new_state[name] = param.to(device).clone() | |||||
else: | |||||
new_state[name] = param | |||||
return new_state | |||||
@@ -0,0 +1,89 @@ | |||||
from typing import Optional | |||||
from typing import Union, List | |||||
import subprocess | |||||
from pathlib import Path | |||||
from fastNLP.core.drivers.driver import Driver | |||||
def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver: | |||||
r""" | |||||
根据输入的参数 'gpus' 的格式来决定具体的工作模式; | |||||
:param model: 运行过程中使用的具体的最原始的模型; | |||||
:param driver: 应当为字符串或者 `Driver` 实例,表示运行中具体使用的训练/评测模式; | |||||
:param device: 具体的形式请参见 `fastNLP.core.drivers.torch_driver.utils.initialize_torch_dirver` 的注释; | |||||
:param kwargs: 其余的传给 `Driver` 的参数; | |||||
""" | |||||
# 如果用户直接传进来一个 driver 实例,我们就直接返回回去,目前用户需要自己保证传进来的 driver 的正确性; | |||||
if isinstance(driver, Driver): | |||||
return driver | |||||
if driver in {"torch", "torch_ddp", "fairscale"}: | |||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | |||||
return initialize_torch_driver(driver, device, model, **kwargs) | |||||
elif driver in {"jittor"}: | |||||
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver | |||||
return initialize_jittor_driver(driver, device, model, **kwargs) | |||||
elif driver in {"paddle", "fleet"}: | |||||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | |||||
return initialize_paddle_driver(driver, device, model, **kwargs) | |||||
else: | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', " | |||||
"'jittor', 'paddle', 'fleet'].") | |||||
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None): | |||||
""" | |||||
使用 command 通过 subprocess.Popen 开启新的进程。 | |||||
:param output_from_new_proc: 可选 ["ignore", "all", "only_error"],以上三个为特殊关键字,分别表示完全忽略拉起进程的打印输出, | |||||
only_error 表示只打印错误输出流;all 表示子进程的所有输出都打印。如果不为以上的关键字,则表示一个文件夹,将在该文件夹下建立 | |||||
两个文件,名称分别为 {rank}_std.log, {rank}_err.log 。原有的文件会被直接覆盖。 | |||||
:param command: List[str] 启动的命令 | |||||
:param env_copy: 需要注入的环境变量。 | |||||
:param rank: | |||||
:return: | |||||
""" | |||||
if output_from_new_proc == "all": | |||||
proc = subprocess.Popen(command, env=env_copy) | |||||
elif output_from_new_proc == "only_error": | |||||
proc = subprocess.Popen(command, env=env_copy, stdout=subprocess.DEVNULL) | |||||
elif output_from_new_proc == "ignore": | |||||
proc = subprocess.Popen(command, env=env_copy, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | |||||
else: | |||||
assert rank is not None | |||||
std_f = open(output_from_new_proc + f'/{rank}_std.log', 'w') | |||||
err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w') | |||||
proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f) | |||||
return proc | |||||
def load_model(filepath: Union[str, Path], backend: str = "torch", **kwargs): | |||||
r""" | |||||
对应 `load_model`,用来帮助用户加载之前通过 `load_model` 所保存的模型; | |||||
:param filepath: 加载的文件的位置; | |||||
:param backend: 使用哪种 backend 来加载该 filepath, 目前支持 ["torch", "paddle", "jittor"] 。 | |||||
""" | |||||
if filepath is None: | |||||
raise ValueError("Parameter `path` can not be None.") | |||||
assert backend is not None, "Parameter `backend` can not be None." | |||||
if backend == "torch": | |||||
import torch | |||||
_res = torch.load(filepath) | |||||
return _res | |||||
elif backend == "jittor": | |||||
raise NotImplementedError | |||||
elif backend == "paddle": | |||||
raise NotImplementedError | |||||
else: | |||||
raise ValueError("Parameter `backend` could only be one of these values: ['torch', 'jittor', 'paddle']") | |||||
@@ -0,0 +1,21 @@ | |||||
__all__ = [ | |||||
'BucketSampler', | |||||
'SortedSampler', | |||||
'ConstTokenNumSampler', | |||||
'ConstantTokenNumSampler', | |||||
'UnrepeatedDistributedSampler', | |||||
'MixSampler', | |||||
'InnerSampler', | |||||
'DopedSampler', | |||||
'MixSequentialSampler', | |||||
'PollingSampler', | |||||
'ReproducibleIterator', | |||||
'RandomSampler', | |||||
'ReproducibleBatchSampler', | |||||
're_instantiate_sampler' | |||||
] | |||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | |||||
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, ReproducibleBatchSampler, re_instantiate_sampler | |||||
@@ -0,0 +1,659 @@ | |||||
import array | |||||
import numpy as np | |||||
from typing import Union, List, Iterable, Dict | |||||
__all__ = [ | |||||
'MixSampler', | |||||
'InnerSampler', | |||||
'DopedSampler', | |||||
'MixSequentialSampler', | |||||
'PollingSampler' | |||||
] | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import SequentialSampler, Sampler | |||||
import torch | |||||
class MixSampler: | |||||
""" | |||||
mix_sampler的基类 | |||||
""" | |||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None, | |||||
ds_ratio: Union[str, List[float], Dict[str, float]] = None, | |||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||||
""" | |||||
:param dataset: 实现了__getitem__和__len__的数据容器列表 | |||||
:param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset | |||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||||
""" | |||||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||||
if isinstance(dataset, Dict) and isinstance(sampler, List): | |||||
raise ValueError(f"{sampler} must be dict") | |||||
if batch_size <= 0: | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got batch_size={}".format(batch_size)) | |||||
if not isinstance(drop_last, bool): | |||||
raise ValueError("drop_last should be a boolean value, but got " | |||||
"drop_last={}".format(drop_last)) | |||||
if not isinstance(sampler, str) and (rank >= 0 or word_size >= 0): | |||||
raise ValueError("if rank>=0 and word_size>=0, sampler must be str") | |||||
if sampler is None and (word_size < 0 or rank < 0): | |||||
if isinstance(dataset, List): | |||||
self.sampler = [SequentialSampler(ds) for ds in dataset] | |||||
elif isinstance(dataset, Dict): | |||||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||||
elif isinstance(sampler, List): | |||||
if len(sampler) != len(dataset): | |||||
raise ValueError("the length of sampler != the length of sampler") | |||||
self.sampler = sampler | |||||
elif isinstance(sampler, Dict): | |||||
self.sampler = sampler | |||||
else: | |||||
# 单卡多机情况下, sampler为None或者str且word_size>0, rank > 0 | |||||
if isinstance(sampler, str): | |||||
if sampler not in ['seq', 'rand']: | |||||
raise ValueError(f"sampler is {sampler}, but seq or rand is required") | |||||
self.sampler = sampler | |||||
# 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len | |||||
sampler_lens, total_lens, sampler_index = [], 0, [] | |||||
if isinstance(self.sampler, List): | |||||
if ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(ds_ratio, List): | |||||
if not all(item >= 0 for item in ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
if ds_ratio is None: | |||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||||
elif ds_ratio == 'pad_to_most': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif ds_ratio == 'truncate_to_least': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif isinstance(ds_ratio, Dict): | |||||
if not all(item >= 0 for item in ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()] | |||||
else: | |||||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
# sampler为str时候,初始化下移到iter方法中 | |||||
if len(sampler_lens) > 0: | |||||
sampler_index = [sampler_lens[0]] | |||||
for idx in sampler_lens[1:]: | |||||
temp = sampler_index[-1] | |||||
sampler_index.append(temp + idx) | |||||
self.batch_size = batch_size | |||||
self.drop_last = drop_last | |||||
self.ds_ratio = ds_ratio | |||||
self.rank = rank | |||||
self.word_size = word_size | |||||
self.datasets = dataset | |||||
self.num_samplers = sampler_index | |||||
self.len_samplers = total_lens | |||||
self.epoch = 0 | |||||
def __iter__(self): | |||||
pass | |||||
def __len__(self): | |||||
pass | |||||
def set_epoch(self, epoch: int) -> None: | |||||
""" | |||||
配合ddp使用, 控制随机数种子 | |||||
:param epoch: 当前的轮次 | |||||
:return: | |||||
""" | |||||
self.epoch = epoch | |||||
class InnerSampler: | |||||
""" | |||||
提供多卡情况下使用的内部sampler | |||||
""" | |||||
def __init__(self, ds_ind_list: List) -> None: | |||||
self.ds_ind_list = ds_ind_list | |||||
def __iter__(self) -> int: | |||||
for item in self.ds_ind_list: | |||||
yield item | |||||
def __len__(self) -> int: | |||||
return len(self.ds_ind_list) | |||||
class DopedSampler(MixSampler): | |||||
""" | |||||
定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 | |||||
""" | |||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None, | |||||
ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, | |||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||||
super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size, | |||||
sampler=sampler, ds_ratio=ds_ratio, | |||||
drop_last=drop_last, rank=rank, word_size=word_size) | |||||
def __iter__(self) -> List[int]: | |||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||||
if isinstance(self.sampler, str): | |||||
if self.sampler == 'seq': | |||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
elif self.sampler == 'rand': | |||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(indices)) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||||
if isinstance(self.sampler, List): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(self.ds_ratio, List): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif isinstance(self.ds_ratio, Dict): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * self.ds_ratio[name]) for name, spl in self.sampler.items()] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
else: | |||||
raise ValueError("datasets must be dict or list") | |||||
# 初始化参数 | |||||
sampler_index = [sampler_lens[0]] | |||||
for idx in sampler_lens[1:]: | |||||
temp = sampler_index[-1] | |||||
sampler_index.append(temp + idx) | |||||
self.num_samplers = sampler_index | |||||
self.len_samplers = total_lens | |||||
# 每个batch的数据, 总的数据量total_index, 每个数据集的samplers | |||||
batch_idx, samplers = [], [] | |||||
# 如果单机则用所有数据,否则采用多卡 | |||||
if self.rank < 0 or self.word_size < 0: | |||||
# 根据sampler长度判断是否使用unsigned int 或者unsigned long | |||||
if self.len_samplers > 42e8: | |||||
total_index = array.array('L', list(range(self.len_samplers))) | |||||
else: | |||||
total_index = array.array('I', list(range(self.len_samplers))) | |||||
else: | |||||
if (self.len_samplers // self.word_size) > 42e8: | |||||
# 整分给每个卡的数据 | |||||
self.len_samplers = self.len_samplers - self.len_samplers % self.word_size | |||||
total_index = array.array('L', list(range(self.len_samplers))[self.rank::self.word_size]) | |||||
else: | |||||
total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size]) | |||||
# 根据sampler的类型取出每个数据集的sampler | |||||
if isinstance(self.sampler, List): | |||||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||||
samplers = [(iter(spl), idx, base_index) | |||||
for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))] | |||||
else: | |||||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||||
samplers = [(iter(spl), name, sampler_base_index[idx]) | |||||
for idx, (name, spl) in enumerate(self.sampler.items())] | |||||
# 生成随机数 | |||||
np.random.seed(self.epoch) | |||||
np.random.shuffle(total_index) | |||||
for idx in total_index: | |||||
ds_index = np.searchsorted(self.num_samplers, idx, side='right') | |||||
spl, name, base_index = samplers[ds_index] | |||||
try: | |||||
batch_idx.append(next(spl) + base_index) | |||||
except StopIteration: | |||||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | |||||
spl = iter(self.sampler[name]) | |||||
batch_idx.append(next(spl) + base_index) | |||||
samplers[name] = (spl, name, base_index) | |||||
if len(batch_idx) == self.batch_size: | |||||
yield batch_idx | |||||
batch_idx = [] | |||||
if len(batch_idx) > 0 and not self.drop_last: | |||||
yield batch_idx | |||||
def __len__(self) -> int: | |||||
# 多卡情况下 | |||||
if self.rank >= 0 and self.word_size >= 0: | |||||
# 整分给每个卡的数据 | |||||
self.len_samplers = (self.len_samplers - self.len_samplers % self.word_size) / self.word_size | |||||
if self.drop_last: | |||||
return self.len_samplers // self.batch_size | |||||
else: | |||||
return (self.len_samplers + self.batch_size - 1) // self.batch_size | |||||
class MixSequentialSampler(MixSampler): | |||||
""" | |||||
定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表顺序采样并返回index,只有处理了上一个dataset才会处理下一个。 | |||||
""" | |||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None, | |||||
ds_ratio: Union[str, List[float], Dict[str, float]] = None, | |||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||||
""" | |||||
:param dataset: 实现了__getitem__和__len__的数据容器列表 | |||||
:param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset | |||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||||
""" | |||||
super(MixSequentialSampler, self).__init__(dataset=dataset, batch_size=batch_size, | |||||
sampler=sampler, ds_ratio=ds_ratio, | |||||
drop_last=drop_last, rank=rank, word_size=word_size) | |||||
def __iter__(self) -> Iterable[List[int]]: | |||||
""" | |||||
按照dataset的顺序采样,打包成一个batch后返回 | |||||
:return: | |||||
""" | |||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||||
if isinstance(self.sampler, str): | |||||
if self.sampler == 'seq': | |||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
elif self.sampler == 'rand': | |||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(indices)) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||||
if isinstance(self.sampler, List): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(self.ds_ratio, List): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif isinstance(self.ds_ratio, Dict): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * self.ds_ratio[name]) for name, spl in self.sampler.items()] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
else: | |||||
raise ValueError("datasets must be dict or list") | |||||
# 初始化参数 | |||||
sampler_index = [sampler_lens[0]] | |||||
for idx in sampler_lens[1:]: | |||||
temp = sampler_index[-1] | |||||
sampler_index.append(temp + idx) | |||||
self.num_samplers = sampler_index | |||||
self.len_samplers = total_lens | |||||
batch_idx, total_index, samplers = [], list(range(self.len_samplers)), [] | |||||
if isinstance(self.sampler, List): | |||||
if self.word_size > 0 and self.rank >= 0: | |||||
sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1] | |||||
else: | |||||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||||
samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in | |||||
enumerate(zip(self.sampler, sampler_base_index))] | |||||
else: | |||||
if self.word_size > 0 and self.rank >= 0: | |||||
sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||||
else: | |||||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||||
samplers = [(iter(spl), name, sampler_base_index[idx]) | |||||
for idx, (name, spl) in enumerate(self.sampler.items())] | |||||
for idx in total_index: | |||||
ds_index = np.searchsorted(self.num_samplers, idx, side='right') | |||||
spl, name, base_index = samplers[ds_index] | |||||
try: | |||||
batch_idx.append(next(spl) + base_index) | |||||
except StopIteration: | |||||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | |||||
spl = iter(self.sampler[name]) | |||||
batch_idx.append(next(spl) + base_index) | |||||
samplers[name] = (spl, name, base_index) | |||||
if len(batch_idx) == self.batch_size: | |||||
yield batch_idx | |||||
batch_idx = [] | |||||
# 当前数据集采样完,需要及时处理最后一个batch | |||||
if self.num_samplers[ds_index] == (idx + 1): | |||||
if len(batch_idx) > 0 and not self.drop_last: | |||||
yield batch_idx | |||||
batch_idx = [] | |||||
def __len__(self) -> int: | |||||
lens, index = 0, 0 | |||||
num_sampler = [] | |||||
for ds_len in self.num_samplers: | |||||
num_sampler.append(ds_len - index) | |||||
index = ds_len | |||||
for ds_len in num_sampler: | |||||
if self.drop_last: | |||||
lens += ds_len // self.batch_size | |||||
else: | |||||
lens += (ds_len + self.batch_size - 1) // self.batch_size | |||||
return lens | |||||
class PollingSampler(MixSampler): | |||||
""" | |||||
定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表轮流采样并返回index,处理了上个dataset的一个batch后会处理下一个。 | |||||
""" | |||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = 16, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None, | |||||
drop_last: bool = False, ds_ratio="pad_to_most", rank: int = -1, | |||||
word_size: int = -1) -> None: | |||||
""" | |||||
:param dataset: 实现了__getitem__和__len__的数据容器列表 | |||||
:param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset | |||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||||
:param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时, | |||||
以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 | |||||
""" | |||||
super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size, | |||||
sampler=sampler, ds_ratio=ds_ratio, | |||||
drop_last=drop_last, rank=rank, word_size=word_size) | |||||
def __iter__(self) -> List[int]: | |||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||||
if isinstance(self.sampler, str): | |||||
if self.sampler == 'seq': | |||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
elif self.sampler == 'rand': | |||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(indices)) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||||
if isinstance(self.sampler, List): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(self.ds_ratio, List): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [max(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_len = sum([1 for _ in self.sampler.keys()]) | |||||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||||
elif isinstance(self.ds_ratio, Dict): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * self.ds_ratio[name]) for name, spl in self.sampler.items()] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
else: | |||||
raise ValueError("datasets must be dict or list") | |||||
# 初始化参数 | |||||
sampler_index = [sampler_lens[0]] | |||||
for idx in sampler_lens[1:]: | |||||
temp = sampler_index[-1] | |||||
sampler_index.append(temp + idx) | |||||
self.num_samplers = sampler_index | |||||
self.len_samplers = total_lens | |||||
start_idx, samplers = 0, [] | |||||
if isinstance(self.sampler, List): | |||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||||
for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)): | |||||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx)) | |||||
start_idx = end_idx | |||||
else: | |||||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||||
end_idx = self.num_samplers[idx] | |||||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name)) | |||||
start_idx = end_idx | |||||
while True: | |||||
# 退出循环 | |||||
if len(samplers) == 0: | |||||
break | |||||
batch_idx, flag = [], False | |||||
ds_total_iter, ds_sampler, ds_base_idx, sampler_idx = samplers.pop(0) | |||||
for _ in range(self.batch_size): | |||||
try: | |||||
# 取出数据 | |||||
next(ds_total_iter) | |||||
# 取出真正数据, 若取完则重新初始化一个 | |||||
try: | |||||
batch_idx.append(next(ds_sampler) + ds_base_idx) | |||||
except StopIteration: | |||||
ds_sampler = iter(self.sampler[sampler_idx]) | |||||
batch_idx.append(next(ds_sampler) + ds_base_idx) | |||||
except StopIteration: | |||||
# 当前ds所有的数据集采样完毕,将其清除队列 | |||||
flag = True | |||||
# 判断是否真正解决某个数据集的采样 | |||||
if flag is False: | |||||
samplers.append((ds_total_iter, ds_sampler, ds_base_idx, sampler_idx)) | |||||
if len(batch_idx) == self.batch_size: | |||||
yield batch_idx | |||||
elif len(batch_idx) > 0 and not self.drop_last: | |||||
yield batch_idx | |||||
def __len__(self) -> int: | |||||
lens, index = 0, 0 | |||||
num_sampler = [] | |||||
for ds_len in self.num_samplers: | |||||
num_sampler.append(ds_len - index) | |||||
index = ds_len | |||||
for ds_len in num_sampler: | |||||
if self.drop_last: | |||||
lens += ds_len // self.batch_size | |||||
else: | |||||
lens += (ds_len + self.batch_size - 1) // self.batch_size | |||||
return lens | |||||
if __name__ == '__main__': | |||||
from fastNLP.core.dataset import DataSet | |||||
ds = DataSet({'x': ["x1a", "1ws2", "xa qa", "ax wq", "iu, lk"] * 101, 'y': [1, 0, 1, 0, 1] * 101}) | |||||
ds1 = DataSet({'x': ["x12a", "1wzs2", "xa xqa", "aax wq", "iau, lk"] * 101, 'y': ['1', '0', '1', '0', '1'] * 101}) | |||||
sampler = DopedSampler(dataset=[ds, ds1], batch_size=6, rank=0, word_size=-2, sampler='seq') | |||||
seqSpl = MixSequentialSampler(dataset=[ds, ds1], batch_size=6, rank=0, word_size=2, sampler='seq', drop_last=True) | |||||
polSpl = PollingSampler(dataset=[ds, ds1], batch_size=6, rank=1, word_size=2, sampler='seq', drop_last=False) | |||||
for idx, batch in enumerate(polSpl): | |||||
print(idx, batch) | |||||
# print(len(seqSpl)) |
@@ -0,0 +1,315 @@ | |||||
from typing import Dict, List | |||||
import math | |||||
import numpy as np | |||||
from array import array | |||||
from copy import deepcopy | |||||
__all__ = [ | |||||
'ReproducibleIterator', | |||||
'RandomSampler', | |||||
'ReproducibleBatchSampler', | |||||
're_instantiate_sampler' | |||||
] | |||||
def re_instantiate_sampler(sampler): | |||||
all_attributes = vars(sampler) | |||||
return type(sampler)(**all_attributes) | |||||
class ReproducibleIterator: | |||||
""" | |||||
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||||
或者 batch_sampler; | |||||
""" | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.") | |||||
def __len__(self): | |||||
raise NotImplementedError("Each specific sampler should implement its own `__len__` method.") | |||||
def __iter__(self): | |||||
raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | |||||
def state_dict(self): | |||||
raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | |||||
def load_state_dict(self, states): | |||||
raise NotImplementedError("Each specific sampler should implement its own `load_state_dict` method.") | |||||
@property | |||||
def num_left_samples(self): | |||||
raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.") | |||||
def set_epoch(self, epoch): | |||||
pass | |||||
class RandomSampler(ReproducibleIterator): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||||
self.dataset = dataset | |||||
self.shuffle = shuffle | |||||
self.seed = seed | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 | |||||
# 多卡的相关的参数 | |||||
self.num_replicas = kwargs.get("num_replicas", 1) | |||||
self.rank = kwargs.get("rank", 0) | |||||
self.epoch = kwargs.get("epoch", -1) | |||||
self.pad = kwargs.get("pad", False) # 该参数在单卡上不具有任何意义; | |||||
# 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict() | |||||
self._during_iter = kwargs.get("_during_iter", False) | |||||
def __len__(self): | |||||
""" | |||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||||
:return: | |||||
""" | |||||
return self.total_size//self.num_replicas | |||||
def __iter__(self): | |||||
r""" | |||||
当前使用num_consumed_samples做法会在交替使用的时候遇到问题; | |||||
Example: | |||||
>>> sampler = RandomSampler() | |||||
>>> iter1 = iter(sampler) | |||||
>>> iter2 = iter(sampler) | |||||
>>> next(iter1) | |||||
>>> next(iter2) # 当前num_consumed_samples的数量会发生变化 | |||||
""" | |||||
if self._during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||||
self.num_consumed_samples = 0 | |||||
self._during_iter = True | |||||
indices = self.generate_indices() | |||||
if self.pad: | |||||
# add extra samples to make it evenly divisible | |||||
padding_size = self.total_size - len(indices) | |||||
if padding_size <= len(indices): | |||||
indices += indices[:padding_size] | |||||
else: | |||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] | |||||
else: | |||||
# remove tail of data to make it evenly divisible. | |||||
indices = indices[:self.total_size] | |||||
assert len(indices) == self.total_size | |||||
# subsample | |||||
indices = indices[self.num_consumed_samples:] | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
assert len(indices) == self.num_left_samples | |||||
for index in indices: | |||||
self.num_consumed_samples += self.num_replicas | |||||
yield index | |||||
self._during_iter = False | |||||
self.num_consumed_samples = 0 | |||||
def generate_indices(self) -> List[int]: | |||||
""" | |||||
生成随机序列 | |||||
:return: | |||||
""" | |||||
if self.shuffle: | |||||
indices = list(range(len(self.dataset))) | |||||
seed = self.seed + self.epoch | |||||
rng = np.random.default_rng(abs(seed)) | |||||
rng.shuffle(indices) | |||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||||
self.epoch -= 1 | |||||
else: | |||||
indices = list(range(len(self.dataset))) | |||||
return indices | |||||
def state_dict(self) -> Dict: | |||||
states = { | |||||
'seed': self.seed, | |||||
'epoch': self.epoch, | |||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
'shuffle': self.shuffle | |||||
} | |||||
return states | |||||
def load_state_dict(self, states: Dict): | |||||
# 如果 self._during_iter 是 True,那么 data_idx 一定是 0; | |||||
assert self._during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
"during an unfinished iteration." | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||||
f"we cannot use {self.__class__.__name__} to load it." | |||||
length = states['length'] | |||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | |||||
self.seed = states['seed'] | |||||
self.epoch = states['epoch'] | |||||
self.num_consumed_samples = states['num_consumed_samples'] | |||||
if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||||
self.num_consumed_samples = 0 | |||||
self.shuffle = states["shuffle"] | |||||
def set_epoch(self, epoch: int) -> None: | |||||
self.epoch = epoch | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
""" | |||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; | |||||
:param num_replicas: | |||||
:param rank: | |||||
:param pad: 这个 pad 的意思是指如果 sample 数量不整除 num_replicas 的时候,要不要 pad 一下,使得最终使得 replica 上 | |||||
的 sample 数量是完全一致的。 | |||||
:return: | |||||
""" | |||||
assert self._during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||||
"during an unfinished iteration." | |||||
assert num_replicas>0 and isinstance(num_replicas, int) | |||||
assert isinstance(rank, int) and 0<=rank<num_replicas | |||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||||
self.num_replicas = num_replicas | |||||
self.rank = rank | |||||
self.pad = pad | |||||
return self | |||||
@property | |||||
def total_size(self): | |||||
""" | |||||
这个变量代表的含义是当前这个sampler会最终产生出的index数量,因为replica和pad的原因,这个值可能等于、大于或者小于len(dataset) | |||||
:return: | |||||
""" | |||||
return self.num_consumed_samples + self.num_replicas*self.num_left_samples | |||||
@property | |||||
def num_left_samples(self): | |||||
""" | |||||
返回当前 iteration 还有多少个 sample 结束 | |||||
:return: | |||||
""" | |||||
num_consumed_samples = self.num_consumed_samples | |||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
class ReproducibleBatchSampler: | |||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | |||||
self.batch_sampler = batch_sampler | |||||
self.batch_size = batch_size | |||||
self.drop_last = drop_last | |||||
self.data_idx = kwargs.get("data_idx", 0) | |||||
self._index_list = kwargs.get("_index_list", self._iterate_sampler()) | |||||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | |||||
def _iterate_sampler(self): | |||||
_index_lst = [] | |||||
for idx in self.batch_sampler: | |||||
if isinstance(idx, list): | |||||
_index_lst.extend(idx) | |||||
# 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | |||||
else: | |||||
_index_lst.append(idx) | |||||
# 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||||
if len(_index_lst) > 4294967295: | |||||
# 注意 self._index_list 内存放的是全部数据的 index; | |||||
# unsigned long | |||||
_index_lst = array("L", _index_lst) | |||||
else: | |||||
# unsigned int | |||||
_index_lst = array("I", _index_lst) | |||||
return _index_lst | |||||
def __iter__(self): | |||||
if self.need_reinitialize: | |||||
self._index_list = self._iterate_sampler() | |||||
self.data_idx = 0 | |||||
else: | |||||
self.need_reinitialize = True | |||||
batch = [] | |||||
if self.data_idx: | |||||
index_list = self._index_list[self.data_idx:] | |||||
else: | |||||
index_list = self._index_list | |||||
for idx in index_list: | |||||
batch.append(idx) | |||||
self.data_idx += 1 | |||||
if len(batch) == self.batch_size: | |||||
yield batch | |||||
batch = [] | |||||
if len(batch) > 0 and not self.drop_last: | |||||
yield batch | |||||
def __len__(self) -> int: | |||||
if self.drop_last: | |||||
return len(self._index_list) // self.batch_size | |||||
else: | |||||
return (len(self._index_list) + self.batch_size - 1) // self.batch_size | |||||
def state_dict(self) -> Dict: | |||||
return {"index_list": deepcopy(self._index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||||
def load_state_dict(self, states: Dict): | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||||
f"we cannot use {self.__class__.__name__} to load it." | |||||
_index_list = states["index_list"] | |||||
assert len(_index_list) == len(self._index_list), "The number of samples is different between the checkpoint " \ | |||||
"record and current dataset." | |||||
self._index_list = _index_list | |||||
self.data_idx = states["data_idx"] | |||||
self.need_reinitialize = False | |||||
def set_distributed(self): | |||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | |||||
def set_epoch(self, epoch): | |||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | |||||
self.batch_sampler.sampler.set_epoch(epoch) | |||||
@property | |||||
def batch_idx_in_epoch(self): | |||||
if self.drop_last: | |||||
return len(self._index_list) // self.batch_size - (len(self._index_list) - self.data_idx) // self.batch_size | |||||
else: | |||||
return (len(self._index_list) + self.batch_size - 1) // self.batch_size - \ | |||||
(len(self._index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||||
# todo | |||||
# class SortedSampler(ReproducibleIterator): | |||||
# def __init__(self, dataset, key): | |||||
# pass | |||||
# | |||||
# | |||||
# class BucketedSampler(ReproducibleIterator): | |||||
# def __init__(self, dataset, key): | |||||
# pass | |||||
if __name__ == "__main__": | |||||
sampler = RandomSampler(1) | |||||
print(vars(sampler)) | |||||
batch_sampler = ReproducibleBatchSampler(list(range(3)), 1, True) | |||||
print(vars(batch_sampler)) | |||||
@@ -0,0 +1,813 @@ | |||||
r""" | |||||
sampler 子类实现了 fastNLP 所需的各种采样器。 | |||||
""" | |||||
__all__ = [ | |||||
"BucketSampler", | |||||
"SortedSampler", | |||||
'ConstTokenNumSampler', | |||||
"ConstantTokenNumSampler", | |||||
"UnrepeatedDistributedSampler", | |||||
] | |||||
from itertools import chain | |||||
from typing import List, Iterable | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import SequentialSampler, Sampler, RandomSampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Sampler | |||||
# class DopedSampler(Sampler): | |||||
# """ | |||||
# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 | |||||
# """ | |||||
# | |||||
# def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
# sampler: Union[List[Sampler], Dict[str, Sampler]] = None, | |||||
# ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, drop_last: bool = False) -> None: | |||||
# if batch_size <= 0: | |||||
# raise ValueError("batch_size should be a positive integer value, " | |||||
# "but got batch_size={}".format(batch_size)) | |||||
# if not isinstance(drop_last, bool): | |||||
# raise ValueError("drop_last should be a boolean value, but got " | |||||
# "drop_last={}".format(drop_last)) | |||||
# self.batch_size = batch_size | |||||
# self.drop_last = drop_last | |||||
# self.ds_ratio = ds_ratio | |||||
# if sampler is None: | |||||
# if isinstance(dataset, List): | |||||
# self.sampler = [SequentialSampler(ds) for ds in dataset] | |||||
# elif isinstance(dataset, Dict): | |||||
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||||
# | |||||
# elif isinstance(sampler, List): | |||||
# if len(sampler) != len(dataset): | |||||
# raise ValueError("the length of sampler != the length of sampler") | |||||
# self.sampler = sampler | |||||
# else: | |||||
# self.sampler = sampler | |||||
# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None: | |||||
# self.ds_ratio = ds_ratio | |||||
# elif isinstance(ds_ratio, List): | |||||
# if not all(item >= 0 for item in ds_ratio): | |||||
# raise ValueError("batch_size should be a positive integer value, " | |||||
# "but got batch_size={}".format(ds_ratio)) | |||||
# self.ds_ratio = ds_ratio | |||||
# else: | |||||
# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None") | |||||
# | |||||
# def __iter__(self): | |||||
# samplers, index = [], 0 | |||||
# if isinstance(self.sampler, List): | |||||
# for idx, sampler in enumerate(self.sampler): | |||||
# samplers.append((iter(sampler), self.batch_size, index, 0, idx)) | |||||
# index += len(sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# for name, sampler in self.sampler.items(): | |||||
# samplers.append((iter(sampler), self.batch_size, index, 0, name)) | |||||
# index += len(sampler) | |||||
# | |||||
# def __len__(self): | |||||
# lens = 0 | |||||
# max_len, ds_len = 0, 0 | |||||
# if self.ds_ratio == 'truncate_to_least': | |||||
# if isinstance(self.sampler, List): | |||||
# max_len = min(len(sampler) for sampler in self.sampler) | |||||
# ds_len = len(self.sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# max_len = min(len(sampler) for _, sampler in self.sampler.items()) | |||||
# for _, _ in self.sampler.items(): | |||||
# ds_len += 1 | |||||
# | |||||
# elif self.ds_ratio == 'pad_to_most': | |||||
# if isinstance(self.sampler, List): | |||||
# max_len = max(len(sampler) for sampler in self.sampler) | |||||
# ds_len = len(self.sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# max_len = max(len(sampler) for _, sampler in self.sampler.items()) | |||||
# for _, _ in self.sampler.items(): | |||||
# ds_len += 1 | |||||
# | |||||
# if self.ds_ratio is None: | |||||
# if isinstance(self.sampler, List): | |||||
# for i in range(len(self.sampler)): | |||||
# sampler = self.sampler[i] | |||||
# if self.drop_last: | |||||
# lens += len(sampler) // self.batch_size | |||||
# else: | |||||
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size | |||||
# elif isinstance(self.sampler, Dict): | |||||
# for name, sampler in self.sampler.items(): | |||||
# if self.drop_last: | |||||
# lens += len(sampler) // self.batch_size | |||||
# else: | |||||
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size | |||||
# elif self.ds_ratio == 'truncate_to_least' or self.ds_ratio == 'pad_to_most': | |||||
# for i in range(ds_len): | |||||
# if self.drop_last: | |||||
# lens += max_len // self.batch_size | |||||
# else: | |||||
# lens += (max_len + self.batch_size - 1) // self.batch_size | |||||
# return lens | |||||
# | |||||
# def demo(self): | |||||
# indexes = np.array([0]*self.batch_size + [1]*self.batch_size + [2]*self.batch_size) | |||||
# shift = np.array([0]*self.batch_size + [len(ds1)]*self.batch_size + [len(ds1)+len(ds2)]*self.batch_size) | |||||
# buffer = np.zeros(self.batch_size*self.num_ds, dtype=int) | |||||
# select_sampler = np.random.randint(0, self.batch_size*self.num_ds, num_sample=self.batch_size) | |||||
# select_indices = buffer[select_sampler] + shift[select_sampler] | |||||
# num_1 = (indexes[select_sampler]==0).sum() | |||||
# | |||||
# class MixSequentialSampler(Sampler): | |||||
# """ | |||||
# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表顺序采样并返回index,只有处理了上一个dataset才会处理下一个。 | |||||
# """ | |||||
# | |||||
# def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
# sampler: Union[List[Sampler], Dict[str, Sampler], None] = None, | |||||
# drop_last: bool = False) -> None: | |||||
# """ | |||||
# | |||||
# :param dataset: 实现了__getitem__和__len__的数据容器列表 | |||||
# :param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset | |||||
# :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||||
# :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||||
# """ | |||||
# # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||||
# if isinstance(dataset, Dict) and isinstance(sampler, List): | |||||
# raise ValueError(f"{sampler} must be dict") | |||||
# | |||||
# # 判断batch_size是否大于等于0 | |||||
# if batch_size <= 0: | |||||
# raise ValueError("batch_size should be a positive integer value, " | |||||
# "but got batch_size={}".format(batch_size)) | |||||
# | |||||
# if not isinstance(drop_last, bool): | |||||
# raise ValueError("drop_last should be a boolean value, but got " | |||||
# "drop_last={}".format(drop_last)) | |||||
# self.batch_size = batch_size | |||||
# self.drop_last = drop_last | |||||
# if sampler is None: | |||||
# if isinstance(dataset, List): | |||||
# self.sampler = [SequentialSampler(ds) for ds in dataset] | |||||
# elif isinstance(dataset, Dict): | |||||
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||||
# elif isinstance(sampler, List): | |||||
# if len(sampler) != len(dataset): | |||||
# raise ValueError("the length of sampler != the length of sampler") | |||||
# self.sampler = sampler | |||||
# | |||||
# def __iter__(self) -> Iterable[List[int]]: | |||||
# """ | |||||
# 按照dataset的顺序采样,打包成一个batch后返回 | |||||
# :return: | |||||
# """ | |||||
# index = 0 | |||||
# batch = [] | |||||
# if isinstance(self. sampler, List): | |||||
# for i in range(len(self.sampler)): | |||||
# sampler = self.sampler[i] | |||||
# for idx in sampler: | |||||
# batch.append(idx + index) | |||||
# if len(batch) == self.batch_size: | |||||
# yield batch | |||||
# batch = [] | |||||
# if len(batch) > 0 and not self.drop_last: | |||||
# yield batch | |||||
# batch = [] | |||||
# index += len(sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# for name, sampler in self.sampler.items(): | |||||
# for idx in sampler: | |||||
# batch.append(idx + index) | |||||
# if len(batch) == self.batch_size: | |||||
# yield batch | |||||
# batch = [] | |||||
# if len(batch) > 0 and not self.drop_last: | |||||
# yield batch | |||||
# batch = [] | |||||
# index += len(sampler) | |||||
# | |||||
# def __len__(self) -> int: | |||||
# lens = 0 | |||||
# if isinstance(self.sampler, List): | |||||
# for i in range(len(self.sampler)): | |||||
# sampler = self.sampler[i] | |||||
# if self.drop_last: | |||||
# lens += len(sampler) // self.batch_size | |||||
# else: | |||||
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size | |||||
# elif isinstance(self.sampler, Dict): | |||||
# for _, sampler in self.sampler.items(): | |||||
# if self.drop_last: | |||||
# lens += len(sampler) // self.batch_size | |||||
# else: | |||||
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size | |||||
# return lens | |||||
# class PollingSampler(Sampler): | |||||
# """ | |||||
# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表轮流采样并返回index,处理了上个dataset的一个batch后会处理下一个。 | |||||
# """ | |||||
# | |||||
# def __init__(self, dataset: Union[List, Dict], batch_size: int = 16, | |||||
# sampler: Union[List[Sampler], Dict[str, Sampler]] = None, | |||||
# drop_last: bool = False, ds_ratio="pad_to_most") -> None: | |||||
# """ | |||||
# | |||||
# :param dataset: 实现了__getitem__和__len__的数据容器列表 | |||||
# :param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset | |||||
# :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||||
# :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||||
# :param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时, | |||||
# 以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 | |||||
# """ | |||||
# # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||||
# if isinstance(dataset, Dict) and isinstance(sampler, List): | |||||
# raise ValueError(f"{sampler} must be dict") | |||||
# if isinstance(dataset, List) and isinstance(sampler, Dict): | |||||
# raise ValueError(f"{sampler} must be list") | |||||
# # 判断batch_size是否大于等于0 | |||||
# if batch_size <= 0: | |||||
# raise ValueError("batch_size should be a positive integer value, " | |||||
# "but got batch_size={}".format(batch_size)) | |||||
# | |||||
# if not isinstance(drop_last, bool): | |||||
# raise ValueError("drop_last should be a boolean value, but got " | |||||
# "drop_last={}".format(drop_last)) | |||||
# | |||||
# self.batch_size = batch_size | |||||
# self.drop_last = drop_last | |||||
# if sampler is None: | |||||
# if isinstance(dataset, List): | |||||
# self.sampler = [SequentialSampler(ds) for ds in dataset] | |||||
# elif isinstance(dataset, Dict): | |||||
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||||
# | |||||
# elif isinstance(sampler, List): | |||||
# if len(sampler) != len(dataset): | |||||
# raise ValueError("the length of sampler != the length of sampler") | |||||
# self.sampler = sampler | |||||
# else: | |||||
# self.sampler = sampler | |||||
# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None: | |||||
# self.ds_ratio = ds_ratio | |||||
# else: | |||||
# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None") | |||||
# | |||||
# def __iter__(self) -> Iterable[List[int]]: | |||||
# # index是数据集下标基址, pointer指向数据集列表的某个数据集 | |||||
# index, pointer, samplers, flag = 0, 0, [], False | |||||
# | |||||
# if isinstance(self.sampler, List): | |||||
# for idx, sampler in enumerate(self.sampler): | |||||
# samplers.append((iter(sampler), self.batch_size, index, 0, idx)) | |||||
# index += len(sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# for name, sampler in self.sampler.items(): | |||||
# samplers.append((iter(sampler), self.batch_size, index, 0, name)) | |||||
# index += len(sampler) | |||||
# if self.ds_ratio == 'pad_to_most': | |||||
# if isinstance(self.sampler, List): | |||||
# limit_len = max(len(ds) for ds in self.sampler) | |||||
# else: | |||||
# limit_len = max(len(ds) for _, ds in self.sampler.items()) | |||||
# elif self.ds_ratio == 'truncate_to_least': | |||||
# if isinstance(self.sampler, List): | |||||
# limit_len = min(len(ds) for ds in self.sampler) | |||||
# else: | |||||
# limit_len = min(len(ds) for _, ds in self.sampler.items()) | |||||
# else: | |||||
# limit_len = 0 | |||||
# # 最后一个批次的大小 | |||||
# last_batch_size = limit_len % self.batch_size | |||||
# | |||||
# while True: | |||||
# # 全部采样完,退出 | |||||
# if len(samplers) == 0: | |||||
# break | |||||
# batch, flag = [], False | |||||
# # sampler_len代表已经取出来的数据个数 | |||||
# sampler, batch_size, index, sampler_len, name = samplers.pop(0) | |||||
# for _ in range(batch_size): | |||||
# try: | |||||
# batch.append(index + next(sampler)) | |||||
# sampler_len += 1 | |||||
# except StopIteration: | |||||
# flag = True | |||||
# # ds_ratio为None,第一种情况,删除掉采样完的数据即可。 | |||||
# if self.ds_ratio == 'pad_to_most' and sampler_len < limit_len: | |||||
# # 重置sampler,并取足一个batch数据 | |||||
# sampler = iter(self.sampler[name]) | |||||
# # 由于batch_size一定小于等于ds的长度,故能够取足一个batch_size的数据 | |||||
# for _ in range(batch_size-len(batch)): | |||||
# batch.append(next(sampler) + index) | |||||
# sampler_len += 1 | |||||
# break | |||||
# | |||||
# # ds_ratio不为None情况 | |||||
# # 两种情况会触发一下逻辑:1.truncate_to_least时,最短的数据集最后一个batch大小不等于batch_size时, | |||||
# # 其他较长的数据集的最后一个batch长度会较长;2. pad_to_most,最长的数据集最后一个batch不等于batch_size时,较短数据集最后一个 | |||||
# # batch长度会较长 | |||||
# if limit_len != 0 and limit_len < sampler_len: | |||||
# batch = batch[:last_batch_size] | |||||
# # ds_ratio为任意情况下, 没有取完所有数据,则添加到队列尾部 | |||||
# elif (limit_len == 0 and flag == False) or limit_len > sampler_len: | |||||
# samplers.append((sampler, batch_size, index, sampler_len, name)) | |||||
# if len(batch) == batch_size: | |||||
# yield batch | |||||
# elif len(batch) > 0 and not self.drop_last: | |||||
# yield batch | |||||
# | |||||
# def __len__(self) -> int: | |||||
# lens = 0 | |||||
# max_len, ds_len = 0, 0 | |||||
# if self.ds_ratio == 'truncate_to_least': | |||||
# if isinstance(self.sampler, List): | |||||
# max_len = min(len(sampler) for sampler in self.sampler) | |||||
# ds_len = len(self.sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# max_len = min(len(sampler) for _, sampler in self.sampler.items()) | |||||
# for _, _ in self.sampler.items(): | |||||
# ds_len += 1 | |||||
# | |||||
# elif self.ds_ratio == 'pad_to_most': | |||||
# if isinstance(self.sampler, List): | |||||
# max_len = max(len(sampler) for sampler in self.sampler) | |||||
# ds_len = len(self.sampler) | |||||
# elif isinstance(self.sampler, Dict): | |||||
# max_len = max(len(sampler) for _, sampler in self.sampler.items()) | |||||
# for _, _ in self.sampler.items(): | |||||
# ds_len += 1 | |||||
# if self.ds_ratio is None: | |||||
# if isinstance(self.sampler, List): | |||||
# for i in range(len(self.sampler)): | |||||
# sampler = self.sampler[i] | |||||
# if self.drop_last: | |||||
# lens += len(sampler) // self.batch_size | |||||
# else: | |||||
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size | |||||
# elif isinstance(self.sampler, Dict): | |||||
# for name, sampler in self.sampler.items(): | |||||
# if self.drop_last: | |||||
# lens += len(sampler) // self.batch_size | |||||
# else: | |||||
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size | |||||
# else: | |||||
# for i in range(ds_len): | |||||
# if self.drop_last: | |||||
# lens += max_len // self.batch_size | |||||
# else: | |||||
# lens += (max_len + self.batch_size - 1) // self.batch_size | |||||
# return lens | |||||
class BucketSampler(Sampler): | |||||
r""" | |||||
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 | |||||
""" | |||||
def __init__(self, dataset, num_buckets=10, batch_size=None, seq_len_field_name='seq_len', drop_last=False) -> None: | |||||
r""" | |||||
:param int num_buckets: bucket的数量 | |||||
:param int batch_size: batch的大小. 默认为None,Trainer/Tester在调用BucketSampler时,会将该值正确设置,如果是非 | |||||
Trainer/Tester场景使用,需要显示传递该值 | |||||
:param str seq_len_field_name: 对应序列长度的 `field` 的名字 | |||||
""" | |||||
self.dataset = dataset | |||||
self.num_buckets = num_buckets | |||||
self.batch_size = batch_size | |||||
self.seq_len_field_name = seq_len_field_name | |||||
def set_batch_size(self, batch_size) -> None: | |||||
r""" | |||||
:param int batch_size: 每个batch的大小 | |||||
:return: | |||||
""" | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
if self.batch_size is None: | |||||
raise RuntimeError("batch_size is None.") | |||||
seq_lens = self.dataset.get_all_fields()[self.seq_len_field_name].content | |||||
total_sample_num = len(seq_lens) | |||||
bucket_indexes = [] | |||||
assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets." | |||||
num_sample_per_bucket = total_sample_num // self.num_buckets | |||||
for i in range(self.num_buckets): | |||||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) | |||||
bucket_indexes[-1][1] = total_sample_num | |||||
sorted_seq_lens = list(sorted([(idx, seq_len) for | |||||
idx, seq_len in zip(range(total_sample_num), seq_lens)], | |||||
key=lambda x: x[1])) | |||||
batchs = [] | |||||
left_init_indexes = [] | |||||
for b_idx in range(self.num_buckets): | |||||
start_idx = bucket_indexes[b_idx][0] | |||||
end_idx = bucket_indexes[b_idx][1] | |||||
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx] | |||||
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens]) | |||||
num_batch_per_bucket = len(left_init_indexes) // self.batch_size | |||||
np.random.shuffle(left_init_indexes) | |||||
for i in range(num_batch_per_bucket): | |||||
batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size]) | |||||
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:] | |||||
if (left_init_indexes) != 0: | |||||
batchs.append(left_init_indexes) | |||||
np.random.shuffle(batchs) | |||||
return chain(*batchs) | |||||
class ConstTokenNumSampler(Sampler): | |||||
""" | |||||
尽量保证每个batch的输入token数量是接近的。 | |||||
""" | |||||
def __init__(self, dataset, seq_len_field_name: List[int], max_token: int = 4096, max_sentence: int = -1, | |||||
need_be_multiple_of: int = 1, num_bucket: int = -1) -> None: | |||||
""" | |||||
:param dataset: | |||||
:param List[int] seq_len_field_name: 哪个field指示的sample的长度 | |||||
:param int max_token: 每个batch的最大的token数量 | |||||
:param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定 | |||||
:param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到 | |||||
:param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。 | |||||
""" | |||||
assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1 | |||||
self.dataset = dataset | |||||
self.seq_len_field_name = seq_len_field_name | |||||
self.num_bucket = num_bucket | |||||
self.max_token = max_token | |||||
self._max_sentence = max_sentence | |||||
self.need_be_multiple_of = need_be_multiple_of | |||||
assert len(self.dataset) > self.num_bucket, "The number of samples should be larger than buckets." | |||||
seq_len = self.dataset.get_field(self.seq_len_field_name) | |||||
self.seq_len = seq_len | |||||
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)] | |||||
seq_len_indice.sort(key=lambda x: x[0]) | |||||
indice_in_buckets = [] | |||||
if self.num_bucket > 0: | |||||
sample_per_bucket = len(seq_len_indice) // self.num_bucket | |||||
i = 0 | |||||
while len(indice_in_buckets) < len(seq_len_indice): | |||||
indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket]) | |||||
i += 1 | |||||
else: | |||||
indice_in_buckets = [seq_len_indice] | |||||
self.indice_in_buckets = indice_in_buckets | |||||
self.get_new_order() | |||||
@property | |||||
def max_sentence(self): | |||||
if self._max_sentence < 1: | |||||
return 100000000 | |||||
return self._max_sentence | |||||
@max_sentence.setter | |||||
def max_sentence(self, max_sentence): | |||||
self._max_sentence = max_sentence | |||||
def get_new_order(self) -> None: | |||||
np.random.shuffle(self.indice_in_buckets) | |||||
for bucket in self.indice_in_buckets: | |||||
np.random.shuffle(bucket) | |||||
indices = list(chain(*self.indice_in_buckets)) | |||||
batches = [] | |||||
cur_max_len = 0 | |||||
batch = [] | |||||
for length, i in indices: | |||||
max_len = max(length, cur_max_len) | |||||
if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence: | |||||
left_sample = len(batch) % self.need_be_multiple_of | |||||
add_samples = batch.copy() | |||||
cur_max_len = length | |||||
if left_sample != 0: | |||||
add_samples = add_samples[:-left_sample] | |||||
batch = batch[-left_sample:] | |||||
cur_max_len = max(cur_max_len, max(batch)) | |||||
else: | |||||
batch = [] | |||||
if len(add_samples) == 0: | |||||
raise RuntimeError( | |||||
f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.") | |||||
batches.append(add_samples) | |||||
else: | |||||
cur_max_len = max_len | |||||
batch.append(i) | |||||
if batch: | |||||
left_sample = len(batch) % self.need_be_multiple_of | |||||
add_samples = batch.copy() | |||||
if left_sample != 0: | |||||
add_samples = add_samples[:-left_sample].copy() | |||||
if add_samples: | |||||
batches.append(add_samples) | |||||
np.random.shuffle(batches) | |||||
self.batches = batches | |||||
def __iter__(self) -> Iterable[int]: | |||||
for batch in self.batches: | |||||
yield batch | |||||
self.get_new_order() | |||||
def __len__(self): | |||||
return len(self.batches) | |||||
class ConstantTokenNumSampler: | |||||
""" | |||||
尽量保证每个batch的输入token数量是接近的。 | |||||
""" | |||||
def __init__(self, seq_len, max_token: List[int] = 4096, max_sentence: int = -1, | |||||
need_be_multiple_of: int = 1, num_bucket: int = -1) -> None: | |||||
""" | |||||
:param List[int] seq_len: list[int], 是每个sample的长度。一般可以通过dataset.get_field('seq_len').content传入 | |||||
:param int max_token: 每个batch的最大的token数量 | |||||
:param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定 | |||||
:param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到 | |||||
:param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。 | |||||
""" | |||||
assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1 | |||||
assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets." | |||||
self.seq_len = seq_len | |||||
self.max_token = max_token | |||||
self._max_sentence = max_sentence | |||||
self.need_be_multiple_of = need_be_multiple_of | |||||
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)] | |||||
seq_len_indice.sort(key=lambda x: x[0]) | |||||
indice_in_buckets = [] | |||||
if num_bucket > 0: | |||||
sample_per_bucket = len(seq_len_indice) // num_bucket | |||||
i = 0 | |||||
while len(indice_in_buckets) < len(seq_len_indice): | |||||
indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket]) | |||||
i += 1 | |||||
else: | |||||
indice_in_buckets = [seq_len_indice] | |||||
self.indice_in_buckets = indice_in_buckets | |||||
self.get_new_order() | |||||
@property | |||||
def max_sentence(self): | |||||
if self._max_sentence < 1: | |||||
return 100000000 | |||||
return self._max_sentence | |||||
@max_sentence.setter | |||||
def max_sentence(self, max_sentence): | |||||
self._max_sentence = max_sentence | |||||
def get_new_order(self) -> None: | |||||
np.random.shuffle(self.indice_in_buckets) | |||||
for bucket in self.indice_in_buckets: | |||||
np.random.shuffle(bucket) | |||||
indices = list(chain(*self.indice_in_buckets)) | |||||
batches = [] | |||||
cur_max_len = 0 | |||||
batch = [] | |||||
for length, i in indices: | |||||
max_len = max(length, cur_max_len) | |||||
if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence: | |||||
left_sample = len(batch) % self.need_be_multiple_of | |||||
add_samples = batch.copy() | |||||
cur_max_len = length | |||||
if left_sample != 0: | |||||
add_samples = add_samples[:-left_sample] | |||||
batch = batch[-left_sample:] | |||||
cur_max_len = max(cur_max_len, max(batch)) | |||||
else: | |||||
batch = [] | |||||
if len(add_samples) == 0: | |||||
raise RuntimeError( | |||||
f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.") | |||||
batches.append(add_samples) | |||||
else: | |||||
cur_max_len = max_len | |||||
batch.append(i) | |||||
if batch: | |||||
left_sample = len(batch) % self.need_be_multiple_of | |||||
add_samples = batch.copy() | |||||
if left_sample != 0: | |||||
add_samples = add_samples[:-left_sample].copy() | |||||
if add_samples: | |||||
batches.append(add_samples) | |||||
np.random.shuffle(batches) | |||||
self.batches = batches | |||||
def __iter__(self) -> Iterable[int]: | |||||
for batch in self.batches: | |||||
yield batch | |||||
self.get_new_order() | |||||
def __len__(self): | |||||
return len(self.batches) | |||||
class SortedSampler(Sampler): | |||||
r""" | |||||
按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding) | |||||
""" | |||||
def __init__(self, dataset, seq_len_field_name: str = 'seq_len', descending: bool = True) -> None: | |||||
""" | |||||
:param str seq_len_field_name: 按哪个field进行排序。如果传入的field是数字,则直接按照该数字大小排序;如果传入的field不是 | |||||
数字,则使用该field的长度进行排序 | |||||
:param bool descending: 是否降序排列 | |||||
""" | |||||
self.dataset = dataset | |||||
self.seq_len_field_name = seq_len_field_name | |||||
self.descending = descending | |||||
def __iter__(self) -> Iterable[int]: | |||||
seq_lens = self.dataset.get_field(self.seq_len_field_name).content | |||||
try: | |||||
seq_lens = list(map(len, seq_lens)) | |||||
except: | |||||
pass | |||||
orders = np.argsort(seq_lens).tolist() # 从小到大的顺序 | |||||
if self.descending: | |||||
orders = orders[::-1] | |||||
for order in orders: | |||||
yield order | |||||
def simple_sort_bucketing(lengths): | |||||
r""" | |||||
:param lengths: list of int, the lengths of all examples. | |||||
:return data: 2-level list | |||||
:: | |||||
[ | |||||
[index_11, index_12, ...], # bucket 1 | |||||
[index_21, index_22, ...], # bucket 2 | |||||
... | |||||
] | |||||
""" | |||||
lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)] | |||||
sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1]) | |||||
# TODO: need to return buckets | |||||
return [idx for idx, _ in sorted_lengths] | |||||
def k_means_1d(x, k, max_iter=100): | |||||
r"""Perform k-means on 1-D data. | |||||
:param x: list of int, representing points in 1-D. | |||||
:param k: the number of clusters required. | |||||
:param max_iter: maximum iteration | |||||
:return centroids: numpy array, centroids of the k clusters | |||||
assignment: numpy array, 1-D, the bucket id assigned to each example. | |||||
""" | |||||
sorted_x = sorted(list(set(x))) | |||||
x = np.array(x) | |||||
if len(sorted_x) < k: | |||||
raise ValueError("too few buckets") | |||||
gap = len(sorted_x) / k | |||||
centroids = np.array([sorted_x[int(x * gap)] for x in range(k)]) | |||||
assign = None | |||||
for i in range(max_iter): | |||||
# Cluster Assignment step | |||||
assign = np.array([np.argmin([np.absolute(x_i - x) for x in centroids]) for x_i in x]) | |||||
# Move centroids step | |||||
new_centroids = np.array([x[assign == k].mean() for k in range(k)]) | |||||
if (new_centroids == centroids).all(): | |||||
centroids = new_centroids | |||||
break | |||||
centroids = new_centroids | |||||
return np.array(centroids), assign | |||||
def k_means_bucketing(lengths, buckets): | |||||
r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. | |||||
:param lengths: list of int, the length of all samples. | |||||
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length | |||||
threshold for each bucket (This is usually None.). | |||||
:return data: 2-level list | |||||
:: | |||||
[ | |||||
[index_11, index_12, ...], # bucket 1 | |||||
[index_21, index_22, ...], # bucket 2 | |||||
... | |||||
] | |||||
""" | |||||
bucket_data = [[] for _ in buckets] | |||||
num_buckets = len(buckets) | |||||
_, assignments = k_means_1d(lengths, num_buckets) | |||||
for idx, bucket_id in enumerate(assignments): | |||||
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | |||||
bucket_data[bucket_id].append(idx) | |||||
return bucket_data | |||||
class UnrepeatedDistributedSampler: | |||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0): | |||||
""" | |||||
考虑在多卡evaluate的场景下,不能重复sample。 | |||||
:param dataset: | |||||
:param shuffle: | |||||
:param seed: | |||||
""" | |||||
self.dataset = dataset | |||||
self.shuffle = shuffle | |||||
self.seed = seed | |||||
# 多卡的相关的参数 | |||||
self.num_replicas = 1 | |||||
self.rank = 0 | |||||
self.epoch = -1 | |||||
def __len__(self): | |||||
""" | |||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||||
:return: | |||||
""" | |||||
num_common = len(self.dataset)//self.num_replicas | |||||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||||
return self.num_samples | |||||
def __iter__(self): | |||||
r""" | |||||
当前使用num_consumed_samples做法会在交替使用的时候遇到问题; | |||||
Example: | |||||
>>> sampler = RandomSampler() | |||||
>>> iter1 = iter(sampler) | |||||
>>> iter2 = iter(sampler) | |||||
>>> next(iter1) | |||||
>>> next(iter2) # 当前num_consumed_samples的数量会发生变化 | |||||
""" | |||||
indices = self.generate_indices() | |||||
# subsample | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
assert len(indices) == len(self) | |||||
for index in indices: | |||||
yield index | |||||
def generate_indices(self) -> List[int]: | |||||
""" | |||||
生成随机序列 | |||||
:return: | |||||
""" | |||||
if self.shuffle: | |||||
indices = list(range(len(self.dataset))) | |||||
seed = self.seed + self.epoch | |||||
rng = np.random.default_rng(abs(seed)) | |||||
rng.shuffle(indices) | |||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||||
self.epoch -= 1 | |||||
else: | |||||
indices = list(range(len(self.dataset))) | |||||
return indices | |||||
def set_epoch(self, epoch: int) -> None: | |||||
self.epoch = epoch | |||||
def set_distributed(self, num_replicas, rank): | |||||
""" | |||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; | |||||
:param num_replicas: | |||||
:param rank: | |||||
:return: | |||||
""" | |||||
assert num_replicas>0 and isinstance(num_replicas, int) | |||||
assert isinstance(rank, int) and 0<=rank<num_replicas | |||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||||
self.num_replicas = num_replicas | |||||
self.rank = rank | |||||
return self |
@@ -0,0 +1,157 @@ | |||||
import pytest | |||||
from functools import reduce | |||||
from fastNLP.core.callbacks.callback_events import Filter | |||||
class TestFilter: | |||||
def test_params_check(self): | |||||
# 顺利通过 | |||||
_filter1 = Filter(every=10) | |||||
_filter2 = Filter(once=10) | |||||
_filter3 = Filter(filter_fn=lambda: None) | |||||
# 触发 ValueError | |||||
with pytest.raises(ValueError) as e: | |||||
_filter4 = Filter() | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter." | |||||
# 触发 ValueError | |||||
with pytest.raises(ValueError) as e: | |||||
_filter5 = Filter(every=10, once=10) | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "These three values should be only set one." | |||||
# 触发 TypeError | |||||
with pytest.raises(ValueError) as e: | |||||
_filter6 = Filter(every="heihei") | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "Argument every should be integer and greater than zero" | |||||
# 触发 TypeError | |||||
with pytest.raises(ValueError) as e: | |||||
_filter7 = Filter(once="heihei") | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "Argument once should be integer and positive" | |||||
# 触发 TypeError | |||||
with pytest.raises(TypeError) as e: | |||||
_filter7 = Filter(filter_fn="heihei") | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "Argument event_filter should be a callable" | |||||
def test_every_filter(self): | |||||
# every = 10 | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w-1 for w in range(10, 101, 10)] | |||||
# every = 1 | |||||
@Filter(every=1) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == list(range(100)) | |||||
def test_once_filter(self): | |||||
# once = 10 | |||||
@Filter(once=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [9] | |||||
def test_filter_fn(self): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
def test_extract_filter_from_fn(self): | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_filter_num_called = [] | |||||
_filter_num_executed = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
_filter = _fn.__fastNLP_filter__ | |||||
_filter_num_called.append(_filter.num_called) | |||||
_filter_num_executed.append(_filter.num_executed) | |||||
assert _filter_num_called == list(range(1, 101)) | |||||
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10] | |||||
def _fn(data): | |||||
return data | |||||
assert not hasattr(_fn, "__fastNLP_filter__") | |||||
def test_filter_state_dict(self): | |||||
# every = 10 | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(50): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(10, 51, 10)] | |||||
# 保存状态 | |||||
state = _fn.__fastNLP_filter__.state_dict() | |||||
# 加载状态 | |||||
_fn.__fastNLP_filter__.load_state_dict(state) | |||||
_res = [] | |||||
for i in range(50, 100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(60, 101, 10)] | |||||
@@ -0,0 +1,717 @@ | |||||
import os | |||||
import pytest | |||||
from typing import Any | |||||
from dataclasses import dataclass | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from pathlib import Path | |||||
import re | |||||
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import synchronize_safe_rm | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.log import logger | |||||
@dataclass | |||||
class ArgMaxDatasetConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 10 | |||||
data_num: int = 100 | |||||
seed: int = 0 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
@dataclass | |||||
class TrainerParameters: | |||||
model: Any = None | |||||
optimizers: Any = None | |||||
train_dataloader: Any = None | |||||
validate_dataloaders: Any = None | |||||
input_mapping: Any = None | |||||
output_mapping: Any = None | |||||
metrics: Any = None | |||||
@pytest.fixture(scope="module", params=[0], autouse=True) | |||||
def model_and_optimizers(request): | |||||
trainer_params = TrainerParameters() | |||||
trainer_params.model = TorchNormalModel_Classification_1( | |||||
num_labels=ArgMaxDatasetConfig.num_labels, | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | |||||
dataset = TorchArgMaxDatset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | |||||
data_num=ArgMaxDatasetConfig.data_num, | |||||
seed=ArgMaxDatasetConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=ArgMaxDatasetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.train_dataloader = _dataloader | |||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | |||||
return trainer_params | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | |||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | |||||
@magic_argv_env_context | |||||
def test_model_checkpoint_callback_1( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
version, | |||||
only_state_dict | |||||
): | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
if version == 0: | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc", | |||||
save_folder=path, | |||||
save_every_n_epochs=1, | |||||
save_every_n_global_batches=123, # 避免和 epoch 的保存重复; | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc", | |||||
save_folder=path, | |||||
save_every_n_epochs=3, | |||||
save_every_n_global_batches=None, | |||||
save_topk=2, | |||||
save_last=True, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
try: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=10, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
# 检查生成保存模型文件的数量是不是正确的; | |||||
if version == 0: | |||||
if driver == "torch": | |||||
assert "epoch_10-global_batch_250-acc" in all_saved_model_paths | |||||
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] | |||||
assert len(all_saved_model_paths) == 12 | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "epoch_6-global_batch_78-acc" in all_saved_model_paths | |||||
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] | |||||
assert len(all_saved_model_paths) == 11 | |||||
all_state_dicts = [epoch_save_path, step_save_path] | |||||
elif version == 1: | |||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||||
if driver == "torch": | |||||
assert "epoch_9-global_batch_225-acc" in all_saved_model_paths | |||||
assert "last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | |||||
for each_folder_name in all_saved_model_paths: | |||||
each_folder_name = pattern.findall(each_folder_name) | |||||
if len(each_folder_name) != 0: | |||||
aLL_topk_folders.append(each_folder_name[0]) | |||||
assert len(aLL_topk_folders) == 2 | |||||
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"] | |||||
last_save_path = all_saved_model_paths["last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||||
assert len(all_saved_model_paths) == 6 | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "epoch_9-global_batch_117-acc" in all_saved_model_paths | |||||
assert "last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | |||||
for each_folder_name in all_saved_model_paths: | |||||
each_folder_name = pattern.findall(each_folder_name) | |||||
if len(each_folder_name) != 0: | |||||
aLL_topk_folders.append(each_folder_name[0]) | |||||
assert len(aLL_topk_folders) == 2 | |||||
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"] | |||||
last_save_path = all_saved_model_paths["last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||||
assert len(all_saved_model_paths) == 6 | |||||
all_state_dicts = [epoch_save_path, last_save_path, topk_save_path] | |||||
for folder in all_state_dicts: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=2, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.load_model(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
# pass | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("only_state_dict", [True]) | |||||
@magic_argv_env_context | |||||
def test_model_checkpoint_callback_2( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
only_state_dict | |||||
): | |||||
path = Path.cwd().joinpath("test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
from fastNLP.core.callbacks.callback_events import Events | |||||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||||
def raise_exception(trainer): | |||||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | |||||
raise NotImplementedError | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc1", | |||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_global_batches=None, | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=NotImplementedError, | |||||
only_state_dict=only_state_dict | |||||
), | |||||
] | |||||
try: | |||||
with pytest.raises(NotImplementedError): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=10, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
if FASTNLP_DISTRIBUTED_CHECK in os.environ: | |||||
os.environ.pop(FASTNLP_DISTRIBUTED_CHECK) | |||||
# 检查生成保存模型文件的数量是不是正确的; | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
if driver == "torch": | |||||
assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths | |||||
exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"] | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths | |||||
exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"] | |||||
assert len(all_saved_model_paths) == 1 | |||||
all_state_dicts = [exception_model_path] | |||||
for folder in all_state_dicts: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver="torch", | |||||
device=4, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=2, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.load_model(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
# pass | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | |||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | |||||
@magic_argv_env_context | |||||
def test_trainer_checkpoint_callback_1( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
version, | |||||
only_state_dict | |||||
): | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
if version == 0: | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc", | |||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | |||||
save_every_n_epochs=7, | |||||
save_every_n_global_batches=123, # 避免和 epoch 的保存重复; | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc", | |||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_global_batches=None, | |||||
save_topk=2, | |||||
save_last=True, | |||||
save_on_exception=None, | |||||
only_state_dict=only_state_dict | |||||
) | |||||
] | |||||
try: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=10, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
# 检查生成保存模型文件的数量是不是正确的; | |||||
if version == 0: | |||||
if driver == "torch": | |||||
assert "epoch_7-global_batch_175-acc" in all_saved_model_paths | |||||
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] | |||||
assert len(all_saved_model_paths) == 3 | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "epoch_7-global_batch_91-acc" in all_saved_model_paths | |||||
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"] | |||||
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] | |||||
assert len(all_saved_model_paths) == 2 | |||||
all_state_dicts = [epoch_save_path, step_save_path] | |||||
elif version == 1: | |||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
if driver == "torch": | |||||
assert "last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | |||||
for each_folder_name in all_saved_model_paths: | |||||
each_folder_name = pattern.findall(each_folder_name) | |||||
if len(each_folder_name) != 0: | |||||
aLL_topk_folders.append(each_folder_name[0]) | |||||
assert len(aLL_topk_folders) == 2 | |||||
last_save_path = all_saved_model_paths["last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||||
assert len(all_saved_model_paths) == 3 | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | |||||
for each_folder_name in all_saved_model_paths: | |||||
each_folder_name = pattern.findall(each_folder_name) | |||||
if len(each_folder_name) != 0: | |||||
aLL_topk_folders.append(each_folder_name[0]) | |||||
assert len(aLL_topk_folders) == 2 | |||||
last_save_path = all_saved_model_paths["last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||||
assert len(all_saved_model_paths) == 3 | |||||
all_state_dicts = [last_save_path, topk_save_path] | |||||
for folder in all_state_dicts: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=13, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.load(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
pass | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | |||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | |||||
@magic_argv_env_context | |||||
def test_trainer_checkpoint_callback_2( | |||||
driver, | |||||
device, | |||||
version | |||||
): | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
import transformers | |||||
import torch | |||||
from torchmetrics import Accuracy | |||||
from transformers import AutoModelForSequenceClassification | |||||
from fastNLP import Trainer | |||||
from torch.optim import AdamW | |||||
from torch.utils.data import DataLoader, Dataset | |||||
from fastNLP.core.utils.utils import dataclass_to_dict | |||||
logger.info(f"transformer version: {transformers.__version__}") | |||||
task = "mrpc" | |||||
model_checkpoint = "distilbert-base-uncased" | |||||
## Loading the dataset | |||||
from datasets import load_dataset | |||||
actual_task = "mnli" if task == "mnli-mm" else task | |||||
dataset = load_dataset("glue", actual_task) | |||||
# Preprocessing the data | |||||
from transformers import AutoTokenizer | |||||
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) | |||||
task_to_keys = { | |||||
"cola": ("sentence", None), | |||||
"mnli": ("premise", "hypothesis"), | |||||
"mnli-mm": ("premise", "hypothesis"), | |||||
"mrpc": ("sentence1", "sentence2"), | |||||
"qnli": ("question", "sentence"), | |||||
"qqp": ("question1", "question2"), | |||||
"rte": ("sentence1", "sentence2"), | |||||
"sst2": ("sentence", None), | |||||
"stsb": ("sentence1", "sentence2"), | |||||
"wnli": ("sentence1", "sentence2"), | |||||
} | |||||
sentence1_key, sentence2_key = task_to_keys[task] | |||||
if sentence2_key is None: | |||||
print(f"Sentence: {dataset['train'][0][sentence1_key]}") | |||||
else: | |||||
print(f"Sentence 1: {dataset['train'][0][sentence1_key]}") | |||||
print(f"Sentence 2: {dataset['train'][0][sentence2_key]}") | |||||
def preprocess_function(examples): | |||||
if sentence2_key is None: | |||||
return tokenizer(examples[sentence1_key], truncation=True) | |||||
return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True) | |||||
encoded_dataset = dataset.map(preprocess_function, batched=True) | |||||
## Fine-tuning the model | |||||
num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2 | |||||
distilbert_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels) | |||||
class TestDistilBertDataset(Dataset): | |||||
def __init__(self, dataset): | |||||
super(TestDistilBertDataset, self).__init__() | |||||
self._dataset = dataset | |||||
def __len__(self): | |||||
return len(self._dataset) | |||||
def __getitem__(self, item): | |||||
_data = self._dataset[item] | |||||
return _data["input_ids"], _data["attention_mask"], [ | |||||
_data["label"]] # , _data["sentence1"], _data["sentence2"] | |||||
def test_bert_collate_fn(batch): | |||||
input_ids, atten_mask, labels = [], [], [] | |||||
max_length = [0] * 3 | |||||
for each_item in batch: | |||||
input_ids.append(each_item[0]) | |||||
max_length[0] = max(max_length[0], len(each_item[0])) | |||||
atten_mask.append(each_item[1]) | |||||
max_length[1] = max(max_length[1], len(each_item[1])) | |||||
labels.append(each_item[2]) | |||||
max_length[2] = max(max_length[2], len(each_item[2])) | |||||
for i in range(3): | |||||
each = (input_ids, atten_mask, labels)[i] | |||||
for item in each: | |||||
item.extend([0] * (max_length[i] - len(item))) | |||||
return {"input_ids": torch.cat([torch.tensor([item]) for item in input_ids], dim=0), | |||||
"attention_mask": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0), | |||||
"labels": torch.cat([torch.tensor(item) for item in labels], dim=0)} | |||||
test_bert_dataset_train = TestDistilBertDataset(encoded_dataset["train"]) | |||||
test_bert_dataloader_train = DataLoader(dataset=test_bert_dataset_train, batch_size=32, shuffle=True, | |||||
collate_fn=test_bert_collate_fn) | |||||
test_bert_dataset_validate = TestDistilBertDataset(encoded_dataset["test"]) | |||||
test_bert_dataloader_validate = DataLoader(dataset=test_bert_dataset_validate, batch_size=32, shuffle=False, | |||||
collate_fn=test_bert_collate_fn) | |||||
def bert_input_mapping(data): | |||||
data["target"] = data["labels"] | |||||
return data | |||||
def bert_output_mapping(data): | |||||
data = dataclass_to_dict(data) | |||||
data["preds"] = torch.max(data["logits"], dim=-1)[1] | |||||
# data["target"] = data["labels"] | |||||
del data["logits"] | |||||
del data["hidden_states"] | |||||
del data["attentions"] | |||||
return data | |||||
test_bert_optimizers = AdamW(params=distilbert_model.parameters(), lr=5e-5) | |||||
test_bert_model = distilbert_model | |||||
acc = Accuracy() | |||||
def model_save_fn(folder): | |||||
test_bert_model.save_pretrained(folder) | |||||
def model_load_fn(folder): | |||||
test_bert_model.from_pretrained(folder) | |||||
if version == 0: | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc", | |||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_global_batches=50, | |||||
save_topk=None, | |||||
save_last=False, | |||||
save_on_exception=None, | |||||
model_save_fn=model_save_fn | |||||
) | |||||
] | |||||
elif version == 1: | |||||
callbacks = [ | |||||
CheckpointCallback( | |||||
monitor="acc", | |||||
is_trainer_checkpoint=True, | |||||
save_folder=path, | |||||
save_every_n_epochs=None, | |||||
save_every_n_global_batches=None, | |||||
save_topk=1, | |||||
save_last=True, | |||||
save_on_exception=None, | |||||
model_save_fn=model_save_fn | |||||
) | |||||
] | |||||
try: | |||||
trainer = Trainer( | |||||
model=test_bert_model, | |||||
driver=driver, | |||||
device=device, | |||||
n_epochs=2, | |||||
train_dataloader=test_bert_dataloader_train, | |||||
optimizers=test_bert_optimizers, | |||||
validate_dataloaders=test_bert_dataloader_validate, | |||||
input_mapping=bert_input_mapping, | |||||
output_mapping=bert_output_mapping, | |||||
metrics={"acc": acc}, | |||||
callbacks=callbacks | |||||
) | |||||
trainer.run() | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
# 检查生成保存模型文件的数量是不是正确的; | |||||
if version == 0: | |||||
if driver == "torch": | |||||
assert "epoch_1-global_batch_200-acc" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"] | |||||
assert len(all_saved_model_paths) == 4 | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "epoch_1-global_batch_100-acc" in all_saved_model_paths | |||||
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"] | |||||
assert len(all_saved_model_paths) == 2 | |||||
all_state_dicts = [epoch_save_path] | |||||
elif version == 1: | |||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") | |||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
if driver == "torch": | |||||
assert "last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | |||||
for each_folder_name in all_saved_model_paths: | |||||
each_folder_name = pattern.findall(each_folder_name) | |||||
if len(each_folder_name) != 0: | |||||
aLL_topk_folders.append(each_folder_name[0]) | |||||
assert len(aLL_topk_folders) == 1 | |||||
last_save_path = all_saved_model_paths["last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||||
assert len(all_saved_model_paths) == 2 | |||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | |||||
else: | |||||
assert "last" in all_saved_model_paths | |||||
aLL_topk_folders = [] | |||||
for each_folder_name in all_saved_model_paths: | |||||
each_folder_name = pattern.findall(each_folder_name) | |||||
if len(each_folder_name) != 0: | |||||
aLL_topk_folders.append(each_folder_name[0]) | |||||
assert len(aLL_topk_folders) == 1 | |||||
last_save_path = all_saved_model_paths["last"] | |||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] | |||||
assert len(all_saved_model_paths) == 2 | |||||
all_state_dicts = [last_save_path, topk_save_path] | |||||
for folder in all_state_dicts: | |||||
trainer = Trainer( | |||||
model=test_bert_model, | |||||
driver=driver, | |||||
device=device, | |||||
n_epochs=3, | |||||
train_dataloader=test_bert_dataloader_train, | |||||
optimizers=test_bert_optimizers, | |||||
validate_dataloaders=test_bert_dataloader_validate, | |||||
input_mapping=bert_input_mapping, | |||||
output_mapping=bert_output_mapping, | |||||
metrics={"acc": acc}, | |||||
) | |||||
trainer.load(folder, model_load_fn=model_load_fn) | |||||
trainer.run() | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
# pass | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@@ -0,0 +1,119 @@ | |||||
import os | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
from torch import optim | |||||
import torch.distributed as dist | |||||
import pytest | |||||
from dataclasses import dataclass | |||||
from typing import Any | |||||
import numpy as np | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback | |||||
from fastNLP.core import Evaluator | |||||
from fastNLP.core.utils.utils import safe_rm | |||||
from fastNLP.core.drivers.torch_driver import TorchSingleDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
@dataclass | |||||
class ArgMaxDatasetConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 10 | |||||
data_num: int = 100 | |||||
seed: int = 0 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
@dataclass | |||||
class TrainerParameters: | |||||
model: Any = None | |||||
optimizers: Any = None | |||||
train_dataloader: Any = None | |||||
validate_dataloaders: Any = None | |||||
input_mapping: Any = None | |||||
output_mapping: Any = None | |||||
metrics: Any = None | |||||
@pytest.fixture(scope="module", params=[0], autouse=True) | |||||
def model_and_optimizers(request): | |||||
trainer_params = TrainerParameters() | |||||
trainer_params.model = TorchNormalModel_Classification_1( | |||||
num_labels=ArgMaxDatasetConfig.num_labels, | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01) | |||||
dataset = TorchArgMaxDatset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | |||||
data_num=ArgMaxDatasetConfig.data_num, | |||||
seed=ArgMaxDatasetConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=ArgMaxDatasetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.train_dataloader = _dataloader | |||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | |||||
return trainer_params | |||||
# pytest test_load_best_model_callback_torch.py::test_load_best_model_callback -s | |||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [4, 5]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("save_folder", ['save_models', None]) | |||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | |||||
@magic_argv_env_context | |||||
def test_load_best_model_callback( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
save_folder, | |||||
only_state_dict | |||||
): | |||||
callbacks = [LoadBestModelCallback(monitor='acc')] | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=3, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run(num_eval_sanity_batch=0) | |||||
driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda')) | |||||
evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device, | |||||
dataloaders={'dl1': model_and_optimizers.validate_dataloaders}, | |||||
metrics={'acc': Accuracy(aggregate_when_get_metric=False)}, | |||||
output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, | |||||
progress_bar='rich', use_dist_sampler=False) | |||||
results = evaluator.run() | |||||
assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1']) | |||||
if save_folder: | |||||
safe_rm(save_folder) | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@@ -0,0 +1,43 @@ | |||||
import pytest | |||||
from tests.helpers.utils import Capturing | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from fastNLP.core.log.logger import logger | |||||
def test_get_monitor_value(): | |||||
logger.set_stdout(stdout='raw') | |||||
# 测试完全匹配 | |||||
res = {'f1': 0.2, 'acc#rec': 0.3} | |||||
with Capturing() as output: | |||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | |||||
assert monitor == 'f1' and value==0.2 | |||||
assert 'We can not find' not in output[0] | |||||
# 测试可以匹配,且选择更靠前的 | |||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | |||||
with Capturing() as output: | |||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | |||||
assert monitor=='acc#f1' and value==0.2 | |||||
assert 'We can not find' in output[0] | |||||
# 测试monitor匹配不上,使用real_monitor | |||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | |||||
with Capturing() as output: | |||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) | |||||
assert monitor=='acc#rec' and value==0.3 | |||||
assert 'We can not find' not in output[0] | |||||
# 测试monitor/real_monitor匹配不上, 重新选择 | |||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | |||||
with Capturing() as output: | |||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) | |||||
assert monitor=='acc#f1' and value==0.2 | |||||
assert 'We can not find' in output[0] | |||||
# 测试partial的位置 | |||||
res = {"acc#acc": 0.52, "loss#loss": 2} | |||||
with Capturing() as output: | |||||
monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) | |||||
assert monitor=='loss#loss' and value==2 | |||||
assert 'We can not find' in output[0] |
@@ -0,0 +1,120 @@ | |||||
""" | |||||
python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_test_distributed_launch_torch_1.py | |||||
""" | |||||
import argparse | |||||
import os | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" | |||||
import sys | |||||
path = os.path.abspath(__file__) | |||||
folders = path.split(os.sep) | |||||
for folder in list(folders[::-1]): | |||||
if 'fastnlp' not in folder.lower(): | |||||
folders.pop(-1) | |||||
else: | |||||
break | |||||
path = os.sep.join(folders) | |||||
sys.path.extend([path, os.path.join(path, 'fastNLP')]) | |||||
import torch | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | |||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
@dataclass | |||||
class NormalClassificationTrainTorchConfig: | |||||
num_labels: int = 2 | |||||
feature_dimension: int = 3 | |||||
each_label_data: int = 100 | |||||
seed: int = 0 | |||||
n_epochs: int = 10 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
driver: str = "torch" | |||||
device: int = 7 | |||||
local_rank = int(os.environ["LOCAL_RANK"]) | |||||
local_device = torch.device(f"cuda:{local_rank}") | |||||
torch.cuda.set_device(local_device) | |||||
model = TorchNormalModel_Classification_2( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
model.to(local_device) | |||||
dist.init_process_group(backend='nccl', world_size=2, rank=local_rank) | |||||
model = DistributedDataParallel(model, device_ids=[local_device.index], output_device=local_device) | |||||
dist.barrier() | |||||
optimizers = SGD(model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
train_dataloader = _dataloader | |||||
validate_dataloaders = _dataloader | |||||
metrics = {"acc": Accuracy()} | |||||
def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
accumulation_steps, | |||||
fp16 | |||||
): | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="torch_ddp", | |||||
device=None, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
validate_dataloaders=validate_dataloaders, | |||||
metrics=metrics, | |||||
n_epochs=2, | |||||
data_device=local_device, | |||||
progress_bar='rich', | |||||
accumulation_steps=accumulation_steps, | |||||
fp16=fp16, | |||||
) | |||||
trainer.run() | |||||
dist.barrier() | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser(description='Input trainer parameters.') | |||||
parser.add_argument('-v', '--version', type=int, default=0, help="choose one test to run") | |||||
args = parser.parse_args() | |||||
if args.version == 0: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=1, fp16=False) | |||||
elif args.version == 1: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=3, fp16=False) | |||||
elif args.version == 2: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=1, fp16=True) | |||||
elif args.version == 3: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=3, fp16=True) |
@@ -0,0 +1,110 @@ | |||||
""" | |||||
python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_test_distributed_launch_torch_2.py | |||||
""" | |||||
import argparse | |||||
import os | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" | |||||
import sys | |||||
path = os.path.abspath(__file__) | |||||
folders = path.split(os.sep) | |||||
for folder in list(folders[::-1]): | |||||
if 'fastnlp' not in folder.lower(): | |||||
folders.pop(-1) | |||||
else: | |||||
break | |||||
path = os.sep.join(folders) | |||||
sys.path.extend([path, os.path.join(path, 'fastNLP')]) | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | |||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
@dataclass | |||||
class NormalClassificationTrainTorchConfig: | |||||
num_labels: int = 2 | |||||
feature_dimension: int = 3 | |||||
each_label_data: int = 100 | |||||
seed: int = 0 | |||||
n_epochs: int = 10 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
driver: str = "torch" | |||||
device: int = 7 | |||||
model = TorchNormalModel_Classification_1( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
optimizers = SGD(model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
train_dataloader = _dataloader | |||||
validate_dataloaders = _dataloader | |||||
metrics = {"acc": Accuracy()} | |||||
def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
accumulation_steps, | |||||
fp16 | |||||
): | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="torch_ddp", | |||||
device=None, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
validate_dataloaders=validate_dataloaders, | |||||
metrics=metrics, | |||||
n_epochs=2, | |||||
progress_bar='rich', | |||||
accumulation_steps=accumulation_steps, | |||||
fp16=fp16, | |||||
) | |||||
trainer.run() | |||||
dist.barrier() | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser(description='Input trainer parameters.') | |||||
parser.add_argument('-v', '--version', type=int, default=0, help="choose one test to run") | |||||
args = parser.parse_args() | |||||
if args.version == 0: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=1, fp16=False) | |||||
elif args.version == 1: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=3, fp16=False) | |||||
elif args.version == 2: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=1, fp16=True) | |||||
elif args.version == 3: | |||||
_test_trainer_torch_with_evaluator_fp16_accumulation_steps(accumulation_steps=3, fp16=True) | |||||
@@ -0,0 +1,104 @@ | |||||
import pytest | |||||
from typing import Any | |||||
from dataclasses import dataclass | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from torchmetrics import Accuracy | |||||
import torch.distributed as dist | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | |||||
from tests.helpers.utils import magic_argv_env_context, Capturing | |||||
@dataclass | |||||
class NormalClassificationTrainTorchConfig: | |||||
num_labels: int = 2 | |||||
feature_dimension: int = 3 | |||||
each_label_data: int = 100 | |||||
seed: int = 0 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
@dataclass | |||||
class TrainerParameters: | |||||
model: Any = None | |||||
optimizers: Any = None | |||||
train_dataloader: Any = None | |||||
validate_dataloaders: Any = None | |||||
input_mapping: Any = None | |||||
output_mapping: Any = None | |||||
metrics: Any = None | |||||
@pytest.fixture(scope="module", autouse=True) | |||||
def model_and_optimizers(): | |||||
trainer_params = TrainerParameters() | |||||
trainer_params.model = TorchNormalModel_Classification_1( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.train_dataloader = _dataloader | |||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | |||||
return trainer_params | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | |||||
@magic_argv_env_context | |||||
def test_trainer_event_trigger( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs=2, | |||||
): | |||||
with pytest.raises(Exception): | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
for name, member in Events.__members__.items(): | |||||
assert member.value in output[0] | |||||
@@ -0,0 +1,174 @@ | |||||
""" | |||||
注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; | |||||
""" | |||||
import pytest | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | |||||
from typing import Any | |||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDatset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
@dataclass | |||||
class NormalClassificationTrainTorchConfig: | |||||
num_labels: int = 2 | |||||
feature_dimension: int = 3 | |||||
each_label_data: int = 100 | |||||
seed: int = 0 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
@dataclass | |||||
class ArgMaxDatasetConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 10 | |||||
data_num: int = 100 | |||||
seed: int = 0 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
@dataclass | |||||
class TrainerParameters: | |||||
model: Any = None | |||||
optimizers: Any = None | |||||
train_dataloader: Any = None | |||||
validate_dataloaders: Any = None | |||||
input_mapping: Any = None | |||||
output_mapping: Any = None | |||||
metrics: Any = None | |||||
@pytest.fixture(scope="module", params=[1], autouse=True) | |||||
def model_and_optimizers(request): | |||||
trainer_params = TrainerParameters() | |||||
if request.param == 0: | |||||
trainer_params.model = TorchNormalModel_Classification_1( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.train_dataloader = _dataloader | |||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | |||||
elif request.param == 1: | |||||
trainer_params.model = TorchNormalModel_Classification_1( | |||||
num_labels=ArgMaxDatasetConfig.num_labels, | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | |||||
dataset = TorchArgMaxDatset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | |||||
data_num=ArgMaxDatasetConfig.data_num, | |||||
seed=ArgMaxDatasetConfig.seed | |||||
) | |||||
_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=ArgMaxDatasetConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.train_dataloader = _dataloader | |||||
trainer_params.validate_dataloaders = _dataloader | |||||
trainer_params.metrics = {"acc": Accuracy()} | |||||
return trainer_params | |||||
# 测试一下普通的情况; | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | |||||
@magic_argv_env_context | |||||
def test_trainer_torch_with_evaluator( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs=10, | |||||
): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]]) | |||||
@pytest.mark.parametrize("fp16", [True, False]) | |||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | |||||
@magic_argv_env_context | |||||
def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
fp16, | |||||
accumulation_steps, | |||||
n_epochs=6, | |||||
): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
fp16=fp16, | |||||
accumulation_steps=accumulation_steps, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@@ -0,0 +1,319 @@ | |||||
import os.path | |||||
import subprocess | |||||
import sys | |||||
import pytest | |||||
import torch.distributed as dist | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from dataclasses import dataclass | |||||
from typing import Any | |||||
from pathlib import Path | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | |||||
from tests.helpers.utils import magic_argv_env_context, Capturing | |||||
from fastNLP.core import synchronize_safe_rm | |||||
@dataclass | |||||
class NormalClassificationTrainTorchConfig: | |||||
num_labels: int = 2 | |||||
feature_dimension: int = 3 | |||||
each_label_data: int = 100 | |||||
seed: int = 0 | |||||
n_epochs: int = 10 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
driver: str = "torch" | |||||
device: int = 7 | |||||
@dataclass | |||||
class TrainerParameters: | |||||
model: Any = None | |||||
optimizers: Any = None | |||||
train_dataloader: Any = None | |||||
validate_dataloaders: Any = None | |||||
input_mapping: Any = None | |||||
output_mapping: Any = None | |||||
metrics: Any = None | |||||
@pytest.fixture(scope="function", params=[0], autouse=True) | |||||
def model_and_optimizers(request): | |||||
trainer_params = TrainerParameters() | |||||
if request.param == 0: | |||||
trainer_params.model = TorchNormalModel_Classification_1( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
trainer_params.train_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.validate_dataloaders = None | |||||
trainer_params.input_mapping = None | |||||
trainer_params.output_mapping = None | |||||
# elif request.param == 1: | |||||
# model = | |||||
return trainer_params | |||||
# 测试一下 cpu; | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) | |||||
@magic_argv_env_context | |||||
def test_trainer_torch_without_evaluator( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs=10, | |||||
): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", 4), | |||||
@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) | |||||
@pytest.mark.parametrize("fp16", [False, True]) | |||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | |||||
@magic_argv_env_context | |||||
def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
fp16, | |||||
accumulation_steps, | |||||
n_epochs=10, | |||||
): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
fp16=fp16, | |||||
accumulation_steps=accumulation_steps, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
# 测试 accumulation_steps; | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])]) | |||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | |||||
@magic_argv_env_context | |||||
def test_trainer_torch_without_evaluator_accumulation_steps( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
accumulation_steps, | |||||
n_epochs=2, | |||||
): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=[RecordAccumulationStepsCallback_Torch(accumulation_steps)], | |||||
accumulation_steps=accumulation_steps | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.parametrize("driver,device", [("torch", [6, 7])]) | |||||
@pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) | |||||
@magic_argv_env_context | |||||
def test_trainer_output_from_new_proc( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
output_from_new_proc, | |||||
n_epochs=2, | |||||
): | |||||
std_msg = "test std msg trainer, std std std" | |||||
err_msg = "test err msg trainer, err err, err" | |||||
from fastNLP.core.log.logger import logger | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
output_from_new_proc=output_from_new_proc | |||||
) | |||||
if trainer.driver.get_local_rank() != 0: | |||||
logger.warning(std_msg) | |||||
sys.stderr.write(err_msg) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
if output_from_new_proc == "all": | |||||
if trainer.driver.get_local_rank() != 0: | |||||
assert std_msg in output[0] | |||||
assert err_msg in output[0] | |||||
elif output_from_new_proc == "ignore": | |||||
if trainer.driver.get_local_rank() != 0: | |||||
assert std_msg not in output[0] | |||||
assert err_msg not in output[0] | |||||
elif output_from_new_proc == "only_error": | |||||
if trainer.driver.get_local_rank() != 0: | |||||
assert std_msg not in output[0] | |||||
assert err_msg in output[0] | |||||
else: | |||||
std_path = Path(os.path.abspath(output_from_new_proc)).joinpath(f"{trainer.driver.get_local_rank()}_std.log") | |||||
assert std_path.exists() | |||||
err_path = Path(os.path.abspath(output_from_new_proc)).joinpath(f"{trainer.driver.get_local_rank()}_err.log") | |||||
assert err_path.exists() | |||||
path = Path(os.path.abspath(output_from_new_proc)) | |||||
synchronize_safe_rm(path) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) | |||||
@pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | |||||
@magic_argv_env_context | |||||
def test_trainer_on_exception( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
cur_rank, | |||||
n_epochs=2, | |||||
): | |||||
from fastNLP.core.callbacks.callback_events import Events | |||||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||||
def raise_exception(trainer): | |||||
if trainer.driver.get_local_rank() == cur_rank: | |||||
raise NotImplementedError | |||||
with pytest.raises(NotImplementedError): | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | |||||
@magic_argv_env_context | |||||
def test_torch_distributed_launch_1(version): | |||||
""" | |||||
测试用户自己在外面初始化 ddp; | |||||
""" | |||||
from fastNLP.core.drivers.torch_driver.ddp import find_free_network_port | |||||
path = Path(os.path.abspath(__file__)).parent | |||||
command = ["python", "-m", "torch.distributed.launch", "--nproc_per_node", "2", "--master_port", find_free_network_port(), | |||||
f"{path.joinpath('_test_distributed_launch_torch_1.py')}", "-v", f"{version}"] | |||||
subprocess.check_call(command) | |||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | |||||
@magic_argv_env_context | |||||
def test_torch_distributed_launch_2(version): | |||||
""" | |||||
测试用户自己不初始化 ddp,但是使用 torch.distributed.launch 启动; | |||||
""" | |||||
from fastNLP.core.drivers.torch_driver.ddp import find_free_network_port | |||||
path = Path(os.path.abspath(__file__)).parent | |||||
command = ["python", "-m", "torch.distributed.launch", "--nproc_per_node", "2", "--master_port", find_free_network_port(), | |||||
f"{path.joinpath('_test_distributed_launch_torch_2.py')}", "-v", f"{version}"] | |||||
subprocess.check_call(command) | |||||
@@ -0,0 +1,60 @@ | |||||
from functools import reduce | |||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: 该类修改过,记得将 test 也修改; | |||||
from tests.helpers.datasets.normal_data import NormalIterator | |||||
class Test_WrapDataLoader: | |||||
def test_normal_generator(self): | |||||
all_sanity_batches = [4, 20, 100] | |||||
for sanity_batches in all_sanity_batches: | |||||
data = NormalIterator(num_of_data=1000) | |||||
wrapper = _TruncatedDataLoader(num_batches=sanity_batches) | |||||
dataloader = iter(wrapper(dataloader=data)) | |||||
mark = 0 | |||||
while True: | |||||
try: | |||||
_data = next(dataloader) | |||||
except StopIteration: | |||||
break | |||||
mark += 1 | |||||
assert mark == sanity_batches | |||||
def test_torch_dataloader(self): | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from torch.utils.data import DataLoader | |||||
bses = [8, 16, 40] | |||||
all_sanity_batches = [4, 7, 10] | |||||
for bs in bses: | |||||
for sanity_batches in all_sanity_batches: | |||||
dataset = TorchNormalDataset(num_of_data=1000) | |||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | |||||
wrapper = _TruncatedDataLoader(num_batches=sanity_batches) | |||||
dataloader = wrapper(dataloader) | |||||
dataloader = iter(dataloader) | |||||
all_supposed_running_data_num = 0 | |||||
while True: | |||||
try: | |||||
_data = next(dataloader) | |||||
except StopIteration: | |||||
break | |||||
all_supposed_running_data_num += _data.shape[0] | |||||
assert all_supposed_running_data_num == bs * sanity_batches | |||||
def test_len(self): | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from torch.utils.data import DataLoader | |||||
bses = [8, 16, 40] | |||||
all_sanity_batches = [4, 7, 10] | |||||
length = [] | |||||
for bs in bses: | |||||
for sanity_batches in all_sanity_batches: | |||||
dataset = TorchNormalDataset(num_of_data=1000) | |||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | |||||
wrapper = _TruncatedDataLoader(num_batches=sanity_batches) | |||||
dataloader = wrapper(dataloader) | |||||
length.append(len(dataloader)) | |||||
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) |
@@ -0,0 +1,132 @@ | |||||
import os | |||||
import torch | |||||
import torch.distributed as dist | |||||
import numpy as np | |||||
# print(isinstance((1,), tuple)) | |||||
# exit() | |||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object | |||||
from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | |||||
def test_convert_to_tensors(): | |||||
local_rank = 0 | |||||
obj = { | |||||
'tensor': torch.full(size=(2,), fill_value=local_rank), | |||||
'numpy': np.full(shape=(1,), fill_value=local_rank), | |||||
'bool': local_rank % 2 == 0, | |||||
'float': local_rank + 0.1, | |||||
'int': local_rank, | |||||
'dict': { | |||||
'rank': local_rank | |||||
}, | |||||
'list': [local_rank] * 2, | |||||
'str': 'xxx' | |||||
} | |||||
data = convert_to_tensors(obj) | |||||
assert len(data) == len(obj) | |||||
assert (data['tensor'] == obj['tensor']).sum() == 2 | |||||
for name in ['list', 'str']: | |||||
assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ | |||||
isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 | |||||
for name in ['numpy', 'bool', 'float', 'int']: | |||||
assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 | |||||
assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 | |||||
@magic_argv_env_context | |||||
def test_fastnlp_torch_all_gather(): | |||||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||||
os.environ['MASTER_PORT'] = '29500' | |||||
if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: | |||||
os.environ['LOCAL_RANK'] = '0' | |||||
os.environ['RANK'] = '0' | |||||
os.environ['WORLD_SIZE'] = '2' | |||||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||||
torch.distributed.init_process_group(backend='nccl') | |||||
torch.distributed.barrier() | |||||
local_rank = int(os.environ['LOCAL_RANK']) | |||||
torch.cuda.set_device(local_rank) | |||||
obj = { | |||||
'tensor': torch.full(size=(2,), fill_value=local_rank).cuda(), | |||||
'numpy': np.full(shape=(2, ), fill_value=local_rank), | |||||
'bool': local_rank%2==0, | |||||
'float': local_rank + 0.1, | |||||
'int': local_rank, | |||||
'dict': { | |||||
'rank': local_rank | |||||
}, | |||||
'list': [local_rank]*2, | |||||
'str': f'{local_rank}', | |||||
'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), | |||||
torch.full(size=(2,), fill_value=local_rank).cuda()] | |||||
} | |||||
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||||
world_size = int(os.environ['WORLD_SIZE']) | |||||
assert len(data) == world_size | |||||
for i in range(world_size): | |||||
assert (data[i]['tensor']==i).sum()==world_size | |||||
assert data[i]['numpy'][0]==i | |||||
assert data[i]['bool']==(i%2==0) | |||||
assert np.allclose(data[i]['float'], i+0.1) | |||||
assert data[i]['int'] == i | |||||
assert data[i]['dict']['rank'] == i | |||||
assert data[i]['list'][0] == i | |||||
assert data[i]['str'] == f'{i}' | |||||
assert data[i]['tensors'][0][0] == i | |||||
for obj in [1, True, 'xxx']: | |||||
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||||
assert len(data)==world_size | |||||
assert data[0]==data[1] | |||||
@magic_argv_env_context | |||||
def test_fastnlp_torch_broadcast_object(): | |||||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||||
os.environ['MASTER_PORT'] = '29500' | |||||
if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: | |||||
os.environ['LOCAL_RANK'] = '0' | |||||
os.environ['RANK'] = '0' | |||||
os.environ['WORLD_SIZE'] = '2' | |||||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||||
torch.distributed.init_process_group(backend='nccl') | |||||
torch.distributed.barrier() | |||||
local_rank = int(os.environ['LOCAL_RANK']) | |||||
torch.cuda.set_device(local_rank) | |||||
if os.environ['LOCAL_RANK']=="0": | |||||
obj = { | |||||
'tensor': torch.full(size=(2,), fill_value=local_rank).cuda(), | |||||
'numpy': np.full(shape=(2, ), fill_value=local_rank), | |||||
'bool': local_rank%2==0, | |||||
'float': local_rank + 0.1, | |||||
'int': local_rank, | |||||
'dict': { | |||||
'rank': local_rank | |||||
}, | |||||
'list': [local_rank]*2, | |||||
'str': f'{local_rank}', | |||||
'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), | |||||
torch.full(size=(2,), fill_value=local_rank).cuda()] | |||||
} | |||||
else: | |||||
obj = None | |||||
data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) | |||||
i = 0 | |||||
assert data['tensor'][0]==0 | |||||
assert data['numpy'][0]==0 | |||||
assert data['bool']==(i%2==0) | |||||
assert np.allclose(data['float'], i+0.1) | |||||
assert data['int'] == i | |||||
assert data['dict']['rank'] == i | |||||
assert data['list'][0] == i | |||||
assert data['str'] == f'{i}' | |||||
assert data['tensors'][0][0] == i | |||||
for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: | |||||
data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) | |||||
assert int(data)==0 |
@@ -0,0 +1,79 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
# 框架无关的一些接口测试 | |||||
""" | |||||
模拟 | |||||
同一个dl,同时传入trainer和evaluator, | |||||
(1)在训练到一半进行evaluate,需要保证trainer中dl的sampler状态不受影响 | |||||
(2)evaluate设置新的set_distributed不改变原有trainer中的evaluate | |||||
""" | |||||
class SequenceDataSet: | |||||
def __init__(self, num_samples): | |||||
self.data = list(range(num_samples)) | |||||
def __getitem__(self, item): | |||||
return self.data[item] | |||||
def __len__(self): | |||||
return len(self.data) | |||||
def check_replace_sampler(driver): | |||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler | |||||
# reproducible 是 True 和 False | |||||
assert driver.is_distributed() is False, "This test only for non distributed sampler." | |||||
ds = SequenceDataSet(10) | |||||
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True) | |||||
dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True) | |||||
# 迭代两个 batch | |||||
already_seen_idx = set() | |||||
for idx, batch in enumerate(dl1): | |||||
already_seen_idx.update(batch) | |||||
if idx > 1: | |||||
sampler_states = dataloader.sampler.state_dict() | |||||
break | |||||
# 再对原来的dataloader进行迭代,应该不影响 dl1 ,即 dl1 应该继续输出剩下的,而不会重复 | |||||
for idx, batch in enumerate(dataloader): | |||||
pass | |||||
left_idxes = set() | |||||
for idx, batch in enumerate(dl1): | |||||
for b in batch: | |||||
assert b not in already_seen_idx | |||||
left_idxes.update(batch) | |||||
if not driver.is_distributed(): | |||||
# 如果不是分布式的话,应该是等于整个数据的 | |||||
assert len(left_idxes)+len(already_seen_idx) == len(ds) | |||||
# 重新加载,应该是可以输出刚才完全一样的 | |||||
dl1.sampler.load_state_dict(sampler_states) | |||||
for idx, batch in enumerate(dl1): | |||||
for b in batch: | |||||
assert b not in already_seen_idx | |||||
assert b in left_idxes | |||||
@@ -0,0 +1,35 @@ | |||||
from torch.utils.data.sampler import SequentialSampler, RandomSampler | |||||
from fastNLP.core.samplers.sampler import ReproduceSampler | |||||
from tests.helpers.datasets.normal_data import NormalIterator | |||||
class TestReproduceSampler: | |||||
def test_sequentialsampler(self): | |||||
normal_iterator = NormalIterator(num_of_data=20) | |||||
sequential_sampler = SequentialSampler(normal_iterator) | |||||
reproduce_sampler = ReproduceSampler(sequential_sampler) | |||||
# iter_seq_sampler = iter(sequential_sampler) | |||||
# for each in iter_seq_sampler: | |||||
# print(each) | |||||
iter_reproduce_sampler = iter(reproduce_sampler) | |||||
forward_step = 3 | |||||
for _ in range(forward_step): | |||||
next(iter_reproduce_sampler) | |||||
state = reproduce_sampler.save_state() | |||||
assert state["current_batch_idx"] == forward_step | |||||
new_repro_sampler = ReproduceSampler(sequential_sampler) | |||||
assert new_repro_sampler.save_state()["current_batch_idx"] == 0 | |||||
new_repro_sampler.load_state(state) | |||||
iter_new_repro_sampler = iter(new_repro_sampler) | |||||
new_index_list = [] | |||||
for each in iter_new_repro_sampler: | |||||
new_index_list.append(each) | |||||
assert new_index_list == list(range(3, 20)) | |||||
@@ -1,300 +0,0 @@ | |||||
import os | |||||
import tempfile | |||||
import datetime | |||||
from pathlib import Path | |||||
import logging | |||||
import re | |||||
from fastNLP.core.envs.env import FASTNLP_LAUNCH_TIME | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import synchronize_safe_rm | |||||
# 测试 TorchDDPDriver; | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_1(): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||||
多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
path = Path.cwd() | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath, mode="w") | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
synchronize_safe_rm(filepath) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_2(): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
origin_path = Path.cwd() | |||||
try: | |||||
path = origin_path.joinpath("not_existed") | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath) | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_3(): | |||||
""" | |||||
path = None; | |||||
多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
handler = logger.add_file() | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME)+".log") | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
# print(f"\nrank: {driver.get_local_rank()} line, {line}\n") | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
synchronize_safe_rm(file) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_4(): | |||||
""" | |||||
测试 path 是文件夹; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
path = Path.cwd().joinpath("not_existed") | |||||
try: | |||||
handler = logger.add_file(path) | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log") | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
class TestLogger: | |||||
msg = 'some test log msg' | |||||
def test_add_file_1(self): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath) | |||||
logger.info(self.msg) | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
def test_add_file_2(self): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
origin_path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
path = origin_path.joinpath("not_existed") | |||||
path = path.joinpath('log.txt') | |||||
handler = logger.add_file(path) | |||||
logger.info(self.msg) | |||||
with open(path, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(origin_path) | |||||
logger.removeHandler(handler) | |||||
def test_add_file_3(self): | |||||
""" | |||||
测试 path 是 None; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.add_file() | |||||
logger.info(self.msg) | |||||
path = Path.cwd() | |||||
cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) | |||||
for file in path.iterdir(): | |||||
if file.name.startswith(cur_datetime): | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
file.unlink() | |||||
logger.removeHandler(handler) | |||||
def test_add_file_4(self): | |||||
""" | |||||
测试 path 是文件夹; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
handler = logger.add_file(path) | |||||
logger.info(self.msg) | |||||
cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) | |||||
for file in path.iterdir(): | |||||
if file.name.startswith(cur_datetime): | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
def test_stdout(self, capsys): | |||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.set_stdout(stdout="raw") | |||||
logger.info(self.msg) | |||||
logger.debug('aabbc') | |||||
captured = capsys.readouterr() | |||||
assert "some test log msg\n" == captured.out | |||||
logger.removeHandler(handler) | |||||
@@ -0,0 +1,508 @@ | |||||
import unittest | |||||
from itertools import product | |||||
import numpy as np | |||||
from functools import partial | |||||
from array import array | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, ReproducibleBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
class TestRandomSamplerYh(unittest.TestCase): | |||||
def test_init(self): | |||||
# 测试能否正确初始化 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
sampler = RandomSampler(dataset) | |||||
for i in sampler: | |||||
pass | |||||
def test_during_iter(self): | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
sampler = RandomSampler(dataset) | |||||
for i in sampler: | |||||
with self.assertRaises(AssertionError): | |||||
sampler.set_distributed(1, 0) | |||||
break | |||||
# should not raise | |||||
for i in sampler: | |||||
pass | |||||
sampler.set_distributed(1, 0) | |||||
def test_set_distributed(self): | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
sampler = RandomSampler(dataset, shuffle=False) | |||||
sampler.set_distributed(num_replicas=2, rank=0, pad=False) | |||||
self.assertEqual(len(sampler), 50) | |||||
count = 0 | |||||
for i in sampler: | |||||
self.assertEqual(i%2, 0) | |||||
count += 1 | |||||
self.assertEqual(count, 50) | |||||
sampler.set_distributed(num_replicas=2, rank=1, pad=False) | |||||
self.assertEqual(len(sampler), 50) | |||||
count = 0 | |||||
for i in sampler: | |||||
self.assertEqual(i%2, 1) | |||||
count += 1 | |||||
self.assertEqual(count, 50) | |||||
dataset = TorchNormalDataset(num_of_data=101) | |||||
sampler = RandomSampler(dataset, shuffle=False) | |||||
sampler.set_distributed(num_replicas=2, rank=0, pad=True) | |||||
self.assertEqual(len(sampler), 51) | |||||
count = 0 | |||||
for i in sampler: | |||||
self.assertEqual(i%2, 0) | |||||
count += 1 | |||||
self.assertEqual(count, 51) | |||||
sampler.set_distributed(num_replicas=2, rank=1, pad=True) | |||||
self.assertEqual(len(sampler), 51) | |||||
count = 0 | |||||
for i in sampler: | |||||
if i!=0: | |||||
self.assertEqual(i%2, 1) | |||||
count += 1 | |||||
self.assertEqual(count, 51) | |||||
def test_state_dict_check_length(self): | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
sampler = RandomSampler(dataset, shuffle=False) | |||||
states = sampler.state_dict() | |||||
new_ds = TorchNormalDataset(num_of_data=10) | |||||
with self.assertRaises(AssertionError): | |||||
new_sampler = RandomSampler(new_ds) | |||||
new_sampler.load_state_dict(states) | |||||
new_ds = TorchNormalDataset(num_of_data=100) | |||||
new_sampler = RandomSampler(new_ds) | |||||
new_sampler.load_state_dict(states) | |||||
def test_state_dict(self): | |||||
num_samples = 100 | |||||
dataset = TorchNormalDataset(num_of_data=num_samples) | |||||
# 测试使用 前后shuffle不一致的load操作 | |||||
lst = [0]+np.random.randint(1, num_samples, size=3).tolist() | |||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||||
lst): | |||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
self.assertEqual(len(already_numbers), num_consumed_samples) | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, already_numbers) | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, other_rank_number) | |||||
other_rank_number.add(i) | |||||
self.assertNotIn(i, already_numbers) | |||||
count += 1 | |||||
def test_state_dict_2(self): | |||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | |||||
num_samples = 100 | |||||
dataset = TorchNormalDataset(num_of_data=num_samples) | |||||
# 测试使用 前后shuffle不一致的load操作 | |||||
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() | |||||
# lst = [30] | |||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||||
lst): | |||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||||
already_numbers = set() | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
self.assertEqual(len(already_numbers), num_consumed_samples*2) | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, already_numbers) | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||||
count = 0 | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, other_rank_number) | |||||
other_rank_number.add(i) | |||||
self.assertNotIn(i, already_numbers) | |||||
count += 1 | |||||
class TestRandomSampler(unittest.TestCase): | |||||
# 测试单卡; | |||||
def test_seed_work_when_shuffle_is_true(self): | |||||
data_length = 100 | |||||
torch_normal_data = TorchNormalDataset(num_of_data=data_length) | |||||
for shuffle in [True, False]: | |||||
iterable = RandomSampler(dataset=torch_normal_data, shuffle=shuffle) | |||||
# 迭代一些数据,但是不迭代完; | |||||
iterable.set_epoch(1) | |||||
iterator = iter(iterable) | |||||
pre_data = [] | |||||
forward_steps = 30 | |||||
for _ in range(forward_steps): | |||||
pre_data.append(next(iterator)) | |||||
# 看重新生成迭代器是否能够完全重置状态; | |||||
iterator = iter(iterable) | |||||
res = [] | |||||
for _ in range(forward_steps): | |||||
res.append(next(iterator)) | |||||
assert pre_data == res | |||||
# 测试断点重训; | |||||
# 如果 shuffle,那么下一轮的数据应当与前一轮不一样;并且如果是断点重训,两次的下一轮应当是一样的; | |||||
def test_2(self): | |||||
data_length = 100 | |||||
torch_normal_data = TorchNormalDataset(num_of_data=data_length) | |||||
random_sampler_1 = RandomSampler(dataset=torch_normal_data, shuffle=True) | |||||
iterator = iter(random_sampler_1) | |||||
# 第一轮 | |||||
random_sampler_1.set_epoch(0) | |||||
first_epoch = [] | |||||
forward_steps = 30 | |||||
for _ in range(forward_steps): | |||||
first_epoch.append(next(iterator)) | |||||
# 先提前保存断点重训的结果; | |||||
state = random_sampler_1.state_dict() | |||||
# 保存第一个 epoch 的之后的结果,用于查看断点重训是否正确; | |||||
first_left_data = [] | |||||
while True: | |||||
try: | |||||
first_left_data.append(next(iterator)) | |||||
except StopIteration: | |||||
break | |||||
# 第二轮 | |||||
random_sampler_1.set_epoch(1) | |||||
iterator = iter(random_sampler_1) | |||||
second_epoch = [] | |||||
for _ in range(forward_steps): | |||||
second_epoch.append(next(iterator)) | |||||
assert first_epoch != second_epoch | |||||
# 重新加载第一轮的状态,查看断点重训是否正确; | |||||
random_sampler_2 = RandomSampler(dataset=torch_normal_data, shuffle=True) | |||||
random_sampler_2.load_state_dict(state) | |||||
random_sampler_2.set_epoch(0) | |||||
iterator = iter(random_sampler_2) | |||||
re_first_epoch = [] | |||||
while True: | |||||
try: | |||||
re_first_epoch.append(next(iterator)) | |||||
except StopIteration: | |||||
break | |||||
assert re_first_epoch == first_left_data | |||||
# 查看第二轮的结果是否也是和第一次的第二轮完全一致; | |||||
random_sampler_2.set_epoch(1) | |||||
iterator = iter(random_sampler_2) | |||||
re_second_epoch = [] | |||||
for _ in range(forward_steps): | |||||
re_second_epoch.append(next(iterator)) | |||||
assert re_second_epoch == second_epoch | |||||
# 多卡; | |||||
# 如果一个 sampler 还没有迭代完,我们又直接 iter(sampler) 那么是否正确(应当生成一个全新的 sampler)? | |||||
def test_3(self): | |||||
data_length = 100 | |||||
torch_normal_data = TorchNormalDataset(num_of_data=data_length) | |||||
random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=False) | |||||
random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True) | |||||
iterable_items = [random_sampler_1, random_sampler_2] | |||||
world_size = 3 | |||||
for pad in {True, False}: | |||||
for iterable in iterable_items: | |||||
for rank in range(world_size): | |||||
each_rank_iterable = iterable() | |||||
each_rank_iterable.set_epoch(0) | |||||
each_rank_iterable.set_distributed(num_replicas=world_size, rank=rank, pad=pad) | |||||
# 迭代一些数据,但是不迭代完; | |||||
iterator = iter(each_rank_iterable) | |||||
pre_data = [] | |||||
forward_steps = 10 | |||||
for _ in range(forward_steps): | |||||
pre_data.append(next(iterator)) | |||||
# 看重新生成迭代器是否能够完全重置状态; | |||||
iterator = iter(each_rank_iterable) | |||||
res = [] | |||||
for _ in range(forward_steps): | |||||
res.append(next(iterator)) | |||||
assert res == pre_data | |||||
# 测试断点重训; | |||||
# 如果 shuffle,那么下一轮的数据应当与前一轮不一样;并且如果是断点重训,两次的下一轮应当是一样的; | |||||
def test_4(self): | |||||
data_length = 100 | |||||
torch_normal_data = TorchNormalDataset(num_of_data=data_length) | |||||
random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True) | |||||
world_size_1 = 2 | |||||
forward_steps = 10 | |||||
for pad in {True, False}: | |||||
all_rank_state = {} | |||||
all_rank_first_left_data = {} | |||||
all_rank_second_epoch = {} | |||||
for rank in range(world_size_1): | |||||
each_rank_iterable = random_sampler_1() | |||||
each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad) | |||||
iterator = iter(each_rank_iterable) | |||||
# 第一轮 | |||||
each_rank_iterable.set_epoch(0) | |||||
first_epoch = [] | |||||
for _ in range(forward_steps): | |||||
first_epoch.append(next(iterator)) | |||||
# 先提前保存断点重训的结果; | |||||
all_rank_state[rank] = each_rank_iterable.state_dict() | |||||
# 保存第一个 epoch 的之后的结果,用于查看断点重训是否正确; | |||||
first_left_data = [] | |||||
while True: | |||||
try: | |||||
first_left_data.append(next(iterator)) | |||||
except StopIteration: | |||||
break | |||||
all_rank_first_left_data[rank] = first_left_data | |||||
# 第二轮 | |||||
each_rank_iterable.set_epoch(1) | |||||
iterator = iter(each_rank_iterable) | |||||
second_epoch = [] | |||||
for _ in range(forward_steps): | |||||
second_epoch.append(next(iterator)) | |||||
all_rank_second_epoch[rank] = second_epoch | |||||
assert first_epoch != second_epoch | |||||
# 重新加载第一轮的状态,查看断点重训是否正确; | |||||
random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True) | |||||
for rank in range(world_size_1): | |||||
each_rank_iterable = random_sampler_2() | |||||
each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad) | |||||
each_rank_iterable.load_state_dict(all_rank_state[rank]) | |||||
each_rank_iterable.set_epoch(0) | |||||
iterator = iter(each_rank_iterable) | |||||
re_first_epoch = [] | |||||
while True: | |||||
try: | |||||
re_first_epoch.append(next(iterator)) | |||||
except StopIteration: | |||||
break | |||||
assert re_first_epoch == all_rank_first_left_data[rank] | |||||
# 查看第二轮的结果是否也是和第一次的第二轮完全一致; | |||||
each_rank_iterable.set_epoch(1) | |||||
iterator = iter(each_rank_iterable) | |||||
re_second_epoch = [] | |||||
for _ in range(forward_steps): | |||||
re_second_epoch.append(next(iterator)) | |||||
assert re_second_epoch == all_rank_second_epoch[rank] | |||||
# todo 测试 ddp 时 world_size 改变的断点重训; | |||||
def test_5(self): | |||||
... | |||||
class TestReproducibleBatchSampler: | |||||
def test_torch_dataloader_1(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
next(iter_dataloader) | |||||
# 1. 保存状态 | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||||
"sampler_type": "ReproducibleBatchSampler"} | |||||
# 2. 断点重训,重新生成一个 dataloader; | |||||
# 不改变 batch_size; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# 改变 batch_size; | |||||
after_batch_size = 3 | |||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# 断点重训的第二轮是否是一个完整的 dataloader; | |||||
# 先把断点重训所在的那一个 epoch 跑完; | |||||
begin_idx = 27 | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
# 开始新的一轮; | |||||
begin_idx = 0 | |||||
iter_dataloader = iter(dataloader) | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
def test_torch_dataloader_2(self): | |||||
# 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# 将一轮的所有数据保存下来,看是否恢复的是正确的; | |||||
all_supposed_data = [] | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
# 1. 保存状态 | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
# 2. 断点重训,重新生成一个 dataloader; | |||||
# 不改变 batch_size; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# 先把这一轮的数据过完; | |||||
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
while True: | |||||
try: | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
except StopIteration: | |||||
break | |||||
assert all_supposed_data == list(pre_index_list) | |||||
# 重新开启新的一轮; | |||||
for _ in range(3): | |||||
iter_dataloader = iter(dataloader) | |||||
res = [] | |||||
while True: | |||||
try: | |||||
res.append(next(iter_dataloader)) | |||||
except StopIteration: | |||||
break | |||||
def test_3(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader, RandomSampler, BatchSampler | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
for idx, data in enumerate(dataloader): | |||||
if idx > 3: | |||||
break | |||||
iterator = iter(dataloader) | |||||
for each in iterator: | |||||
pass |
@@ -0,0 +1,38 @@ | |||||
import unittest | |||||
import random | |||||
from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler | |||||
from fastNLP.core.dataset import DataSet | |||||
from array import array | |||||
import torch | |||||
from fastNLP.core.samplers.sampler import ReproduceBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
class SamplerTest(unittest.TestCase): | |||||
def test_sequentialsampler(self): | |||||
ds = DataSet({'x': [1, 2, 3, 4] * 10}) | |||||
sqspl = SequentialSampler(ds) | |||||
for idx, inst in enumerate(sqspl): | |||||
self.assertEqual(idx, inst) | |||||
def test_randomsampler(self): | |||||
ds = DataSet({'x': [1, 2, 3, 4] * 10}) | |||||
rdspl = RandomSampler(ds) | |||||
ans = [ds[i] for i in rdspl] | |||||
self.assertEqual(len(ans), len(ds)) | |||||
def test_bucketsampler(self): | |||||
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10}) | |||||
sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len") | |||||
@@ -0,0 +1,7 @@ | |||||
class Demo: | |||||
def __init__(self): | |||||
pass | |||||
def demo(self): | |||||
b = 1 | |||||
return b |
@@ -0,0 +1,8 @@ | |||||
class Demo: | |||||
def __init__(self): | |||||
self.b = 1 | |||||
def demo(self): | |||||
b = 1 | |||||
return b | |||||
@@ -0,0 +1,304 @@ | |||||
import time | |||||
import os | |||||
import pytest | |||||
from subprocess import Popen, PIPE | |||||
from io import StringIO | |||||
import sys | |||||
from fastNLP.core.utils.cache_results import cache_results | |||||
from tests.helpers.common.utils import check_time_elapse | |||||
from fastNLP.core import synchronize_safe_rm | |||||
def get_subprocess_results(cmd): | |||||
pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) | |||||
output, err = pipe.communicate() | |||||
if output: | |||||
output = output.decode('utf8') | |||||
else: | |||||
output = '' | |||||
if err: | |||||
err = err.decode('utf8') | |||||
else: | |||||
err = '' | |||||
res = output + err | |||||
return res | |||||
class Capturing(list): | |||||
# 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面 | |||||
def __enter__(self): | |||||
self._stdout = sys.stdout | |||||
self._stderr = sys.stderr | |||||
sys.stdout = self._stringio = StringIO() | |||||
sys.stderr = self._stringioerr = StringIO() | |||||
return self | |||||
def __exit__(self, *args): | |||||
self.append(self._stringio.getvalue() + self._stringioerr.getvalue()) | |||||
del self._stringio, self._stringioerr # free up some memory | |||||
sys.stdout = self._stdout | |||||
sys.stderr = self._stderr | |||||
class TestCacheResults: | |||||
def test_cache_save(self): | |||||
cache_fp = 'demo.pkl' | |||||
try: | |||||
@cache_results(cache_fp) | |||||
def demo(): | |||||
time.sleep(1) | |||||
return 1 | |||||
res = demo() | |||||
with check_time_elapse(1, op='lt'): | |||||
res = demo() | |||||
finally: | |||||
synchronize_safe_rm(cache_fp) | |||||
def test_cache_save_refresh(self): | |||||
cache_fp = 'demo.pkl' | |||||
try: | |||||
@cache_results(cache_fp, _refresh=True) | |||||
def demo(): | |||||
time.sleep(1.5) | |||||
return 1 | |||||
res = demo() | |||||
with check_time_elapse(1, op='ge'): | |||||
res = demo() | |||||
finally: | |||||
synchronize_safe_rm(cache_fp) | |||||
def test_cache_no_func_change(self): | |||||
cache_fp = os.path.abspath('demo.pkl') | |||||
try: | |||||
@cache_results(cache_fp) | |||||
def demo(): | |||||
time.sleep(2) | |||||
return 1 | |||||
with check_time_elapse(1, op='gt'): | |||||
res = demo() | |||||
@cache_results(cache_fp) | |||||
def demo(): | |||||
time.sleep(2) | |||||
return 1 | |||||
with check_time_elapse(1, op='lt'): | |||||
res = demo() | |||||
finally: | |||||
synchronize_safe_rm('demo.pkl') | |||||
def test_cache_func_change(self, capsys): | |||||
cache_fp = 'demo.pkl' | |||||
try: | |||||
@cache_results(cache_fp) | |||||
def demo(): | |||||
time.sleep(2) | |||||
return 1 | |||||
with check_time_elapse(1, op='gt'): | |||||
res = demo() | |||||
@cache_results(cache_fp) | |||||
def demo(): | |||||
time.sleep(1) | |||||
return 1 | |||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo() | |||||
assert 'is different from its last cache' in output[0] | |||||
# 关闭check_hash应该不warning的 | |||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=0) | |||||
assert 'is different from its last cache' not in output[0] | |||||
finally: | |||||
synchronize_safe_rm('demo.pkl') | |||||
def test_cache_check_hash(self): | |||||
cache_fp = 'demo.pkl' | |||||
try: | |||||
@cache_results(cache_fp, _check_hash=False) | |||||
def demo(): | |||||
time.sleep(2) | |||||
return 1 | |||||
with check_time_elapse(1, op='gt'): | |||||
res = demo() | |||||
@cache_results(cache_fp, _check_hash=False) | |||||
def demo(): | |||||
time.sleep(1) | |||||
return 1 | |||||
# 默认不会check | |||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo() | |||||
assert 'is different from its last cache' not in output[0] | |||||
# check也可以 | |||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=True) | |||||
assert 'is different from its last cache' in output[0] | |||||
finally: | |||||
synchronize_safe_rm('demo.pkl') | |||||
# 外部 function 改变也会 导致改变 | |||||
def test_refer_fun_change(self): | |||||
cache_fp = 'demo.pkl' | |||||
test_type = 'func_refer_fun_change' | |||||
try: | |||||
with check_time_elapse(3, op='gt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
# 引用的function没有变化 | |||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'Read cache from' in res | |||||
assert 'is different from its last cache' not in res | |||||
# 引用的function有变化 | |||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'is different from its last cache' in res | |||||
finally: | |||||
synchronize_safe_rm(cache_fp) | |||||
# 外部 method 改变也会 导致改变 | |||||
def test_refer_class_method_change(self): | |||||
cache_fp = 'demo.pkl' | |||||
test_type = 'refer_class_method_change' | |||||
try: | |||||
with check_time_elapse(3, op='gt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
# 引用的class没有变化 | |||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'Read cache from' in res | |||||
assert 'is different from its last cache' not in res | |||||
# 引用的class有变化 | |||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'is different from its last cache' in res | |||||
finally: | |||||
synchronize_safe_rm(cache_fp) | |||||
def test_duplicate_keyword(self): | |||||
with pytest.raises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_verbose(a, _verbose): | |||||
pass | |||||
func_verbose(0, 1) | |||||
with pytest.raises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_cache(a, _cache_fp): | |||||
pass | |||||
func_cache(1, 2) | |||||
with pytest.raises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_refresh(a, _refresh): | |||||
pass | |||||
func_refresh(1, 2) | |||||
with pytest.raises(RuntimeError): | |||||
@cache_results(None) | |||||
def func_refresh(a, _check_hash): | |||||
pass | |||||
func_refresh(1, 2) | |||||
def test_create_cache_dir(self): | |||||
@cache_results('demo/demo.pkl') | |||||
def cache(): | |||||
return 1, 2 | |||||
try: | |||||
results = cache() | |||||
assert (1, 2) == results | |||||
finally: | |||||
synchronize_safe_rm('demo/') | |||||
def test_result_none_error(self): | |||||
@cache_results('demo.pkl') | |||||
def cache(): | |||||
pass | |||||
try: | |||||
with pytest.raises(RuntimeError): | |||||
results = cache() | |||||
finally: | |||||
synchronize_safe_rm('demo.pkl') | |||||
if __name__ == '__main__': | |||||
import argparse | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--test_type', type=str, default='refer_class_method_change') | |||||
parser.add_argument('--turn', type=int, default=1) | |||||
parser.add_argument('--cache_fp', type=str, default='demo.pkl') | |||||
args = parser.parse_args() | |||||
test_type = args.test_type | |||||
cache_fp = args.cache_fp | |||||
turn = args.turn | |||||
if test_type == 'func_refer_fun_change': | |||||
if turn == 0: | |||||
def demo(): | |||||
b = 1 | |||||
return b | |||||
else: | |||||
def demo(): | |||||
b = 2 | |||||
return b | |||||
@cache_results(cache_fp) | |||||
def demo_refer_other_func(): | |||||
time.sleep(3) | |||||
b = demo() | |||||
return b | |||||
res = demo_refer_other_func() | |||||
if test_type == 'refer_class_method_change': | |||||
print(f"Turn:{turn}") | |||||
if turn == 0: | |||||
from helper_for_cache_results_1 import Demo | |||||
else: | |||||
from helper_for_cache_results_2 import Demo | |||||
demo = Demo() | |||||
# import pdb | |||||
# pdb.set_trace() | |||||
@cache_results(cache_fp) | |||||
def demo_func(): | |||||
time.sleep(3) | |||||
b = demo.demo() | |||||
return b | |||||
res = demo_func() | |||||
@@ -0,0 +1,91 @@ | |||||
import os | |||||
from fastNLP.envs.distributed import rank_zero_call, all_rank_call | |||||
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context | |||||
@rank_zero_call | |||||
def write_something(): | |||||
print(os.environ.get('RANK', '0')*5, flush=True) | |||||
def write_other_thing(): | |||||
print(os.environ.get('RANK', '0')*5, flush=True) | |||||
class PaddleTest: | |||||
# @x54-729 | |||||
def test_rank_zero_call(self): | |||||
pass | |||||
def test_all_rank_run(self): | |||||
pass | |||||
class JittorTest: | |||||
# @x54-729 | |||||
def test_rank_zero_call(self): | |||||
pass | |||||
def test_all_rank_run(self): | |||||
pass | |||||
class TestTorch: | |||||
@magic_argv_env_context | |||||
def test_rank_zero_call(self): | |||||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||||
os.environ['MASTER_PORT'] = '29500' | |||||
if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: | |||||
os.environ['LOCAL_RANK'] = '0' | |||||
os.environ['RANK'] = '0' | |||||
os.environ['WORLD_SIZE'] = '2' | |||||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||||
with Capturing() as output: | |||||
write_something() | |||||
output = output[0] | |||||
if os.environ['LOCAL_RANK'] == '0': | |||||
assert '00000' in output and '11111' not in output | |||||
else: | |||||
assert '00000' not in output and '11111' not in output | |||||
with Capturing() as output: | |||||
rank_zero_call(write_other_thing)() | |||||
output = output[0] | |||||
if os.environ['LOCAL_RANK'] == '0': | |||||
assert '00000' in output and '11111' not in output | |||||
else: | |||||
assert '00000' not in output and '11111' not in output | |||||
@magic_argv_env_context | |||||
def test_all_rank_run(self): | |||||
os.environ['MASTER_ADDR'] = '127.0.0.1' | |||||
os.environ['MASTER_PORT'] = '29500' | |||||
if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ: | |||||
os.environ['LOCAL_RANK'] = '0' | |||||
os.environ['RANK'] = '0' | |||||
os.environ['WORLD_SIZE'] = '2' | |||||
re_run_current_cmd_for_torch(1, output_from_new_proc='all') | |||||
# torch.distributed.init_process_group(backend='nccl') | |||||
# torch.distributed.barrier() | |||||
with all_rank_call(): | |||||
with Capturing(no_del=True) as output: | |||||
write_something() | |||||
output = output[0] | |||||
if os.environ['LOCAL_RANK'] == '0': | |||||
assert '00000' in output | |||||
else: | |||||
assert '11111' in output | |||||
with all_rank_call(): | |||||
with Capturing(no_del=True) as output: | |||||
rank_zero_call(write_other_thing)() | |||||
output = output[0] | |||||
if os.environ['LOCAL_RANK'] == '0': | |||||
assert '00000' in output | |||||
else: | |||||
assert '11111' in output |
@@ -0,0 +1,200 @@ | |||||
import unittest | |||||
import paddle | |||||
from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_device | |||||
############################################################################ | |||||
# | |||||
# 测试仅将单个paddle张量迁移到指定设备 | |||||
# | |||||
############################################################################ | |||||
class PaddleToDeviceTestCase(unittest.TestCase): | |||||
def test_case(self): | |||||
tensor = paddle.rand((4, 5)) | |||||
res = paddle_to(tensor, "gpu") | |||||
self.assertTrue(res.place.is_gpu_place()) | |||||
self.assertEqual(res.place.gpu_device_id(), 0) | |||||
res = paddle_to(tensor, "cpu") | |||||
self.assertTrue(res.place.is_cpu_place()) | |||||
res = paddle_to(tensor, "gpu:2") | |||||
self.assertTrue(res.place.is_gpu_place()) | |||||
self.assertEqual(res.place.gpu_device_id(), 2) | |||||
res = paddle_to(tensor, "gpu:1") | |||||
self.assertTrue(res.place.is_gpu_place()) | |||||
self.assertEqual(res.place.gpu_device_id(), 1) | |||||
############################################################################ | |||||
# | |||||
# 测试将参数中包含的所有paddle张量迁移到指定设备 | |||||
# | |||||
############################################################################ | |||||
class PaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
def check_gpu(self, tensor, idx): | |||||
""" | |||||
检查张量是否在指定的设备上的工具函数 | |||||
""" | |||||
self.assertTrue(tensor.place.is_gpu_place()) | |||||
self.assertEqual(tensor.place.gpu_device_id(), idx) | |||||
def check_cpu(self, tensor): | |||||
""" | |||||
检查张量是否在cpu上的工具函数 | |||||
""" | |||||
self.assertTrue(tensor.place.is_cpu_place()) | |||||
def test_tensor_transfer(self): | |||||
""" | |||||
测试单个张量的迁移 | |||||
""" | |||||
paddle_tensor = paddle.rand((3, 4, 5)).cpu() | |||||
res = paddle_move_data_to_device(paddle_tensor, device=None, data_device=None) | |||||
self.check_cpu(res) | |||||
res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) | |||||
self.check_gpu(res, 0) | |||||
res = paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) | |||||
self.check_gpu(res, 1) | |||||
res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device="cpu") | |||||
self.check_gpu(res, 0) | |||||
res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") | |||||
self.check_gpu(res, 0) | |||||
res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:1") | |||||
self.check_gpu(res, 1) | |||||
def test_list_transfer(self): | |||||
""" | |||||
测试张量列表的迁移 | |||||
""" | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||||
res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_cpu(r) | |||||
res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_gpu(r, 0) | |||||
res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
def test_tensor_tuple_transfer(self): | |||||
""" | |||||
测试张量元组的迁移 | |||||
""" | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||||
paddle_tuple = tuple(paddle_list) | |||||
res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_cpu(r) | |||||
res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_gpu(r, 0) | |||||
res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
def test_dict_transfer(self): | |||||
""" | |||||
测试字典结构的迁移 | |||||
""" | |||||
paddle_dict = { | |||||
"tensor": paddle.rand((3, 4)), | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)], | |||||
"tensor": paddle.rand((3, 4)) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | |||||
self.assertIsInstance(res, dict) | |||||
self.check_gpu(res["tensor"], 0) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_gpu(t, 0) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_gpu(t, 0) | |||||
self.check_gpu(res["dict"]["tensor"], 0) | |||||
res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu") | |||||
self.assertIsInstance(res, dict) | |||||
self.check_gpu(res["tensor"], 0) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_gpu(t, 0) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_gpu(t, 0) | |||||
self.check_gpu(res["dict"]["tensor"], 0) | |||||
res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | |||||
self.assertIsInstance(res, dict) | |||||
self.check_gpu(res["tensor"], 1) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_gpu(t, 1) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_gpu(t, 1) | |||||
self.check_gpu(res["dict"]["tensor"], 1) | |||||
res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | |||||
self.assertIsInstance(res, dict) | |||||
self.check_cpu(res["tensor"]) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_cpu(t) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_cpu(t) | |||||
self.check_cpu(res["dict"]["tensor"]) |
@@ -0,0 +1,205 @@ | |||||
import unittest | |||||
import paddle | |||||
import torch | |||||
from fastNLP.core.utils.torch_paddle_utils import torch_paddle_move_data_to_device | |||||
############################################################################ | |||||
# | |||||
# 测试将参数中包含的所有torch和paddle张量迁移到指定设备 | |||||
# | |||||
############################################################################ | |||||
class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
def check_gpu(self, tensor, idx): | |||||
""" | |||||
检查张量是否在指定显卡上的工具函数 | |||||
""" | |||||
if isinstance(tensor, paddle.Tensor): | |||||
self.assertTrue(tensor.place.is_gpu_place()) | |||||
self.assertEqual(tensor.place.gpu_device_id(), idx) | |||||
elif isinstance(tensor, torch.Tensor): | |||||
self.assertTrue(tensor.is_cuda) | |||||
self.assertEqual(tensor.device.index, idx) | |||||
def check_cpu(self, tensor): | |||||
if isinstance(tensor, paddle.Tensor): | |||||
self.assertTrue(tensor.place.is_cpu_place()) | |||||
elif isinstance(tensor, torch.Tensor): | |||||
self.assertFalse(tensor.is_cuda) | |||||
def test_tensor_transfer(self): | |||||
""" | |||||
测试迁移单个张量 | |||||
""" | |||||
paddle_tensor = paddle.rand((3, 4, 5)).cpu() | |||||
res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device=None) | |||||
self.check_cpu(res) | |||||
res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) | |||||
self.check_gpu(res, 0) | |||||
res = torch_paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) | |||||
self.check_gpu(res, 1) | |||||
res = torch_paddle_move_data_to_device(paddle_tensor, device="cuda:0", data_device="cpu") | |||||
self.check_gpu(res, 0) | |||||
res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") | |||||
self.check_gpu(res, 0) | |||||
res = torch_paddle_move_data_to_device(paddle_tensor, device=None, data_device="cuda:1") | |||||
self.check_gpu(res, 1) | |||||
torch_tensor = torch.rand(3, 4, 5) | |||||
res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device=None) | |||||
self.check_cpu(res) | |||||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device=None) | |||||
print(res.device) | |||||
self.check_gpu(res, 0) | |||||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:1", data_device=None) | |||||
self.check_gpu(res, 1) | |||||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device="cpu") | |||||
self.check_gpu(res, 0) | |||||
res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:0") | |||||
self.check_gpu(res, 0) | |||||
res = torch_paddle_move_data_to_device(torch_tensor, device=None, data_device="gpu:1") | |||||
self.check_gpu(res, 1) | |||||
def test_list_transfer(self): | |||||
""" | |||||
测试迁移张量的列表 | |||||
""" | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)] | |||||
res = torch_paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
res = torch_paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_cpu(r) | |||||
res = torch_paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_gpu(r, 0) | |||||
res = torch_paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | |||||
self.assertIsInstance(res, list) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
def test_tensor_tuple_transfer(self): | |||||
""" | |||||
测试迁移张量的元组 | |||||
""" | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + [torch.rand((6, 4, 2)) for i in range(5)] | |||||
paddle_tuple = tuple(paddle_list) | |||||
res = torch_paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
res = torch_paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_cpu(r) | |||||
res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_gpu(r, 0) | |||||
res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | |||||
self.assertIsInstance(res, tuple) | |||||
for r in res: | |||||
self.check_gpu(r, 1) | |||||
def test_dict_transfer(self): | |||||
""" | |||||
测试迁移复杂的字典结构 | |||||
""" | |||||
paddle_dict = { | |||||
"torch_tensor": torch.rand((3, 4)), | |||||
"torch_list": [torch.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)], | |||||
"torch_tensor": torch.rand((3, 4)), | |||||
"paddle_tensor": paddle.rand((3, 4)) | |||||
}, | |||||
"paddle_tensor": paddle.rand((3, 4)), | |||||
"list": [paddle.rand((6, 4, 2)) for i in range(10)] , | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = torch_paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | |||||
self.assertIsInstance(res, dict) | |||||
self.check_gpu(res["torch_tensor"], 0) | |||||
self.check_gpu(res["paddle_tensor"], 0) | |||||
self.assertIsInstance(res["torch_list"], list) | |||||
for t in res["torch_list"]: | |||||
self.check_gpu(t, 0) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_gpu(t, 0) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_gpu(t, 0) | |||||
self.check_gpu(res["dict"]["torch_tensor"], 0) | |||||
self.check_gpu(res["dict"]["paddle_tensor"], 0) | |||||
res = torch_paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | |||||
self.assertIsInstance(res, dict) | |||||
self.check_gpu(res["torch_tensor"], 1) | |||||
self.check_gpu(res["paddle_tensor"], 1) | |||||
self.assertIsInstance(res["torch_list"], list) | |||||
for t in res["torch_list"]: | |||||
self.check_gpu(t, 1) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_gpu(t, 1) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_gpu(t, 1) | |||||
self.check_gpu(res["dict"]["torch_tensor"], 1) | |||||
self.check_gpu(res["dict"]["paddle_tensor"], 1) | |||||
res = torch_paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | |||||
self.assertIsInstance(res, dict) | |||||
self.check_cpu(res["torch_tensor"]) | |||||
self.check_cpu(res["paddle_tensor"]) | |||||
self.assertIsInstance(res["torch_list"], list) | |||||
for t in res["torch_list"]: | |||||
self.check_cpu(t) | |||||
self.assertIsInstance(res["list"], list) | |||||
for t in res["list"]: | |||||
self.check_cpu(t) | |||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | |||||
self.check_cpu(t) | |||||
self.check_cpu(res["dict"]["torch_tensor"]) | |||||
self.check_cpu(res["dict"]["paddle_tensor"]) |
@@ -0,0 +1,122 @@ | |||||
""" | |||||
该模块用于实现一些帮助我们在测试的 callback 类; | |||||
""" | |||||
from fastNLP.core.callbacks.callback import Callback | |||||
class RecordLossCallback(Callback): | |||||
""" | |||||
通过该 callback 来测试模型的训练是否基本正常; | |||||
""" | |||||
def __init__(self, loss_threshold: float): | |||||
self.loss = None | |||||
self.loss_threshold = loss_threshold | |||||
self.loss_begin_value = None | |||||
def on_before_backward(self, trainer, outputs): | |||||
loss = trainer.extract_loss_from_outputs(outputs) | |||||
loss = trainer.driver.tensor_to_numeric(loss) | |||||
self.loss = loss | |||||
if self.loss_begin_value is None: | |||||
self.loss_begin_value = loss | |||||
def on_train_end(self, trainer): | |||||
assert self.loss < self.loss_begin_value | |||||
if self.loss_threshold is not None: | |||||
assert self.loss < self.loss_threshold | |||||
class RecordMetricCallback(Callback): | |||||
""" | |||||
通过该 callback 来测试带有 metrics 的 Trainer 是否训练测试正确; | |||||
""" | |||||
def __init__(self, monitor: str, metric_threshold: float, larger_better: bool): | |||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
self.metric = None | |||||
self.metric_threshold = metric_threshold | |||||
self.metric_begin_value = None | |||||
def on_validate_end(self, trainer, results): | |||||
self.metric = results[self.monitor] | |||||
if self.metric_begin_value is None: | |||||
self.metric_begin_value = self.metric | |||||
def on_train_end(self, trainer): | |||||
if self.larger_better: | |||||
assert self.metric >= self.metric_begin_value | |||||
assert self.metric > self.metric_threshold | |||||
else: | |||||
assert self.metric <= self.metric_begin_value | |||||
assert self.metric < self.metric_threshold | |||||
class RecordTrainerEventTriggerCallback(Callback): | |||||
""" | |||||
测试每一个 callback 是否在 trainer 中都得到了调用; | |||||
""" | |||||
def on_after_trainer_initialized(self, trainer, driver): | |||||
print("on_after_trainer_initialized") | |||||
def on_sanity_check_begin(self, trainer): | |||||
print("on_sanity_check_begin") | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||||
print("on_sanity_check_end") | |||||
def on_train_begin(self, trainer): | |||||
print("on_train_begin") | |||||
def on_train_end(self, trainer): | |||||
print("on_train_end") | |||||
def on_train_epoch_begin(self, trainer): | |||||
if trainer.current_epoch_idx >= 1: | |||||
# 触发 on_exception; | |||||
raise Exception | |||||
print("on_train_epoch_begin") | |||||
def on_train_epoch_end(self, trainer): | |||||
print("on_train_epoch_end") | |||||
def on_fetch_data_begin(self, trainer): | |||||
print("on_fetch_data_begin") | |||||
def on_fetch_data_end(self, trainer): | |||||
print("on_fetch_data_end") | |||||
def on_train_batch_begin(self, trainer, batch, indices=None): | |||||
print("on_train_batch_begin") | |||||
def on_train_batch_end(self, trainer): | |||||
print("on_train_batch_end") | |||||
def on_exception(self, trainer, exception): | |||||
print("on_exception") | |||||
def on_before_backward(self, trainer, outputs): | |||||
print("on_before_backward") | |||||
def on_after_backward(self, trainer): | |||||
print("on_after_backward") | |||||
def on_before_optimizer_step(self, trainer, optimizers): | |||||
print("on_before_optimizer_step") | |||||
def on_before_zero_grad(self, trainer, optimizers): | |||||
print("on_before_zero_grad") | |||||
def on_validate_begin(self, trainer): | |||||
print("on_validate_begin") | |||||
def on_validate_end(self, trainer, results): | |||||
print("on_validate_end") | |||||
@@ -0,0 +1,56 @@ | |||||
import torch | |||||
from copy import deepcopy | |||||
from fastNLP.core.callbacks.callback import Callback | |||||
class RecordAccumulationStepsCallback_Torch(Callback): | |||||
""" | |||||
通过该 callback 来测试 Trainer 的 accumulation_steps 是否实现正确; | |||||
1. 在每一个 batch 检验模型是否正确地得到了更新(只有每隔 accumulation_steps 模型的参数才应该改变); | |||||
2. 检验 optimizer 的参数是否只在正确的时刻进行了清零; | |||||
""" | |||||
def __init__(self, accumulation_steps: int): | |||||
self.accumulation_steps = accumulation_steps | |||||
self.last_batch_params = None | |||||
self.equal = 0 | |||||
def on_train_batch_end(self, trainer): | |||||
# 注意这里的 trainer.global_forward_steps 的值比 trainer 上一次调用 batch_step_fn 的值大一; | |||||
if trainer.global_forward_batches % trainer.accumulation_steps == 0: | |||||
# 模型的参数应该与上一个 batch 不同; | |||||
cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach()) | |||||
if self.last_batch_params is not None: | |||||
assert not cur_batch_params.equal(self.last_batch_params) | |||||
if cur_batch_params.equal(self.last_batch_params): | |||||
self.equal += 1 | |||||
# optimizer 的梯度应该得到了清零; | |||||
optimizers = trainer.driver.optimizers | |||||
for optimizer in optimizers: | |||||
param_groups = optimizer.param_groups | |||||
for group in param_groups: | |||||
for p in group['params']: | |||||
assert p.grad is None or p.grad.equal(torch.zeros_like(p.grad)) | |||||
else: | |||||
# 模型的参数应该与上一个 batch 相同; | |||||
cur_batch_params = deepcopy(next(trainer.driver.unwrap_model().parameters()).cpu().detach()) | |||||
if self.last_batch_params is not None: | |||||
assert cur_batch_params.equal(self.last_batch_params) | |||||
# optimizer 的梯度不应该得到了清零; | |||||
optimizers = trainer.driver.optimizers | |||||
for optimizer in optimizers: | |||||
param_groups = optimizer.param_groups | |||||
for group in param_groups: | |||||
for p in group['params']: | |||||
assert p.grad is not None and not p.grad.equal(torch.zeros_like(p.grad)) | |||||
self.last_batch_params = cur_batch_params | |||||
def on_train_end(self, trainer): | |||||
print(f"\n equal num: {self.equal}.\n") | |||||
print(f"\ntotal_batch_num: {trainer.total_batches}.\n") |
@@ -0,0 +1,33 @@ | |||||
import time | |||||
from contextlib import contextmanager | |||||
@contextmanager | |||||
def check_time_elapse(seconds, op='lt'): | |||||
""" | |||||
检测某一段程序所花费的时间,是否 op 给定的seconds | |||||
:param int seconds: | |||||
:param str op: | |||||
:return: | |||||
""" | |||||
start = time.time() | |||||
yield | |||||
end = time.time() | |||||
if op == 'lt': | |||||
assert end-start < seconds | |||||
elif op == 'gt': | |||||
assert end-start > seconds | |||||
elif op == 'eq': | |||||
assert end - start == seconds | |||||
elif op == 'le': | |||||
assert end - start <= seconds | |||||
elif op == 'ge': | |||||
assert end - start >= seconds | |||||
else: | |||||
raise ValueError("Only supports lt,gt,eq,le,ge.") | |||||
@@ -0,0 +1,18 @@ | |||||
class NormalIterator: | |||||
def __init__(self, num_of_data=1000): | |||||
self._num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
self._index = 0 | |||||
def __iter__(self): | |||||
return self | |||||
def __next__(self): | |||||
if self._index >= self._num_of_data: | |||||
raise StopIteration | |||||
_data = self._data[self._index] | |||||
self._index += 1 | |||||
return self._data | |||||
def __len__(self): | |||||
return self._num_of_data |
@@ -0,0 +1,54 @@ | |||||
import paddle | |||||
from paddle.io import Dataset | |||||
import numpy as np | |||||
class PaddleNormalDataset(Dataset): | |||||
def __init__(self, num_of_data=1000): | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return self._data[item] | |||||
class PaddleRandomDataset(Dataset): | |||||
def __init__(self, num_of_data=1000, features=64, labels=10): | |||||
self.num_of_data = num_of_data | |||||
self.x = [ | |||||
paddle.rand((features,)) | |||||
for i in range(num_of_data) | |||||
] | |||||
self.y = [ | |||||
paddle.rand((labels,)) | |||||
for i in range(num_of_data) | |||||
] | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} | |||||
class PaddleDataset_MNIST(Dataset): | |||||
def __init__(self, mode="train"): | |||||
self.dataset = [ | |||||
( | |||||
np.array(img).astype('float32').reshape(-1), | |||||
label | |||||
) for img, label in paddle.vision.datasets.MNIST(mode=mode) | |||||
] | |||||
def __getitem__(self, idx): | |||||
return {"x": self.dataset[idx][0], "y": self.dataset[idx][1]} | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
@@ -0,0 +1,68 @@ | |||||
import torch | |||||
from functools import reduce | |||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
class TorchNormalDataset(Dataset): | |||||
def __init__(self, num_of_data=1000): | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return self._data[item] | |||||
# 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据; | |||||
class TorchNormalDataset_Classification(Dataset): | |||||
def __init__(self, num_labels, feature_dimension=2, each_label_data=1000, seed=0): | |||||
self.num_labels = num_labels | |||||
self.feature_dimension = feature_dimension | |||||
self.each_label_data = each_label_data | |||||
self.seed = seed | |||||
torch.manual_seed(seed) | |||||
self.x_center = torch.randint(low=-100, high=100, size=[num_labels, feature_dimension]) | |||||
random_shuffle = torch.randn([num_labels, each_label_data, feature_dimension]) / 10 | |||||
self.x = self.x_center.unsqueeze(1).expand(num_labels, each_label_data, feature_dimension) + random_shuffle | |||||
self.x = self.x.view(num_labels * each_label_data, feature_dimension) | |||||
self.y = reduce(lambda x, y: x+y, [[i] * each_label_data for i in range(num_labels)]) | |||||
def __len__(self): | |||||
return self.num_labels * self.each_label_data | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} | |||||
class TorchArgMaxDatset(Dataset): | |||||
def __init__(self, feature_dimension=10, data_num=1000, seed=0): | |||||
self.num_labels = feature_dimension | |||||
self.feature_dimension = feature_dimension | |||||
self.data_num = data_num | |||||
self.seed = seed | |||||
g = torch.Generator() | |||||
g.manual_seed(1000) | |||||
self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float() | |||||
self.y = torch.max(self.x, dim=-1)[1] | |||||
def __len__(self): | |||||
return self.data_num | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} | |||||
if __name__ == "__main__": | |||||
a = TorchNormalDataset_Classification(2, each_label_data=4) | |||||
print(a.x) | |||||
print(a.y) | |||||
print(a[0]) | |||||
@@ -0,0 +1,32 @@ | |||||
import paddle | |||||
import paddle.nn as nn | |||||
class PaddleNormalModel_Classification(paddle.nn.Layer): | |||||
""" | |||||
基础的paddle分类模型 | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(PaddleNormalModel_Classification, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def forward(self, x): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return x | |||||
def train_step(self, x, y): | |||||
x = self(x) | |||||
return {"loss": self.loss_fn(x, y)} | |||||
def validate_step(self, x, y): | |||||
x = self(x) | |||||
return {"pred": x, "target": y.reshape((-1,))} |
@@ -0,0 +1,65 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
# 1. 最为基础的分类模型 | |||||
class TorchNormalModel_Classification_1(nn.Module): | |||||
""" | |||||
单独实现 train_step 和 evaluate_step; | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(TorchNormalModel_Classification_1, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def forward(self, x): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return x | |||||
def train_step(self, x, y): | |||||
x = self(x) | |||||
return {"loss": self.loss_fn(x, y)} | |||||
def validate_step(self, x, y): | |||||
""" | |||||
如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"}; | |||||
""" | |||||
x = self(x) | |||||
x = torch.max(x, dim=-1)[1] | |||||
return {"preds": x, "target": y} | |||||
class TorchNormalModel_Classification_2(nn.Module): | |||||
""" | |||||
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(TorchNormalModel_Classification_2, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def forward(self, x, y): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
loss = self.loss_fn(x, y) | |||||
x = torch.max(x, dim=-1)[1] | |||||
return {"loss": loss, "preds": x, "target": y} | |||||
@@ -0,0 +1,123 @@ | |||||
import os | |||||
import sys | |||||
import __main__ | |||||
from functools import wraps | |||||
import inspect | |||||
from inspect import ismethod | |||||
import functools | |||||
from copy import deepcopy | |||||
from io import StringIO | |||||
import time | |||||
import numpy as np | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
from fastNLP.core.drivers.utils import distributed_open_proc | |||||
def get_class_that_defined_method(meth): | |||||
if isinstance(meth, functools.partial): | |||||
return get_class_that_defined_method(meth.func) | |||||
if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)): | |||||
for cls in inspect.getmro(meth.__self__.__class__): | |||||
if meth.__name__ in cls.__dict__: | |||||
return cls | |||||
meth = getattr(meth, '__func__', meth) # fallback to __qualname__ parsing | |||||
if inspect.isfunction(meth): | |||||
cls = getattr(inspect.getmodule(meth), | |||||
meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0], | |||||
None) | |||||
if isinstance(cls, type): | |||||
return cls | |||||
return getattr(meth, '__objclass__', None) # handle special descriptor objects | |||||
def magic_argv_env_context(fn): | |||||
@wraps(fn) | |||||
def wrapper(*args, **kwargs): | |||||
command = deepcopy(sys.argv) | |||||
env = deepcopy(os.environ.copy()) | |||||
used_args = [] | |||||
for each_arg in sys.argv[1:]: | |||||
if "test" not in each_arg: | |||||
used_args.append(each_arg) | |||||
pytest_current_test = os.environ.get('PYTEST_CURRENT_TEST') | |||||
try: | |||||
l_index = pytest_current_test.index("[") | |||||
r_index = pytest_current_test.index("]") | |||||
subtest = pytest_current_test[l_index: r_index + 1] | |||||
except: | |||||
subtest = "" | |||||
if not ismethod(fn) and get_class_that_defined_method(fn) is None: | |||||
sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{fn.__name__}{subtest}"] + used_args | |||||
else: | |||||
sys.argv = [sys.argv[0], f"{os.path.abspath(sys.modules[fn.__module__].__file__)}::{get_class_that_defined_method(fn).__name__}::{fn.__name__}{subtest}"] + used_args | |||||
res = fn(*args, **kwargs) | |||||
sys.argv = deepcopy(command) | |||||
os.environ = env | |||||
return res | |||||
return wrapper | |||||
class Capturing(list): | |||||
# 用来捕获当前环境中的stdout和stderr,会将其中stderr的输出拼接在stdout的输出后面 | |||||
""" | |||||
使用例子 | |||||
with Capturing() as output: | |||||
do_something | |||||
assert 'xxx' in output[0] | |||||
""" | |||||
def __init__(self, no_del=False): | |||||
# 如果no_del为True,则不会删除_stringio,和_stringioerr | |||||
super().__init__() | |||||
self.no_del = no_del | |||||
def __enter__(self): | |||||
self._stdout = sys.stdout | |||||
self._stderr = sys.stderr | |||||
sys.stdout = self._stringio = StringIO() | |||||
sys.stderr = self._stringioerr = StringIO() | |||||
return self | |||||
def __exit__(self, *args): | |||||
self.append(self._stringio.getvalue() + self._stringioerr.getvalue()) | |||||
if not self.no_del: | |||||
del self._stringio, self._stringioerr # free up some memory | |||||
sys.stdout = self._stdout | |||||
sys.stderr = self._stderr | |||||
def re_run_current_cmd_for_torch(num_procs, output_from_new_proc='ignore'): | |||||
# Script called as `python a/b/c.py` | |||||
if int(os.environ.get('LOCAL_RANK', '0')) == 0: | |||||
if __main__.__spec__ is None: # pragma: no-cover | |||||
# pull out the commands used to run the script and resolve the abs file path | |||||
command = sys.argv | |||||
command[0] = os.path.abspath(command[0]) | |||||
# use the same python interpreter and actually running | |||||
command = [sys.executable] + command | |||||
# Script called as `python -m a.b.c` | |||||
else: | |||||
command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:] | |||||
for rank in range(1, num_procs+1): | |||||
env_copy = os.environ.copy() | |||||
env_copy["LOCAL_RANK"] = f"{rank}" | |||||
env_copy['WOLRD_SIZE'] = f'{num_procs+1}' | |||||
env_copy['RANK'] = f'{rank}' | |||||
# 如果是多机,一定需要用户自己拉起,因此我们自己使用 open_subprocesses 开启的进程的 FASTNLP_GLOBAL_RANK 一定是 LOCAL_RANK; | |||||
env_copy[FASTNLP_GLOBAL_RANK] = str(rank) | |||||
proc = distributed_open_proc(output_from_new_proc, command, env_copy, None) | |||||
delay = np.random.uniform(1, 5, 1)[0] | |||||
time.sleep(delay) |