@@ -46,9 +46,11 @@ __all__ = [
     'TorchDataLoader',
     'PaddleDataLoader',
     'JittorDataLoader',
+    'OneflowDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
     'prepare_torch_dataloader',
+    'prepare_oneflow_dataloader',
     "prepare_dataloader",

     # dataset
@@ -63,6 +65,8 @@ __all__ = [
     "PaddleFleetDriver",
     "JittorSingleDriver",
     "JittorMPIDriver",
+    "OneflowSingleDriver",
+    "OneflowDDPDriver",

     # log
     "logger",
@@ -18,7 +18,7 @@ from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, Mappi
     NestedMappingPackerUnpacker

 sequence_idx_str = re.compile(r'^_\d+$')  # matches _0, _1, etc.
-SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
+SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'oneflow', 'numpy', 'raw', 'auto', None]
 # Since the jittor DataLoader converts to jittor tensors automatically (to_jittor), collating into numpy is enough.
 AUTO_BACKEND_MAPPING = {'jittor': 'numpy'}
@@ -103,7 +103,7 @@ class Collator:
     The first time it pads, the Collator automatically picks a padder for each field based on the configuration and
     the data; every subsequent call reuses the padder chosen for that field.

-    :param backend: which kind of tensor to use for paddable fields; one of ['torch','jittor','paddle','numpy','raw', auto, None].
+    :param backend: which kind of tensor to use for paddable fields; one of ['torch','jittor','paddle','oneflow','numpy','raw', auto, None].
         If 'auto', the backend is decided from the calling environment at pad time. This parameter has no effect on
         fields that cannot be padded; such fields are always returned as a list.
     """
@@ -200,8 +200,8 @@ class Collator:
         field, so if the field is not paddable in the first place there is no need to explicitly set this to None.
         If backend is None, this value is meaningless.
     :param dtype: for fields that need padding, the dtype the padded data should have.
-    :param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto']; the output will be a list,
-        numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var, respectively. If pad_val is None, this value is meaningless.
+    :param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']; the output will be a list,
+        numpy.ndarray, torch.Tensor, paddle.Tensor, jittor.Var or oneflow.Tensor, respectively. If pad_val is None, this value is meaningless.
     :param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored. The
         function receives the batch of this field as input: the Collator un-batches the data, groups each field into
         its own batch, feeds that batch to pad_fn, and uses its return value directly as the output.
@@ -275,7 +275,7 @@ class Collator:
         """
         Set the default tensor type that paddable fields are padded to.

-        :param backend: which kind of tensor to use for paddable fields; one of ['torch','jittor','paddle','numpy','raw', 'auto', None].
+        :param backend: which kind of tensor to use for paddable fields; one of ['torch','jittor','paddle','oneflow','numpy','raw', 'auto', None].
             If auto, the backend is decided from the calling environment at pad time.
         :return:
         """
@@ -10,6 +10,7 @@ from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPad
 from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
 from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
 from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder
+from .oneflow_padder import OneflowTensorPadder, OneflowSequencePadder, OneflowNumberPadder

 from .exceptions import *
@@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
             return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'jittor':
             return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        elif backend == 'oneflow':
+            return OneflowNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         else:
             raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
@@ -105,6 +108,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
             return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'jittor':
             return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        elif backend == 'oneflow':
+            return OneflowSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         else:
             raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
@@ -121,6 +126,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
             return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'jittor':
             return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        elif backend == 'oneflow':
+            return OneflowTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         else:
             raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
@@ -18,9 +18,9 @@ def _get_dtype(ele_dtype, dtype, class_name):
     """
     Detect the dtype of the data, judging from the inner (element) and outer dtypes.

-    :param ele_dtype the dtype of the inner elements
-    :param dtype the outer dtype of the data
-    :param class_name the name of the class
+    :param ele_dtype: the dtype of the inner elements
+    :param dtype: the outer dtype of the data
+    :param class_name: the name of the class
     """
     if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
         raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
@@ -0,0 +1,204 @@
__all__ = [
    'OneflowNumberPadder',
    'OneflowSequencePadder',
    'OneflowTensorPadder'
]
from inspect import isclass
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW

if _NEED_IMPORT_ONEFLOW:
    import oneflow

    numpy_to_oneflow_dtype_dict = {
        np.bool_: oneflow.bool,
        np.uint8: oneflow.uint8,
        np.int8: oneflow.int8,
        np.int32: oneflow.int32,
        np.int64: oneflow.int64,
        np.float16: oneflow.float16,
        np.float32: oneflow.float32,
        np.float64: oneflow.float32,  # unify everything to float32 here, since numpy defaults to float64 most of the time
    }
    number_to_oneflow_dtype_dict = {
        float: oneflow.float32,  # because oneflow.tensor([1], dtype=float) gives oneflow.float64
        int: oneflow.int64,
        bool: oneflow.bool
    }

from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *


def is_oneflow_tensor(dtype):
    """
    Check whether ``dtype`` is a oneflow dtype.

    :param dtype: the dtype of the data
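
    A small illustrative check (assuming oneflow is installed)::

        >>> is_oneflow_tensor(oneflow.float32)
        True
        >>> is_oneflow_tensor(int)  # a class, not a oneflow dtype instance
        False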
""" | |||||
if not isclass(dtype) and isinstance(dtype, oneflow.dtype): | |||||
return True | |||||
return False | |||||
def _get_dtype(ele_dtype, dtype, class_name): | |||||
""" | |||||
用于检测数据的 dtype 类型, 根据内部和外部数据判断。 | |||||
:param ele_dtype: 内部数据的类型 | |||||
:param dtype: 数据外部类型 | |||||
:param class_name: 类的名称 | |||||
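
    For instance (illustrative; the printed repr of oneflow dtypes is assumed)::

        >>> _get_dtype(int, None, "OneflowNumberPadder")  # falls back to the element dtype
        oneflow.int64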
""" | |||||
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_oneflow_tensor(ele_dtype))): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||||
f"or numpy numbers or oneflow.Tensor but get `{ele_dtype}`.") | |||||
if dtype is not None: | |||||
if not (is_oneflow_tensor(dtype) or is_number(dtype)): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or oneflow.dtype but get `{dtype}`.") | |||||
dtype = number_to_oneflow_dtype_dict.get(dtype, dtype) | |||||
else: | |||||
if ele_dtype is not None: | |||||
if (is_number(ele_dtype) or is_oneflow_tensor(ele_dtype)): | |||||
ele_dtype = number_to_oneflow_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_oneflow_dtype_dict.get(ele_dtype) | |||||
return dtype | |||||
class OneflowNumberPadder(Padder): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 oneflow.Tensor([1, 2, 3]) | |||||
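
    A minimal illustrative call (the output repr is assumed)::

        >>> OneflowNumberPadder.pad([1, 2, 3], dtype=oneflow.int64)
        tensor([1, 2, 3], dtype=oneflow.int64)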
    :param pad_val: this value is meaningless here
    :param ele_dtype: used to check whether the element type of the current field can be converted to oneflow.tensor.
    :param dtype: the dtype of the output data, e.g. oneflow.long, oneflow.float32, int, float, etc.
    """
    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val=0, dtype=None):
        return oneflow.tensor(batch_field, dtype=dtype)


class OneflowSequencePadder(Padder):
    """
    Pads data like [[1], [1, 2]] into oneflow.Tensor([[1, 0], [1, 2]]); it can pad arbitrarily nested data.
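
    An illustrative call (the output repr is assumed)::

        >>> OneflowSequencePadder.pad([[1], [1, 2]], pad_val=0, dtype=oneflow.int64)
        tensor([[1, 0],
                [1, 2]], dtype=oneflow.int64)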
    :param pad_val: the value to pad with.
    :param ele_dtype: used to check whether the element type of the current field can be converted to oneflow.tensor.
    :param dtype: the dtype of the output data, e.g. oneflow.long, oneflow.float32, int, float, etc.
    """
    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val=0, dtype=None):
        tensor = get_padded_oneflow_tensor(batch_field, dtype=dtype, pad_val=pad_val)
        return tensor


class OneflowTensorPadder(Padder):
    """
    Currently supports inputs like [oneflow.tensor([3, 2]), oneflow.tensor([1])]. If the inner elements are not
    oneflow.tensor, they must have a tolist() method.

    >>> OneflowTensorPadder.pad([np.array([3, 4]), np.array([1])], pad_val=-100)
    [[   3.    4.]
     [   1. -100.]]
    >>> OneflowTensorPadder.pad([oneflow.LongTensor([3, 4]), oneflow.LongTensor([1])], pad_val=-100)
    tensor([[   3,    4],
            [   1, -100]])

    :param pad_val: the value to pad with.
    :param ele_dtype: used to check whether the element type of the current field can be converted to oneflow.tensor.
    :param dtype: the dtype of the output data, e.g. oneflow.long, oneflow.float32, int, float, etc.
    """
    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val=0, dtype=None):
        device = None
        try:
            if not isinstance(batch_field[0], oneflow.Tensor):
                batch_field = [oneflow.tensor(field.tolist(), dtype=dtype) for field in batch_field]
            else:
                batch_field = [field.to(dtype) for field in batch_field]
            device = batch_field[0].device
            if dtype is None:
                dtype = batch_field[0].dtype
        except AttributeError:
            raise RuntimeError(f"If the field is not a oneflow.Tensor (it is {type(batch_field[0])}), "
                               f"it must have tolist() method.")

        shapes = [field.shape for field in batch_field]
        if len(batch_field) < 2:
            max_shape = [len(batch_field)] + list(shapes[0])
        else:
            max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]

        tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device)
        for i, field in enumerate(batch_field):
            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
            tensor[slices] = field
        return tensor


def fill_tensor(batch_field, padded_batch, dtype):
    """
    Fill the values from batch_field into the pre-allocated tensor.

    :param batch_field: the content to fill into the array
    :param padded_batch: the tensor to be filled
    :param dtype: the dtype of the data
    :return:
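
    A sketch of how it behaves (the output repr is assumed)::

        >>> padded = oneflow.full([2, 3], value=0, dtype=oneflow.int64)
        >>> fill_tensor([[1, 2], [3]], padded, dtype=oneflow.int64)
        tensor([[1, 2, 0],
                [3, 0, 0]], dtype=oneflow.int64)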
""" | |||||
if padded_batch.ndim == 2: | |||||
for i, content_i in enumerate(batch_field): | |||||
padded_batch[i, :len(content_i)] = oneflow.tensor(content_i, dtype=dtype) | |||||
elif padded_batch.ndim == 3: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
padded_batch[i, j, :len(content_ii)] = oneflow.tensor(content_ii, dtype=dtype) | |||||
elif padded_batch.ndim == 4: | |||||
try: # 应该是图像,所以直接应该就 ok 了。 | |||||
padded_batch = oneflow.tensor(batch_field) | |||||
except: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
for k, content_iii in enumerate(content_ii): | |||||
padded_batch[i, j, k, :len(content_iii)] = oneflow.tensor(content_iii, dtype=dtype) | |||||
elif padded_batch.ndim == 1: | |||||
padded_batch[:] = oneflow.tensor(batch_field, dtype=dtype) | |||||
else: | |||||
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||||
"report.") | |||||
return padded_batch | |||||
def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0): | |||||
""" | |||||
例如: | |||||
[[1,2], [3]] -> oneflow.LongTensor([[1, 2], [3, 0]]) | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param dtype: 目标类别是什么 | |||||
:param pad_val: pad 的 value | |||||
:return: | |||||
""" | |||||
shapes = get_shape(batch_field) | |||||
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val) | |||||
tensor = fill_tensor(batch_field, tensor, dtype=dtype) | |||||
return tensor |
@@ -13,9 +13,9 @@ def _get_dtype(ele_dtype, dtype, class_name):
     """
     Detect the dtype of the data, judging from the inner (element) and outer dtypes.

-    :param ele_dtype the dtype of the inner elements
-    :param dtype the outer dtype of the data
-    :param class_name the name of the class
+    :param ele_dtype: the dtype of the inner elements
+    :param dtype: the outer dtype of the data
+    :param class_name: the name of the class
     """
     if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
         raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
@@ -38,7 +38,7 @@ def is_torch_tensor(dtype):
     """
     Check whether ``dtype`` is a torch dtype.

-    :param dtype the dtype of the data
+    :param dtype: the dtype of the data
     """
     if not isclass(dtype) and isinstance(dtype, torch.dtype):
         return True
@@ -49,9 +49,9 @@ def _get_dtype(ele_dtype, dtype, class_name):
     """
     Detect the dtype of the data, judging from the inner (element) and outer dtypes.

-    :param ele_dtype the dtype of the inner elements
-    :param dtype the outer dtype of the data
-    :param class_name the name of the class
+    :param ele_dtype: the dtype of the inner elements
+    :param dtype: the outer dtype of the data
+    :param class_name: the name of the class
     """
     if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
         raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
@@ -3,9 +3,11 @@ __all__ = [
     'TorchDataLoader',
     'PaddleDataLoader',
     'JittorDataLoader',
+    'OneflowDataLoader',
     'prepare_jittor_dataloader',
     'prepare_paddle_dataloader',
     'prepare_torch_dataloader',
+    'prepare_oneflow_dataloader',
     "prepare_dataloader",
@@ -15,5 +17,6 @@ __all__ = [
 from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
 from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
 from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
+from .oneflow_dataloader import OneflowDataLoader, prepare_oneflow_dataloader
 from .prepare_dataloader import prepare_dataloader

 from .utils import OverfitDataLoader
@@ -0,0 +1,6 @@
__all__ = [
    "OneflowDataLoader",
    "prepare_oneflow_dataloader",
]

from .fdl import OneflowDataLoader, prepare_oneflow_dataloader
@@ -0,0 +1,353 @@
__all__ = [
    'OneflowDataLoader',
    'prepare_oneflow_dataloader'
]

from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List, Any
from abc import ABC
from copy import deepcopy

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
from ..utils import _match_param
from ..utils import HasLenGetitemType

if _NEED_IMPORT_ONEFLOW:
    from oneflow.utils.data import DataLoader, Sampler, Dataset
else:
    from fastNLP.core.utils.dummy_class import DummyClass as DataLoader


class _FDataSet:
    """
    A wrapper class used by ``OneflowDataLoader``. It wraps the dataset and modifies its __getitem__ so that the
    index idx is returned along with the data.

    .. note::

        Note that the dataset passed to ``__init__`` must implement __getattribute__ so that the dataset's methods
        can be called on a _FDataSet instance.
    """

    def __init__(self, dataset) -> None:
        self.dataset = dataset

    def __getitem__(self, item: Union[int, list]) -> Tuple:
        return (item, self.dataset[item])

    def __getattr__(self, item):
        try:
            return self.dataset.__getattribute__(item)
        except AttributeError as e:
            raise e

    def __len__(self) -> int:
        return len(self.dataset)


class OneflowDataLoader(DataLoader):
    """
    A ``DataLoader`` for the ``oneflow`` framework. ``OneflowDataLoader`` provides a ``Collator`` that automatically
    detects whether each field of the dataset can be padded: paddable fields are padded to the same length, while
    the rest are simply collected into a batch as-is.
    See :class:`~fastNLP.core.collators.Collator` for details. Whether this feature is used is controlled through
    collate_fn, which can only take one of the values ``['auto', None, Callable]``:

    * when collate_fn is ``'auto'``, ``OneflowDataLoader`` uses :class:`~fastNLP.core.collators.Collator` as the
      collate_fn. In this case the ``set_pad`` and ``set_ignore`` methods of ``OneflowDataLoader`` can be used to
      set pad_val or skip the detection of certain fields.
    * when collate_fn is ``None``, ``OneflowDataLoader`` uses the default collate_fn of the oneflow DataLoader.
    * when collate_fn is a ``Callable``, the callable must accept a single ``batch`` argument, where batch is a
      List whose entries are individual samples of the dataset, and it must return an object.
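
    A minimal usage sketch (assumes a fastNLP ``DataSet``; the field names are illustrative)::

        >>> from fastNLP import DataSet
        >>> ds = DataSet({"x": [[1, 2], [2, 3, 4]], "y": [1, 0]})
        >>> dl = OneflowDataLoader(ds, batch_size=2, collate_fn='auto')
        >>> for batch in dl:
        ...     print(batch["x"].shape)  # the "x" field is padded into a single oneflow.Tensor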
""" | |||||
def __init__(self, dataset, batch_size: int = 16, | |||||
shuffle: bool = False, sampler = None, batch_sampler = None, | |||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | |||||
pin_memory: bool = False, drop_last: bool = False, | |||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||||
persistent_workers: bool = False, **kwargs) -> None: | |||||
""" | |||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``OneflowDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
``OneflowDataLoader`` 调用默认的 oneflow 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理 | |||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
* callate_fn 为 ``'auto'`` 时,``OneflowDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
此时可以配套使用 ``OneflowDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
:param pin_memory: 如果其为 ``True``, 那么 ``OneflowDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``OneflowDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param timeout: 子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
:param multiprocessing_context: 多进程的上下文环境 | |||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个 ``base_seed`` | |||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
:param persistent_workers: 如果其为 ``True``, ``OneflowDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
""" | |||||
if isinstance(dataset, DataSet) and collate_fn is None: | |||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||||
if not isinstance(dataset, _FDataSet): | |||||
dataset = _FDataSet(dataset) | |||||
if num_workers>0 and multiprocessing_context is None: | |||||
multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 | |||||
if batch_sampler is not None: | |||||
batch_size = 1 | |||||
shuffle = False | |||||
sampler = None | |||||
elif sampler is None: | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
shuffle = False | |||||
if isinstance(collate_fn, str): | |||||
if collate_fn == 'auto': | |||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | |||||
collate_fn = deepcopy(dataset.dataset.collator) | |||||
collate_fn.set_backend(backend="oneflow") | |||||
else: | |||||
collate_fn = Collator(backend="oneflow") | |||||
else: | |||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||||
dl_kwargs = _match_param(OneflowDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__) | |||||
if dl_kwargs is None: | |||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers) | |||||
else: | |||||
super().__init__(**dl_kwargs) | |||||
self.cur_batch_indices = None | |||||
def __iter__(self): | |||||
self.collate_fn = indice_collate_wrapper(self.collate_fn) | |||||
for indices, data in super().__iter__(): | |||||
self.cur_batch_indices = indices | |||||
yield data | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> Collator: | |||||
""" | |||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, paddle.Tensor, jittor.Var, oneflow.Tensor 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
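
        Example (the field name is illustrative)::

            >>> dl.set_pad('input_ids', pad_val=-100, backend='oneflow')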
        :return: the Collator
        """
        collator = self._get_collator()
        if isinstance(collator, Collator):
            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
            return collator
        else:
            raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

    def _get_collator(self):
        """
        Return the collate_fn if it is a Collator object; otherwise return None.

        :return:
        """
        collator = None
        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
            collator = self.collate_fn.__wrapped__
        elif isinstance(self.collate_fn, Collator):
            collator = self.collate_fn
        return collator

    def set_ignore(self, *field_names) -> Collator:
        """
        Use this method if some content should not appear in the output; the configured fields will be ignored in
        the output batch.

        Example::

            collator.set_ignore('field1', 'field2')

        :param field_names: the names of the fields to ignore. If the dataset's __getitem__ returns a dict, the
            corresponding field key can be used directly; for nested dicts, a tuple can be used, e.g. ('a', 'b')
            for {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' can be used to refer to the 0th or
            1st element.
        :return: the Collator itself
        """
        collator = self._get_collator()
        if isinstance(collator, Collator):
            collator.set_ignore(*field_names)
            return collator
        else:
            raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

    def get_batch_indices(self) -> List[int]:
        """
        Get the index of each sample in the current ``batch``.

        :return: the indices of the current ``batch``;
        """
        return self.cur_batch_indices


def prepare_oneflow_dataloader(ds_or_db,
                               batch_size: int = 16,
                               shuffle: bool = None,
                               sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                               batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                               num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
                               pin_memory: bool = False, drop_last: bool = False,
                               timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                               multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                               persistent_workers: bool = False,
                               non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                               non_train_batch_size: int = None) \
        -> Union[OneflowDataLoader, Dict[str, OneflowDataLoader]]:
    """
    ``prepare_oneflow_dataloader`` converts one or more input datasets into ``OneflowDataLoader`` objects at once;
    see :class:`~fastNLP.OneflowDataLoader`. The return value depends on the type of ds_or_db
    (``[DataSet, DataBundle, Dict[name, Dataset]]``), as follows:

    * when ds_or_db is a ``DataSet``, ``prepare_oneflow_dataloader`` uses all parameters except
      non_train_batch_size and non_train_sampler to instantiate an ``OneflowDataLoader`` and returns it. See
      :class:`~fastNLP.core.dataloaders.OneflowDataLoader`.
    * when ds_or_db is a :class:`~fastNLP.io.DataBundle`, ``prepare_oneflow_dataloader`` iterates over the
      key-value pairs of the ``DataBundle``'s datasets to create different ``OneflowDataLoader`` objects; when a
      key contains the string 'train', the value is treated as the train dataset and batch_size and sampler are
      passed as its parameters, while datasets whose keys do not contain 'train' use non_train_batch_size and
      non_train_sampler instead. A ``Dict[key, OneflowDataLoader]`` built from ``key: OneflowDataLoader`` is
      returned.
    * when ds_or_db is a ``Dict[str, DataSet]``, ``prepare_oneflow_dataloader`` iterates over the dict's key-value
      pairs to create different ``OneflowDataLoader`` objects; when a key contains the string 'train', the value is
      treated as the train dataset and batch_size and sampler are passed as its parameters, while datasets whose
      keys do not contain 'train' use non_train_batch_size and non_train_sampler instead. A
      ``Dict[key, OneflowDataLoader]`` built from ``key: OneflowDataLoader`` is returned.
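
    A minimal sketch of the dict case (the dataset names are illustrative)::

        >>> dls = prepare_oneflow_dataloader({"train": train_ds, "dev": dev_ds}, batch_size=32)
        >>> train_dl, dev_dl = dls["train"], dls["dev"]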
    :param ds_or_db: can take one of three types of values,

        * ds_or_db is a :class:`~fastNLP.io.DataBundle`: the return value is a ``Dict[str, OneflowDataLoader]``
        * ds_or_db is a ``Dict[str, DataSet]``: the return value is a ``Dict[str, OneflowDataLoader]``
        * ds_or_db is an object implementing __getitem__() and __len__(): the return value is a
          :class:`~fastNLP.OneflowDataLoader`

    :param batch_size: batch size, ``16`` by default; only effective when batch_sampler is None.
    :param non_train_batch_size: the batch size of the ``OneflowDataLoader`` for non-'train' datasets, ``16`` by
        default; only effective when batch_sampler is None.
    :param shuffle: whether to shuffle the dataset, ``None`` by default; if the passed ``ds_or_db`` allows
        identifying which dataset is 'train', its shuffle is set to True and the others to False.
    :param sampler: an instance implementing __len__() and __iter__(), whose __iter__() yields one dataset index at
        a time; None by default. When it is not None, the shuffle parameter has no effect.
    :param non_train_sampler: the sampler for non-'train' datasets, an instance implementing __len__() and
        __iter__(), whose __iter__() yields one dataset index at a time; None by default. When it is not None, the
        shuffle parameter has no effect.
    :param batch_sampler: an instance implementing __len__() and __iter__(), whose __iter__() yields a List of
        dataset indices at a time; None by default. When it is not None, batch_size, sampler and shuffle all have
        no effect.
    :param num_workers: when ``num_workers > 0``, ``OneflowDataLoader`` spawns num_workers subprocesses to process
        the data, which can speed up data processing at the cost of extra memory. When ``num_workers=0``, no
        subprocess is spawned. ``0`` by default.
    :param collate_fn: a Callable used to pack a batch of data fetched from the dataset; must be one of
        ``[None, "auto", Callable]``.

        * when collate_fn is ``None``, note that the dataset passed in must not be a
          :class:`~fastNLP.core.dataset.DataSet`; in this case ``OneflowDataLoader`` falls back to the
          ``default_collate_fn`` of the oneflow ``DataLoader``, which cannot handle a
          :class:`~fastNLP.core.dataset.DataSet` object.
        * when collate_fn is ``'auto'``, ``OneflowDataLoader`` uses :class:`~fastNLP.core.collators.Collator` as
          the default collate_fn. In this case the ``set_pad`` and ``set_ignore`` methods of ``OneflowDataLoader``
          can be used to set pad_val or skip the detection of certain fields.
        * when collate_fn is a ``Callable``, the callable must accept a single ``batch`` argument, where batch is a
          List whose entries are individual samples of the dataset, and it must return an object.

    :param pin_memory: if ``True``, ``OneflowDataLoader`` copies the data tensors into CUDA pinned memory before
        returning them.
    :param drop_last: when ``drop_last=True``, ``OneflowDataLoader`` drops the last batch if it is shorter than
        ``batch_size``; when ``drop_last=False``, that batch is returned. ``False`` by default.
    :param timeout: the timeout for fetching data from the output queue of the subprocesses.
    :param worker_init_fn: an init function; if not None, it is called in each subprocess at initialization.
    :param multiprocessing_context: the multiprocessing context.
    :param generator: if not ``None``, it is used by the RandomSampler to generate random indices, and a
        ``base_seed`` is generated for each subprocess.
    :param prefetch_factor: number of samples loaded in advance by each worker. ``2`` means 2*num_workers samples
        are prefetched across all processes. ``2`` by default.
    :param persistent_workers: if ``True``, ``OneflowDataLoader`` does not shut down the worker processes after the
        dataset has been iterated once. ``False`` by default.
    """

    from fastNLP.io import DataBundle

    if isinstance(ds_or_db, DataBundle):
        dl_bundle = {}
        for name, ds in ds_or_db.iter_datasets():
            if 'train' in name:
                dl_bundle[name] = OneflowDataLoader(dataset=ds, batch_size=batch_size,
                                                    shuffle=True if shuffle is None else shuffle,
                                                    sampler=sampler, batch_sampler=batch_sampler,
                                                    num_workers=num_workers, collate_fn=collate_fn,
                                                    pin_memory=pin_memory,
                                                    drop_last=drop_last, timeout=timeout,
                                                    worker_init_fn=worker_init_fn,
                                                    multiprocessing_context=multiprocessing_context,
                                                    generator=generator,
                                                    prefetch_factor=prefetch_factor,
                                                    persistent_workers=persistent_workers,
                                                    )
            else:
                dl_bundle[name] = OneflowDataLoader(dataset=ds,
                                                    batch_size=non_train_batch_size if non_train_batch_size else batch_size,
                                                    shuffle=False if shuffle is None else shuffle,
                                                    sampler=non_train_sampler if non_train_sampler else sampler,
                                                    batch_sampler=batch_sampler,
                                                    num_workers=num_workers, collate_fn=collate_fn,
                                                    pin_memory=pin_memory,
                                                    drop_last=drop_last, timeout=timeout,
                                                    worker_init_fn=worker_init_fn,
                                                    multiprocessing_context=multiprocessing_context,
                                                    generator=generator,
                                                    prefetch_factor=prefetch_factor,
                                                    persistent_workers=persistent_workers,
                                                    )
        return dl_bundle

    elif isinstance(ds_or_db, Mapping):
        dl_bundle = {}
        for name, ds in ds_or_db.items():
            if 'train' in name:
                dl_bundle[name] = OneflowDataLoader(dataset=ds, batch_size=batch_size,
                                                    shuffle=True if shuffle is None else shuffle,
                                                    sampler=sampler, batch_sampler=batch_sampler,
                                                    num_workers=num_workers, collate_fn=collate_fn,
                                                    pin_memory=pin_memory,
                                                    drop_last=drop_last, timeout=timeout,
                                                    worker_init_fn=worker_init_fn,
                                                    multiprocessing_context=multiprocessing_context,
                                                    generator=generator,
                                                    prefetch_factor=prefetch_factor,
                                                    persistent_workers=persistent_workers,
                                                    )
            else:
                dl_bundle[name] = OneflowDataLoader(dataset=ds,
                                                    batch_size=non_train_batch_size if non_train_batch_size else batch_size,
                                                    shuffle=False if shuffle is None else shuffle,
                                                    sampler=non_train_sampler if non_train_sampler else sampler,
                                                    batch_sampler=batch_sampler,
                                                    num_workers=num_workers, collate_fn=collate_fn,
                                                    pin_memory=pin_memory,
                                                    drop_last=drop_last, timeout=timeout,
                                                    worker_init_fn=worker_init_fn,
                                                    multiprocessing_context=multiprocessing_context,
                                                    generator=generator,
                                                    prefetch_factor=prefetch_factor,
                                                    persistent_workers=persistent_workers,
                                                    )
        return dl_bundle

    elif isinstance(ds_or_db, HasLenGetitemType):
        dl = OneflowDataLoader(dataset=ds_or_db, batch_size=batch_size,
                               shuffle=False if shuffle is None else shuffle, sampler=sampler,
                               batch_sampler=batch_sampler,
                               num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                               drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                               multiprocessing_context=multiprocessing_context, generator=generator,
                               prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                               )
        return dl

    else:
        raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or mapping!")
@@ -9,6 +9,7 @@ import sys

 from .torch_dataloader import prepare_torch_dataloader
 from .paddle_dataloader import prepare_paddle_dataloader
 from .jittor_dataloader import prepare_jittor_dataloader
+from .oneflow_dataloader import prepare_oneflow_dataloader
 from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger

@@ -37,7 +38,7 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop
         * a ``Callable``: it must accept a ``batch`` of data as its argument and return an object.
         * ``None``: the default ``collate_fn`` of the respective framework's DataLoader is used.

     :param num_workers: how many subprocesses to use for data fetching.
-    :param backend: currently supports the four types ``["auto", "torch", "paddle", "jittor"]``.
+    :param backend: currently supports the five types ``["auto", "torch", "paddle", "jittor", "oneflow"]``.

         * ``auto``: first (1) decide from the environment variable "FASTNLP_BACKEND"; if it is not set, (2) decide
           from the ``backend`` modules already imported in ``sys.modules``. If neither works, an error is raised.
           If one is found,
@@ -45,6 +46,7 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop
         * ``torch``: use :func:`~fastNLP.prepare_torch_dataloader`.
         * ``paddle``: use :func:`~fastNLP.prepare_paddle_dataloader`.
         * ``jittor``: use :func:`~fastNLP.prepare_jittor_dataloader`.
+        * ``oneflow``: use :func:`~fastNLP.prepare_oneflow_dataloader`.

     :return
     """
@@ -61,6 +63,10 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = None, drop
             prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn,
                                       num_workers=num_workers, batch_size=batch_size, shuffle=shuffle,
                                       drop_last=drop_last)
+    elif backend == 'oneflow':
+        return prepare_oneflow_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
+                                          num_workers=num_workers, shuffle=shuffle, sampler=None,
+                                          batch_size=batch_size)
     else:
         raise ValueError(f"Currently we do not support backend:{backend}.")
@@ -1,22 +1,27 @@
 __all__ = [
     'Driver',
     'TorchDriver',
-    "TorchSingleDriver",
-    "TorchDDPDriver",
-    "PaddleDriver",
-    "PaddleSingleDriver",
-    "PaddleFleetDriver",
-    "JittorDriver",
-    "JittorSingleDriver",
-    "JittorMPIDriver",
+    'TorchSingleDriver',
+    'TorchDDPDriver',
+    'PaddleDriver',
+    'PaddleSingleDriver',
+    'PaddleFleetDriver',
+    'JittorDriver',
+    'JittorSingleDriver',
+    'JittorMPIDriver',
+    'OneflowDriver',
+    'OneflowSingleDriver',
+    'OneflowDDPDriver',
     'torch_seed_everything',
     'paddle_seed_everything',
+    'oneflow_seed_everything',
     'optimizer_state_to_device'
 ]

 from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device
 from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver
 from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything
+from .oneflow_driver import OneflowDriver, OneflowSingleDriver, OneflowDDPDriver, oneflow_seed_everything
 from .driver import Driver
@@ -1,7 +1,7 @@
 from typing import Union, Optional, List

 from .driver import Driver
-from ..utils import is_torch_module, is_paddle_module, is_jittor_module
+from ..utils import is_torch_module, is_paddle_module, is_jittor_module, is_oneflow_module


 def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
@@ -25,6 +25,8 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
             driver = "paddle"
         elif is_jittor_module(model):
             driver = "jittor"
+        elif is_oneflow_module(model):
+            driver = "oneflow"
         else:
             raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.")
@@ -37,6 +39,9 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
     elif driver in {"paddle"}:
         from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
         return initialize_paddle_driver(driver, device, model, **kwargs)
+    elif driver in {"oneflow"}:
+        from fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver import initialize_oneflow_driver
+        return initialize_oneflow_driver(driver, device, model, **kwargs)
     else:
         raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', "
-                         "'jittor', 'paddle'].")
+                         "'jittor', 'paddle', 'oneflow'].")
@@ -0,0 +1,18 @@
__all__ = [
    "OneflowDDPDriver",
    "OneflowSingleDriver",
    "OneflowDriver",
    "oneflow_seed_everything",
    "optimizer_state_to_device"
]

from .ddp import OneflowDDPDriver
from .single_device import OneflowSingleDriver
from .oneflow_driver import OneflowDriver
from .utils import oneflow_seed_everything, optimizer_state_to_device
@@ -0,0 +1,323 @@
import os
from typing import List, Optional, Union, Dict

from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW

if _NEED_IMPORT_ONEFLOW:
    import oneflow
    import oneflow.comm as comm
    import oneflow.env as dist_env
    from oneflow.nn.parallel import DistributedDataParallel
    from oneflow.utils.data import BatchSampler

__all__ = [
    "OneflowDDPDriver"
]

from .oneflow_driver import OneflowDriver
from fastNLP.core.drivers.oneflow_driver.utils import (
    replace_sampler,
    replace_batch_sampler
)
from fastNLP.core.utils import check_user_specific_params
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, \
    ReproducibleBatchSampler, \
    re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
from fastNLP.envs import FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger
from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather, fastnlp_oneflow_broadcast_object
from .utils import _check_dataloader_args_for_distributed


class OneflowDDPDriver(OneflowDriver):
    r"""
    ``OneflowDDPDriver`` implements data-parallel distributed training with ``DistributedDataParallel`` in eager
    mode.

    .. note::

        In the vast majority of cases you do not need to use this class yourself; by passing the right parameters
        to ``Trainer`` you can deploy distributed training quickly and conveniently.

        ``OneflowDDPDriver`` currently supports two launch modes:

        1. the user does nothing special and launches with
           ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py``;
        2. the user wraps the model with ``DistributedDataParallel`` first and then launches with
           ``python -m oneflow.distributed.launch --nproc_per_node 2 train.py``;

        Note that multi-machine launching strictly requires the user to run
        ``python -m oneflow.distributed.launch`` on every machine; therefore ``OneflowDDPDriver`` does not store
        any information about how many machines there currently are.

    :param model: the ``model`` parameter passed to ``Trainer``;
    :param parallel_device: this parameter has no effect; **FastNLP** automatically obtains the device of the
        current process;
    :param fp16: whether to enable fp16 training; this parameter currently has no effect;
    :param oneflow_kwargs:
        * *ddp_kwargs* -- additional parameters for ``DistributedDataParallel``; see the official **oneflow**
          documentation for details;
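
    A launch sketch (illustrative; assumes a standard fastNLP ``Trainer`` script)::

        >>> # started via: python -m oneflow.distributed.launch --nproc_per_node 2 train.py
        >>> trainer = Trainer(model=model, driver="oneflow", ...)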
""" | |||||
def __init__( | |||||
self, | |||||
model, | |||||
parallel_device: Optional["oneflow.device"], | |||||
fp16: bool = False, | |||||
oneflow_kwargs: Dict = {}, | |||||
**kwargs | |||||
): | |||||
super(OneflowDDPDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs) | |||||
# oneflow 会自己初始化通信组,因此 parallel_device 实际上不起作用,可以通过 current_device 获取设备 | |||||
self.model_device = oneflow.device("cuda", oneflow.cuda.current_device()) | |||||
self._data_device = self.model_device | |||||
self.global_rank = int(os.environ["RANK"]) | |||||
self.world_size = int(os.environ["WORLD_SIZE"]) | |||||
self._ddp_kwargs = self._oneflow_kwargs.get("ddp_kwargs", {}) | |||||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__) | |||||
if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | |||||
logger.info("Notice your model has buffers and you are using `OneflowDDPDriver`, but you do not set " | |||||
"'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" | |||||
" to 'False' to avoid redundant data communication between different processes.") | |||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | |||||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | |||||
if self.output_from_new_proc not in {"all", "ignore", "only_error"}: | |||||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
self._has_ddpwrapped = False# hasattr(model, ) | |||||
def setup(self): | |||||
r""" | |||||
将模型用 ``DistributedDataParallel`` 进行处理; | |||||
""" | |||||
if self._has_setup: | |||||
return | |||||
self._has_setup = True | |||||
self.configure_ddp() | |||||
self.barrier() | |||||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | |||||
# self._pids = [oneflow.tensor(0, dtype=oneflow.int).to(self.data_device) for _ in range(dist_env.get_world_size())] | |||||
# comm.all_gather(self._pids, oneflow.tensor(os.getpid(), dtype=oneflow.int).to(self.data_device)) | |||||
# local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None | |||||
# if local_world_size is None: | |||||
# local_world_size = oneflow.tensor(int(os.environ.get("LOCAL_RANK")), dtype=oneflow.int).to(self.data_device) | |||||
# comm.all_reduce(local_world_size, op=dist_env.ReduceOp.MAX) | |||||
# local_world_size = local_world_size.tolist() + 1 | |||||
# node_rank = self.global_rank // local_world_size | |||||
# self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] | |||||
# self._pids = self.tensor_to_numeric(self._pids) | |||||
def configure_ddp(self): | |||||
if not hasattr(self.model, "_ddp_state_for_reversed_params"): | |||||
self.model.to(self.model_device) | |||||
self.model = DistributedDataParallel( | |||||
# 注意这里的 self.model_device 是 `oneflow.device` type,因此 self.model_device.index; | |||||
self.model, | |||||
**self._ddp_kwargs | |||||
) | |||||
self._has_ddpwrapped = True | |||||
@property | |||||
def master_address(self) -> str: | |||||
return os.environ.get("MASTER_ADDR") | |||||
@property | |||||
def master_port(self) -> str: | |||||
return os.environ.get("MASTER_PORT") | |||||
@property | |||||
def world_size(self) -> int: | |||||
return self._world_size | |||||
@world_size.setter | |||||
def world_size(self, size: int): | |||||
self._world_size = size | |||||
@property | |||||
def global_rank(self) -> int: | |||||
return self._global_rank | |||||
@global_rank.setter | |||||
def global_rank(self, rank: int) -> None: | |||||
self._global_rank = rank | |||||
@property | |||||
def local_rank(self) -> int: # 这个不会受到 all_rank_call_context 的影响 | |||||
return int(os.environ.get("LOCAL_RANK", 0)) | |||||
@property | |||||
def data_device(self): | |||||
return self._data_device | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]] = None, | |||||
reproducible: bool = False): | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load_checkpoint 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 OneflowDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||||
# trainer, evaluator | |||||
if dist is None: | |||||
if reproducible: | |||||
raise RuntimeError("It is not allowed to save checkpoint if the sampler is not allowed to be replaced.") | |||||
else: | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler)) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler)) | |||||
return dataloader | |||||
# trainer | |||||
elif dist == "dist": | |||||
args = self.get_dataloader_args(dataloader) | |||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
batch_sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
_check_dataloader_args_for_distributed(args, controller="Trainer") | |||||
sampler = RandomSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||||
) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, sampler) | |||||
# evaluator | |||||
elif dist == "unrepeatdist": | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
_check_dataloader_args_for_distributed(args, controller="Evaluator") | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank | |||||
) | |||||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | |||||
raise ValueError( | |||||
"Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||||
def is_global_zero(self): | |||||
r""" | |||||
:return: 返回当前的进程是否在全局上是进程 0 ; | |||||
""" | |||||
return self.global_rank == 0 | |||||
def get_model_no_sync_context(self): | |||||
r""" | |||||
:return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;该功能暂时无效,返回一个空的上下文环境; | |||||
""" | |||||
# TODO 暂时没有在 oneflow 中找到类似的功能; | |||||
from fastNLP.core.utils import nullcontext | |||||
return nullcontext | |||||
return self.model.no_sync | |||||
def unwrap_model(self): | |||||
r""" | |||||
:return: 返回原始模型; | |||||
""" | |||||
return self.model | |||||
def get_local_rank(self) -> int: | |||||
r""" | |||||
:return: 返回当前进程局部的进程编号; | |||||
""" | |||||
return self.local_rank | |||||
def barrier(self): | |||||
r""" | |||||
通过使用该函数来使得各个进程之间同步操作; | |||||
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | |||||
comm.barrier() | |||||
def is_distributed(self): | |||||
r""" | |||||
:return: 返回当前使用的 driver 是否是分布式的 driver,对于 ``OneflowDDPDriver`` 来说,该函数一定返回 ``True``; | |||||
""" | |||||
return True | |||||
def broadcast_object(self, obj, src: int = 0, **kwargs): | |||||
r""" | |||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||||
:param int src: source 的 global rank 。 | |||||
:param int dst: target 的 global rank,可以是多个目标 rank | |||||
:param group: 所属的 group | |||||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||||
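
        A sketch (illustrative; ``driver`` is an ``OneflowDDPDriver`` instance)::

            >>> obj = {"a": 1} if driver.global_rank == 0 else None
            >>> obj = driver.broadcast_object(obj, src=0)  # every rank now holds {"a": 1}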
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。 | |||||
return | |||||
return fastnlp_oneflow_broadcast_object(obj, src, device=self.data_device) | |||||
def all_gather(self, obj) -> List: | |||||
r""" | |||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||||
pickle 进行序列化,接收到之后再反序列化。 | |||||
example:: | |||||
obj = { | |||||
'a': [1, 1], | |||||
'b': [[1, 2], [1, 2]], | |||||
'c': { | |||||
'd': [1, 2] | |||||
} | |||||
} | |||||
-> | |||||
[ | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
] | |||||
:param obj: 需要传输的对象,在每个 rank 上都应该保持相同的结构。
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 为第 i 个 rank 上的 obj 。
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行 | |||||
return [obj] | |||||
return fastnlp_oneflow_all_gather(obj) |
@@ -0,0 +1,306 @@ | |||||
import io | |||||
import pickle | |||||
import os | |||||
from typing import Any, List | |||||
from fastNLP.core.utils import apply_to_collection, get_oneflow_device | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.envs.env import FASTNLP_NO_SYNC | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
import oneflow.env as dist_env | |||||
PROTOCOL_VERSION = 1 | |||||
def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
if dst == my_rank: | |||||
if not gather_list: | |||||
raise ValueError( | |||||
"Argument ``gather_list`` must be specified on destination rank." | |||||
) | |||||
elif gather_list: | |||||
raise ValueError( | |||||
"Argument ``gather_list`` must NOT be specified " | |||||
"on non-destination ranks." | |||||
) | |||||
obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} | |||||
pickled_bytes = pickle.dumps(obj) | |||||
def fastnlp_oneflow_gather_object(obj, dst=0): | |||||
""" | |||||
从其它 rank gather 东西到 dst rank 。 | |||||
Example:: | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = fastnlp_oneflow_gather_object(
gather_objects[dist_env.get_rank()],
dst=0
)
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
:param obj: 需要发送的 obj 对象,需要是可以 picklable 的对象
:param dst: 目标的 rank 。
:return: 在 dst 上返回长度为 world_size 的 list,依次为 rank 0、rank 1、... 上的 obj;其它 rank 上返回 None 。
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: | |||||
return [obj] | |||||
if dist_env.get_rank() == dst: | |||||
object_gather_list = [None for _ in range(dist_env.get_world_size())] | |||||
else: | |||||
object_gather_list = None | |||||
# Ensure object_gather_list is specified appropriately.
my_rank = dist_env.get_rank() | |||||
_validate_output_list_for_rank(my_rank, dst, object_gather_list) | |||||
# 防止 unpickle 的时候出现在了发送的 gpu 上。 | |||||
obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) | |||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
current_device = oneflow.device("cuda") | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist_env.get_world_size() | |||||
object_sizes_tensor = oneflow.zeros(group_size, dtype=oneflow.long, device=current_device) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes. An all-gather is needed here despite this being a | |||||
# gather, since each rank needs to broadcast a tensor of the same (maximal) | |||||
# size. | |||||
comm.all_gather(object_size_list, local_size) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor = input_tensor.reshape(max_object_size) | |||||
# Avoid populating output tensors if the result won't be gathered on this rank. | |||||
if my_rank == dst: | |||||
coalesced_output_tensor = oneflow.empty( | |||||
max_object_size * group_size, dtype=oneflow.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
# All ranks call gather with equal-sized tensors. | |||||
comm.gather( | |||||
input_tensor, | |||||
gather_list=output_tensors if my_rank == dst else None, | |||||
dst=dst, | |||||
) | |||||
if my_rank != dst: | |||||
return | |||||
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(oneflow.uint8)  # type: ignore[call-overload]
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
return object_gather_list
def _object_to_tensor(obj, device=None): | |||||
f = io.BytesIO() | |||||
obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} | |||||
pickled_bytes = pickle.dumps(obj) | |||||
byte_tensor = oneflow.ByteTensor(list(pickled_bytes)) | |||||
local_size = oneflow.LongTensor([byte_tensor.numel()]) | |||||
if device is not None: | |||||
byte_tensor = byte_tensor.to(device) | |||||
local_size = local_size.to(device) | |||||
return byte_tensor, local_size | |||||
def _tensor_to_object(tensor, tensor_size): | |||||
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size] | |||||
res = pickle.loads(buf) | |||||
assert res["protocol_version"] == PROTOCOL_VERSION | |||||
return res["data"] | |||||
def send_recv_object(obj, src, cur_rank, device): | |||||
r""" | |||||
oneflow 中的单点对多点的分发函数; | |||||
例如将进程 0 上的对象 object 分发到其它进程上; | |||||
Example:: | |||||
cur_rank = int(os.environ.get('LOCAL_RANK', 0)) | |||||
# 拿到 local_device | |||||
send_recv_object(object, 0, cur_rank, local_device) | |||||
:param obj: 一个可以序列化的 python 对象; | |||||
:param src: 从哪一个 rank 上发送到其它 rank; | |||||
:param cur_rank: 当前的进程的 rank 序号; | |||||
:param device: 当前的进程所在的设备; | |||||
:return: 如果当前 rank 即为 src,返回 None;否则返回从 src 接收到的对象。
""" | |||||
# src rank send to all other ranks | |||||
size = oneflow.LongTensor([0]).to(device) | |||||
if cur_rank == src: | |||||
world_size = dist_env.get_world_size() | |||||
tensor, size = _object_to_tensor(obj) | |||||
tensor = tensor.to(device) | |||||
size = size.to(device) | |||||
# 首先同步 obj 的 size 的信息; | |||||
comm.broadcast(size, src) | |||||
for subrank in range(world_size): | |||||
if subrank != src: | |||||
comm.send(tensor=tensor, dst=subrank) | |||||
else: | |||||
comm.broadcast(size, src) | |||||
tensor = oneflow.ByteTensor([0] * size).to(device) | |||||
comm.recv(tensor=tensor, src=src) | |||||
return _tensor_to_object(tensor.cpu(), size) | |||||
def _to_device(tensor, device): | |||||
return tensor.contiguous().to(device) | |||||
def fastnlp_oneflow_all_gather(obj: Any, device=None) -> List:
""" | |||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||||
example:: | |||||
obj = { | |||||
'a': [1, 1], | |||||
'b': [[1, 2], [1, 2]], | |||||
'c': { | |||||
'd': [1, 2] | |||||
} | |||||
} | |||||
-> | |||||
[ | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||||
] | |||||
:param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 | |||||
序列化之后进行传输。 | |||||
:param device: 当前该参数无意义。
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: | |||||
return [obj] | |||||
if isinstance(obj, oneflow.Tensor): | |||||
objs = [oneflow.zeros_like(obj) for _ in range(dist_env.get_world_size())] | |||||
comm.all_gather(objs, obj) | |||||
else: | |||||
objs = [None for _ in range(dist_env.get_world_size())] | |||||
# 防止 unpickle 的时候弄到发送的 gpu 上了 | |||||
obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) | |||||
all_gather_object(objs, obj) | |||||
return objs | |||||
def fastnlp_oneflow_broadcast_object(obj, src, device=None): | |||||
""" | |||||
将 src 上的 obj 对象广播到其它 rank 上。 | |||||
:param obj: 需要发送的对象 | |||||
:param src: 从哪里发出。 | |||||
:param device: 当前的设备,若为 None 则使用当前进程的 cuda 设备。
:return: 返回广播后得到的 obj 对象。
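下面是一个使用示意(假设 world_size 为 2,仅为示意):

Example::

>>> obj = {"a": 1} if dist_env.get_rank() == 0 else None
>>> obj = fastnlp_oneflow_broadcast_object(obj, src=0)
>>> obj   # 两个 rank 上均为
{'a': 1}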
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: | |||||
if src == dist_env.get_rank(): | |||||
return obj | |||||
else: | |||||
return None | |||||
cur_rank = dist_env.get_rank() | |||||
if cur_rank == src: | |||||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||||
obj = apply_to_collection(obj, oneflow.Tensor, _to_device, device=oneflow.device("cpu")) | |||||
if device is None: | |||||
device = oneflow.cuda.current_device() | |||||
device = get_oneflow_device(device) | |||||
if cur_rank == src: | |||||
tensor, size = _object_to_tensor(obj, device=device) | |||||
else: | |||||
size = oneflow.LongTensor([0]).to(device) | |||||
comm.broadcast(size, src=src) | |||||
if cur_rank != src: | |||||
tensor = oneflow.empty( | |||||
size.int().item(), # type: ignore[arg-type] | |||||
dtype=oneflow.uint8, | |||||
device=device | |||||
) | |||||
comm.broadcast(tensor, src=src) | |||||
return _tensor_to_object(tensor, tensor_size=size.item()) | |||||
def all_gather_object(object_list, obj): | |||||
""" | |||||
Example:: | |||||
>>> # Note: Process group initialization omitted on each rank. | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = [None for _ in gather_objects] | |||||
>>> all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
>>> output | |||||
['foo', 12, {1: 2}] | |||||
:param object_list: 用于保存收集结果的 list,长度应等于 world_size 。
:param obj: 当前 rank 上需要被收集的对象,需要可以被 pickle 。
:return: 收集了各个 rank 上的 obj 的 object_list 。
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, "0")) == 2: | |||||
return [obj] | |||||
current_device = get_oneflow_device(oneflow.cuda.current_device()) | |||||
input_tensor, local_size = _object_to_tensor(obj, device=current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist_env.get_world_size() | |||||
object_sizes_tensor = oneflow.zeros( | |||||
group_size, dtype=oneflow.long, device=current_device | |||||
) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes | |||||
comm.all_gather(object_size_list, local_size) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor = input_tensor.reshape(max_object_size) | |||||
coalesced_output_tensor = oneflow.empty( | |||||
max_object_size * group_size, dtype=oneflow.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
comm.all_gather(output_tensors, input_tensor) | |||||
# Deserialize outputs back to object. | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(oneflow.uint8) | |||||
if tensor.device != oneflow.device("cpu"): | |||||
tensor = tensor.cpu() | |||||
tensor_size = object_size_list[i] | |||||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
return object_list |
@@ -0,0 +1,70 @@ | |||||
import os | |||||
from typing import Optional, Union, List, Sequence | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from .oneflow_driver import OneflowDriver | |||||
from .single_device import OneflowSingleDriver | |||||
from .ddp import OneflowDDPDriver | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | |||||
__all__ = [] | |||||
def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow.device", int, List[int]]], | |||||
model: "oneflow.nn.Module", **kwargs) -> OneflowDriver: | |||||
r""" | |||||
用来根据参数 ``driver`` 和 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去;
:param driver: 该参数的值应为以下之一:``["oneflow"]``; | |||||
:param device: 该参数的格式与 ``Trainer`` 对参数 ``device`` 的要求一致; | |||||
:param model: 训练或者评测的具体的模型; | |||||
:return: 返回一个 :class:`~fastNLP.core.OneflowSingleDriver` 或 :class:`~fastNLP.core.OneflowDDPDriver` 实例; | |||||
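下面是一个单卡场景的使用示意(模型与设备编号均为假设,仅供参考):

Example::

>>> model = oneflow.nn.Linear(2, 2)
>>> driver = initialize_oneflow_driver("oneflow", device=0, model=model)
>>> type(driver).__name__
'OneflowSingleDriver'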
""" | |||||
# world_size 和 rank | |||||
if FASTNLP_BACKEND_LAUNCH in os.environ: | |||||
if device is not None: | |||||
logger.rank_zero_warning("Parameter `device` would be ignored when you are using `oneflow.distributed.launch` to pull " | |||||
"up your script. ", once=True) | |||||
return OneflowDDPDriver(model, None, **kwargs) | |||||
if driver not in {"oneflow"}: | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['oneflow'].") | |||||
_could_use_device_num = oneflow.cuda.device_count() | |||||
if isinstance(device, str): | |||||
device = oneflow.device(device) | |||||
elif isinstance(device, int): | |||||
if device < 0: | |||||
if device != -1: | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||||
device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)] | |||||
elif device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies does not exist.")
else: | |||||
device = oneflow.device(f"cuda:{device}") | |||||
elif isinstance(device, Sequence): | |||||
device = list(set(device)) | |||||
for each in device: | |||||
if not isinstance(each, int): | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | |||||
elif each < 0: | |||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | |||||
elif each >= _could_use_device_num: | |||||
raise ValueError(f"When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | |||||
f" the available gpu number:{_could_use_device_num}.") | |||||
device = [oneflow.device(f"cuda:{w}") for w in device] | |||||
elif device is not None and not isinstance(device, oneflow.device): | |||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||||
if driver == "oneflow": # single, ddp, 直接启动。 | |||||
if not isinstance(device, List): | |||||
return OneflowSingleDriver(model, device, **kwargs) | |||||
else:
raise RuntimeError("If you want to run distributed training, please use "
"'python -m oneflow.distributed.launch xxx.py'.")
@@ -0,0 +1,445 @@ | |||||
import os | |||||
from typing import Union, Dict, Optional, Callable, Tuple | |||||
from functools import partial | |||||
import numpy as np | |||||
import random | |||||
from dataclasses import dataclass | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from pathlib import Path | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.utils.data import DataLoader, Sampler, BatchSampler, Dataset | |||||
from oneflow.optim import Optimizer | |||||
from oneflow.utils.data import RandomSampler as OneflowRandomSampler | |||||
_reduces = { | |||||
"sum": oneflow.sum, | |||||
"min": oneflow.min, | |||||
"max": oneflow.max, | |||||
"mean": oneflow.mean | |||||
} | |||||
__all__ = [ | |||||
"OneflowDriver" | |||||
] | |||||
from .utils import optimizer_state_to_device, DummyGradScaler | |||||
from fastNLP.core.drivers.driver import Driver | |||||
from fastNLP.core.utils.utils import _get_fun_msg, nullcontext | |||||
from fastNLP.core.utils import apply_to_collection, oneflow_move_data_to_device, auto_param_call | |||||
from fastNLP.envs import rank_zero_call | |||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.core.dataloaders import OverfitDataLoader | |||||
class OneflowDriver(Driver): | |||||
r""" | |||||
专属于 ``oneflow`` 的 ``driver``,是 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver`` 的父类; | |||||
.. warning:: | |||||
您不应当直接初始化该类,然后传入给 ``Trainer``,换句话说,您应当使用该类的子类 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver``,而不是 | |||||
该类本身; | |||||
.. note:: | |||||
您可以在使用 ``OneflowSingleDriver`` 和 ``OneflowDDPDriver`` 时使用 ``OneflowDriver`` 提供的接口; | |||||
""" | |||||
def __init__(self, model, fp16: Optional[bool] = False, oneflow_kwargs: Dict = {}, **kwargs): | |||||
super(OneflowDriver, self).__init__(model) | |||||
""" 进行 fp16 的设置 """ | |||||
self._oneflow_kwargs = oneflow_kwargs | |||||
self.fp16 = fp16 | |||||
if fp16: | |||||
logger.warn("OneflowDriver of eager mode dose not support fp16 now.``") | |||||
# self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16) | |||||
# self.grad_scaler = _grad_scaler(**self._oneflow_kwargs.get("gradscaler_kwargs", {})) | |||||
self.auto_cast = nullcontext | |||||
self.grad_scaler = DummyGradScaler() | |||||
self.set_grad_to_none = self._oneflow_kwargs.get("set_grad_to_none") | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self): | |||||
for optimizer in self.optimizers: | |||||
optimizer.zero_grad(self.set_grad_to_none) | |||||
def backward(self, loss): | |||||
loss.backward() | |||||
# self.grad_scaler.scale(loss).backward() | |||||
def step(self): | |||||
for optimizer in self.optimizers: | |||||
self.grad_scaler.step(optimizer) | |||||
self.grad_scaler.update() | |||||
def check_dataloader_legality(self, dataloader): | |||||
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader): | |||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | |||||
if len(dataloader) == 0: | |||||
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it " | |||||
"may cause some unexpected exceptions.", once=True) | |||||
@staticmethod | |||||
def _check_optimizer_legality(optimizers): | |||||
for each_optimizer in optimizers: | |||||
if not isinstance(each_optimizer, Optimizer): | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " | |||||
f"not {type(each_optimizer)}.") | |||||
@staticmethod | |||||
def tensor_to_numeric(tensor, reduce: str = None): | |||||
r""" | |||||
将 ``oneflow.Tensor`` 转换成 python 中的数值类型; | |||||
:param tensor: ``oneflow.Tensor``; | |||||
:param reduce: 当 tensor 是一个多数值的张量时,应当使用何种聚合操作来转换成单一数值,应当为以下类型之一:``['max', 'min', 'sum', 'mean']``;
:return: 返回一个单一数值,其数值类型是 python 中的基本的数值类型,例如 ``int,float`` 等; | |||||
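下面是一个使用示意(输出值仅为示意):

Example::

>>> OneflowDriver.tensor_to_numeric(oneflow.tensor([1.0, 2.0]))
[1.0, 2.0]
>>> OneflowDriver.tensor_to_numeric(oneflow.tensor([1.0, 2.0]), reduce="sum")
3.0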
""" | |||||
if tensor is None: | |||||
return None | |||||
def _translate(_data): | |||||
if _data.numel() == 1: | |||||
return _data.item() | |||||
if reduce is None: | |||||
return _data.tolist() | |||||
return _reduces[reduce](_data).item() | |||||
return apply_to_collection( | |||||
data=tensor, | |||||
dtype=oneflow.Tensor, | |||||
function=_translate | |||||
) | |||||
def set_model_mode(self, mode: str): | |||||
r""" | |||||
设置模型的状态是 ``train`` 还是 ``eval``; | |||||
:param mode: ``'train'`` 或 ``'eval'``; | |||||
""" | |||||
assert mode in {"train", "eval"} | |||||
getattr(self.model, mode)() | |||||
@rank_zero_call | |||||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): | |||||
""" | |||||
保存当前 driver 的模型到 filepath 。
:param filepath: 保存文件的路径;
:param only_state_dict: 是否只保存权重;如果使用 ``DistributedDataParallel`` 启动分布式训练的话,该参数只能为 ``True``; | |||||
:return: | |||||
""" | |||||
model = self.unwrap_model() | |||||
if not only_state_dict and self.is_distributed(): | |||||
logger.warn("`Cannot save ddp model directly, we will save its state_dict for you.") | |||||
only_state_dict = True | |||||
if only_state_dict: | |||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
oneflow.save(states, filepath) | |||||
else: | |||||
if self.model_device is not None: | |||||
if not self.is_distributed(): | |||||
self.move_model_to_device(model, oneflow.device("cpu")) | |||||
oneflow.save(model, filepath) | |||||
if not self.is_distributed(): | |||||
self.move_model_to_device(model, self.model_device) | |||||
else: | |||||
oneflow.save(model, filepath) | |||||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): | |||||
""" | |||||
从 filepath 中加载权重并赋值到当前 driver 的模型上。
:param filepath: 加载权重或模型的路径
:param only_state_dict: 加载的内容是否只是权重。
:param kwargs: | |||||
:return: | |||||
""" | |||||
model = self.unwrap_model() | |||||
res = oneflow.load(filepath) | |||||
if isinstance(res, dict) and only_state_dict is False:
logger.rank_zero_warning(f"It seems that {filepath} only contains state, you may need to use "
f"`only_state_dict=True`")
elif not isinstance(res, dict) and only_state_dict is True:
logger.rank_zero_warning(f"It seems that {filepath} is not state, you may need to use "
f"`only_state_dict=False`")
if not isinstance(res, dict): | |||||
res = res.state_dict() | |||||
model.load_state_dict(res) | |||||
@rank_zero_call | |||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||||
# 1. sampler 的状态; | |||||
num_consumed_batches = states.pop("num_consumed_batches") | |||||
states["sampler_states"] = self.get_sampler_state(dataloader, num_consumed_batches) | |||||
# 2. 保存模型的状态; | |||||
if should_save_model: | |||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
self.save_model(model_path, only_state_dict=only_state_dict) | |||||
# 3. 保存 optimizers 的状态; | |||||
states["optimizers_state_dict"] = self.get_optimizer_state() | |||||
logger.debug("Save optimizer state dict.") | |||||
# # 4. 保存fp16的状态 | |||||
# if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
# grad_scaler_state_dict = self.grad_scaler.state_dict() | |||||
# states['grad_scaler_state_dict'] = grad_scaler_state_dict | |||||
oneflow.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
def get_sampler_state(self, dataloader, num_consumed_batches): | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | |||||
sampler_states = sampler.state_dict() | |||||
if dataloader_args.batch_size is not None: | |||||
sampler_states["num_consumed_samples"] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
* num_consumed_batches | |||||
else: | |||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on sampler's " | |||||
"`num_consumed_samples`, it may cause missing some samples when reload.") | |||||
else: | |||||
raise RuntimeError("The sampler has no `state_dict()` method, fastNLP cannot save the training " | |||||
"state.") | |||||
return sampler_states | |||||
def load_sampler_state(self, dataloader, sampler_states): | |||||
states = {} | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | |||||
elif isinstance(dataloader_args.sampler, OneflowRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace oneflow RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our" | |||||
"`ReproducibleSampler`.") | |||||
else: | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(sampler_states) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 修改 trainer_state.batch_idx_in_epoch | |||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# sampler 是 batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
return states | |||||
def get_optimizer_state(self): | |||||
optimizers_state_dict = {} | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: oneflow.optim.Optimizer = self.optimizers[i] | |||||
optimizer_state = optimizer.state_dict() | |||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], oneflow.device("cpu")) | |||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
return optimizers_state_dict | |||||
def load_optimizer_state(self, states): | |||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
f"checkpoint it is:{len(states)}" | |||||
for i in range(len(self.optimizers)): | |||||
optimizer: oneflow.optim.Optimizer = self.optimizers[i] | |||||
optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
logger.debug("Load optimizer state dict.") | |||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = oneflow.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||||
# 1. 加载 optimizers 的状态; | |||||
optimizers_state_dict = states.pop("optimizers_state_dict") | |||||
self.load_optimizer_state(optimizers_state_dict) | |||||
# 2. 加载模型状态; | |||||
if should_load_model: | |||||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||||
# # 3. 加载 fp16 的状态 | |||||
# if "grad_scaler_state_dict" in states: | |||||
# grad_scaler_state_dict = states.pop("grad_scaler_state_dict") | |||||
# if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
# self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||||
# logger.debug("Load grad_scaler state dict...") | |||||
# elif not isinstance(self.grad_scaler, DummyGradScaler): | |||||
# logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||||
# f"the training process may be unstable.") | |||||
# 4. 恢复 sampler 的状态; | |||||
sampler_states = states.pop("sampler_states") | |||||
states_ret = self.load_sampler_state(dataloader, sampler_states) | |||||
states.update(states_ret) | |||||
return states | |||||
def get_evaluate_context(self): | |||||
r""" | |||||
:return: 返回 ``oneflow.no_grad`` 这个 context; | |||||
""" | |||||
return oneflow.no_grad | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return fn(batch) | |||||
def get_model_call_fn(self, fn: str) -> Tuple: | |||||
if hasattr(self.model, fn): | |||||
fn = getattr(self.model, fn) | |||||
if not callable(fn): | |||||
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") | |||||
logger.debug(f"Use {_get_fun_msg(fn, with_fp=False)}...") | |||||
return fn, None | |||||
elif fn in {"train_step", "evaluate_step"}: | |||||
logger.debug(f"Use {_get_fun_msg(self.model.forward, with_fp=False)}...") | |||||
return self.model, self.model.forward | |||||
else: | |||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | |||||
@staticmethod | |||||
def move_model_to_device(model: "oneflow.nn.Module", device: "oneflow.device"): | |||||
r""" | |||||
将模型迁移到对应的设备上; | |||||
""" | |||||
if device is not None: | |||||
model.to(device) | |||||
def move_data_to_device(self, batch): | |||||
""" | |||||
将一个 batch 的数据迁移到对应的设备上; | |||||
:param batch: 一个 batch 的数据,可以是 ``list、dict`` 等; | |||||
:return: | |||||
""" | |||||
return oneflow_move_data_to_device(batch, self.data_device) | |||||
@staticmethod | |||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||||
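# 与 pytorch-lightning 的 worker 种子策略类似:由进程种子、worker_id 与 global_rank 组合出 SeedSequence,
# 再分别为 numpy、oneflow、stdlib random 派生互不相同且可复现的随机种子;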
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
process_seed = oneflow.initial_seed() | |||||
base_seed = process_seed - worker_id | |||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||||
np.random.seed(ss.generate_state(4)) | |||||
oneflow_ss, stdlib_ss = ss.spawn(2) | |||||
oneflow.manual_seed(oneflow_ss.generate_state(1, dtype=np.uint64)[0]) | |||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||||
random.seed(stdlib_seed) | |||||
def set_deterministic_dataloader(self, dataloader: "DataLoader"): | |||||
if dataloader.worker_init_fn is None: | |||||
dataloader.worker_init_fn = partial(self.worker_init_function, | |||||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))) | |||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx: int): | |||||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||||
if callable(getattr(dataloader.sampler, "set_epoch", None)): | |||||
dataloader.sampler.set_epoch(cur_epoch_idx) | |||||
@staticmethod | |||||
def get_dataloader_args(dataloader: "DataLoader"): | |||||
""" | |||||
获取 dataloader 的 dataset、sampler、batch_size、shuffle、drop_last 等属性;
""" | |||||
@dataclass | |||||
class Res: | |||||
dataset: Optional[Dataset] = None | |||||
batch_sampler: Optional[BatchSampler] = None | |||||
sampler: Optional[Sampler] = None | |||||
batch_size: Optional[int] = None | |||||
shuffle: Optional[bool] = None | |||||
drop_last: Optional[bool] = None | |||||
res = Res() | |||||
# oneflow 的 DataLoader 一定会有 dataset 属性; | |||||
res.dataset = dataloader.dataset | |||||
# dataloader 使用的是 sampler; | |||||
if dataloader.batch_sampler is None: | |||||
res.sampler = dataloader.sampler | |||||
res.batch_size = 1 | |||||
res.shuffle = True if isinstance(dataloader.sampler, RandomSampler) else False | |||||
res.drop_last = False | |||||
# dataloader 使用的是 batch_sampler; | |||||
else: | |||||
res.batch_sampler = dataloader.batch_sampler | |||||
if hasattr(dataloader.batch_sampler, "batch_size"): | |||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; | |||||
else: | |||||
dataloader_iter = iter(dataloader) | |||||
pre_sample = next(dataloader_iter) | |||||
res.batch_size = pre_sample.shape[0] | |||||
if hasattr(dataloader.batch_sampler, "sampler"): | |||||
res.sampler = dataloader.batch_sampler.sampler | |||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | |||||
elif isinstance(dataloader.batch_sampler.sampler, OneflowRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
# ReproduceBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||||
res.sampler = batch_sampler.sampler | |||||
if hasattr(batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = batch_sampler.sampler.shuffle
elif isinstance(batch_sampler.sampler, OneflowRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
else: | |||||
# 如果 dataloader.batch_sampler 没有 sampler 这个属性,那么说明其使用的是自己的 batch_sampler,且没有 "sampler" 属性; | |||||
# 这种情况下 DataLoader 会自己初始化一个 sampler;我们因此将这个默认初始化的 sampler 挂载到 res 上; | |||||
res.sampler = dataloader.sampler | |||||
res.shuffle = False | |||||
if hasattr(dataloader.batch_sampler, "drop_last"): | |||||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last") | |||||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; | |||||
else: | |||||
res.drop_last = False | |||||
return res |
@@ -0,0 +1,114 @@ | |||||
import os | |||||
from typing import Dict, Union | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.utils.data import SequentialSampler as OneflowSequentialSampler | |||||
from oneflow.utils.data import BatchSampler as OneflowBatchSampler | |||||
__all__ = [ | |||||
"OneflowSingleDriver" | |||||
] | |||||
from .oneflow_driver import OneflowDriver | |||||
from fastNLP.core.drivers.oneflow_driver.utils import replace_sampler, replace_batch_sampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \ | |||||
ReproduceBatchSampler | |||||
from fastNLP.core.samplers import RandomSampler | |||||
from fastNLP.core.log import logger | |||||
class OneflowSingleDriver(OneflowDriver): | |||||
r""" | |||||
用于执行 ``oneflow`` 动态图 cpu 和 单卡 gpu 运算的 ``driver``; | |||||
:param model: 传入给 ``Trainer`` 的 ``model`` 参数; | |||||
:param device: oneflow.device,当前进程所使用的设备; | |||||
:param fp16: 是否开启 fp16;目前动态图的单卡下该参数无效; | |||||
:param oneflow_kwargs: | |||||
""" | |||||
def __init__(self, model, device: "oneflow.device", fp16: bool = False, oneflow_kwargs: Dict = {}, **kwargs): | |||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) | |||||
if cuda_visible_devices == "": | |||||
device = oneflow.device("cpu") | |||||
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" | |||||
"use `cpu` instead of `gpu` device.") | |||||
super(OneflowSingleDriver, self).__init__(model, fp16=fp16, oneflow_kwargs=oneflow_kwargs, **kwargs)
if device is None: | |||||
logger.debug("device is not set, fastNLP will try to automatically get it.") | |||||
try: | |||||
device = next(model.parameters()).device | |||||
assert isinstance(device, oneflow.device) | |||||
except Exception:
raise ValueError("fastNLP cannot get device automatically, please set device explicitly.") | |||||
self.model_device = device | |||||
self.local_rank = 0 | |||||
self.global_rank = 0 | |||||
self.world_size = 1 | |||||
def setup(self): | |||||
r""" | |||||
将模型迁移到相应的设备上; | |||||
""" | |||||
if self.model_device is not None: | |||||
self.model.to(self.model_device) | |||||
def set_dist_repro_dataloader(self, dataloader, | |||||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None, | |||||
reproducible: bool = False): | |||||
# 如果 dist 为 ReproducibleBatchSampler 或 ReproducibleSampler,说明是在断点重训时 driver.load_checkpoint 函数调用;
if isinstance(dist, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, dist) | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | |||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用;
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | |||||
if type(args.batch_sampler) is OneflowBatchSampler: | |||||
if type(args.sampler) is OneflowSequentialSampler: | |||||
# 需要替换为不要 shuffle 的。 | |||||
sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
logger.debug("Replace oneflow SequentialSampler into fastNLP RandomSampler.") | |||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | |||||
return dataloader | |||||
def unwrap_model(self): | |||||
r""" | |||||
:return: 返回模型 | |||||
""" | |||||
return self.model | |||||
@property | |||||
def data_device(self): | |||||
r""" | |||||
:return: 数据和模型所在的设备; | |||||
""" | |||||
return self.model_device | |||||
def is_distributed(self): | |||||
r""" | |||||
:return: 返回当前使用的 driver 是否是分布式的 driver,在 ``OneflowSingleDriver`` 中返回 ``False``; | |||||
""" | |||||
return False |
@@ -0,0 +1,292 @@ | |||||
import os | |||||
from typing import Any, Dict, Optional | |||||
from enum import IntEnum | |||||
import contextlib | |||||
import random | |||||
import numpy as np | |||||
import inspect | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.envs.utils import get_global_seed | |||||
from fastNLP.envs import ( | |||||
get_global_rank, | |||||
FASTNLP_BACKEND_LAUNCH, | |||||
FASTNLP_GLOBAL_SEED, | |||||
) | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.utils import auto_param_call | |||||
from fastNLP.core.log import logger | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.nn import Module | |||||
from oneflow.utils.data import DataLoader | |||||
from oneflow.utils.data import RandomSampler as oneflowRandomSampler | |||||
from oneflow.utils.data import SequentialSampler as oneflowSequentialSampler | |||||
from oneflow.utils.data import BatchSampler as oneflowBatchSampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
__all__ = [ | |||||
'oneflow_seed_everything', | |||||
'optimizer_state_to_device' | |||||
] | |||||
def oneflow_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int: | |||||
r""" | |||||
为 **oneflow**、**numpy**、**python.random** 伪随机数生成器设置种子。 | |||||
:param seed: 全局随机状态的整数值种子。如果为 ``None`` 则会根据时间戳生成一个种子。 | |||||
:param add_global_rank_to_seed: 在分布式训练中,是否在不同 **rank** 中使用不同的随机数。 | |||||
当设置为 ``True`` 时,**FastNLP** 会将种子加上当前的 ``global_rank``。 | |||||
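下面是一个使用示意(不叠加 global_rank 时返回设置的种子本身):

Example::

>>> oneflow_seed_everything(42, add_global_rank_to_seed=False)
42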
""" | |||||
max_seed_value = np.iinfo(np.uint32).max | |||||
min_seed_value = np.iinfo(np.uint32).min | |||||
if seed is None: | |||||
if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1": | |||||
seed = 42 | |||||
else: | |||||
seed = get_global_seed() | |||||
logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.") | |||||
if not isinstance(seed, int): | |||||
seed = int(seed) | |||||
if not (min_seed_value <= seed <= max_seed_value): | |||||
logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.") | |||||
seed %= max_seed_value | |||||
os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}" | |||||
if add_global_rank_to_seed: | |||||
seed += get_global_rank() | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
oneflow.manual_seed(seed) | |||||
oneflow.cuda.manual_seed_all(seed) | |||||
return seed | |||||
class ForwardState(IntEnum): | |||||
TRAIN = 0 | |||||
VALIDATE = 1 | |||||
TEST = 2 | |||||
PREDICT = 3 | |||||
class _DDPWrappingModel(Module): | |||||
""" | |||||
该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; | |||||
之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; | |||||
另一方面,我们要求用户在使用我们的框架时,需要针对不同的模式实现不同的处理函数,例如 'train_step', 'evaluate_step' 等;
然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 | |||||
`model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; | |||||
因此出于以上考虑,我们实现了这一函数; | |||||
对于更详细的解释,可以参考 'pytorch_lightning' 的 ddp 的设计; | |||||
""" | |||||
def __init__(self, model: Module): | |||||
super(_DDPWrappingModel, self).__init__() | |||||
self.model = model | |||||
def forward(self, batch, **kwargs) -> Dict: | |||||
""" | |||||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必要,先写个注释放这里,之后有需求了再看;
""" | |||||
fn = kwargs.pop("fastnlp_fn") | |||||
signature_fn = kwargs.pop("fastnlp_signature_fn") | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | |||||
else: | |||||
return fn(batch) | |||||
class DummyGradScaler: | |||||
def __init__(self, *args, **kwargs): | |||||
pass | |||||
def get_scale(self): | |||||
return 1.0 | |||||
def is_enabled(self): | |||||
return False | |||||
def scale(self, outputs): | |||||
return outputs | |||||
def step(self, optimizer, *args, **kwargs): | |||||
optimizer.step(*args, **kwargs) | |||||
def update(self, new_scale=None): | |||||
pass | |||||
def unscale_(self, optimizer): | |||||
pass | |||||
def load_state_dict(self, state_dict): | |||||
pass | |||||
def state_dict(self): | |||||
return {} | |||||
def _build_fp16_env(dummy=False):
# oneflow 的 eager 模式暂时不支持 amp,该函数目前直接返回;以下为参照 torch 预留的逻辑,待 oneflow 支持后再启用;
return
if dummy:
autocast = contextlib.ExitStack
GradScaler = DummyGradScaler
else:
if not oneflow.cuda.is_available():
raise RuntimeError("Oneflow is not installed in gpu version, please use device='cpu'.")
if oneflow.cuda.get_device_capability(0)[0] < 7:
logger.rank_zero_warning(
"NOTE: your device does NOT support faster training with fp16, "
"please switch to FP32 which is likely to be faster"
)
try:
from oneflow.cuda.amp import autocast, GradScaler
except ImportError:
raise RuntimeError("Your oneflow version does not support amp (autocast and GradScaler).")
return autocast, GradScaler
def replace_sampler(dataloader: "DataLoader", sampler): | |||||
r""" | |||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | |||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | |||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader | |||||
的类,而不是直接的 DataLoader; | |||||
如果需要定制自己的 dataloader,保证以下两点: | |||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; | |||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 | |||||
来获取实际的参数的值; | |||||
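下面是一个满足上述两点要求的定制 DataLoader 的最简示意(类名与参数均为假设,仅供参考):

Example::

>>> class MyDataLoader(DataLoader):
...     def __init__(self, dataset, my_arg=0, **kwargs):
...         self.my_arg = my_arg  # 与参数同名的实例属性,重新初始化时从这里取值;
...         super().__init__(dataset, **kwargs)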
""" | |||||
# 拿到实例属性; | |||||
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} | |||||
# 'multiprocessing_context' 是 user-defined function; | |||||
if getattr(dataloader, 'multiprocessing_context', None) is not None: | |||||
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context | |||||
# 拿到 dataloader '__init__' 函数的默认函数签名; | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
# 防止用户的 DataLoader 是继承了 oneflow 的 DataLoader,然后还是使用了 **kwargs 的方式对父类传参数 | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||||
if has_variadic_kwargs and isinstance(dataloader, DataLoader): | |||||
# 防止用户写入了 super().__init__(**kwargs) | |||||
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items(): | |||||
if key not in init_params and key != 'self': | |||||
init_params[key] = value | |||||
# 如果初始化dataloader所使用的参数不是默认值,那么我们需要将其记录下来用于重新初始化时设置; | |||||
non_default_params = {name for name, p in init_params.items() if | |||||
name in instance_attrs and p.default != instance_attrs[name]} | |||||
# add `dataset` as it might have been replaced with `*args` | |||||
non_default_params.add("dataset") | |||||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||||
if isinstance(dataloader, DataLoader): | |||||
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None}) | |||||
batch_sampler = getattr(dataloader, "batch_sampler") | |||||
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): | |||||
raise RuntimeError("It should not be running here, please report a bug to us.") | |||||
required_args = { | |||||
p.name | |||||
for p in init_params.values() | |||||
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) | |||||
and p.default is p.empty | |||||
and p.name not in reconstruct_args | |||||
} | |||||
# 在 attribute 中没有找到这些参数,导致了没有办法重新初始化 | |||||
if required_args: | |||||
required_args = sorted(required_args) | |||||
dataloader_self_name = dataloader.__class__.__name__ | |||||
raise Exception( | |||||
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. " | |||||
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its " | |||||
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be " | |||||
f"`{dataloader_self_name}`'s attribute." | |||||
) | |||||
# 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; | |||||
if not has_variadic_kwargs: | |||||
# the dataloader signature does not allow keyword arguments that need to be passed | |||||
missing_kwargs = reconstruct_args.keys() - init_params.keys() | |||||
if missing_kwargs: | |||||
missing_kwargs = sorted(missing_kwargs) | |||||
dataloader_self_name = dataloader.__class__.__name__ | |||||
raise Exception( | |||||
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found." | |||||
) | |||||
# 如果没有kwargs,则保证一下只传入需要的参数 | |||||
if not isinstance(dataloader, DataLoader): | |||||
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params} | |||||
return type(dataloader)(**reconstruct_args) | |||||
def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
r""" | |||||
替换一个 dataloader 的 batch_sampler; | |||||
""" | |||||
params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] | |||||
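# batch_size、sampler、drop_last 等参数与 batch_sampler 互斥,重新实例化 dataloader 时需要剔除;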
for k in ["batch_size", "sampler", "drop_last", "batch_sampler", "dataset_kind"]: | |||||
if k in params_keys: | |||||
params_keys.remove(k) | |||||
params = {k: getattr(dataloader, k) for k in params_keys} | |||||
params["batch_sampler"] = new_batch_sampler | |||||
if not isinstance(dataloader, DataLoader): | |||||
init_params = dict(inspect.signature(dataloader.__init__).parameters) | |||||
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||||
if not has_variadic_kwargs: | |||||
params = {key:value for key,value in params.items() if key in init_params} | |||||
return type(dataloader)(**params) | |||||
def optimizer_state_to_device(state, device): | |||||
r""" | |||||
将一个 ``optimizer`` 的 ``state_dict`` 迁移到对应的设备; | |||||
:param state: ``optimzier.state_dict()``; | |||||
:param device: 要迁移到的目的设备; | |||||
:return: 返回迁移后的新的 state_dict; | |||||
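下面是一个使用示意(假设 optimizer 为一个 oneflow 的优化器实例):

Example::

>>> state = optimizer.state_dict()
>>> state["state"] = optimizer_state_to_device(state["state"], oneflow.device("cpu"))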
""" | |||||
new_state = {} | |||||
for name, param in state.items(): | |||||
if isinstance(param, dict): | |||||
new_state[name] = optimizer_state_to_device(param, device) | |||||
elif isinstance(param, oneflow.Tensor): | |||||
new_state[name] = param.to(device).clone() | |||||
else: | |||||
new_state[name] = param | |||||
return new_state | |||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||||
if type(args.batch_sampler) is not oneflowBatchSampler or (type(args.sampler) not in {oneflowRandomSampler, | |||||
oneflowSequentialSampler}): | |||||
mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||||
f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||||
f"``{substitution}``. The customized sampler should set for distributed running " | |||||
f"before initializing ``{controller}`` , and then set the " | |||||
f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") |
@@ -8,6 +8,7 @@ from .backend import Backend | |||||
from .torch_backend.backend import TorchBackend | from .torch_backend.backend import TorchBackend | ||||
from .paddle_backend.backend import PaddleBackend | from .paddle_backend.backend import PaddleBackend | ||||
from .jittor_backend.backend import JittorBackend | from .jittor_backend.backend import JittorBackend | ||||
from .oneflow_backend.backend import OneflowBackend | |||||
class AutoBackend(Backend): | class AutoBackend(Backend): | ||||
@@ -52,6 +53,8 @@ class AutoBackend(Backend): | |||||
self.__class__ = PaddleBackend | self.__class__ = PaddleBackend | ||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
self.__class__ = JittorBackend | self.__class__ = JittorBackend | ||||
elif backend == 'oneflow': | |||||
self.__class__ = OneflowBackend | |||||
elif backend is None: | elif backend is None: | ||||
# 不用做任何事情就可以初始化了 | # 不用做任何事情就可以初始化了 | ||||
pass | pass | ||||
@@ -0,0 +1,130 @@ | |||||
from typing import List | |||||
import numpy as np | |||||
from fastNLP.core.metrics.backend import Backend | |||||
from fastNLP.core.metrics.utils import AggregateMethodError | |||||
from fastNLP.core.utils import is_in_oneflow_dist | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
__all__ = [] | |||||
class OneflowBackend(Backend): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self._specified = True | |||||
def aggregate(self, tensor, method: str): | |||||
""" | |||||
聚集结果,并根据 method 计算后,返回结果 | |||||
:param tensor: 需要聚合的张量 | |||||
:param method: 聚合的方法,目前支持 ``['sum', 'mean', 'max', 'min']``:
* method 为 ``'sum'`` 时,会将多张卡上聚合结果在维度为 `0` 上累加起来。
* method 为 ``'mean'`` 时,会将多张卡上聚合结果在维度为 `0` 上取平均值。
* method 为 ``'max'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最大值。
* method 为 ``'min'`` 时,会将多张卡上聚合结果在维度为 `0` 上取最小值。
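下面是一个分布式环境下的使用示意(假设 world_size 为 2,两张卡上的 ``tensor`` 取值分别为 1 和 2,仅为示意):

Example::

>>> backend = OneflowBackend()
>>> backend.aggregate(tensor, method="sum")   # 每张卡上都会得到聚合后的结果 3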
""" | |||||
if isinstance(tensor, oneflow.Tensor): | |||||
# TODO 暂时没有找到 oneflow 中检测是否初始化了分布式环境的方法 | |||||
if is_in_oneflow_dist(): | |||||
if method is None: | |||||
raise AggregateMethodError(should_have_aggregate_method=True) | |||||
tensor = self.all_gather_object(tensor) | |||||
if isinstance(tensor[0], oneflow.Tensor): | |||||
tensor = oneflow.stack(tensor) | |||||
# 第一步, aggregate结果 | |||||
if method == 'sum': | |||||
tensor = oneflow.sum(tensor, dim=0) | |||||
elif method == 'mean': | |||||
tensor = oneflow.mean(tensor, dim=0) | |||||
elif method == 'max': | |||||
tensor, _ = oneflow.max(tensor, dim=0) | |||||
elif method == 'min': | |||||
tensor, _ = oneflow.min(tensor, dim=0) | |||||
else: | |||||
raise AggregateMethodError(should_have_aggregate_method=False) | |||||
return tensor | |||||
def create_tensor(self, value: float): | |||||
""" | |||||
创建 tensor,并且填入 value 作为值 | |||||
:param value: 创建张量的初始值 | |||||
""" | |||||
tensor = oneflow.ones(1).fill_(value) | |||||
return tensor | |||||
def fill_value(self, tensor, value: float): | |||||
""" | |||||
将 tensor 的值设置为 value | |||||
:param tensor: 传入的张量 | |||||
:param value: 需要 fill 的值。 | |||||
""" | |||||
tensor.fill_(value) | |||||
return tensor | |||||
def get_scalar(self, tensor) -> float: | |||||
""" | |||||
获取 tensor 的 scalar 值 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
return tensor.item() | |||||
def tensor2numpy(self, tensor) -> np.ndarray:
""" | |||||
将 tensor 转为 numpy 值, 主要是在 metric 计算中使用 | |||||
:param tensor: 传入的张量 | |||||
""" | |||||
if isinstance(tensor, oneflow.Tensor): | |||||
return tensor.cpu().detach().numpy() | |||||
elif isinstance(tensor, np.ndarray): | |||||
return tensor | |||||
elif isinstance(tensor, (float, int)): | |||||
return tensor | |||||
else: | |||||
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") | |||||
@staticmethod | |||||
def is_distributed() -> bool: | |||||
""" | |||||
判断是否为 ddp 状态 | |||||
:return: | |||||
""" | |||||
return is_in_oneflow_dist() | |||||
def move_tensor_to_device(self, tensor, device): | |||||
""" | |||||
将张量移到设备上 | |||||
:param tensor: 需要移动的张量 | |||||
:param device: 设备名, 一般为 "cpu", "cuda:0"等字符串 | |||||
""" | |||||
return tensor.to(device) | |||||
def all_gather_object(self, obj, group=None) -> List: | |||||
""" | |||||
给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||||
:param obj: | |||||
:param group: | |||||
""" | |||||
if self.is_distributed(): | |||||
obj_list = fastnlp_oneflow_all_gather(obj) | |||||
return obj_list | |||||
return [obj] | |||||
@@ -14,6 +14,10 @@ __all__ = [ | |||||
'f_rich_progress', | 'f_rich_progress', | ||||
'torch_move_data_to_device', | 'torch_move_data_to_device', | ||||
'is_torch_module', | 'is_torch_module', | ||||
'get_oneflow_device', | |||||
'oneflow_move_data_to_device', | |||||
'is_oneflow_module', | |||||
'is_in_oneflow_dist', | |||||
'get_fn_arg_names', | 'get_fn_arg_names', | ||||
'auto_param_call', | 'auto_param_call', | ||||
'check_user_specific_params', | 'check_user_specific_params', | ||||
@@ -36,6 +40,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module | is_in_fnlp_paddle_dist, is_in_paddle_launch_dist, is_paddle_module | ||||
from .rich_progress import f_rich_progress | from .rich_progress import f_rich_progress | ||||
from .torch_utils import torch_move_data_to_device, is_torch_module | from .torch_utils import torch_move_data_to_device, is_torch_module | ||||
from .oneflow_utils import oneflow_move_data_to_device, is_oneflow_module, is_in_oneflow_dist, get_oneflow_device | |||||
from .utils import * | from .utils import * | ||||
from .tqdm_progress import f_tqdm_progress | from .tqdm_progress import f_tqdm_progress | ||||
from .seq_len_to_mask import seq_len_to_mask | from .seq_len_to_mask import seq_len_to_mask | ||||
@@ -0,0 +1,69 @@ | |||||
import os | |||||
from typing import Any, Union, Optional | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
__all__ = [ | |||||
'get_oneflow_device',
'oneflow_move_data_to_device',
'is_oneflow_module', | |||||
'is_in_oneflow_dist', | |||||
] | |||||
from .utils import apply_to_collection | |||||
def get_oneflow_device(device): | |||||
""" | |||||
构造一个 :class:`oneflow.device` 实例并返回。 | |||||
:param device: 字符串或 gpu 编号 | |||||
:return: :class:`oneflow.device` | |||||
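下面是一个使用示意:

Example::

>>> get_oneflow_device(0)        # 等价于 oneflow.device("cuda", 0)
>>> get_oneflow_device("cpu")    # 等价于 oneflow.device("cpu")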
""" | |||||
if isinstance(device, oneflow.device): | |||||
return device | |||||
if isinstance(device, int): | |||||
return oneflow.device("cuda", device) | |||||
if isinstance(device, str): | |||||
return oneflow.device(device) | |||||
raise RuntimeError(f"Cannot get `oneflow.device` from {device}.") | |||||
def oneflow_move_data_to_device(batch: Any, device: Optional[Union[str, "oneflow.device"]] = None) -> Any: | |||||
r""" | |||||
在 **oneflow** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变; | |||||
:param batch: 需要迁移的数据; | |||||
:param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作; | |||||
:return: 迁移到新设备上的数据集合; | |||||
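下面是一个使用示意(嵌套结构中的 oneflow.Tensor 会被迁移,其余对象保持不变):

Example::

>>> batch = {"x": oneflow.ones(2), "meta": [1, 2]}
>>> batch = oneflow_move_data_to_device(batch, "cuda:0")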
""" | |||||
if device is None: | |||||
return batch | |||||
def batch_to(data: Any) -> Any: | |||||
data_output = data.to(device) | |||||
if data_output is not None: | |||||
return data_output | |||||
# user wrongly implemented the `TransferableDataType` and forgot to return `self`. | |||||
return data | |||||
return apply_to_collection(batch, dtype=oneflow.Tensor, function=batch_to) | |||||
def is_oneflow_module(model) -> bool: | |||||
""" | |||||
判断传入的 ``model`` 是否是 :class:`oneflow.nn.Module` 类型 | |||||
:param model: 模型; | |||||
:return: 当前模型是否为 ``oneflow`` 的模型; | |||||
""" | |||||
try: | |||||
return isinstance(model, oneflow.nn.Module) | |||||
except BaseException: | |||||
return False | |||||
def is_in_oneflow_dist() -> bool: | |||||
""" | |||||
判断是否处于 **oneflow** 分布式的进程下。 | |||||
""" | |||||
return "GLOG_log_dir" in os.environ |
@@ -22,5 +22,6 @@ _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and | |||||
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | _NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | ||||
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | _NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | ||||
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | _NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | ||||
_NEED_IMPORT_ONEFLOW = _module_available("oneflow") and 'oneflow' in need_import | |||||
_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") | _TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") |
@@ -8,7 +8,7 @@ from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VIS | |||||
from fastNLP.envs.utils import _module_available, get_gpu_count | from fastNLP.envs.utils import _module_available, get_gpu_count | ||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor', 'oneflow'] | |||||
def _set_backend(): | def _set_backend(): | ||||
@@ -145,6 +145,9 @@ def set_env(global_seed=None): | |||||
if backend == 'torch': | if backend == 'torch': | ||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
if backend == 'oneflow': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
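For context, a hedged sketch of selecting the new backend through the `FASTNLP_BACKEND` environment variable that this module reads (the variable is imported at the top of the file; the snippet itself is illustrative):

import os
os.environ["FASTNLP_BACKEND"] = "oneflow"  # must be one of SUPPORT_BACKENDS
import fastNLP  # the availability assert above fires during this import if oneflow is missing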
def dump_fastnlp_backend(default:bool = False, backend=None): | def dump_fastnlp_backend(default:bool = False, backend=None): | ||||
""" | """ | ||||
@@ -50,6 +50,15 @@ def set_env_on_import_jittor(): | |||||
if 'log_silent' not in os.environ: | if 'log_silent' not in os.environ: | ||||
os.environ['log_silent'] = '1' | os.environ['log_silent'] = '1' | ||||
def set_env_on_import_oneflow(): | |||||
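# assumption, inferred from the checks in this function: oneflow's distributed launcher
# exports GLOG_log_dir (and RANK) in every worker, so their presence marks a launched worker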
if 'GLOG_log_dir' in os.environ: | |||||
os.environ[FASTNLP_GLOBAL_RANK] = os.environ['RANK'] | |||||
if int(os.environ.get(FASTNLP_REMOVE_LOCAL_RANK, 1)): | |||||
remove_local_rank_in_argv() | |||||
if 'GLOG_log_dir' in os.environ and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
os.environ[FASTNLP_BACKEND_LAUNCH] = '1' | |||||
def set_env_on_import(): | def set_env_on_import(): | ||||
""" | """ | ||||
@@ -61,6 +70,7 @@ def set_env_on_import(): | |||||
set_env_on_import_torch() | set_env_on_import_torch() | ||||
set_env_on_import_paddle() | set_env_on_import_paddle() | ||||
set_env_on_import_jittor() | set_env_on_import_jittor() | ||||
set_env_on_import_oneflow() | |||||
# fastNLP 内部使用的一些变量 | # fastNLP 内部使用的一些变量 | ||||
if FASTNLP_LAUNCH_TIME not in os.environ: | if FASTNLP_LAUNCH_TIME not in os.environ: | ||||
@@ -3,7 +3,7 @@ import numpy as np | |||||
from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ | from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ | ||||
_get_element_shape_dtype | _get_element_shape_dtype | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR, _NEED_IMPORT_ONEFLOW | |||||
def test_get_element_shape_dtype(): | def test_get_element_shape_dtype(): | ||||
@@ -14,10 +14,11 @@ def test_get_element_shape_dtype(): | |||||
catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) | catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) | ||||
# @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | |||||
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'paddle']) | |||||
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'paddle', 'jittor', 'oneflow']) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
@pytest.mark.jittor | |||||
@pytest.mark.oneflow | |||||
def test_get_padder_run(backend): | def test_get_padder_run(backend): | ||||
if not _NEED_IMPORT_TORCH and backend == 'torch': | if not _NEED_IMPORT_TORCH and backend == 'torch': | ||||
pytest.skip("No torch") | pytest.skip("No torch") | ||||
@@ -25,6 +26,8 @@ def test_get_padder_run(backend): | |||||
pytest.skip("No paddle") | pytest.skip("No paddle") | ||||
if not _NEED_IMPORT_JITTOR and backend == 'jittor': | if not _NEED_IMPORT_JITTOR and backend == 'jittor': | ||||
pytest.skip("No jittor") | pytest.skip("No jittor") | ||||
if not _NEED_IMPORT_ONEFLOW and backend == 'oneflow': | |||||
pytest.skip("No oneflow") | |||||
batch_field = [1, 2, 3] | batch_field = [1, 2, 3] | ||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | ||||
@@ -163,3 +166,57 @@ def test_torch_padder(): | |||||
assert isinstance(pad_batch, np.ndarray) | assert isinstance(pad_batch, np.ndarray) | ||||
assert np.shape(pad_batch) == (3, 3, 3) | assert np.shape(pad_batch) == (3, 3, 3) | ||||
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | ||||
@pytest.mark.oneflow | |||||
def test_oneflow_padder(): | |||||
if not _NEED_IMPORT_ONEFLOW: | |||||
pytest.skip("No oneflow.") | |||||
import oneflow | |||||
backend = 'oneflow' | |||||
target_type = oneflow.Tensor | |||||
batch_field = [1, 2, 3] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert (pad_batch == oneflow.LongTensor(batch_field)).sum()==len(batch_field) | |||||
batch_field = [[1], [2, 2], [3, 3, 3]] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3) | |||||
assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==3 | |||||
batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,3))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3, 3) | |||||
assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==9 | |||||
batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3, 3) | |||||
assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==12 | |||||
batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,))] | |||||
with pytest.raises(InconsistencyError): | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
# numpy.ndarray inputs are also accepted | |||||
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3, 3) | |||||
assert (pad_batch == oneflow.zeros(pad_batch.shape)).sum()==12 | |||||
# test padding oneflow tensors to numpy | |||||
batch_field = [oneflow.ones((3,3)), oneflow.ones((2,3)), oneflow.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, np.ndarray) | |||||
assert np.shape(pad_batch) == (3, 3, 3) | |||||
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 |
@@ -0,0 +1,105 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.core.collators.padders.oneflow_padder import OneflowTensorPadder, OneflowSequencePadder, OneflowNumberPadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
@pytest.mark.oneflow | |||||
class TestOneflowNumberPadder: | |||||
def test_run(self): | |||||
padder = OneflowNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | |||||
t_a = padder(a) | |||||
assert isinstance(t_a, oneflow.Tensor) | |||||
assert (t_a == oneflow.LongTensor(a)).sum() == 3 | |||||
@pytest.mark.oneflow | |||||
class TestOneflowSequencePadder: | |||||
def test_run(self): | |||||
padder = OneflowSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, oneflow.Tensor) | |||||
assert tuple(shape) == (2, 3) | |||||
b = oneflow.LongTensor([[1, 2, 3], [3, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
def test_dtype_check(self): | |||||
padder = OneflowSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | |||||
padder = OneflowSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
padder = OneflowSequencePadder(pad_val=-1, ele_dtype=oneflow.long, dtype=int) | |||||
padder = OneflowSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None) | |||||
a = padder([[1], [2, 322]]) | |||||
assert (a>67).sum()==0 # int8 wraps around (range -128..127): 322 overflows to 66, so nothing exceeds 67 | |||||
padder = OneflowSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None) | |||||
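The overflow assertion above, spelled out: numpy's `astype` wraps out-of-range values, presumably the same wraparound the padder's int8 conversion performs (a small illustrative check, not part of the test suite):

import numpy as np
assert np.array([322]).astype(np.int8)[0] == 66  # 322 - 256 == 66, hence (a > 67).sum() == 0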
@pytest.mark.oneflow | |||||
class TestOneflowTensorPadder: | |||||
def test_run(self): | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.zeros(3).dtype, dtype=int) | |||||
a = [oneflow.zeros(3), oneflow.zeros(2), oneflow.zeros(0)] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, oneflow.Tensor) | |||||
assert tuple(shape) == (3, 3) | |||||
b = oneflow.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
a = [oneflow.zeros((3, 2)), oneflow.zeros((2, 2)), oneflow.zeros((1, 2))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, oneflow.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = oneflow.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, 0], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
a = [oneflow.zeros((3, 2)), oneflow.zeros((2, 2)), oneflow.zeros((1, 1))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, oneflow.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = oneflow.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.zeros(3).dtype, dtype=int) | |||||
a = [oneflow.zeros((3, 2)), oneflow.zeros((2, 2)), oneflow.zeros((1, 0))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, oneflow.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = oneflow.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.zeros(3).dtype, dtype=None) | |||||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, oneflow.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = oneflow.FloatTensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
def test_dtype_check(self): | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=oneflow.long, dtype=int) | |||||
padder = OneflowTensorPadder(pad_val=-1, ele_dtype=int, dtype=oneflow.long) | |||||
@@ -2,7 +2,7 @@ | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR, _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.core.collators.collator import Collator | from fastNLP.core.collators.collator import Collator | ||||
from ...helpers.utils import Capturing | from ...helpers.utils import Capturing | ||||
@@ -14,6 +14,10 @@ def _assert_equal(d1, d2): | |||||
if 'float64' in str(d2.dtype): | if 'float64' in str(d2.dtype): | ||||
print(d2.dtype) | print(d2.dtype) | ||||
assert (d1 == d2).all().item() | assert (d1 == d2).all().item() | ||||
elif 'oneflow' in str(type(d1)): | |||||
if 'float64' in str(d2.dtype): | |||||
print(d2.dtype) | |||||
assert (d1 == d2).all().item() | |||||
else: | else: | ||||
assert all(d1 == d2) | assert all(d1 == d2) | ||||
except TypeError: | except TypeError: | ||||
@@ -43,9 +47,9 @@ def findListDiff(d1, d2): | |||||
class TestCollator: | class TestCollator: | ||||
@pytest.mark.torch | |||||
def test_run(self): | |||||
dict_batch = [{ | |||||
@staticmethod | |||||
def setup_class(cls): | |||||
cls.dict_batch = [{ | |||||
'str': '1', | 'str': '1', | ||||
'lst_str': ['1'], | 'lst_str': ['1'], | ||||
'int': 1, | 'int': 1, | ||||
@@ -75,17 +79,21 @@ class TestCollator: | |||||
} | } | ||||
] | ] | ||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
cls.list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
def test_run_raw(self): | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | ||||
collator = Collator(backend='raw') | collator = Collator(backend='raw') | ||||
assert raw_pad_batch == collator(dict_batch) | |||||
assert raw_pad_batch == collator(self.dict_batch) | |||||
collator = Collator(backend='raw') | collator = Collator(backend='raw') | ||||
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | ||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | ||||
[{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
findListDiff(raw_pad_lst, collator(self.list_batch)) | |||||
def test_run_numpy(self): | |||||
collator = Collator(backend='numpy') | collator = Collator(backend='numpy') | ||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), | numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), | ||||
@@ -94,36 +102,60 @@ class TestCollator: | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), | 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), | ||||
'b': np.array([[1, 2], [1, 2]])}} | 'b': np.array([[1, 2], [1, 2]])}} | ||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
findDictDiff(numpy_pad_batch, collator(self.dict_batch)) | |||||
collator = Collator(backend='numpy') | collator = Collator(backend='numpy') | ||||
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), | numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), | ||||
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | ||||
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), | np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), | ||||
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | ||||
[{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
findListDiff(numpy_pad_lst, collator(list_batch)) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
collator = Collator(backend='torch') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), | |||||
'lst_int': torch.LongTensor([[1, 0], [1, 2]]), | |||||
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
'float': torch.FloatTensor([1.1, 2.1]), | |||||
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), | |||||
'numpy': torch.FloatTensor([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), | |||||
'b': torch.LongTensor( | |||||
[[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='torch') | |||||
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), | |||||
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), | |||||
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(torch_pad_lst, collator(list_batch)) | |||||
findListDiff(numpy_pad_lst, collator(self.list_batch)) | |||||
@pytest.mark.torch | |||||
def test_run_torch(self): | |||||
import torch | |||||
collator = Collator(backend='torch') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), | |||||
'lst_int': torch.LongTensor([[1, 0], [1, 2]]), | |||||
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
'float': torch.FloatTensor([1.1, 2.1]), | |||||
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), | |||||
'numpy': torch.FloatTensor([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), | |||||
'b': torch.LongTensor( | |||||
[[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(self.dict_batch)) | |||||
collator = Collator(backend='torch') | |||||
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), | |||||
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), | |||||
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(torch_pad_lst, collator(self.list_batch)) | |||||
@pytest.mark.oneflow | |||||
def test_run_oneflow(self): | |||||
import oneflow | |||||
collator = Collator(backend='oneflow') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': oneflow.LongTensor([1, 2]), | |||||
'lst_int': oneflow.LongTensor([[1, 0], [1, 2]]), | |||||
'nest_lst_int': oneflow.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
'float': oneflow.FloatTensor([1.1, 2.1]), | |||||
'lst_float': oneflow.FloatTensor([[1.1], [2.1]]), 'bool': oneflow.BoolTensor([True, False]), | |||||
'numpy': oneflow.FloatTensor([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': oneflow.LongTensor([1, 2]), | |||||
'b': oneflow.LongTensor( | |||||
[[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(self.dict_batch)) | |||||
collator = Collator(backend='oneflow') | |||||
oneflow_pad_lst = [['1', '2'], [['1'], ['2', '2']], oneflow.LongTensor([1, 2]), oneflow.LongTensor([[1, 0], [2, 2]]), | |||||
oneflow.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
oneflow.FloatTensor([1.1, 2.1]), oneflow.FloatTensor([[1.1], [2.1]]), oneflow.BoolTensor([True, False]), | |||||
oneflow.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(oneflow_pad_lst, collator(self.list_batch)) | |||||
def test_pad(self): | def test_pad(self): | ||||
dict_batch = [{ | dict_batch = [{ | ||||
@@ -366,6 +398,46 @@ def test_torch_dl(): | |||||
with pytest.raises(KeyError): | with pytest.raises(KeyError): | ||||
dl.set_pad('i', pad_val=None) | dl.set_pad('i', pad_val=None) | ||||
@pytest.mark.oneflow | |||||
def test_oneflow_dl(): | |||||
from fastNLP import OneflowDataLoader | |||||
from fastNLP import DataSet | |||||
import numpy as np | |||||
import oneflow | |||||
ds = DataSet({ | |||||
'x': [1, 2], 'y': [[1,2], [3]], 'z':[np.ones((1, 2)), np.ones((2, 3))], | |||||
'i': [{'j': [1, 2]}, {'j': [3]}], 'j': ['a', 'b'] | |||||
}) | |||||
dl = OneflowDataLoader(ds, batch_size=2) | |||||
batch = next(iter(dl)) | |||||
assert 'x' in batch and 'y' in batch and 'z' in batch and 'i' in batch and 'j' in batch | |||||
assert batch['z'].dtype == oneflow.float32 | |||||
assert isinstance(batch['j'], list) | |||||
assert batch['i']['j'].dtype == oneflow.long | |||||
dl.set_ignore('x') | |||||
batch = next(iter(dl)) | |||||
assert 'x' not in batch and 'y' in batch and 'z' in batch | |||||
dl.set_pad('y', pad_val=None) | |||||
batch = next(iter(dl)) | |||||
assert 'x' not in batch and 'y' in batch and 'z' in batch | |||||
assert isinstance(batch['y'], list) | |||||
assert len(batch['y'][0])!=len(batch['y'][1]) # not padded | |||||
dl.set_pad(('i', 'j'), pad_val=None) | |||||
batch = next(iter(dl)) | |||||
assert 'x' not in batch and 'y' in batch and 'z' in batch | |||||
assert isinstance(batch['y'], list) | |||||
assert len(batch['y'][0])!=len(batch['y'][1]) # not padded | |||||
assert isinstance(batch['i']['j'], list) | |||||
assert len(batch['i']['j'][0])!=len(batch['i']['j'][1]) # not padded | |||||
with pytest.raises(KeyError): | |||||
dl.set_pad('i', pad_val=None) | |||||
def test_compare_tuple(): | def test_compare_tuple(): | ||||
from fastNLP.core.collators.collator import _compare_tuple | from fastNLP.core.collators.collator import _compare_tuple | ||||
@@ -0,0 +1,96 @@ | |||||
""" | |||||
Test multi-card training with oneflow in eager (dynamic graph) mode:: | |||||
>>> # without wrapping the model in DistributedDataParallel | |||||
>>> python -m oneflow.distributed.launch --nproc_per_node 2 _test_trainer_oneflow.py | |||||
>>> # with the model wrapped in DistributedDataParallel | |||||
>>> python -m oneflow.distributed.launch --nproc_per_node 2 _test_trainer_oneflow.py -w | |||||
""" | |||||
import sys | |||||
sys.path.append("../../../") | |||||
import os | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.nn.parallel import DistributedDataParallel | |||||
from oneflow.optim import Adam | |||||
from oneflow.utils.data import DataLoader | |||||
from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 | |||||
from tests.helpers.datasets.oneflow_data import OneflowArgMaxDataset | |||||
@dataclass | |||||
class TrainOneflowConfig: | |||||
num_labels: int = 3 | |||||
feature_dimension: int = 3 | |||||
batch_size: int = 2 | |||||
shuffle: bool = True | |||||
evaluate_every = 2 | |||||
def test_trainer_oneflow( | |||||
callbacks, | |||||
wrapped=False, | |||||
n_epochs=2, | |||||
): | |||||
model = OneflowNormalModel_Classification_1( | |||||
num_labels=TrainOneflowConfig.num_labels, | |||||
feature_dimension=TrainOneflowConfig.feature_dimension | |||||
) | |||||
optimizers = Adam(params=model.parameters(), lr=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=OneflowArgMaxDataset(20, TrainOneflowConfig.feature_dimension), | |||||
batch_size=TrainOneflowConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=OneflowArgMaxDataset(12, TrainOneflowConfig.feature_dimension), | |||||
batch_size=TrainOneflowConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
evaluate_dataloaders = val_dataloader | |||||
evaluate_every = TrainOneflowConfig.evaluate_every | |||||
metrics = {"acc": Accuracy()} | |||||
if wrapped: | |||||
model.to(int(os.environ["LOCAL_RANK"])) | |||||
model = DistributedDataParallel(model) | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="oneflow", | |||||
device=0, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
evaluate_every=evaluate_every, | |||||
input_mapping=None, | |||||
output_mapping=None, | |||||
metrics=metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
trainer.run() | |||||
if __name__ == "__main__": | |||||
import argparse | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument( | |||||
"-w", | |||||
"--wrapped", | |||||
default=False, | |||||
action="store_true", | |||||
help="Use DistributedDataParallal to wrap model first.", | |||||
) | |||||
args = parser.parse_args() | |||||
callbacks = [] | |||||
test_trainer_oneflow(callbacks, args.wrapped) |
@@ -0,0 +1,70 @@ | |||||
import os | |||||
import pytest | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
from oneflow.optim import Adam | |||||
from oneflow.utils.data import DataLoader | |||||
from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 | |||||
from tests.helpers.datasets.oneflow_data import OneflowArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
@dataclass | |||||
class TrainOneflowConfig: | |||||
num_labels: int = 3 | |||||
feature_dimension: int = 3 | |||||
batch_size: int = 2 | |||||
shuffle: bool = True | |||||
evaluate_every = 2 | |||||
@pytest.mark.parametrize("device", ["cpu", 1]) | |||||
@pytest.mark.parametrize("callbacks", [[]]) | |||||
@pytest.mark.oneflow | |||||
@magic_argv_env_context | |||||
def test_trainer_oneflow( | |||||
device, | |||||
callbacks, | |||||
n_epochs=2, | |||||
): | |||||
model = OneflowNormalModel_Classification_1( | |||||
num_labels=TrainOneflowConfig.num_labels, | |||||
feature_dimension=TrainOneflowConfig.feature_dimension | |||||
) | |||||
optimizers = Adam(params=model.parameters(), lr=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=OneflowArgMaxDataset(20, TrainOneflowConfig.feature_dimension), | |||||
batch_size=TrainOneflowConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=OneflowArgMaxDataset(12, TrainOneflowConfig.feature_dimension), | |||||
batch_size=TrainOneflowConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
evaluate_dataloaders = val_dataloader | |||||
evaluate_every = TrainOneflowConfig.evaluate_every | |||||
metrics = {"acc": Accuracy()} | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="oneflow", | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
evaluate_every=evaluate_every, | |||||
input_mapping=None, | |||||
output_mapping=None, | |||||
metrics=metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
trainer.run() |
@@ -0,0 +1,169 @@ | |||||
import pytest | |||||
from fastNLP.core.dataloaders.oneflow_dataloader import OneflowDataLoader, prepare_oneflow_dataloader | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from tests.helpers.utils import Capturing, recover_logger | |||||
from fastNLP import logger | |||||
import numpy as np | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
@pytest.mark.oneflow | |||||
class TestFdl: | |||||
def test_init_v1(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
fdl = OneflowDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||||
# for batch in fdl: | |||||
# print(batch) | |||||
fdl1 = OneflowDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||||
# for batch in fdl1: | |||||
# print(batch) | |||||
def test_set_padding(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
fdl = OneflowDataLoader(ds, batch_size=3) | |||||
fdl.set_pad("x", -1) | |||||
for batch in fdl: | |||||
assert batch['x'].shape == oneflow.Size([3, 4]) | |||||
def test_get_batch_indices(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
fdl = OneflowDataLoader(ds, batch_size=3, shuffle=True) | |||||
for batch in fdl: | |||||
assert len(fdl.get_batch_indices()) == 3 | |||||
def test_other_dataset(self): | |||||
import numpy as np | |||||
class _DataSet: | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, item): | |||||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||||
def __len__(self): | |||||
return 10 | |||||
def __getattribute__(self, item): | |||||
return object.__getattribute__(self, item) | |||||
dataset = _DataSet() | |||||
dl = OneflowDataLoader(dataset, batch_size=2, shuffle=True) | |||||
# dl.set_inputs('data', 'labels') | |||||
# dl.set_pad_val('labels', val=None) | |||||
for batch in dl: | |||||
assert batch[0].shape == oneflow.Size([2, 5]) | |||||
assert batch[1].shape == oneflow.Size([2, 2, 3]) | |||||
def test_default_collate_fn(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
with pytest.raises(ValueError): | |||||
fdl = OneflowDataLoader(ds, batch_size=3, collate_fn=None) | |||||
import numpy as np | |||||
class _DataSet: | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, item): | |||||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||||
def __len__(self): | |||||
return 10 | |||||
fdl = OneflowDataLoader(_DataSet(), batch_size=3, collate_fn=None, drop_last=True) | |||||
for batch in fdl: | |||||
assert batch[0].shape == oneflow.Size([3, 5]) | |||||
def test_my_collate_fn(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
def collate_fn(batch): | |||||
res = {'x': [], 'y': []} | |||||
for ins in batch: | |||||
res['x'].append(ins['x']) | |||||
res['y'].append(ins['y']) | |||||
return res | |||||
fdl = OneflowDataLoader(ds, collate_fn=collate_fn, batch_size=3, drop_last=True) | |||||
for batch in fdl: | |||||
assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] | |||||
assert batch['y'] == [1, 0, 1] | |||||
def test_prepare_oneflow_dataloader(self): | |||||
# test with a fastNLP DataSet | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dl = prepare_oneflow_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl, OneflowDataLoader) | |||||
ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | |||||
dl_bundle = prepare_oneflow_dataloader(dbl) | |||||
assert isinstance(dl_bundle['train'], OneflowDataLoader) | |||||
assert isinstance(dl_bundle['val'], OneflowDataLoader) | |||||
ds_dict = {'train_1': ds, 'val': ds1} | |||||
dl_dict = prepare_oneflow_dataloader(ds_dict) | |||||
assert isinstance(dl_dict['train_1'], OneflowDataLoader) | |||||
assert isinstance(dl_dict['val'], OneflowDataLoader) | |||||
# test with other dataset types | |||||
class _DataSet: | |||||
def __init__(self): | |||||
pass | |||||
def __getitem__(self, item): | |||||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||||
def __len__(self): | |||||
return 10 | |||||
def __getattribute__(self, item): | |||||
return object.__getattribute__(self, item) | |||||
ds2 = _DataSet() | |||||
dl1 = prepare_oneflow_dataloader(ds2, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl1, OneflowDataLoader) | |||||
ds3 = _DataSet() | |||||
dbl1 = DataBundle(datasets={'train': ds2, 'val': ds3}) | |||||
dl_bundle1 = prepare_oneflow_dataloader(dbl1) | |||||
assert isinstance(dl_bundle1['train'], OneflowDataLoader) | |||||
assert isinstance(dl_bundle1['val'], OneflowDataLoader) | |||||
ds_dict1 = {'train_1': ds2, 'val': ds3} | |||||
dl_dict1 = prepare_oneflow_dataloader(ds_dict1) | |||||
assert isinstance(dl_dict1['train_1'], OneflowDataLoader) | |||||
assert isinstance(dl_dict1['val'], OneflowDataLoader) | |||||
ds = [[1, [1]], [2, [2, 2]]] | |||||
dl = prepare_oneflow_dataloader(ds, batch_size=2) | |||||
for batch in dl: | |||||
assert (batch[0] == oneflow.LongTensor([1, 2])).sum()==2 | |||||
assert (batch[1] == oneflow.LongTensor([[1, 0], [2, 2]])).sum()==4 | |||||
# sequence = [ds, ds1] | |||||
# seq_ds = prepare_oneflow_dataloader(sequence) | |||||
# assert isinstance(seq_ds[0], OneflowDataLoader) | |||||
# assert isinstance(seq_ds[1], OneflowDataLoader) | |||||
def test_get_backend(self): | |||||
from fastNLP.core.collators import Collator | |||||
from oneflow.utils.data import DataLoader, Dataset | |||||
class MyDataset(Dataset): | |||||
def __len__(self): | |||||
return 1000 | |||||
def __getitem__(self, item): | |||||
return [[1, 0], [1], [1, 2, 4]], [1, 0] | |||||
collate_batch = Collator(backend='auto') | |||||
dl = DataLoader(MyDataset(), collate_fn=collate_batch) | |||||
for batch in dl: | |||||
print(batch) |
@@ -0,0 +1,78 @@ | |||||
import oneflow | |||||
from oneflow import nn | |||||
from oneflow.utils.data import DataLoader, Dataset | |||||
from oneflow.nn.parallel import DistributedDataParallel as ddp | |||||
import os | |||||
# PLACEMENT = oneflow.placement("cuda", [0,1]) | |||||
# S0 = oneflow.sbp.split(0) | |||||
# B = oneflow.sbp.broadcast | |||||
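(A note on the commented-out lines above, for readers skimming this scratch script: they sketch oneflow's global-tensor path, where `placement` pins tensors to a device group and `sbp.split` / `sbp.broadcast` describe how each tensor is laid out across it; the script instead uses plain local tensors plus manual rank-based sharding in the training loop below.)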
class OneflowArgMaxDataset(Dataset): | |||||
def __init__(self, feature_dimension=10, data_num=1000, seed=0): | |||||
self.num_labels = feature_dimension | |||||
self.feature_dimension = feature_dimension | |||||
self.data_num = data_num | |||||
self.seed = seed | |||||
g = oneflow.Generator() | |||||
g.manual_seed(1000) | |||||
self.x = oneflow.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float() | |||||
self.y = oneflow.max(self.x, dim=-1)[1] | |||||
def __len__(self): | |||||
return self.data_num | |||||
def __getitem__(self, item): | |||||
return self.x[item], self.y[item] | |||||
class Model(nn.Module): | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(Model, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||||
def forward(self, x): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return x | |||||
dataset = OneflowArgMaxDataset(10, 100) | |||||
model = Model(10, 10) | |||||
loss_func = nn.CrossEntropyLoss() | |||||
optimizer = oneflow.optim.Adam(model.parameters(), 0.001) | |||||
dataloader = oneflow.utils.data.DataLoader(dataset, batch_size=32) | |||||
device = "cuda" | |||||
model.to(device) | |||||
# model = ddp(model) | |||||
loss_func.to(device) | |||||
# model = model.to_global(PLACEMENT, B) | |||||
for epoch in range(2): | |||||
for i, (x, y) in enumerate(dataloader): | |||||
if i % 2 != oneflow.env.get_rank(): | |||||
continue | |||||
x = x.to(device) | |||||
y = y.to(device) | |||||
# x = x.to_global(PLACEMENT, S0) | |||||
# y = y.to_global(PLACEMENT, S0) | |||||
output = model(x) | |||||
loss = loss_func(output, y) | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
oneflow.save(model, "ttt") | |||||
print("end.") | |||||
# python -m oneflow.distributed.launch --nproc_per_node 2 dist.py | |||||
@@ -0,0 +1,948 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append("../../../../") | |||||
import pytest | |||||
from pathlib import Path | |||||
from fastNLP.core.drivers.oneflow_driver.ddp import OneflowDDPDriver | |||||
from fastNLP import prepare_oneflow_dataloader | |||||
from fastNLP.core.samplers import ( | |||||
RandomSampler, | |||||
UnrepeatedSampler, | |||||
BucketedBatchSampler, | |||||
UnrepeatedRandomSampler, | |||||
UnrepeatedSequentialSampler, | |||||
) | |||||
from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 | |||||
from tests.helpers.datasets.oneflow_data import OneflowNormalDataset, OneflowNormalXYDataset | |||||
from tests.helpers.utils import recover_logger | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP import logger | |||||
from fastNLP.core.drivers.oneflow_driver.dist_utils import fastnlp_oneflow_all_gather | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
import oneflow.env as dist_env | |||||
from oneflow.utils.data import DataLoader, BatchSampler | |||||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
oneflow_model = OneflowNormalModel_Classification_1(labels, features) | |||||
oneflow_opt = oneflow.optim.Adam(params=oneflow_model.parameters(), lr=0.01) | |||||
device = [oneflow.device("cuda", i) for i in device] | |||||
driver = OneflowDDPDriver( | |||||
model=oneflow_model, | |||||
parallel_device=device, | |||||
fp16=fp16, | |||||
output_from_new_proc=output_from_new_proc | |||||
) | |||||
driver.set_optimizers(oneflow_opt) | |||||
driver.setup() | |||||
return driver | |||||
def dataloader_with_bucketedbatchsampler(dataset, length, batch_size, shuffle, drop_last): | |||||
""" | |||||
Build a dataloader whose batch_sampler is a BucketedBatchSampler. | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=BucketedBatchSampler( | |||||
dataset, | |||||
length, | |||||
batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
), | |||||
) | |||||
return dataloader | |||||
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0, unrepeated=False): | |||||
""" | |||||
Build a dataloader whose sampler is a RandomSampler. | |||||
""" | |||||
if unrepeated: | |||||
sampler = UnrepeatedRandomSampler(dataset, shuffle, seed) | |||||
else: | |||||
sampler = RandomSampler(dataset, shuffle, seed=seed) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=sampler, | |||||
drop_last=drop_last, | |||||
batch_size=batch_size | |||||
) | |||||
return dataloader | |||||
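A quick sketch of how these two helpers are used by the tests below (toy sizes; `OneflowNormalDataset` comes from the imports at the top of the file):

ds = OneflowNormalDataset(16)
dl = dataloader_with_randomsampler(ds, batch_size=4, shuffle=True, drop_last=False)
assert isinstance(dl.batch_sampler.sampler, RandomSampler)  # DataLoader wraps the sampler in a BatchSampler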
############################################################################ | |||||
# | |||||
# Tests for some functions of OneflowDDPDriver | |||||
# | |||||
############################################################################ | |||||
@pytest.mark.oneflow | |||||
class TestDDPDriverFunction: | |||||
""" | |||||
A test class for some simple OneflowDDPDriver functions; it mostly checks that they run at all and that there are no import errors. | |||||
""" | |||||
def test_simple_functions(self): | |||||
""" | |||||
A quick smoke test covering several driver functions. | |||||
""" | |||||
driver = generate_driver(10, 10) | |||||
""" | |||||
Test the move_data_to_device function. | |||||
""" | |||||
driver.move_data_to_device(oneflow.rand((32, 64))) | |||||
comm.barrier() | |||||
""" | |||||
Test the is_distributed function. | |||||
""" | |||||
assert driver.is_distributed() == True | |||||
comm.barrier() | |||||
""" | |||||
Test the get_model_no_sync_context function. | |||||
""" | |||||
res = driver.get_model_no_sync_context() | |||||
comm.barrier() | |||||
""" | |||||
Test the is_global_zero function. | |||||
""" | |||||
driver.is_global_zero() | |||||
comm.barrier() | |||||
""" | |||||
Test the unwrap_model function. | |||||
""" | |||||
driver.unwrap_model() | |||||
comm.barrier() | |||||
""" | |||||
Test the get_local_rank function. | |||||
""" | |||||
driver.get_local_rank() | |||||
comm.barrier() | |||||
""" | |||||
Test the all_gather function. | |||||
More detailed tests live in test_dist_utils.py. | |||||
""" | |||||
obj = { | |||||
"rank": driver.global_rank | |||||
} | |||||
obj_list = driver.all_gather(obj) | |||||
for i, res in enumerate(obj_list): | |||||
assert res["rank"] == i | |||||
""" | |||||
Test the broadcast_object function. | |||||
More detailed tests live in test_dist_utils.py. | |||||
""" | |||||
if driver.global_rank == 0: | |||||
obj = { | |||||
"rank": driver.global_rank | |||||
} | |||||
else: | |||||
obj = None | |||||
res = driver.broadcast_object(obj, src=0) | |||||
assert res["rank"] == 0 | |||||
############################################################################ | |||||
# | |||||
# Tests for the set_dist_repro_dataloader function | |||||
# | |||||
############################################################################ | |||||
@pytest.mark.oneflow | |||||
class TestSetDistReproDataloader: | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.device = [0, 1] | |||||
def setup_method(self): | |||||
self.dataset = OneflowNormalDataset(100) | |||||
""" | |||||
The case where the `dist` argument is a concrete ReproducibleSampler or ReproducibleBatchSampler; | |||||
this corresponds to the situation inside driver.load_checkpoint. | |||||
""" | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is a BucketedBatchSampler. | |||||
The dataloader's batch_sampler should be replaced by the BucketedBatchSampler passed in as dist. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | |||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler is batch_sampler | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is a RandomSampler. | |||||
The dataloader's batch_sampler.sampler should be replaced by the RandomSampler passed in as dist. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | |||||
sampler = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is sampler | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
comm.barrier() | |||||
""" | |||||
The case where `dist` is None; this occurs during trainer and evaluator initialization when the user sets | |||||
`use_dist_sampler` to False. The function then branches on the `reproducible` setting. | |||||
When `reproducible` is False, whether the dataloader gets re-instantiated depends on whether its batch_sampler | |||||
or sampler is already Reproducible. | |||||
""" | |||||
def test_with_dist_none_reproducible_true(self): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None and reproducible is True. | |||||
When the user initializes the distributed environment outside the driver, fastNLP cannot support resumable training, so an error should be raised. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
with pytest.raises(RuntimeError): | |||||
# a RuntimeError should be raised | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader | |||||
has a BucketedBatchSampler. | |||||
The incoming dataloader's batch_sampler is expected to have had set_distributed called already; | |||||
a new dataloader is produced whose batch_sampler is the same as the original's. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank, | |||||
pad=True | |||||
) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
self.check_distributed_sampler(dataloader.batch_sampler) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader has a RandomSampler. | |||||
The incoming dataloader's batch_sampler.sampler is expected to have had set_distributed called already; | |||||
a new dataloader is produced whose batch_sampler.sampler is the same as the original's. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank | |||||
) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.drop_last == False | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one. | |||||
The original dataloader should be returned untouched. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert replaced_loader is dataloader | |||||
comm.barrier() | |||||
""" | |||||
The case where `dist` is the string 'dist'; this occurs during trainer initialization when the user sets | |||||
`use_dist_sampler` to True. The function then decides how to re-instantiate the dataloader based on whether its batch_sampler or sampler is Reproducible. | |||||
""" | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a | |||||
ReproducibleBatchSampler. | |||||
A new dataloader should be returned whose batch_sampler matches the original's, with the distributed attributes set correctly. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler is a | |||||
ReproducibleSampler. | |||||
A new dataloader should be returned whose batch_sampler.sampler matches the original's, with the | |||||
distributed attributes set correctly. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_normal(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one. | |||||
A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, with the | |||||
distributed attributes set correctly. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
comm.barrier() | |||||
""" | |||||
The case where `dist` is the string 'unrepeatdist'; this occurs during evaluator initialization when the user | |||||
sets `use_dist_sampler` to True. The function then decides how to re-instantiate the dataloader based on whether its sampler is Unrepeated and Reproducible. | |||||
""" | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is a | |||||
ReproducibleSampler. | |||||
A new dataloader should be returned with the original sampler replaced by an UnrepeatedRandomSampler, with | |||||
the distributed attributes set correctly. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is an | |||||
UnrepeatedSampler. | |||||
A new dataloader should be returned in which the original sampler has been re-instantiated. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
comm.barrier() | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_normal(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one. | |||||
A new dataloader should be returned with the sampler replaced by an UnrepeatedSequentialSampler, with the | |||||
distributed attributes set correctly. | |||||
""" | |||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
comm.barrier() | |||||
def check_distributed_sampler(self, sampler): | |||||
""" | |||||
Check that the distributed attributes of the replaced sampler or batch_sampler are set correctly. | |||||
""" | |||||
assert sampler.num_replicas == dist_env.get_world_size() | |||||
assert sampler.rank == dist_env.get_rank() | |||||
if not isinstance(sampler, UnrepeatedSampler): | |||||
assert sampler.pad == True | |||||
def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): | |||||
""" | |||||
测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||||
""" | |||||
# 迭代两个 batch | |||||
num_replicas = len(self.device) | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
replaced_loader.batch_sampler.set_epoch(4)
else:
replaced_loader.batch_sampler.sampler.set_epoch(4)
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch.tolist()) | |||||
comm.barrier() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# Reload; the remaining samples should come out, and for OneflowNormalDataset the sorted result should form a range
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||||
# rebuild the dataloader
new_loader = dataloader_with_bucketedbatchsampler( | |||||
replaced_loader.dataset, | |||||
length=replaced_loader.dataset._data, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False, | |||||
) | |||||
new_loader.batch_sampler.set_distributed( | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank, | |||||
pad=True | |||||
) | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||||
new_loader.batch_sampler.set_epoch(4) | |||||
else: | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||||
# rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | |||||
new_loader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank | |||||
) | |||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
new_loader.batch_sampler.sampler.set_epoch(4) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch.tolist()) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||||
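# The bookkeeping exercised by check_set_dist_repro_dataloader above boils down to
# one invariant: after n batches, the restored sampler state must record the number
# of samples consumed across *all* ranks. A minimal sketch (illustrative only,
# following the conventions visible in these tests, not part of fastNLP):
def _sketch_global_num_consumed_samples(num_consumed_batches, batch_size, num_replicas):
    # e.g. 2 batches of per-rank batch_size 4 on 2 replicas -> 16 samples globally;
    # this is exactly the value written into sampler_states["num_consumed_samples"]
    # before the state is loaded into the rebuilt dataloader
    return num_consumed_batches * batch_size * num_replicas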
############################################################################ | |||||
# | |||||
# Tests for save / load related functionality
# | |||||
############################################################################ | |||||
@pytest.mark.oneflow | |||||
class TestSaveLoad: | |||||
""" | |||||
测试多卡情况下 save 和 load 相关函数的表现 | |||||
""" | |||||
def setup_method(self): | |||||
self.dataset = OneflowNormalXYDataset(100) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_model(self, only_state_dict): | |||||
""" | |||||
测试 save_model 和 load_model 函数 | |||||
""" | |||||
try: | |||||
path = "model" | |||||
dataloader = DataLoader(self.dataset, batch_size=2) | |||||
driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1) | |||||
driver1.save_model(path, only_state_dict) | |||||
# synchronize
comm.barrier() | |||||
driver2.load_model(path, only_state_dict) | |||||
for idx, batch in enumerate(dataloader): | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert oneflow.all(res1["preds"] == res2["preds"]) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
@pytest.mark.parametrize("device", ([[0,1]])) | |||||
def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
num_replicas = len(device) | |||||
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \ | |||||
generate_driver(20, 1, device=device, fp16=False) | |||||
dataloader = dataloader_with_bucketedbatchsampler( | |||||
self.dataset, | |||||
length=[10 for i in range(len(self.dataset))], | |||||
batch_size=4, | |||||
shuffle=True, | |||||
drop_last=False | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=driver1.world_size, | |||||
rank=driver1.global_rank, | |||||
pad=True | |||||
) | |||||
num_consumed_batches = 4 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
driver1.set_sampler_epoch(dataloader, 4) | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
# synchronize
comm.barrier() | |||||
# save the state
sampler_states = dataloader.batch_sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
comm.barrier() | |||||
# load
# change the batch_size
dataloader = dataloader_with_bucketedbatchsampler( | |||||
self.dataset, | |||||
length=[10 for i in range(len(self.dataset))], | |||||
batch_size=2, | |||||
shuffle=True, | |||||
drop_last=False | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=driver2.world_size, | |||||
rank=driver2.global_rank, | |||||
pad=True | |||||
) | |||||
comm.barrier() | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
comm.barrier() | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state
# TODO the optimizer state_dict is always empty
# 2. check that the batch_sampler has been correctly loaded and replaced
assert not (replaced_loader is dataloader) | |||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | |||||
# # 3. check whether the fp16 state is loaded
# if fp16:
#     assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler)
# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
# 4 batches of per-rank batch_size 4 were consumed; with the new batch_size 2,
# those 16 samples correspond to 8 == 2 * num_consumed_batches batches
assert start_batch == 2 * num_consumed_batches
left_x_batches = set() | |||||
left_y_batches = set() | |||||
driver2.set_sampler_epoch(replaced_loader, 4) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert oneflow.all(res1["preds"] == res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | |||||
comm.barrier() | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
@pytest.mark.parametrize("device", ([[0,1]])) | |||||
def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "checkpoints/" | |||||
num_replicas = len(device) | |||||
driver1 = generate_driver(20, 1, device=device, fp16=fp16) | |||||
driver2 = generate_driver(20, 1, device=device, fp16=False) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=driver1.world_size, | |||||
rank=driver1.global_rank, | |||||
pad=True | |||||
) | |||||
num_consumed_batches = 4 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
driver1.set_sampler_epoch(dataloader, 4) | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
# synchronize
comm.barrier() | |||||
# save the state
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
comm.barrier() # wait until the save has finished
# load
# change the batch_size
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=driver2.world_size, | |||||
rank=driver2.global_rank, | |||||
pad=True | |||||
) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state
# TODO the optimizer state_dict is always empty
# 2. check that the sampler has been correctly loaded and replaced
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
if os.environ['FASTNLP_GLOBAL_RANK'] == '0': | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | |||||
# # 3. check whether the fp16 state is loaded
# if fp16:
#     assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler)
# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
driver2.set_sampler_epoch(replaced_loader, 4) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert oneflow.all(res1["preds"] == res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) | |||||
@pytest.mark.parametrize("drop_last", ([True, False])) | |||||
def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True): | |||||
try: | |||||
# verify that set_dist_repro_dataloader does not modify the dataloader arguments
num_samples = 200 | |||||
dataset = OneflowNormalXYDataset(num_samples) | |||||
dl = prepare_oneflow_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
device = [oneflow.device("cuda", i) for i in [0, 1]] | |||||
driver = OneflowDDPDriver(model, parallel_device=device) | |||||
driver.setup() | |||||
dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | |||||
data = [] | |||||
flags = [] | |||||
for batch in dl: | |||||
flags.append(batch['x'].size(0) == batch_size) | |||||
data.extend(batch['x'].reshape(-1).tolist()) | |||||
_num_samples = num_samples//2 | |||||
if drop_last and _num_samples%batch_size != 0: | |||||
assert len(data)!=_num_samples | |||||
assert all(flags) == True | |||||
elif _num_samples%batch_size!=0: | |||||
assert flags[-1] is False | |||||
else: | |||||
assert len(data) == _num_samples | |||||
if not shuffle: | |||||
for i in range(1, len(data)-1): | |||||
assert data[i]>data[i-1] | |||||
else: | |||||
flags = [] | |||||
for i in range(1, len(data)-1): | |||||
flags.append(data[i]>data[i-1]) | |||||
assert all(flags) is False | |||||
datas = fastnlp_oneflow_all_gather(data) | |||||
if drop_last: | |||||
assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2 | |||||
else: | |||||
assert len(set(datas[0] + datas[1])) == num_samples | |||||
finally: | |||||
pass | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) | |||||
@pytest.mark.parametrize("drop_last", ([True, False])) | |||||
def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=True): | |||||
try: | |||||
# verify that set_dist_repro_dataloader does not modify the dataloader arguments
num_samples = 200 | |||||
num_device = 2 | |||||
dataset = OneflowNormalXYDataset(num_samples) | |||||
sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | |||||
shuffle=shuffle, num_batch_per_bucket=2) | |||||
dl = prepare_oneflow_dataloader(dataset, batch_sampler=sampler) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
device = [oneflow.device("cuda", i) for i in [0, 1]] | |||||
driver = OneflowDDPDriver(model, parallel_device=device) | |||||
driver.setup() | |||||
dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | |||||
data = [] | |||||
flags = [] | |||||
for batch in dl: | |||||
d = batch['x'].reshape(-1).tolist() | |||||
diff = max(d) - min(d) | |||||
assert diff<batch_size*2*2*2 | |||||
data.extend(d) | |||||
flags.append(len(d)==batch_size) | |||||
_num_samples = num_samples//num_device | |||||
if drop_last and _num_samples%batch_size != 0: | |||||
assert len(data)!=num_samples | |||||
assert all(flags) == True | |||||
elif _num_samples%batch_size!=0: | |||||
assert flags[-1] is False | |||||
else: | |||||
assert len(data) == _num_samples | |||||
if not shuffle: | |||||
for i in range(1, len(data)-1): | |||||
assert data[i]<data[i-1] | |||||
else: | |||||
flags = [] | |||||
for i in range(1, len(data)-1): | |||||
flags.append(data[i]<data[i-1]) | |||||
assert all(flags) is False | |||||
datas = fastnlp_oneflow_all_gather(data) | |||||
if drop_last: | |||||
assert len(set(datas[0] + datas[1])) == num_samples-_num_samples%batch_size*2 | |||||
else: | |||||
assert len(set(datas[0] + datas[1])) == num_samples | |||||
finally: | |||||
pass | |||||
@pytest.mark.oneflow | |||||
@recover_logger | |||||
@pytest.mark.parametrize("inherit", ([True, False])) | |||||
def test_customized_batch_sampler_dataloader(inherit): | |||||
try: | |||||
logger.set_stdout('raw', level='info') | |||||
# verify that set_dist_repro_dataloader runs correctly with a customized batch_sampler
num_samples = 10 | |||||
dataset = OneflowNormalXYDataset(num_samples) | |||||
if inherit: | |||||
class BatchSampler(oneflow.utils.data.BatchSampler): | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self):
indices = list(range(len(dataset)))
for i in range(len(self)):
start = i * self.batch_size
end = (i + 1) * self.batch_size
yield indices[start:end]
def __len__(self): | |||||
return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||||
else: | |||||
class BatchSampler: | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self):
indices = list(range(len(dataset)))
for i in range(len(self)):
start = i * self.batch_size
end = (i + 1) * self.batch_size
yield indices[start:end]
def __len__(self): | |||||
return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||||
dl = prepare_oneflow_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4)) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
device = [oneflow.device("cuda", i) for i in [0, 1]] | |||||
driver = OneflowDDPDriver(model, parallel_device=device) | |||||
driver.setup() | |||||
# TODO this is expected to raise
with pytest.raises(TypeError): | |||||
dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||||
finally: | |||||
pass | |||||
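# Why both custom BatchSampler variants above are rejected: set_dist_repro_dataloader
# only knows how to checkpoint and shard samplers that expose a reproducible
# interface. Judging from the attributes exercised elsewhere in these tests, a
# compatible batch sampler roughly needs the following surface (an illustrative
# sketch of the assumed interface, not the actual fastNLP base class):
class _SketchReproducibleBatchSampler:
    def set_distributed(self, num_replicas, rank, pad=True):
        ...  # shard batches across ranks, optionally padding the last ones

    def set_epoch(self, epoch):
        ...  # reseed shuffling per epoch

    def state_dict(self):
        return {"num_consumed_samples": 0}  # bookkeeping needed for resumption

    def load_state_dict(self, states):
        ...  # restore the resumption bookkeeping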
@pytest.mark.oneflow | |||||
@recover_logger | |||||
@pytest.mark.parametrize("inherit", ([True, False])) | |||||
def test_customized_sampler_dataloader(inherit): | |||||
try: | |||||
logger.set_stdout('raw', level='info') | |||||
# verify that set_dist_repro_dataloader runs correctly with a customized sampler
num_samples = 10 | |||||
dataset = OneflowNormalXYDataset(num_samples) | |||||
if inherit: | |||||
class Sampler(oneflow.utils.data.RandomSampler): | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
indices = list(range(len(dataset))) | |||||
return iter(indices) | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
else: | |||||
class Sampler: | |||||
def __init__(self, dataset, batch_size): | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
def __iter__(self): | |||||
indices = list(range(len(dataset))) | |||||
return iter(indices) | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
dl = prepare_oneflow_dataloader(dataset, sampler=Sampler(dataset, batch_size=4)) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
device = [oneflow.device("cuda", i) for i in [0, 1]] | |||||
driver = OneflowDDPDriver(model, parallel_device=device) | |||||
driver.setup() | |||||
# TODO this is expected to raise
with pytest.raises(TypeError): | |||||
dl = driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||||
finally: | |||||
pass | |||||
if __name__ == "__main__": | |||||
# python -m oneflow.distributed.launch --nproc_per_node 2 test_ddp.py | |||||
from tests.helpers.utils import run_pytest | |||||
run_pytest(sys.argv) |
@@ -0,0 +1,157 @@ | |||||
import sys | |||||
sys.path.append("../../../../") | |||||
import os | |||||
import pytest | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from fastNLP.core.drivers.oneflow_driver.dist_utils import ( | |||||
_tensor_to_object, | |||||
_object_to_tensor, | |||||
fastnlp_oneflow_all_gather, | |||||
fastnlp_oneflow_broadcast_object, | |||||
) | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
import oneflow.comm as comm | |||||
# @pytest.mark.oneflow
# class TestDistUtilsTools:
#     """
#     Test some utility functions
#     """
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("device", (["cpu", int(os.getenv("LOCAL_RANK", "0"))])) | |||||
def test_tensor_object_transfer_tensor(device): | |||||
""" | |||||
测试 _tensor_to_object 和 _object_to_tensor 二者的结果能否互相转换 | |||||
""" | |||||
# 张量 | |||||
oneflow_tensor = oneflow.rand(3, 4, 5) | |||||
obj_tensor, size = _object_to_tensor(oneflow_tensor, device=device) | |||||
res = _tensor_to_object(obj_tensor, size) | |||||
assert oneflow.all(res == oneflow_tensor) | |||||
# list
oneflow_list = [oneflow.rand(6, 4, 2) for i in range(10)] | |||||
obj_tensor, size = _object_to_tensor(oneflow_list, device=device) | |||||
res = _tensor_to_object(obj_tensor, size) | |||||
assert isinstance(res, list) | |||||
for before, after in zip(oneflow_list, res): | |||||
assert oneflow.all(after == before) | |||||
# tuple
oneflow_list = [oneflow.rand(6, 4, 2) for i in range(10)] | |||||
oneflow_tuple = tuple(oneflow_list) | |||||
obj_tensor, size = _object_to_tensor(oneflow_tuple, device=device) | |||||
res = _tensor_to_object(obj_tensor, size) | |||||
assert isinstance(res, tuple) | |||||
for before, after in zip(oneflow_list, res): | |||||
assert oneflow.all(after == before) | |||||
# dict
oneflow_dict = { | |||||
"tensor": oneflow.rand(3, 4), | |||||
"list": [oneflow.rand(6, 4, 2) for i in range(10)], | |||||
"dict":{ | |||||
"list": [oneflow.rand(6, 4, 2) for i in range(10)], | |||||
"tensor": oneflow.rand(3, 4) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
obj_tensor, size = _object_to_tensor(oneflow_dict, device=device) | |||||
res = _tensor_to_object(obj_tensor, size) | |||||
assert isinstance(res, dict) | |||||
assert oneflow.all(res["tensor"] == oneflow_dict["tensor"]) | |||||
assert isinstance(res["list"], list) | |||||
for before, after in zip(oneflow_dict["list"], res["list"]): | |||||
assert oneflow.all(after == before) | |||||
assert isinstance(res["dict"], dict) | |||||
assert oneflow.all(res["dict"]["tensor"] == oneflow_dict["dict"]["tensor"]) | |||||
for before, after in zip(oneflow_dict["dict"]["list"], res["dict"]["list"]): | |||||
assert oneflow.all(after == before) | |||||
assert res["int"] == oneflow_dict["int"] | |||||
assert res["string"] == oneflow_dict["string"] | |||||
@pytest.mark.oneflow | |||||
def test_fastnlp_oneflow_all_gather(): | |||||
local_rank = int(os.environ["LOCAL_RANK"]) | |||||
obj = { | |||||
"tensor": oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(), | |||||
"numpy": np.full(shape=(2, ), fill_value=local_rank), | |||||
"bool": local_rank % 2 == 0, | |||||
"float": local_rank + 0.1, | |||||
"int": local_rank, | |||||
"dict": { | |||||
"rank": local_rank | |||||
}, | |||||
"list": [local_rank]*2, | |||||
"str": f"{local_rank}", | |||||
"tensors": [oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(), | |||||
oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda()] | |||||
} | |||||
data = fastnlp_oneflow_all_gather(obj) | |||||
world_size = int(os.environ["WORLD_SIZE"]) | |||||
assert len(data) == world_size | |||||
for i in range(world_size): | |||||
assert (data[i]["tensor"] == i).sum() == world_size | |||||
assert data[i]["numpy"][0] == i | |||||
assert data[i]["bool"] == (i % 2 == 0) | |||||
assert np.allclose(data[i]["float"], i + 0.1) | |||||
assert data[i]["int"] == i | |||||
assert data[i]["dict"]["rank"] == i | |||||
assert data[i]["list"][0] == i | |||||
assert data[i]["str"] == f"{i}" | |||||
assert data[i]["tensors"][0][0] == i | |||||
for obj in [1, True, "xxx"]: | |||||
data = fastnlp_oneflow_all_gather(obj) | |||||
assert len(data) == world_size | |||||
assert data[0] == data[1] | |||||
@pytest.mark.oneflow | |||||
def test_fastnlp_oneflow_broadcast_object(): | |||||
local_rank = int(os.environ["LOCAL_RANK"]) | |||||
if os.environ["LOCAL_RANK"] == "0": | |||||
obj = { | |||||
"tensor": oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(), | |||||
"numpy": np.full(shape=(2, ), fill_value=local_rank, dtype=int), | |||||
"bool": local_rank % 2 == 0, | |||||
"float": local_rank + 0.1, | |||||
"int": local_rank, | |||||
"dict": { | |||||
"rank": local_rank | |||||
}, | |||||
"list": [local_rank] * 2, | |||||
"str": f"{local_rank}", | |||||
"tensors": [oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(), | |||||
oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda()] | |||||
} | |||||
else: | |||||
obj = None | |||||
# device = oneflow.cuda.current_device()
data = fastnlp_oneflow_broadcast_object(obj, src=0, device=local_rank) | |||||
i = 0 | |||||
assert data["tensor"][0] == 0 | |||||
assert data["numpy"][0] == 0 | |||||
assert data["bool"] == (i % 2 == 0) | |||||
assert np.allclose(data["float"], i + 0.1) | |||||
assert data["int"] == i | |||||
assert data["dict"]["rank"] == i | |||||
assert data["list"][0] == i | |||||
assert data["str"] == f"{i}" | |||||
assert data["tensors"][0][0] == i | |||||
for obj in [local_rank, bool(local_rank == 1), str(local_rank)]:
data = fastnlp_oneflow_broadcast_object(obj, src=0, device=local_rank) | |||||
assert int(data) == 0 | |||||
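# _object_to_tensor / _tensor_to_object behave as inverses in the tests above.
# Conceptually the round trip is "pickle to a byte buffer and back"; below is a
# minimal pure-Python sketch of that idea (an assumption about the approach, not
# the actual fastNLP implementation, which packs the bytes into a oneflow tensor):
import pickle

def _sketch_object_to_buffer(obj):
    buf = pickle.dumps(obj)          # serialize an arbitrary python object
    return bytearray(buf), len(buf)  # buffer (may later be padded) plus its true size

def _sketch_buffer_to_object(buf, size):
    return pickle.loads(bytes(buf[:size]))  # drop any padding, then deserialize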
if __name__ == "__main__": | |||||
# python -m oneflow.distributed.launch --nproc_per_node 2 test_dist_utils.py | |||||
pytest.main([ | |||||
f'{__file__}' | |||||
]) |
@@ -0,0 +1,76 @@ | |||||
import pytest | |||||
from fastNLP.core.drivers import OneflowSingleDriver, OneflowDDPDriver | |||||
from fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver import initialize_oneflow_driver | |||||
from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow import device as oneflowdevice | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as oneflowdevice | |||||
@pytest.mark.oneflow | |||||
def test_incorrect_driver(): | |||||
model = OneflowNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_oneflow_driver("paddle", 0, model) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
["cpu", "cuda:0", 0, oneflowdevice("cuda:0")] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["oneflow"] | |||||
) | |||||
def test_get_single_device(driver, device): | |||||
""" | |||||
测试正常情况下初始化OneflowSingleDriver的情况 | |||||
""" | |||||
model = OneflowNormalModel_Classification_1(20, 10) | |||||
driver = initialize_oneflow_driver(driver, device, model) | |||||
assert isinstance(driver, OneflowSingleDriver) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[[0, 1], -1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["oneflow"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp(driver, device): | |||||
""" | |||||
测试 ddp 多卡的初始化情况 | |||||
""" | |||||
model = OneflowNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
driver = initialize_oneflow_driver(driver, device, model) | |||||
# assert isinstance(driver, OneflowDDPDriver) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[-2, [0, 20, 3], [-2], 20] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["oneflow"] | |||||
) | |||||
def test_device_out_of_range(driver, device): | |||||
""" | |||||
测试传入的device超过范围的情况 | |||||
""" | |||||
model = OneflowNormalModel_Classification_1(20, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_oneflow_driver(driver, device, model) |
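# Summary of the dispatch behavior asserted above (inferred from these tests only,
# not from the initialize_oneflow_driver implementation):
# - a driver name other than "oneflow" raises ValueError;
# - "cpu", "cuda:0", an int, or an oneflow.device yields an OneflowSingleDriver;
# - a device list or -1 targets OneflowDDPDriver, which raises RuntimeError here
#   because these tests are not launched via oneflow.distributed.launch;
# - out-of-range device values raise ValueError.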
@@ -0,0 +1,790 @@ | |||||
import pytest | |||||
from pathlib import Path | |||||
from fastNLP.core.drivers.oneflow_driver.single_device import OneflowSingleDriver | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1 | |||||
from tests.helpers.datasets.oneflow_data import OneflowNormalDataset, OneflowNormalXYDataset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW, _NEED_IMPORT_TORCH | |||||
from fastNLP import prepare_oneflow_dataloader, BucketedBatchSampler | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.utils.data import DataLoader, BatchSampler | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | |||||
""" | |||||
建立一个 batch_sampler 为 ReproduceBatchSampler 的 dataloader | |||||
""" | |||||
if shuffle: | |||||
sampler = oneflow.utils.data.RandomSampler(dataset) | |||||
else: | |||||
sampler = oneflow.utils.data.SequentialSampler(dataset) | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler( | |||||
sampler, batch_size=batch_size, drop_last=drop_last | |||||
), | |||||
batch_size=batch_size, | |||||
drop_last=drop_last, | |||||
), | |||||
) | |||||
return dataloader | |||||
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): | |||||
""" | |||||
建立一个 sampler 为 RandomSampler 的 dataloader | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=RandomSampler(dataset, shuffle, seed=seed), | |||||
drop_last=drop_last, | |||||
batch_size=batch_size | |||||
) | |||||
return dataloader | |||||
############################################################################ | |||||
# | |||||
# Tests for some simple functions of the base class OneflowDriver
# | |||||
############################################################################ | |||||
class TestOneflowDriverFunctions: | |||||
""" | |||||
使用 OneflowSingleDriver 测试基类的函数 | |||||
""" | |||||
@classmethod
def setup_class(cls):
model = OneflowNormalModel_Classification_1(10, 32)
cls.driver = OneflowSingleDriver(model, device="cpu")
@pytest.mark.oneflow | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
测试对合法 optimizers 的检查 | |||||
""" | |||||
# 单个 optimizer | |||||
optimizer = oneflow.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
# a list of optimizers
optimizers = [ | |||||
oneflow.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
@pytest.mark.torchoneflow | |||||
def test_invalid_optimizers(self): | |||||
""" | |||||
测试传入非法的 optimizers | |||||
""" | |||||
optimizer = torch.optim.Adam( | |||||
params=TorchNormalModel_Classification_1(10, 32).parameters(), | |||||
lr=0.01, | |||||
) | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizers = [ | |||||
torch.optim.Adam( | |||||
params=TorchNormalModel_Classification_1(10, 32).parameters(), | |||||
lr=0.01, | |||||
) | |||||
] | |||||
with pytest.raises(TypeError): | |||||
self.driver.set_optimizers(optimizers) | |||||
@pytest.mark.oneflow | |||||
def test_check_dataloader_legality(self): | |||||
""" | |||||
测试 check_dataloader_legality 函数的表现 | |||||
""" | |||||
dataloader = DataLoader(OneflowNormalDataset()) | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.torchoneflow | |||||
def test_check_dataloader_legality_invalid(self): | |||||
""" | |||||
测试 check_dataloader_legality 函数传入其他类型的表现 | |||||
""" | |||||
# 创建 torch 的 dataloader | |||||
dataloader = torch.utils.data.DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(TypeError): | |||||
self.driver.check_dataloader_legality(dataloader) | |||||
@pytest.mark.oneflow | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
测试 tensor_to_numeric 函数 | |||||
""" | |||||
# 单个张量 | |||||
tensor = oneflow.tensor(3) | |||||
res = OneflowSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == 3 | |||||
tensor = oneflow.rand(3, 4) | |||||
res = OneflowSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == tensor.tolist() | |||||
# a list of tensors
tensor_list = [oneflow.rand(6, 4, 2) for i in range(10)] | |||||
res = OneflowSingleDriver.tensor_to_numeric(tensor_list) | |||||
assert isinstance(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
assert res == tensor_list | |||||
# a tuple of tensors
tensor_tuple = tuple([oneflow.rand(6, 4, 2) for i in range(10)]) | |||||
res = OneflowSingleDriver.tensor_to_numeric(tensor_tuple) | |||||
assert isinstance(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
assert res == tensor_tuple | |||||
# a dict of tensors
tensor_dict = { | |||||
"tensor": oneflow.rand(3, 4), | |||||
"list": [oneflow.rand(6, 4, 2) for i in range(10)], | |||||
"dict":{ | |||||
"list": [oneflow.rand(6, 4, 2) for i in range(10)], | |||||
"tensor": oneflow.rand(3, 4) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = OneflowSingleDriver.tensor_to_numeric(tensor_dict) | |||||
assert isinstance(res, dict) | |||||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||||
assert isinstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
assert r == d.tolist() | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
assert r == d.tolist() | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||||
@pytest.mark.oneflow | |||||
def test_tensor_to_numeric_reduce(self): | |||||
tensor = oneflow.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | |||||
res_max = OneflowSingleDriver.tensor_to_numeric(tensor, reduce="max") | |||||
res_min = OneflowSingleDriver.tensor_to_numeric(tensor, reduce="min") | |||||
res_sum = OneflowSingleDriver.tensor_to_numeric(tensor, reduce="sum") | |||||
res_mean = OneflowSingleDriver.tensor_to_numeric(tensor, reduce="mean") | |||||
assert res_max == 6 | |||||
assert res_min == 1 | |||||
assert res_sum == 21 | |||||
assert res_mean == 3.5 | |||||
@pytest.mark.oneflow | |||||
def test_set_model_mode(self): | |||||
""" | |||||
测试set_model_mode函数 | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
assert self.driver.model.training | |||||
self.driver.set_model_mode("eval") | |||||
assert not self.driver.model.training | |||||
# should raise an error
with pytest.raises(AssertionError): | |||||
self.driver.set_model_mode("test") | |||||
@pytest.mark.oneflow | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
OneflowSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
assert self.driver.model.linear1.weight.device.type == "cpu" | |||||
@pytest.mark.oneflow | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
测试move_model_to_device函数 | |||||
""" | |||||
OneflowSingleDriver.move_model_to_device(self.driver.model, "cuda") | |||||
assert self.driver.model.linear1.weight.device.type == "cuda" | |||||
assert self.driver.model.linear1.weight.device.index == 0 | |||||
@pytest.mark.oneflow | |||||
def test_worker_init_function(self): | |||||
""" | |||||
测试worker_init_function | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
OneflowSingleDriver.worker_init_function(0) | |||||
@pytest.mark.oneflow | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
测试set_deterministic_dataloader | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(OneflowNormalDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
@pytest.mark.oneflow | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
测试set_sampler_epoch | |||||
""" | |||||
# 先确保不影响运行 | |||||
# TODO:正确性 | |||||
dataloader = DataLoader(OneflowNormalDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试正常情况下 get_dataloader_args 的表现 | |||||
""" | |||||
dataloader = DataLoader( | |||||
OneflowNormalDataset(), | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
) | |||||
res = OneflowSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, OneflowNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, oneflow.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, oneflow.utils.data.SequentialSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试替换了 batch_sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = OneflowNormalDataset() | |||||
dataloader = dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last) | |||||
res = OneflowSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, OneflowNormalDataset) | |||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, oneflow.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, oneflow.utils.data.SequentialSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
测试替换了 sampler 后 get_dataloader_args 的表现 | |||||
""" | |||||
dataset = OneflowNormalDataset() | |||||
dataloader = dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last) | |||||
res = OneflowSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, OneflowNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
assert isinstance(res.sampler, RandomSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
############################################################################ | |||||
# | |||||
# Tests for some simple functions of OneflowSingleDriver
# | |||||
############################################################################ | |||||
@pytest.mark.oneflow | |||||
class TestSingleDeviceFunction: | |||||
""" | |||||
测试其它函数的测试例 | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
model = OneflowNormalModel_Classification_1(10, 784) | |||||
cls.driver = OneflowSingleDriver(model, device="cpu") | |||||
def test_unwrap_model(self): | |||||
""" | |||||
测试能否运行 | |||||
""" | |||||
res = self.driver.unwrap_model() | |||||
assert res is self.driver.model | |||||
def test_is_distributed(self): | |||||
assert self.driver.is_distributed() == False | |||||
def test_move_data_to_device(self): | |||||
self.driver.move_data_to_device(oneflow.rand(32, 64)) | |||||
############################################################################ | |||||
# | |||||
# Tests for the set_dist_repro_dataloader function
# | |||||
############################################################################ | |||||
@pytest.mark.oneflow | |||||
class TestSetDistReproDataloader: | |||||
""" | |||||
专门测试 set_dist_repro_dataloader 函数的类 | |||||
""" | |||||
def setup_method(self): | |||||
self.dataset = OneflowNormalDataset(20) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
self.driver = OneflowSingleDriver(model, device="cpu") | |||||
def test_with_reproducible_false(self): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 | |||||
当dist为字符串时,此时应该返回原来的 dataloader | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert replaced_loader is dataloader | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
def test_with_reproducible_true(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | |||||
当 dist 为字符串时,此时应该返回新的 dataloader;如果 shuffle 为 False,则只会替换 sampler; | |||||
否则将会替换 BatchSampler | |||||
TODO: | |||||
在 Sampler 的参数不是默认的情况下会替换 batch_sampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||||
assert not (replaced_loader is dataloader) | |||||
# the sampler is replaced
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
if shuffle: | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler.sampler, oneflow.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(replaced_loader.batch_sampler, oneflow.utils.data.BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler | |||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 | |||||
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | |||||
dist = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is dist | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = dataloader_with_randombatchsampler(self.dataset, 4, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 | |||||
应该返回新的 dataloader,且其余各项设置和原来相同 | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
""" | |||||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 | |||||
""" | |||||
# 迭代两个 batch | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
replaced_loader.batch_sampler.set_epoch(3) | |||||
else: | |||||
replaced_loader.batch_sampler.sampler.set_epoch(3) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch.tolist()) | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# Reload; the remaining samples should come out, and for OneflowNormalDataset the sorted result should form a range
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# rebuild the dataloader
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||||
new_loader.batch_sampler.set_epoch(3) | |||||
else: | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
new_loader.batch_sampler.sampler.set_epoch(3) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch.tolist()) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||||
############################################################################ | |||||
# | |||||
# Tests for save / load related functionality
# | |||||
############################################################################ | |||||
def generate_random_driver(labels, features, fp16=False, device="cpu"): | |||||
""" | |||||
生成driver | |||||
""" | |||||
model = OneflowNormalModel_Classification_1(labels, features) | |||||
opt = oneflow.optim.Adam(params=model.parameters(), lr=0.01) | |||||
driver = OneflowSingleDriver(model, device=device, fp16=fp16) | |||||
driver.set_optimizers(opt) | |||||
driver.setup() | |||||
return driver | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_model(only_state_dict): | |||||
""" | |||||
测试 save_model 和 load_model 函数 | |||||
""" | |||||
try: | |||||
path = "model" | |||||
dataset = OneflowNormalXYDataset(20) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1) | |||||
driver1.save_model(path, only_state_dict) | |||||
driver2.load_model(path, only_state_dict) | |||||
for batch in dataloader: | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert oneflow.all(res1["preds"] == res2["preds"]) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
dataset = OneflowNormalXYDataset(20) | |||||
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | |||||
driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda") | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
driver1.set_sampler_epoch(dataloader, 3) | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
sampler_states = dataloader.batch_sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# load
# change the batch_size
dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state
# TODO the optimizer state_dict is always empty
# 2. check that the batch_sampler has been correctly loaded and replaced
assert not (replaced_loader is dataloader) | |||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | |||||
# # 3. check whether the fp16 state is loaded
# if fp16:
#     assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler)
# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
driver1.set_sampler_epoch(replaced_loader, 3) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
batch = driver2.move_data_to_device(batch) | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert oneflow.all(res1["preds"] == res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) | |||||
assert len(left_x_batches | already_seen_x_set) == len(dataset) | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
""" | |||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况 | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda") | |||||
dataset = OneflowNormalXYDataset(40) | |||||
dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
driver1.set_sampler_epoch(dataloader, 3) | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist()) | |||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist()) | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# load
# change the batch_size
dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state
# TODO the optimizer state_dict is always empty
# 2. check that the sampler has been correctly loaded and replaced
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# # 3. check whether the fp16 state is loaded
# if fp16:
#     assert not isinstance(driver2.grad_scaler, oneflow.cuda.amp.GradScaler)
# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
# set epoch | |||||
driver2.set_sampler_epoch(replaced_loader, 3) | |||||
for idx, batch in enumerate(replaced_loader): | |||||
batch = driver2.move_data_to_device(batch) | |||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist()) | |||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert oneflow.all(res1["preds"] == res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) | |||||
assert len(left_x_batches | already_seen_x_set) == len(dataset) | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) | |||||
@pytest.mark.parametrize("drop_last", ([True, False])) | |||||
@pytest.mark.parametrize("reproducible", ([True, False])) | |||||
def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible): | |||||
    # verify that set_dist_repro_dataloader does not modify the dataloader's parameters
num_samples = 100 | |||||
dataset = OneflowNormalXYDataset(num_samples) | |||||
dl = prepare_oneflow_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
driver = OneflowSingleDriver(model, device="cpu") | |||||
dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) | |||||
data = [] | |||||
flags = [] | |||||
for batch in dl: | |||||
flags.append(batch['x'].size(0) == batch_size) | |||||
data.extend(batch['x'].reshape(-1).tolist()) | |||||
    if drop_last and num_samples % batch_size != 0:
        assert len(data) != num_samples
        assert all(flags)
    elif num_samples % batch_size != 0:
        assert not flags[-1]
    else:
        assert len(data) == num_samples

    if not shuffle:
        for i in range(1, len(data)):
            assert data[i] > data[i - 1]
    else:
        flags = []
        for i in range(1, len(data)):
            flags.append(data[i] > data[i - 1])
        assert not all(flags)

@pytest.mark.oneflow | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) | |||||
@pytest.mark.parametrize("drop_last", ([True, False])) | |||||
@pytest.mark.parametrize("reproducible", ([True, False])) | |||||
def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible): | |||||
    # verify that set_dist_repro_dataloader does not modify the dataloader's parameters
num_samples = 100 | |||||
dataset = OneflowNormalXYDataset(num_samples) | |||||
sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | |||||
shuffle=shuffle, num_batch_per_bucket=2) | |||||
dl = prepare_oneflow_dataloader(dataset, batch_sampler=sampler) | |||||
model = OneflowNormalModel_Classification_1(10, 32) | |||||
driver = OneflowSingleDriver(model, device="cpu") | |||||
dl = driver.set_dist_repro_dataloader(dataloader=dl, reproducible=reproducible) | |||||
data = [] | |||||
flags = [] | |||||
for batch in dl: | |||||
d = batch['x'].reshape(-1).tolist() | |||||
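        # with num_batch_per_bucket=2, every batch is drawn from a bucket of at most
        # 2 * batch_size samples with consecutive values, bounding max - min below that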
diff = max(d) - min(d) | |||||
        assert diff < batch_size * 2
data.extend(d) | |||||
flags.append(len(d)==batch_size) | |||||
    if drop_last and num_samples % batch_size != 0:
        assert len(data) != num_samples
        assert all(flags)
    elif num_samples % batch_size != 0:
        assert not flags[-1]
    else:
        assert len(data) == num_samples

    if not shuffle:
        for i in range(1, len(data)):
            assert data[i] < data[i - 1]
    else:
        flags = []
        for i in range(1, len(data)):
            flags.append(data[i] < data[i - 1])
        assert not all(flags)

@@ -0,0 +1,39 @@ | |||||
import pytest | |||||
from fastNLP.core.drivers.oneflow_driver.utils import ( | |||||
replace_batch_sampler, | |||||
replace_sampler, | |||||
) | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
from tests.helpers.datasets.oneflow_data import OneflowNormalDataset | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
from oneflow.utils.data import DataLoader, BatchSampler | |||||
@pytest.mark.oneflow | |||||
def test_replace_batch_sampler(): | |||||
dataset = OneflowNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | |||||
    assert replaced_loader is not dataloader
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, OneflowNormalDataset) | |||||
assert len(replaced_loader.dataset) == len(dataset) | |||||
assert replaced_loader.batch_sampler.batch_size == 16 | |||||
@pytest.mark.oneflow | |||||
def test_replace_sampler(): | |||||
dataset = OneflowNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
sampler = RandomSampler(dataset) | |||||
replaced_loader = replace_sampler(dataloader, sampler) | |||||
    assert replaced_loader is not dataloader
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
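
# Note: both helpers rebuild a new DataLoader from the old one's attributes instead of
# mutating it in place, which is what the `is not` checks above rely on.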
@@ -0,0 +1,55 @@ | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.utils.data import Dataset | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
class OneflowNormalDataset(Dataset): | |||||
def __init__(self, num_of_data=1000): | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return self._data[item] | |||||
class OneflowNormalXYDataset(Dataset): | |||||
""" | |||||
可以被输入到分类模型中的普通数据集 | |||||
""" | |||||
def __init__(self, num_of_data=1000): | |||||
self.num_of_data = num_of_data | |||||
self._data = list(range(num_of_data)) | |||||
def __len__(self): | |||||
return self.num_of_data | |||||
def __getitem__(self, item): | |||||
return { | |||||
"x": oneflow.tensor([self._data[item]], dtype=oneflow.float), | |||||
"y": oneflow.tensor([self._data[item]], dtype=oneflow.float) | |||||
} | |||||
class OneflowArgMaxDataset(Dataset): | |||||
def __init__(self, data_num=1000, feature_dimension=10, seed=0): | |||||
self.num_labels = feature_dimension | |||||
self.feature_dimension = feature_dimension | |||||
self.data_num = data_num | |||||
self.seed = seed | |||||
g = oneflow.Generator() | |||||
        g.manual_seed(seed)
self.x = oneflow.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float() | |||||
self.y = oneflow.max(self.x, dim=-1)[1] | |||||
def __len__(self): | |||||
return self.data_num | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} |
@@ -0,0 +1,43 @@ | |||||
from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW | |||||
if _NEED_IMPORT_ONEFLOW: | |||||
import oneflow | |||||
from oneflow.nn import Module | |||||
import oneflow.nn as nn | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
# 1. the most basic classification model
class OneflowNormalModel_Classification_1(Module):
    """
    Implements train_step and evaluate_step separately;
    """
def __init__(self, num_labels, feature_dimension): | |||||
super(OneflowNormalModel_Classification_1, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def forward(self, x): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return x | |||||
def train_step(self, x, y): | |||||
x = self(x) | |||||
return {"loss": self.loss_fn(x, y)} | |||||
def evaluate_step(self, x, y): | |||||
""" | |||||
如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"}; | |||||
""" | |||||
x = self(x) | |||||
x = oneflow.max(x, dim=-1)[1] | |||||
return {"pred": x, "target": y} |
@@ -8,6 +8,7 @@ from io import StringIO | |||||
import time
import signal
import pytest | |||||
import numpy as np
from fastNLP.core.utils.utils import get_class_that_defined_method
@@ -147,3 +148,42 @@ def re_run_current_cmd_for_torch(num_procs, output_from_new_proc='ignore'): | |||||
            delay = np.random.uniform(1, 5, 1)[0]
            time.sleep(delay)
def re_run_current_cmd_for_oneflow(num_procs, output_from_new_proc='ignore'):
    # the logic is identical to the torch version; it is kept separate only to distinguish frameworks
# Script called as `python a/b/c.py` | |||||
if int(os.environ.get('LOCAL_RANK', '0')) == 0: | |||||
if __main__.__spec__ is None: # pragma: no-cover | |||||
# pull out the commands used to run the script and resolve the abs file path | |||||
command = sys.argv | |||||
command[0] = os.path.abspath(command[0]) | |||||
            # use the same python interpreter to actually run it
command = [sys.executable] + command | |||||
# Script called as `python -m a.b.c` | |||||
else: | |||||
command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:] | |||||
for rank in range(1, num_procs+1): | |||||
env_copy = os.environ.copy() | |||||
env_copy["LOCAL_RANK"] = f"{rank}" | |||||
            env_copy['WORLD_SIZE'] = f'{num_procs+1}'
env_copy['RANK'] = f'{rank}' | |||||
env_copy["GLOG_log_dir"] = os.path.join( | |||||
os.getcwd(), f"oneflow_rank_{rank}" | |||||
) | |||||
os.makedirs(env_copy["GLOG_log_dir"], exist_ok=True) | |||||
            # for multi-node runs the user must launch the processes themselves, so for processes we
            # spawn via open_subprocesses, FASTNLP_GLOBAL_RANK is always the LOCAL_RANK;
env_copy[FASTNLP_GLOBAL_RANK] = str(rank) | |||||
proc = distributed_open_proc(output_from_new_proc, command, env_copy, rank) | |||||
delay = np.random.uniform(1, 5, 1)[0] | |||||
time.sleep(delay) | |||||
def run_pytest(argv): | |||||
cmd = argv[0] | |||||
for i in range(1, len(argv)): | |||||
cmd += "::" + argv[i] | |||||
pytest.main([cmd]) |
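
# Usage sketch (hypothetical path): run_pytest(["tests/core/test_x.py", "TestClass", "test_case"])
# builds the pytest node id "tests/core/test_x.py::TestClass::test_case" and runs it in-process.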
@@ -2,7 +2,9 @@ | |||||
markers =
    torch
    paddle
    oneflow
    paddledist
    jittor
    torchpaddle
    torchjittor
    torchoneflow