@@ -1,5 +1,4 @@
__all__ = [
    'AutoCollator',
    'Collator'
]
from .collator import AutoCollator, Collator
from .collator import Collator
@@ -1,386 +1,648 @@
__all__ = [
    'AutoCollator',
    'Collator'
]
from typing import List, Union, Dict, Callable, Sequence, Mapping
import os
import sys
import inspect
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Callable, Union, Tuple
from numbers import Number
import warnings

from fastNLP.core.log import logger
from .padders.get_padder import get_padder
import numpy as np
import re

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
    pack_batch_sequence

if _NEED_IMPORT_PADDLE:
    import paddle

sequence_idx_str = re.compile(r'^_\d+$')  # matches names like _0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
CHECK_BACKEND = ['torch', 'jittor', 'paddle']  # backends probed when backend is 'auto'

if _NEED_IMPORT_TORCH:
    import torch
class ApplyResultException(Exception):
    def __init__(self, msg, index=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # marks the index of the sample where the problem occurred

class SetInputOrTargetException(Exception):
    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # marks the index of the sample where the problem occurred
        self.field_name = field_name  # marks the name of the current field

def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
    r"""
    Identify the type of `cell` and its number of dimensions.

    numpy scalar types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
    :param cell:
    :param dim:
    :return:
    """
def _get_backend() -> str:
    """
    if isinstance(cell, (str, Number, np.bool_)):
        if hasattr(cell, 'dtype'):
            return cell.dtype.type, dim
        return type(cell), dim
    elif isinstance(cell, list):
        dim += 1
        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
        types = set([i for i, j in res])
        dims = set([j for i, j in res])
        if len(types) > 1:
            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
        elif len(types) == 0:
            raise SetInputOrTargetException("Empty value encountered.")
        if len(dims) > 1:
            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
        return types.pop(), dims.pop()
    elif isinstance(cell, torch.Tensor):
        return cell.dtype, cell.dim() + dim  # e.g. the result of torch.mean is 0-dim
    elif isinstance(cell, paddle.Tensor):
        return cell.dtype, cell.dim() + dim
    elif isinstance(cell, np.ndarray):
        if cell.dtype != np.dtype('O'):  # if it is not object, the array is already well-formatted
            return cell.dtype.type, cell.ndim + dim  # dtype.type returns np.int32, np.float, etc.
        # otherwise we need to keep iterating downwards
        dim += 1
        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
        types = set([i for i, j in res])
        dims = set([j for i, j in res])
        if len(types) > 1:
            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
        elif len(types) == 0:
            raise SetInputOrTargetException("Empty value encountered.")
        if len(dims) > 1:
            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
        return types.pop(), dims.pop()
    else:  # tuple, set, dict and any other types
        raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
def _get_ds_type_dim(ds: dict):
    # get the type and dimension of each field's content in the first row of the dataset
    field_dtype, field_dim = {}, {}
    for field_name, field_content in ds.items():
        type_0, dim_0 = _get_ele_type_and_dim(field_content)
        field_dtype[field_name], field_dim[field_name] = type_0, dim_0
    return field_dtype, field_dim
class Collator(metaclass=ABCMeta):
    r"""
    Helper class that manages the collate_fn for a DataLoader.
    When the Collator's backend is 'auto', this function decides the backend automatically. Two strategies are tried:
    (1) walk up the call stack from the current collator and inspect each caller's module; a module file path
        containing '/site-packages/{backend}' indicates that backend's dataloader.
    (2) if strategy (1) finds nothing, search the modules recorded in sys.modules.
    If neither strategy succeeds, fall back to 'numpy'.
    :return:
    """
    def _check_module(module):
        """
        Check whether this module carries the signature of one of the known backends.
    def __init__(self):
        super(Collator, self).__init__()
        self.collate_fn = []

    @abstractmethod
    def __call__(self, ins_lst: List) -> Any:
        raise NotImplementedError

    @abstractmethod
    def set_pad_val(self, *field_names: str, value=0):
        raise NotImplementedError

        :param module: a module object
        :return:
        """
        catch_backend = []
        try:
            file = module.__file__
            for backend in CHECK_BACKEND:
                if f'{os.sep}site-packages{os.sep}{backend}' in file:
                    catch_backend = [backend, file]
        except:
            pass
        return catch_backend
    currentframe = inspect.currentframe()
    # strategy (1)
    catch_backend = []
    for i in range(100):
        currentframe = currentframe.f_back
        if currentframe is not None:
            module = inspect.getmodule(currentframe)
            if module is not None:
                catch_backend = _check_module(module)
                if len(catch_backend):  # stop as soon as one backend is caught
                    break
        else:
            break
    if len(catch_backend):
        logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.")
        return catch_backend[0]

    # strategy (2)
    for backend in CHECK_BACKEND:
        if backend in sys.modules:
            logger.debug(f"sys.modules contains backend:{backend}.")
            return backend
    for key, module in sys.modules.items():
        catch_backend = _check_module(module)
        if catch_backend:
            break
    if len(catch_backend):
        logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
        return catch_backend[0]

    return 'numpy'
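# A small sketch of the fallback behaviour (assuming none of torch/jittor/paddle
# is importable in the current process): neither the call stack nor sys.modules
# matches a known backend, so the collator quietly degrades to numpy.
#
#     backend = _get_backend()              # -> 'numpy' in a plain script
#     collator = Collator(backend='auto')   # resolved via _get_backend() on the first call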
class Collator:
    def __init__(self, backend='auto'):
        """
        An object used to pad data. It automatically pads every field that can be padded (fastNLP decides from the
        data whether a field is paddable); the default pad value is 0 and can be adjusted with set_pad(). If some
        fields should not be output, configure them with set_ignore(). On the first pad, the Collator picks a padder
        for each field from the settings and the data; every subsequent call reuses that field's Padder.
class _MultiCollator:
    """
    Container that manages all collators.
    Override rule: a collate_fn added later overrides the data processed earlier.
    """

        :param backend: which kind of tensor to use for paddable fields; one of ['torch', 'jittor', 'paddle',
            'numpy', 'raw', 'auto', None]. With 'auto', the backend is decided from the calling environment at pad
            time. This parameter has no effect on fields that cannot be padded; such fields are always returned as
            lists.
        """
        self.unpack_batch_func = None
        self.pack_batch_func = None
        self.ignore_fields = set()
        self.padders = {}
        self.input_fields = {}
        self.batch_data_type = None  # one of 'd', 's', 'l': each sample of the input batch is a dict, a single value, or a list respectively.
        self.set_backend(backend)

    def __call__(self, batch) -> Union[List, Dict]:
        """
        A batch can take one of three forms:
        List[Dict], List[List], List[Sample]
    def __init__(self, collate_fns: Union[Callable, List[Callable], None]):
        Step 1: use unpack_batch_func to gather the contents of the same field into one list.
        Step 2: pad each field with its own padder.
        Step 3: return a batch whose type matches the type of each sample in the input batch.
        if collate_fns is None:
            collate_fns = []
        The first call decides, from the current batch, which unpack_batch_func to use; this function collects the
        same field of different samples into one list. It also decides the pack_batch_func, which restores the
        padded batch to the per-sample container type of the input before returning it.
        The first call also decides the Padder of each field.
        if isinstance(collate_fns, Callable):
            collate_fns = [collate_fns]
        """
        if self.unpack_batch_func is None:
            # decide which unpack_batch_func to use; every choice returns a dict
            if self.batch_data_type is None:
                if isinstance(batch[0], Mapping):
                    self.batch_data_type = 'd'
                elif isinstance(batch[0], Sequence):  # there is a risk of misclassification here
                    self.batch_data_type = 'l'
                else:
                    self.batch_data_type = 's'
                logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
                             f"is `{self.batch_data_type}`.")
            if self.batch_data_type == 's':
                self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch}  # no adjustment needed
                self.pack_batch_func = lambda x: x['_single']
            elif self.batch_data_type == 'l':
                self.unpack_batch_func = unpack_batch_sequence
                self.pack_batch_func = pack_batch_sequence
            elif self.batch_data_type == 'd':
                if any([isinstance(v, Mapping) for v in batch[0].values()]):  # there may be nested dicts: {'a': {'b': xx}} -> {('a', 'b'): value}
                    self.unpack_batch_func = unpack_batch_nested_mapping
                    self.pack_batch_func = pack_batch_nested_mapping
                else:
                    self.unpack_batch_func = unpack_batch_mapping
                    self.pack_batch_func = lambda x: x
        self._collators: list = collate_fns

        if self.unpack_batch_func is unpack_batch_nested_mapping:  # special case: keep the unpacking from descending any further
            unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys()))
        else:
            unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields)  # gather each field into batch form.

        pad_batch = {}
        if len(self.padders) == 0:  # first run: prepare the padders
            if self.backend == 'auto':  # if backend is 'auto', try to infer it from the call stack etc.
                self.backend = _get_backend()

            for key in unpack_batch.keys():
                if key not in self.input_fields and key not in self.ignore_fields:
                    self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}
                elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto':
                    self.input_fields[key]['backend'] = self.backend

            for field_name, setting in self.input_fields.items():
                pad_fn = setting.get('pad_fn', None)
                if callable(pad_fn):
                    padder = pad_fn
                else:
                    backend = self.backend if setting['backend'] == 'auto' else setting['backend']
                    batch_field = unpack_batch.get(field_name)
                    padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
                                        dtype=setting['dtype'], backend=backend,
                                        field_name=field_name)
                self.padders[field_name] = padder
            if self.batch_data_type == 'l':
                self.padders = dict(sorted(self.padders.items(), key=lambda x: int(x[0][1:])))  # sort so that _0, _1, ... keep their order

        for key, padder in self.padders.items():
            batch = unpack_batch.get(key)
            pad_batch[key] = padder(batch)

        return self.pack_batch_func(pad_batch)  # restore the same container type as the input
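# A usage sketch for the dict-batch path (field names 'x'/'y' are illustrative;
# with backend='raw' the padded output stays a plain nested list):
#
#     c = Collator(backend='raw')
#     c([{'x': [1, 2], 'y': 1}, {'x': [3], 'y': 2}])
#     # -> {'x': [[1, 2], [3, 0]], 'y': [1, 2]}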
    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend='auto',
                pad_fn: Callable = None) -> "Collator":
        """
        Use this function to customise how the content of a particular field is handled.

        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the field's
            key directly; for a nested dict, use a tuple for the multi-level key, e.g. ('a', 'b') for {'a': {'b': 1}}.
            If __getitem__ returns a Sequence, use '_0', '_1', ... for the 0th, 1st, ... element. An error is raised
            if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
        :param pad_val: the default pad value of this field. None means the field needs no padding; fastNLP only pads
            fields that can be padded, so a field that is not in a paddable form does not have to be set to None
            explicitly. Meaningless if backend is None.
        :param dtype: for a field that needs padding, the dtype its data should have.
        :param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
            numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. Meaningless if pad_val is None.
        :param pad_fn: a pad function for this field; when given, the pad_val, dtype and backend parameters are
            ignored. The Collator un-batches the data and gathers each field into its own batch; pad_fn receives
            that batch form of the field, and its output is used directly as the result.
        :return: the Collator itself
        """
        self.padders.clear()  # regenerate the padders
        if self.batch_data_type is not None:
            if self.batch_data_type == 's':
                logger.debug("Set as single field mode.")
                self.input_fields.clear()
            elif self.batch_data_type == 'd':
                assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
                                                                   f"index, but other field is set as dict mode."
            elif self.batch_data_type == 'l':
                assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
                                                                       f"field name is {field_name}."

        if field_name == '_single':
            self.batch_data_type = 's'
        elif isinstance(field_name, str) and sequence_idx_str.match(field_name):
            self.batch_data_type = 'l'
        else:
            self.batch_data_type = 'd'
    def __call__(self, ins_lst) -> Dict:
        out, list_out = {}, []
        for idx, _collate_fn in enumerate(self._collators):
            res = _collate_fn(ins_lst)
            if isinstance(res, Dict):
                out.update(res)
            else:
                list_out.append(res)
            # else:
            #     raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict")
        if len(out) > 0 and len(list_out) > 0:
            raise ValueError("the return of collate_fns is not the same, must be dict or list")
        if len(list_out) == 1:
            list_out = list_out[-1]
        # print(list_out)
        return out if len(out) > 0 else list_out

        if field_name in self.ignore_fields:
            logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
        if backend is None:
            backend = self.backend
        else:
            assert backend in SUPPORTED_BACKENDS

    def get_collators(self):
        return self._collators

    def add_collator(self, collator: Callable):
        self._collators.append(collator)

        self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}
        return self
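# A sketch of per-field customisation (hypothetical field names): pad 'x' with
# -1 and keep it a raw list, while 'y' follows the default numpy backend:
#
#     c = Collator(backend='numpy')
#     c.set_pad('x', pad_val=-1, backend='raw')
#     out = c([{'x': [1], 'y': [2, 3]}, {'x': [4, 5], 'y': [6]}])
#     # out['x'] == [[1, -1], [4, 5]]; out['y'] is a numpy array [[2, 3], [6, 0]]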
    def set_as_numpy(self, as_numpy: bool):
    def set_backend(self, backend: str):
        """
        When an AutoCollator is present, as_numpy controls the type of its return value.
        Set which kind of tensor paddable fields are padded to by default.

        :param as_numpy:
        :param backend: which kind of tensor to use for paddable fields; one of ['torch', 'jittor', 'paddle',
            'numpy', 'raw', 'auto', None]. With 'auto', the backend is decided automatically from the calling
            environment at pad time.
        :return:
        """
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_as_numpy(as_numpy)
        return self
        assert backend in SUPPORTED_BACKENDS
        self.padders.clear()
        self.backend = backend

    def set_pad_val(self, *field_names, val=0):
    def set_ignore(self, *field_names) -> "Collator":
        """
        When an AutoCollator is present, set the padding value of the given fields.

        :param field_names: field names in the dataset
        :param val: the padding value
        :return:
        Use this to suppress content you do not want in the output; the configured fields are ignored in the batch
        output.
        Ex::
            collator.set_ignore('field1', 'field2')

        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, use the
            field's key directly; for a nested dict, use a tuple, e.g. ('a', 'b') for {'a': {'b': 1}}. If
            __getitem__ returns a Sequence, use '_0', '_1', ... for the 0th, 1st, ... element.
        :return: the Collator itself
        """
        flag = True
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_pad_val(*field_names, val=val)
                flag = False
        if flag:
            warnings.warn("AutoCollator is removed, set_pad_val is unavailable!!")
        for field_name in field_names:
            if field_name in self.input_fields:
                self.input_fields.pop(field_name)
                logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
            self.padders.pop(field_name, None)  # if it has one, drop its padder.
            self.ignore_fields.add(field_name)
        return self
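# A sketch of set_ignore (hypothetical field names): an ignored field is
# dropped from the padded batch entirely:
#
#     c = Collator(backend='raw')
#     c.set_ignore('y')
#     c([{'x': [1], 'y': 1}, {'x': [2, 3], 'y': 2}])   # -> {'x': [[1, 0], [2, 3]]}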
    def set_input(self, *field_names):
        """
        Set the field_names the AutoCollator needs; fields that are not set are filtered out by default.

        :param field_names:
        :return:
        """
        flag = True
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_input(*field_names)
                flag = False
        if flag:
            warnings.warn("AutoCollator is removed, set_input is unavailable!!")
        return self

class AutoCollator(Collator):

    def __init__(self, as_numpy: bool):
        super(AutoCollator, self).__init__()
        self.pad_field_value = {}  # user-defined padding value per field, default 0
        self.need_inputs = set()  # the field names that are needed
        self.field_dtypes = None  # dtype of each column's elements
        self.field_dims = None  # dimensionality of each column's elements
        self.as_numpy = as_numpy
    def __call__(self, ins_lst: List[Dict]) -> dict:
        if len(self.need_inputs) == 0:
            raise ValueError("need_inputs is empty, you should call the set_input method first!")
        # TODO should first check which fields need padding, then check whether those fields can be padded

        # case 1: values were configured via set_input
        # case 2: decide from the data's type whether to pad
        if self.field_dtypes is None and self.field_dims is None:
            field_dtypes, field_dims = {}, {}
            for key, value in ins_lst[0].items():
                if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
                    field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
            self.field_dtypes = field_dtypes
            self.field_dims = field_dims

        pack_ins_lst, pad_ins_lst = {field_name: []
                                     for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
        # pack the data in the list by column name
        for per_ins in ins_lst:
            for field_name, _field_content in per_ins.items():
                if field_name in self.need_inputs:
                    pack_ins_lst[field_name].append(_field_content)

        pad_field_kv = {field_name: 0 for field_name in self.need_inputs}
        pad_field_kv.update(self.pad_field_value)
        self.pad_field_value = pad_field_kv

        if len(self.pad_field_value.keys()) > 0:
            # drop the columns that need no padding; ignore set_input columns that do not exist
            non_pad_field_names = []
            for k, v in self.pad_field_value.items():
                if v is None:
                    non_pad_field_names.append(k)

            # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
            for field_name in non_pad_field_names:
                field_array = pack_ins_lst.pop(field_name)
                pad_ins_lst[field_name] = np.array(field_array)

            for field_name, field_array in pack_ins_lst.items():
                content = pad_content(field_array, field_name, self.field_dtypes[field_name],
                                      self.field_dims[field_name],
                                      self.pad_field_value[field_name],
                                      as_numpy=self.as_numpy)
                pad_ins_lst[field_name] = content

        # else:
        #     # take each column's data and decide from its type whether it can be padded
        #     for field_name, field_array in pack_ins_lst.items():
        #         pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name],
        #                                       self.field_dims[field_name],
        #                                       pad_val=0, as_numpy=self.as_numpy)
        #         pad_ins_lst[field_name] = pad_field_array

        return pad_ins_lst

    def set_pad_val(self, *field_names, val=0):
        for field_name in field_names:
            self.pad_field_value[field_name] = val

    def set_as_numpy(self, as_numpy: bool):
        self.as_numpy = as_numpy

    def set_input(self, *field_names):
        for field_name in field_names:
            self.need_inputs.add(field_name)
def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
    if field_type:
        # do not process it; return an np.array
        if field_dim > 3:
            return np.array(content)
        # the element type is a numeric type: np.int64, np.float64, int, float, etc.
        if isinstance(field_type, type) and \
                (issubclass(field_type, np.number) or issubclass(field_type, Number)):
            if field_dim == 0:
                array = np.array(content, dtype=field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                array = np.full((len(content), max_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    array[i, :len(content_i)] = content_i
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        array[i, j, :len(content_ii)] = content_ii
            else:
                shape = np.shape(content)
                if len(shape) == 4:  # every dimension has the same size
                    array = np.array(content, dtype=field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            if as_numpy is False:
                array = torch.tensor(array)
            return array
        # the element type is a torch numeric type, e.g. torch.float
        elif str(field_type).startswith('torch'):
            if field_dim == 0:
                tensor = torch.tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                    dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = torch.full([len(content)] + list(shape), fill_value=pad_val,
                                        dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor
        # TODO add jittor/paddle?
        elif str(field_type).startswith('paddle'):
            if field_dim == 0:
                tensor = paddle.Tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                     dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val,
                                         dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor
        else:
            return np.array(content)  # no processing performed
    else:
        return np.array(content)
@@ -1,253 +0,0 @@
@@ -13,6 +13,7 @@ from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from .exceptions import *

@@ -27,7 +28,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
    :param field_name: only used to make error messages clearer.
    :return:
    """
    logger.debug(f"The content in the field:`{field_name}` is:\n"+str(batch_field))
    logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
    if pad_val is None:
        logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
        return NullPadder()

@@ -89,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
            return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
        elif backend == 'torch':
            return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
        elif backend == 'paddle':
            return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

    if depth > 1 and shape_len == 0:  # e.g. [[0, 1], [2]]
        if backend == 'raw':

@@ -97,12 +101,16 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
            return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
        elif backend == 'torch':
            return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
        elif backend == 'paddle':
            return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

    if depth == 1 and shape_len != 0:
        if backend == 'numpy':
            return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
        elif backend == 'torch':
            return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
        elif backend == 'paddle':
            return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

    if shape_len != 0 and depth > 1:
        msg = "Does not support pad tensor under nested list. If you need this, please report."
@@ -0,0 +1,178 @@
__all__ = [
    "PaddleNumberPadder",
    "PaddleTensorPadder",
    "PaddleSequencePadder"
]
from inspect import isclass

import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
    import paddle

numpy_to_paddle_dtype_dict = {
    np.bool_: 'bool',
    np.uint8: 'uint8',
    np.int8: "int8",
    np.int16: "int16",
    np.int32: "int32",
    np.int64: "int64",
    np.float16: "float16",
    np.float32: 'float32',
    np.float64: 'float32',  # unify everything to float32 here, since numpy defaults to float64 most of the time
    np.complex64: 'complex64',
    np.complex128: "complex128"
}
number_to_paddle_dtype_dict = {
    float: 'float32',  # because paddle.to_tensor([1], dtype=float) is paddle.float64
    int: 'int64',
    bool: 'bool'
}

from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *

def is_paddle_tensor(dtype):
    if not isclass(dtype) and isinstance(dtype, paddle.dtype):
        return True
    return False

def is_paddle_dtype_str(dtype):
    try:
        if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
                                                'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
                                                u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
                                                u'int16', u'int32', u'int64', u'uint8', u'complex64',
                                                u'complex128'}:
            return True
    except:
        pass
    return False
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+    if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
+        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+                                       f"or numpy numbers or paddle.Tensor but got `{ele_dtype}`.")
+    if dtype is not None:
+        if not (is_paddle_tensor(dtype) or is_number(dtype) or is_paddle_dtype_str(dtype)):
+            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+                                        f"or paddle.dtype but got `{dtype}`.")
+        dtype = number_to_paddle_dtype_dict.get(dtype, dtype)
+    else:
+        if is_number(ele_dtype) or is_paddle_tensor(ele_dtype):
+            ele_dtype = number_to_paddle_dtype_dict.get(ele_dtype, ele_dtype)
+            dtype = ele_dtype
+        elif is_numpy_number_dtype(ele_dtype):  # numpy dtype objects need an explicit conversion here
+            dtype = numpy_to_paddle_dtype_dict.get(ele_dtype.type)
+        elif is_numpy_generic_class(ele_dtype):
+            dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
+        else:
+            dtype = ele_dtype
+    return dtype
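
If the helper behaves as the mapping tables above suggest, a quick interactive check would look roughly like the sketch below. This is illustrative only: the exact return values depend on how the `is_number*` helpers imported from `.utils` classify class objects, which is not shown in this patch.

    # hypothetical session; expected values inferred from the dtype dicts above
    print(_get_dtype(int, None, class_name="PaddleNumberPadder"))        # expected: 'int64'
    print(_get_dtype(float, None, class_name="PaddleNumberPadder"))      # expected: 'float32'
    print(_get_dtype(np.float64, None, class_name="PaddleNumberPadder")) # expected: 'float32' (float64 is unified down)
    print(_get_dtype(int, float, class_name="PaddleNumberPadder"))       # expected: 'float32' (an explicit dtype wins)
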
+
+
+class PaddleNumberPadder(Padder):
+    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+        # only used when ele_dtype is a python number, a numpy number, or a tensor
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        return paddle.to_tensor(batch_field, dtype=dtype)
+
+
+class PaddleSequencePadder(Padder):
+    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+        return tensor
+
+
+class PaddleTensorPadder(Padder):
+    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+        """
+        Currently only supports inputs like [paddle.to_tensor([3, 2]), paddle.to_tensor([1])].
+
+        :param ele_dtype:
+        :param pad_val:
+        :param dtype:
+        """
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        shapes = [field.shape for field in batch_field]
+        # per-dimension maxima across the batch; max(s) also handles a batch of size one
+        max_shape = [len(batch_field)] + [max(s) for s in zip(*shapes)]
+        tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
+        for i, field in enumerate(batch_field):
+            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
+            if isinstance(field, np.ndarray):
+                field = paddle.to_tensor(field)
+            tensor[slices] = field
+        return tensor
+
+
+def fill_tensor(batch_field, padded_batch, dtype):
+    """
+    Fill the values from batch_field into the pre-allocated tensor.
+
+    :param batch_field: the content to fill into the array
+    :param padded_batch: the tensor to be filled
+    :param dtype: dtype of the data
+    :return:
+    """
+    if padded_batch.ndim == 2:
+        for i, content_i in enumerate(batch_field):
+            padded_batch[i, :len(content_i)] = paddle.to_tensor(content_i, dtype=dtype)
+    elif padded_batch.ndim == 3:
+        for i, content_i in enumerate(batch_field):
+            for j, content_ii in enumerate(content_i):
+                padded_batch[i, j, :len(content_ii)] = paddle.to_tensor(content_ii, dtype=dtype)
+    elif padded_batch.ndim == 4:
+        try:  # most likely images, so converting them in one shot should be fine
+            # convert via numpy so the function still returns a paddle.Tensor
+            padded_batch = paddle.to_tensor(np.array(batch_field), dtype=dtype)
+        except Exception:
+            for i, content_i in enumerate(batch_field):
+                for j, content_ii in enumerate(content_i):
+                    for k, content_iii in enumerate(content_ii):
+                        padded_batch[i, j, k, :len(content_iii)] = paddle.to_tensor(content_iii, dtype=dtype)
+    elif padded_batch.ndim == 1:
+        padded_batch[:] = paddle.to_tensor(batch_field, dtype=dtype)
+    else:
+        raise RuntimeError("fastNLP does not support padding for more than 4 dimensions. If you need this, please "
+                           "report.")
+    return padded_batch
+
+
+def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
+    """
+    For example:
+        [[1, 2], [3]] -> paddle.Tensor([[1, 2], [3, 0]])
+
+    :param batch_field: the object to pad. It must be paddable: 1d (usually sequence lengths) / 2d (usually token
+        sequences) / 3d (usually character sequences) / 4d (usually images) inputs are supported.
+    :param dtype: the target dtype
+    :param pad_val: the value to pad with
+    :return:
+    """
+    shapes = get_shape(batch_field)
+    tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
+    tensor = fill_tensor(batch_field, tensor, dtype=dtype)
+    return tensor
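
Taken together, padding a ragged batch with these helpers works roughly as in the sketch below (illustrative; assumes paddle is installed and bypasses the Padder.__call__ machinery for brevity):

    # pad a ragged integer batch with the sequence padder added above
    padder = PaddleSequencePadder(ele_dtype=int, pad_val=0)
    batch = [[1, 2], [3]]
    tensor = padder.pad(batch, pad_val=0, dtype='int64')
    # tensor is a paddle.Tensor of shape [2, 2]: [[1, 2], [3, 0]]
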
@@ -440,6 +440,7 @@ class Trainer(TrainerEventTrigger):
         """
         _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
         _own_callbacks.extend(self._custom_callbacks[None])
+        logger.debug(f"Got {len(_own_callbacks)} callback fns through Trainer.on().")
         self._custom_callbacks[None] = []
         if self.marker is not None:
             if len(self._custom_callbacks[self.marker]) == 0:
@@ -3,17 +3,18 @@ __all__ = [
     'prepare_jittor_dataloader'
 ]

-from typing import Callable, Optional, List
+from typing import Callable, Optional, List, Union

 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 if _NEED_IMPORT_JITTOR:
     from jittor.dataset.utils import collate_batch
     from jittor.dataset import Dataset
 else:
     from fastNLP.core.dataset import DataSet as Dataset

 from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
-from fastNLP.core.collators import AutoCollator
-from fastNLP.core.utils.utils import indice_collate_wrapper
+from fastNLP.core.collators import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
@@ -48,7 +49,7 @@ class JittorDataLoader:
     def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
-                 collate_fn: Callable = None) -> None:
+                 collate_fn: Union[None, str, Callable] = "auto") -> None:
         """

         :param dataset: a dataset implementing __getitem__ and __len__
@@ -66,11 +67,20 @@ class JittorDataLoader:
         # TODO support fastNLP DataSet
         # TODO verify support for replacesampler (to be finished later)
         # is the dataset a jittor-style dataset?
-        if isinstance(dataset, FDataSet):
-            collator = dataset.get_collator().set_as_numpy(as_numpy=True)
+        if isinstance(collate_fn, str):
+            if collate_fn == "auto":
+                if isinstance(dataset, FDataSet):
+                    self._collate_fn = dataset.collator
+                    self._collate_fn.set_backend(backend="jittor")
+                else:
+                    self._collate_fn = Collator(backend="jittor")
+            else:
+                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
+        elif isinstance(collate_fn, Callable):
+            if collate_fn is not collate_batch:
+                self._collate_fn = collate_fn
         else:
-            collator = None
+            self._collate_fn = collate_batch
         self.dataset = _JittorDataset(dataset)
@@ -80,17 +90,13 @@ class JittorDataLoader:
         if isinstance(self.dataset.dataset, Dataset):
             self.dataset.dataset.set_attrs(batch_size=1)
         # a user-provided collate_fn automatically replaces the collate_batch function supplied by jittor
-        self.collate_fn = collate_fn
-        if self.collate_fn is None:
-            self.collate_fn = collate_batch
-        self.auto_collator = collator
-        self.cur_batch_indices = None
-        # self._collate_fn = _collate_fn
+        self.cur_batch_indices = None

     def __iter__(self):
         # TODO setting collate_fn after the first iteration has no effect
+        self.collate_fn = self._collate_fn
         if self.cur_batch_indices is None:
-            self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn,
-                                                                                             self.auto_collator)))
+            self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
         for indices, data in self.dataset.__iter__():
             self.cur_batch_indices = indices
             yield data
@@ -100,39 +106,56 @@ class JittorDataLoader:
             return len(self.dataset) // self.dataset.batch_size
         return (len(self.dataset) - 1) // self.dataset.batch_size + 1

-    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> Collator:
         """
-        Set the padding value of each field_name; 0 by default. Only effective while an autocollate exists;
-        if none exists, an auto_collator function is added. val=None means none of the given field_names
-        should be padded.
-        :param field_names:
-        :param val: the padding value, 0 by default
-        :return:
+        Use this function when the contents of a field need special handling.
+
+        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's
+            key can be used directly; for a nested dict, a tuple can express the multi-level key, e.g. ('a', 'b')
+            for {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to its 0th or 1st element.
+            An error is raised if the field is not found in the data; if __getitem__ returns the content as a
+            whole, use "_single".
+        :param pad_val: the default pad value for this field. None means the field should not be padded; since
+            fastNLP only pads fields that are paddable anyway, a field that is not in a paddable form does not
+            need to be set to None explicitly. Meaningless when backend is None.
+        :param dtype: for a field that needs padding, the dtype its data should have.
+        :param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type
+            list, numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. Meaningless when pad_val
+            is None.
+        :param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
+            The Collator automatically unbatches the data and regroups each field into its own batch; that batch
+            form of the field is the input to pad_fn, and its output is used directly as the result.
+        :return: the Collator itself
         """
-        if self.auto_collator is None:
-            self.auto_collator = AutoCollator(as_numpy=True)
-        self.auto_collator.set_pad_val(*field_names, val=val)
+        if isinstance(self._collate_fn, Collator):
+            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return self._collate_fn
+        else:
+            raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")

-    def set_input(self, *field_names) -> None:
+    def set_ignore(self, *field_names) -> Collator:
         """
-        field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default
-        :param field_names:
-        :return:
+        Use this to mark contents that should not be output; the fields set here are omitted from the batch.
+
+        Ex::
+            collator.set_ignore('field1', 'field2')
+
+        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the
+            field's key can be used directly; for a nested dict, a tuple can be used, e.g. ('a', 'b') for
+            {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to its 0th or 1st element.
+        :return: the Collator itself
         """
-        if self.auto_collator is None:
-            self.auto_collator = AutoCollator(as_numpy=True)
-        self.auto_collator.set_input(*field_names)
+        if isinstance(self._collate_fn, Collator):
+            self._collate_fn.set_ignore(*field_names)
+            return self._collate_fn
+        else:
+            raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")

     def get_batch_indices(self) -> List[int]:
         """
-        Get the indices of the current data
+        Get the indices of the current batch
         :return:
         """
         return self.cur_batch_indices


 def prepare_jittor_dataloader():
     ...
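
A hedged end-to-end sketch of the reworked JittorDataLoader follows (field names are hypothetical; assumes jittor is installed and that the class is importable from fastNLP.core.dataloaders):

    from fastNLP import DataSet
    from fastNLP.core.dataloaders import JittorDataLoader  # import path assumed

    ds = DataSet({'tokens': [[1, 2], [3]], 'label': [0, 1]})
    dl = JittorDataLoader(ds, batch_size=2, collate_fn="auto")  # "auto" reuses ds.collator with backend="jittor"
    dl.set_pad('tokens', pad_val=0)   # returns the underlying Collator
    dl.set_ignore('label')            # 'label' is dropped from the batch output
    for batch in dl:
        print(dl.get_batch_indices())  # indices of the instances in the current batch
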
@@ -6,6 +6,7 @@ __all__ = [

 from typing import Callable, List, Optional, Union, Dict, Sequence

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

 if _NEED_IMPORT_PADDLE:
     from paddle.io import DataLoader, Dataset
     from paddle.fluid.dataloader.collate import default_collate_fn
@@ -13,9 +14,10 @@ else:
     from fastNLP.core.utils.dummy_class import DummyClass as Dataset
     from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

-from fastNLP.core.collators.collator import _MultiCollator
-from fastNLP.core.utils.utils import indice_collate_wrapper
+from fastNLP.core.collators.collator import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
+from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler


 class _PaddleDataset(Dataset):
@@ -45,7 +47,7 @@ class PaddleDataLoader(DataLoader):
     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
                  batch_size: int = 1, shuffle: bool = False,
-                 drop_last: bool = False, collate_fn: Callable = None,
+                 drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,
                  worker_init_fn: Callable = None, persistent_workers=False) -> None:
@@ -53,6 +55,10 @@ class PaddleDataLoader(DataLoader):
         if not isinstance(dataset, _PaddleDataset):
             dataset = _PaddleDataset(dataset)

+        if batch_sampler is None:
+            batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
+                                               drop_last=drop_last)
+
         super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
                                                return_list=return_list, batch_sampler=batch_sampler,
                                                batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
@@ -60,13 +66,21 @@ class PaddleDataLoader(DataLoader):
                                                use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
                                                timeout=timeout, worker_init_fn=worker_init_fn,
                                                persistent_workers=persistent_workers)
-        if isinstance(dataset.dataset, FDataSet):
-            self._collate_fn = dataset.dataset.get_collator()
-            self._collate_fn.set_as_numpy(as_numpy=True)
-            if collate_fn is not None:
-                self._collate_fn.add_collator(collate_fn)
+        if isinstance(collate_fn, str):
+            if collate_fn == 'auto':
+                if isinstance(dataset.dataset, FDataSet):
+                    self._collate_fn = dataset.dataset.collator
+                    self._collate_fn.set_backend(backend="paddle")
+                else:
+                    self._collate_fn = Collator(backend="paddle")
+            else:
+                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
+        elif isinstance(collate_fn, Callable):
+            if collate_fn is not default_collate_fn:
+                self._collate_fn = collate_fn
         else:
-            self._collate_fn = _MultiCollator(collate_fn)
+            self._collate_fn = default_collate_fn
         # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
         # if collate_fn is not None:
         #     _collate_fn.add_collator(collate_fn)
@@ -75,68 +89,60 @@ class PaddleDataLoader(DataLoader):

     def __iter__(self):
         # with neither an auto_collator nor a custom collate_fn, fall back to the dataloader's own collate_fn and simply batch the data
-        if len(self._collate_fn.get_collators()) == 0:
-            self._collate_fn.add_collator(default_collate_fn)
-            # self._collate_fn = default_collate_fn
+        # if len(self._collate_fn.get_collators()) == 0:
+        #     self._collate_fn.add_collator(default_collate_fn)
+        #     self._collate_fn = default_collate_fn
         self.collate_fn = indice_collate_wrapper(self._collate_fn)
         for indices, data in super().__iter__():
             self.cur_batch_indices = indices
             yield data
-    def __getattr__(self, item):
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> Collator:
         """
-        Expose the dataset's methods and attributes through FDataLoader, so that users can call dataset
-        methods such as apply on an instantiated FDataLoader
-        :param item:
-        :return:
+        Use this function when the contents of a field need special handling.
+
+        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's
+            key can be used directly; for a nested dict, a tuple can express the multi-level key, e.g. ('a', 'b')
+            for {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to its 0th or 1st element.
+            An error is raised if the field is not found in the data; if __getitem__ returns the content as a
+            whole, use "_single".
+        :param pad_val: the default pad value for this field. None means the field should not be padded; since
+            fastNLP only pads fields that are paddable anyway, a field that is not in a paddable form does not
+            need to be set to None explicitly. Meaningless when backend is None.
+        :param dtype: for a field that needs padding, the dtype its data should have.
+        :param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type
+            list, numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. Meaningless when pad_val
+            is None.
+        :param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
+            The Collator automatically unbatches the data and regroups each field into its own batch; that batch
+            form of the field is the input to pad_fn, and its output is used directly as the result.
+        :return: the Collator itself
         """
-        try:
-            return self.dataset.__getattr__(item)
-        except AttributeError as e:
-            raise e
-
-    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
-        """
-        Set the padding value of each field_name; 0 by default. Only effective while an autocollate exists;
-        if none exists, an auto_collator function is added. val=None means none of the given field_names
-        should be padded.
-        :param field_names:
-        :param val: the padding value, 0 by default
-        :return:
-        """
-        for field_name in field_names:
-            self._collate_fn.set_pad_val(field_name, val=val)
-
-    def set_input(self, *field_names) -> None:
-        """
-        field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default
-        :param field_names:
-        :return:
-        """
-        self._collate_fn.set_input(*field_names)
-
-    def set_collator(self, collator: Callable) -> None:
-        """
-        Set the collate_fn; calling this overrides all current collate_fns, including Auto_Collate
-        :param collator: a user-defined Callable
-        :return:
-        """
-        self._collate_fn = _MultiCollator(collator)
+        if isinstance(self._collate_fn, Collator):
+            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return self._collate_fn
+        else:
+            raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")

-    def add_collator(self, collator) -> None:
+    def set_ignore(self, *field_names) -> Collator:
         """
-        Append a collate_fn after the existing collate_fns
-        :param collator:
-        :return:
+        Use this to mark contents that should not be output; the fields set here are omitted from the batch.
+
+        Ex::
+            collator.set_ignore('field1', 'field2')
+
+        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the
+            field's key can be used directly; for a nested dict, a tuple can be used, e.g. ('a', 'b') for
+            {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to its 0th or 1st element.
+        :return: the Collator itself
         """
-        self._collate_fn.add_collator(collator)
+        if isinstance(self._collate_fn, Collator):
+            self._collate_fn.set_ignore(*field_names)
+            return self._collate_fn
+        else:
+            raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")
     def get_batch_indices(self) -> List[int]:
         """
-        Get the indices of the current data
+        Get the indices of the current batch
         :return:
         """

@@ -144,20 +150,22 @@ class PaddleDataLoader(DataLoader):

 def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
-                              return_list: bool = True, batch_sampler=None,
-                              train_batch_size: int = 1, shuffle: bool = False,
-                              drop_last: bool = False, collate_fn: Callable = None,
-                              num_workers: int = 0, use_buffer_reader: bool = True,
-                              use_shared_memory: bool = True, timeout: int = 0,
-                              worker_init_fn: Callable = None, persistent_workers=False,
-                              non_train_batch_size: int = 16,
-                              input_fields: Union[List[str], str] = None)\
-        -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
-    if isinstance(input_fields, str):
-        input_fields = [input_fields]
+                              return_list: bool = True,
+                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
+                              train_batch_size: int = 1, shuffle: bool = False,
+                              drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
+                              num_workers: int = 0, use_buffer_reader: bool = True,
+                              use_shared_memory: bool = True, timeout: int = 0,
+                              worker_init_fn: Callable = None, persistent_workers=False,
+                              non_train_batch_size: int = 16) \
+        -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:

     if isinstance(ds_or_db, Dataset):
-        ...
+        dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
+                              batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
+                              drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
+                              use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
+                              timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
+        return dl
     elif isinstance(ds_or_db, Sequence):
         ds_seq = []
         for ds in ds_or_db:
@@ -166,7 +174,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                   use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                   timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
-            dl.set_input(*input_fields)
             ds_seq.append(dl)
         return ds_seq

@@ -178,14 +185,15 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                                       batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
                                       drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                       use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
-                                      timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
+                                      timeout=timeout, worker_init_fn=worker_init_fn,
+                                      persistent_workers=persistent_workers)
             else:
                 dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
                                       batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle,
                                       drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                       use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
-                                      timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
-                dl.set_input(*input_fields)
+                                      timeout=timeout, worker_init_fn=worker_init_fn,
+                                      persistent_workers=persistent_workers)
             ds_dict[name] = dl
         return ds_dict
     else:
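
For reference, a minimal sketch of how the reworked PaddleDataLoader is driven (field names are hypothetical; assumes paddle is installed and the class is importable from fastNLP.core.dataloaders):

    from fastNLP import DataSet
    from fastNLP.core.dataloaders import PaddleDataLoader  # import path assumed from this patch

    ds = DataSet({'words': [[1, 2], [3]], 'target': [0, 1]})
    dl = PaddleDataLoader(ds, batch_size=2, collate_fn='auto')  # 'auto' picks up ds.collator with backend="paddle"
    dl.set_pad('words', pad_val=-1)   # pad 'words' with -1 instead of 0
    dl.set_ignore('target')           # drop 'target' from the batch output
    for batch in dl:
        ...  # batch['words'] is a paddle.Tensor padded with -1
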
@@ -3,15 +3,14 @@ __all__ = [
     'prepare_torch_dataloader'
 ]

-from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping

 from fastNLP.core.dataset import DataSet
-from fastNLP.core.collators import AutoCollator
-from fastNLP.core.collators.collator import _MultiCollator
-from fastNLP.core.utils.utils import indice_collate_wrapper
+from fastNLP.core.collators import Collator
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler

 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader, Sampler
@@ -51,11 +50,11 @@ class TorchDataLoader(DataLoader):
     def __init__(self, dataset, batch_size: int = 1,
                  shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                  batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
-                 num_workers: int = 0, collate_fn: Optional[Callable] = None,
+                 num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
                  pin_memory: bool = False, drop_last: bool = False,
                  timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                  multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
-                 persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None:
+                 persistent_workers: bool = False, **kwargs) -> None:
         """

         :param dataset: a data container implementing __getitem__ and __len__
@@ -64,7 +63,7 @@ class TorchDataLoader(DataLoader):
         :param sampler: an instantiated sampler object
         :param batch_sampler: an instantiated batch_sampler object whose iterator yields a list of indices
         :param num_workers: number of worker processes; num_workers=0 disables multiprocessing
-        :param collate_fn: a callable that batches the fetched data. [None, 'auto', callable]
+        :param collate_fn: [None, 'auto', callable] a callable that batches the fetched data
         :param pin_memory:
         :param drop_last: whether to drop the final batch when it is smaller than batch_size
         :param timeout:
@@ -73,133 +72,99 @@ class TorchDataLoader(DataLoader):
         :param generator:
         :param prefetch_factor:
         :param persistent_workers:
-        :param as_numpy: whether to return the data as numpy; otherwise as torch.Tensor
         """
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)

+        if sampler is None and batch_sampler is None:
+            sampler = RandomSampler(dataset, shuffle=shuffle)
+
         super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                          batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
                          pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                          multiprocessing_context=multiprocessing_context, generator=generator,
                          prefetch_factor=prefetch_factor,
                          persistent_workers=persistent_workers)
-        if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is in use
-            self._collate_fn = dataset.dataset.get_collator()
-            self._collate_fn.set_as_numpy(as_numpy)
-            if collate_fn is not None and collate_fn is not default_collate:
-                # prevent torch dataloader's default collate from being re-added when ddp re-initializes
-                self._collate_fn.add_collator(collate_fn)
+        if isinstance(collate_fn, str):
+            if collate_fn == 'auto':
+                if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is in use
+                    self._collate_fn = dataset.dataset.collator
+                    self._collate_fn.set_backend(backend="torch")
+                else:
+                    self._collate_fn = Collator(backend="torch")
+            else:
+                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
+        elif isinstance(collate_fn, Callable):
+            if collate_fn is not default_collate:
+                self._collate_fn = collate_fn
         else:
-            self._collate_fn = _MultiCollator(collate_fn)
+            self._collate_fn = default_collate

         self.cur_indices_batch = None
-        self.as_numpy = as_numpy
-
-    def __getattr__(self, item):
-        """
-        Expose the dataset's methods and attributes through FDataLoader, so that users can call dataset
-        methods such as apply on an instantiated FDataLoader
-        :param item:
-        :return:
-        """
-        try:
-            return self.dataset.__getattr__(item)
-        except AttributeError as e:
-            raise e

     def __iter__(self):
         # with neither an auto_collator nor a custom collate_fn, fall back to the dataloader's own collate_fn and simply batch the data
-        if len(self._collate_fn.get_collators()) == 0:
-            self._collate_fn.add_collator(self.collate_fn)
+        # if len(self._collate_fn.get_collators()) == 0:
+        #     self._collate_fn.add_collator(self.collate_fn)
         self.collate_fn = indice_collate_wrapper(self._collate_fn)
         for indices, data in super().__iter__():
             self.cur_batch_indices = indices
             yield data
-    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
-        """
-        Set the padding value of each field_name; 0 by default. Only effective while an autocollate exists;
-        if none exists, an auto_collator function is added. val=None means none of the given field_names
-        should be padded.
-        :param field_names:
-        :param val: the padding value, 0 by default
-        :return:
-        """
-        flag = False
-        for collator in self._collate_fn.get_collators():
-            if isinstance(collator, AutoCollator):
-                flag = True
-                break
-        if flag is False:
-            self._collate_fn.add_collator(AutoCollator(self.as_numpy))
-        for field_name in field_names:
-            self._collate_fn.set_pad_val(field_name, val=val)
-
-    def set_input(self, *field_names) -> None:
-        """
-        field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default
-        :param field_names:
-        :return:
-        """
-        flag = False
-        for collator in self._collate_fn.get_collators():
-            if isinstance(collator, AutoCollator):
-                flag = True
-                break
-        if flag is False:
-            self._collate_fn.add_collator(AutoCollator(self.as_numpy))
-        self._collate_fn.set_input(*field_names)
-
-    def set_collator(self, collator: Callable) -> None:
-        """
-        Set the collate_fn; calling this overrides all current collate_fns, including Auto_Collate
-        :param collator: a user-defined Callable
-        :return:
-        """
-        self._collate_fn = _MultiCollator(collator)
-
-    def add_collator(self, collator) -> None:
-        """
-        Append a collate_fn after the existing collate_fns
-        :param collator:
-        :return:
-        """
-        self._collate_fn.add_collator(collator)
-
-    def get_batch_indices(self) -> List[int]:
-        """
-        Get the indices of the current data
-        :return:
-        """
-        return self.cur_batch_indices
-
-    def set_pad(self):
-        pass
-
-    def set_ignore(self):
-        pass
-
-    def set_backend(self):
-        pass
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> Collator:
+        """
+        Use this function when the contents of a field need special handling.
+
+        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's
+            key can be used directly; for a nested dict, a tuple can express the multi-level key, e.g. ('a', 'b')
+            for {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to its 0th or 1st element.
+            An error is raised if the field is not found in the data; if __getitem__ returns the content as a
+            whole, use "_single".
+        :param pad_val: the default pad value for this field. None means the field should not be padded; since
+            fastNLP only pads fields that are paddable anyway, a field that is not in a paddable form does not
+            need to be set to None explicitly. Meaningless when backend is None.
+        :param dtype: for a field that needs padding, the dtype its data should have.
+        :param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type
+            list, numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. Meaningless when pad_val
+            is None.
+        :param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
+            The Collator automatically unbatches the data and regroups each field into its own batch; that batch
+            form of the field is the input to pad_fn, and its output is used directly as the result.
+        :return: the Collator itself
+        """
+        if isinstance(self._collate_fn, Collator):
+            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return self._collate_fn
+        else:
+            raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")
+
+    def set_ignore(self, *field_names) -> Collator:
+        """
+        Use this to mark contents that should not be output; the fields set here are omitted from the batch.
+
+        Ex::
+            collator.set_ignore('field1', 'field2')
+
+        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the
+            field's key can be used directly; for a nested dict, a tuple can be used, e.g. ('a', 'b') for
+            {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to its 0th or 1st element.
+        :return: the Collator itself
+        """
+        if isinstance(self._collate_fn, Collator):
+            self._collate_fn.set_ignore(*field_names)
+            return self._collate_fn
+        else:
+            raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
-                             shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
-                             batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
-                             num_workers: int = 0, collate_fn: Optional[Callable] = None,
+                             shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
+                             batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
+                             num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
                              pin_memory: bool = False, drop_last: bool = False,
                              timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                              multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                              persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
-                             non_train_batch_size: int = 16, as_numpy: bool = False,
-                             input_fields: Union[List, str, None] = None) \
+                             non_train_batch_size: int = 16) \
         -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
     """
     Given a dataset or data_bundle, process it and return the corresponding FDataLoader instance(s)
@@ -211,7 +176,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param sampler: an instantiated sampler object
     :param batch_sampler: an instantiated batch_sampler object whose iterator yields a list of indices
     :param num_workers: number of worker processes; num_workers=0 disables multiprocessing
-    :param collate_fn: a callable that batches the fetched data
+    :param collate_fn: ['auto', None, callable] a callable that batches the fetched data
     :param pin_memory:
     :param drop_last: whether to drop the final batch when it is smaller than batch_size
     :param timeout:
@@ -222,11 +187,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
     :param persistent_workers:
     :param non_train_sampler: the Sampler used by non-'train' data, and by the second and later datasets of a Sequence
     :param non_train_batch_size:
-    :param as_numpy: whether to return the data as numpy; otherwise as torch.Tensor where appropriate.
     """
-    # TODO this still needs to be provided for the dict and sequence cases
-    if isinstance(input_fields, str):
-        input_fields = [input_fields]

     if isinstance(ds_or_db, DataSet):
         dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
@@ -235,9 +196,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                              drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                              multiprocessing_context=multiprocessing_context, generator=generator,
                              prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                             as_numpy=as_numpy)
-        if input_fields:
-            dl.set_input(*input_fields)
+                             )
         return dl

     elif isinstance(ds_or_db, DataBundle):
@@ -251,7 +210,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
                                                   prefetch_factor=prefetch_factor,
                                                   persistent_workers=persistent_workers,
-                                                  as_numpy=as_numpy)
+                                                  )
             else:
                 dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
                                                   shuffle=shuffle, sampler=non_train_sampler,
@@ -261,9 +220,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
                                                   prefetch_factor=prefetch_factor,
                                                   persistent_workers=persistent_workers,
-                                                  as_numpy=as_numpy)
-        if input_fields:
-            dl_bundle[name].set_input(*input_fields)
+                                                  )
         return dl_bundle

     elif isinstance(ds_or_db, Sequence):
@@ -277,7 +234,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                               drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                               multiprocessing_context=multiprocessing_context, generator=generator,
                               prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                              as_numpy=as_numpy)
+                              )
             )
         else:
             dl_bundle.append(
@@ -287,11 +244,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                               drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                               multiprocessing_context=multiprocessing_context, generator=generator,
                               prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
-                              as_numpy=as_numpy)
+                              )
             )
-        if input_fields:
-            for dl in dl_bundle:
-                dl.set_input(*input_fields)
         return dl_bundle

     elif isinstance(ds_or_db, Mapping):
@@ -305,7 +259,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
                                                   prefetch_factor=prefetch_factor,
                                                   persistent_workers=persistent_workers,
-                                                  as_numpy=as_numpy)
+                                                  )
             else:
                 dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
                                                   shuffle=shuffle, sampler=non_train_sampler,
@@ -315,10 +269,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
                                                   multiprocessing_context=multiprocessing_context, generator=generator,
                                                   prefetch_factor=prefetch_factor,
                                                   persistent_workers=persistent_workers,
-                                                  as_numpy=as_numpy)
-        if input_fields:
-            dl_bundle[name].set_input(*input_fields)
+                                                  )
         return dl_bundle

     else:
@@ -0,0 +1,16 @@
+def indice_collate_wrapper(func):
+    """
+    Wrap a collate_fn so that the (idx, instance) tuples fetched from the dataset are split apart,
+    packing the idx values into indices.
+
+    :param func: the function to wrap
+    :return:
+    """
+
+    def wrapper(tuple_data):
+        indice, ins_list = [], []
+        for idx, ins in tuple_data:
+            indice.append(idx)
+            ins_list.append(ins)
+        return indice, func(ins_list)
+
+    return wrapper
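
To make the contract concrete, a small self-contained check of the wrapper's behavior (the input list here is purely illustrative):

    # each fetched item arrives as (index, instance); the wrapper separates them
    wrapped = indice_collate_wrapper(lambda ins_list: {'batch': ins_list})
    indices, batch = wrapped([(0, {'x': 1}), (3, {'x': 2})])
    assert indices == [0, 3]
    assert batch == {'batch': [{'x': 1}, {'x': 2}]}
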
@@ -23,9 +23,8 @@ except:
 from .field import FieldArray
 from .instance import Instance
 from fastNLP.core.utils.utils import pretty_table_printer, deprecated
-from fastNLP.core.collators import AutoCollator
+from fastNLP.core.collators import Collator
 from fastNLP.core.utils.rich_progress import f_rich_progress
-from fastNLP.core.collators.collator import _MultiCollator


 class ApplyResultException(Exception):
@@ -114,7 +113,7 @@ class DataSet:
         Each element should be an :class:`~fastNLP.Instance` with the same fields.
         """
         self.field_arrays = {}
-        self.collate_fns: _MultiCollator = _MultiCollator(AutoCollator(as_numpy=False))
+        self._collator = Collator(backend="numpy")
         if data is not None:
             if isinstance(data, Dict):
                 length_set = set()
@@ -181,7 +180,7 @@ class DataSet:
             dataset = DataSet()
             for field_name, field in self.field_arrays.items():
                 dataset.add_field(field_name=field_name, fields=field.content[idx])
-            dataset.collate_fns = deepcopy(self.collate_fns)
+            dataset._collator = deepcopy(self.collator)
             return dataset
         elif isinstance(idx, str):
             if idx not in self:
@@ -193,7 +192,7 @@ class DataSet:
                 assert isinstance(i, int), "Only int index allowed."
                 instance = self[i]
                 dataset.append(instance)
-            dataset.collate_fns = deepcopy(self.collate_fns)
+            dataset._collator = deepcopy(self.collator)
             return dataset
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -676,8 +675,8 @@ class DataSet:
             dev_set.append(self[idx])
         for idx in train_indices:
             train_set.append(self[idx])
-        dev_set.collate_fns = deepcopy(self.collate_fns)
-        train_set.collate_fns = deepcopy(self.collate_fns)
+        dev_set._collator = deepcopy(self.collator)
+        train_set._collator = deepcopy(self.collator)
         return dev_set, train_set
@@ -771,67 +770,17 @@ class DataSet:
         df = self.to_pandas()
         return df.to_csv(path, encoding="utf-8")

-    def add_collate_fn(self, collate_fn: Callable) -> None:
-        """
-        Append a collate_fn after the existing collate_fns
-        :param collate_fn: a Callable
-        :return:
-        """
-        self.collate_fns.add_collator(collate_fn)
-
-    def set_collate_fn(self, collate_fn: Callable) -> None:
-        """
-        Set the collate_fn; calling this overrides all current collate_fns, including Auto_Collate
-        :param collate_fn:
-        :return:
-        """
-        self.collate_fns = _MultiCollator(collate_fn)
-
-    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
-        """
-        Set the padding value of each field_name; 0 by default. Only effective while an AutoCollator exists.
-        val=None means none of the given field_names should be padded.
-        :param field_names: field names present in the dataset
-        :param val: 0 by default. If None, the field is not padded.
-        :return:
-        """
-        # TODO must not be empty
-        for field_name in field_names:
-            self.collate_fns.set_pad_val(field_name, val=val)
-
-    def set_input(self, *field_names) -> None:
-        """
-        field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default
-        :param field_names:
-        :return:
-        """
-        self.collate_fns.set_input(*field_names)
-
-    def get_collator(self) -> _MultiCollator:
-        """
-        Get the collate_fn bound to the dataset, including the auto_collate
-        :return:
-        """
-        return self.collate_fns
-
-    @deprecated()
-    def set_target(self, *field_names) -> None:
+    def set_ignore(self, *field_names) -> None:
         """
         field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default
         :param field_names:
         :return:
         """
-        self.collate_fns.set_input(*field_names)
+        self.collator.set_ignore(*field_names)

     @property
-    def collator(self):
+    def collator(self) -> Collator:
         if self._collator is None:
             self._collator = Collator()
         return self._collator
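
In practice the collator now travels with the DataSet, so per-field behavior can be configured once and inherited by every backend-specific dataloader built with collate_fn='auto'. A hedged sketch (field names hypothetical):

    from fastNLP import DataSet

    ds = DataSet({'words': [[1, 2], [3, 4, 5]], 'seq_len': [2, 3]})
    ds.collator.set_pad('words', pad_val=0)   # configure padding on the dataset's own Collator
    ds.collator.set_ignore('seq_len')         # 'seq_len' is dropped from batches
    # any TorchDataLoader / PaddleDataLoader / JittorDataLoader created with collate_fn='auto'
    # now picks up these settings via ds.collator
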
@@ -22,7 +22,7 @@ from fastNLP.core.utils import (
     rank_zero_rm
 )
 from fastNLP.core.samplers import (
-    RandomBatchSampler,
+    ReproduceBatchSampler,
     ReproducibleSampler,
     ReproducibleBatchSampler,
     RandomSampler,
@@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):
         return self.model, model.forward

-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
                                   reproducible: bool = False):
         r"""
         Derive from the given dataloader one that is distributed and reproducible.
@@ -22,7 +22,7 @@ from fastNLP.core.log import logger
 from fastNLP.core.samplers import (
     ReproducibleBatchSampler,
     ReproducibleSampler,
-    RandomBatchSampler,
+    ReproduceBatchSampler,
     RandomSampler,
 )
@@ -345,7 +345,7 @@ class PaddleDriver(Driver):
                     raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
                                        "`ReproducibleSampler`.")
                 else:
-                    sampler = RandomBatchSampler(
+                    sampler = ReproduceBatchSampler(
                         batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                         batch_size=dataloader_args.batch_size,
                         drop_last=dataloader_args.drop_last
@@ -476,7 +476,7 @@ class PaddleDriver(Driver):
                 res.shuffle = True
             else:
                 res.shuffle = False
-        # the RandomBatchSampler case
+        # the ReproduceBatchSampler case
         elif hasattr(dataloader.batch_sampler, "batch_sampler"):
             batch_sampler = dataloader.batch_sampler.batch_sampler
             res.sampler = batch_sampler.sampler
@@ -14,7 +14,7 @@ from fastNLP.core.utils import ( | |||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
ReproducibleBatchSampler, | ReproducibleBatchSampler, | ||||
RandomBatchSampler, | |||||
ReproduceBatchSampler, | |||||
ReproducibleSampler, | ReproducibleSampler, | ||||
RandomSampler, | RandomSampler, | ||||
re_instantiate_sampler, | re_instantiate_sampler, | ||||
@@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -15,7 +15,7 @@ from .torch_driver import TorchDriver | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler | |||||
from fastNLP.core.samplers import RandomSampler | from fastNLP.core.samplers import RandomSampler | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver): | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -293,7 +293,7 @@ class TorchDriver(Driver): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = RandomBatchSampler( | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -407,7 +407,7 @@ class TorchDriver(Driver): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | elif hasattr(dataloader.batch_sampler, "batch_sampler"): | ||||
batch_sampler = dataloader.batch_sampler.batch_sampler | batch_sampler = dataloader.batch_sampler.batch_sampler | ||||
res.sampler = batch_sampler.sampler | res.sampler = batch_sampler.sampler | ||||
@@ -14,9 +14,10 @@ __all__ = [ | |||||
"UnrepeatedSortedSampler", | "UnrepeatedSortedSampler", | ||||
"UnrepeatedSequentialSampler", | "UnrepeatedSequentialSampler", | ||||
"RandomBatchSampler", | |||||
"ReproduceBatchSampler", | |||||
"BucketedBatchSampler", | "BucketedBatchSampler", | ||||
"ReproducibleBatchSampler", | "ReproducibleBatchSampler", | ||||
"RandomBatchSampler", | |||||
"re_instantiate_sampler" | "re_instantiate_sampler" | ||||
] | ] | ||||
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling | |||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | ||||
from .utils import re_instantiate_sampler | from .utils import re_instantiate_sampler | ||||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | ||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||||
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler | |||||
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'BucketedBatchSampler', | 'BucketedBatchSampler', | ||||
"ReproduceBatchSampler", | |||||
"RandomBatchSampler" | "RandomBatchSampler" | ||||
] | ] | ||||
@@ -54,13 +55,13 @@ class ReproducibleBatchSampler: | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | ||||
class RandomBatchSampler(ReproducibleBatchSampler): | |||||
class ReproduceBatchSampler(ReproducibleBatchSampler): | |||||
# the values of these two parameters should be obtained from the driver's get_dataloader_args function;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | """ | ||||
A wrapper that makes the state of a batch_sampler object recoverable.
:param batch_sampler: an iterable that yields integers or lists of integers. RandomBatchSampler first iterates over the object once, then caches the yielded
:param batch_sampler: an iterable that yields integers or lists of integers. ReproduceBatchSampler first iterates over the object once, then caches the yielded
indices and emits them as index lists in batches of batch_size when used.
:param batch_size: the size of each batch.
:param drop_last: whether to drop the last batch if it cannot gather batch_size samples.
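For orientation, here is a minimal sketch of how the renamed wrapper is used, mirroring the torch-side tests later in this diff (the toy dataset is illustrative):

from torch.utils.data import DataLoader, BatchSampler, SequentialSampler
from fastNLP.core.samplers import ReproduceBatchSampler

dataset = list(range(100))  # any container with __len__ / __getitem__
batch_sampler = ReproduceBatchSampler(
    BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False),
    batch_size=4, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

it = iter(loader)
for _ in range(3):                    # consume a few batches ...
    next(it)
state = batch_sampler.state_dict()    # ... then snapshot index_list / num_consumed_samples

# a freshly built sampler restored from the snapshot resumes where the old one stopped
new_sampler = ReproduceBatchSampler(
    BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False),
    batch_size=4, drop_last=False)
new_sampler.load_state_dict(state)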
@@ -143,7 +144,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.need_reinitialize = False | self.need_reinitialize = False | ||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.") | |||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | |||||
def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | ||||
@@ -158,6 +159,211 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | ||||
class RandomBatchSampler(ReproducibleBatchSampler): | |||||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | |||||
drop_last: bool = False, seed: int = 0, **kwargs): | |||||
""" | |||||
A batch_sampler that groups samples into batches at random.
:param dataset: a data container that implements __len__.
:param batch_size: the size of each batch.
:param shuffle: whether to shuffle the order of the data on each iteration; if False, samples are returned in their original order.
:param drop_last: whether to drop the last batch if it cannot gather batch_size samples.
:param seed: the random seed to use.
:param kwargs: reserved for internal use by fastNLP.
""" | |||||
super().__init__() | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
self.shuffle = shuffle | |||||
self.drop_last = drop_last | |||||
self.seed = seed | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been consumed in total, including those produced on the other cards in the multi-card case
# parameters related to multi-card training
self.num_replicas = kwargs.get("num_replicas", 1)
self.rank = kwargs.get("rank", 0)
self.epoch = kwargs.get("epoch", -1)
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card;
# whether we are between iterations; when True, set_distributed() and load_state_dict() must not be called
self.during_iter = kwargs.get("during_iter", False)
# the variables below are used internally to restore state.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
def set_distributed(self, num_replicas, rank, pad=True): | |||||
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||||
"during an unfinished iteration." | |||||
assert num_replicas > 0 and isinstance(num_replicas, int) | |||||
assert isinstance(rank, int) and 0 <= rank < num_replicas | |||||
# note that when this function is called, all state should correspond to the very beginning of an epoch;
self.num_replicas = num_replicas | |||||
self.rank = rank | |||||
self.pad = pad | |||||
return self | |||||
def __iter__(self): | |||||
if self.during_iter:  # if during_iter is True, the previous iteration has not finished, so a forced re-initialization is the only option
self.num_consumed_samples = 0 | |||||
self.during_iter = True | |||||
indices = list(range(len(self.dataset))) | |||||
if self.shuffle: | |||||
if self.num_consumed_samples > 0:  # first reproduce the original ordering, then drop the indices that were already consumed
_batches = [] | |||||
for _i in range(self.old_num_replicas): | |||||
_indices = indices[_i:len(indices):self.old_num_replicas] | |||||
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch) | |||||
_batches.append(__batches) | |||||
batches = list(chain(*[_ for _ in zip(*_batches)])) | |||||
indices = list(chain(*batches)) | |||||
indices = indices[self.num_consumed_samples:] | |||||
# take the slice of indices belonging to this rank
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch) | |||||
batches = list(map(list, batches)) | |||||
else: | |||||
indices = indices[self.num_consumed_samples:] | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
_num_batches = len(indices) // self.batch_size | |||||
if _num_batches == 0: | |||||
batches = [indices] | |||||
else: | |||||
batches = list(map(list, np.array_split(indices[:_num_batches * self.batch_size], _num_batches)))
if len(indices) % self.batch_size != 0:
batches.append(indices[_num_batches * self.batch_size:])
need_pad_num = (len(self.dataset) - self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num != 0 and need_pad_num <= self.rank:
if len(batches) > 0:
if len(batches[-1]) < self.batch_size:
batches[-1].append(batches[-1][0])  # this guarantees that the length of this batch is not broken.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num != 0 and need_pad_num > self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1]) == 0:
batches.pop(-1)
assert sum(map(len, batches)) == self.num_left_samples | |||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||||
batches = batches[:-1] | |||||
for batch in batches: | |||||
self.num_consumed_samples += self.num_replicas * len(batch) | |||||
yield list(map(int, batch)) | |||||
self.during_iter = False | |||||
self.num_consumed_samples = 0 | |||||
self.old_batch_size = self.batch_size | |||||
self.old_num_replicas = self.num_replicas | |||||
if self.epoch < 0:  # guard against the user never setting the epoch, which would make every epoch identical
self.epoch -= 1 | |||||
def batchify(self, indices, batch_size, seed): | |||||
""" | |||||
Split indices into batches.
:param indices: List[int]
:param batch_size: int
:param seed: int
:return: List[List[int]]
"""
# shuffle the indices deterministically with the given seed
rng = np.random.default_rng(abs(seed)) | |||||
rng.shuffle(indices) | |||||
num_samples = 0 | |||||
batches = [] | |||||
while num_samples < len(indices):
batches.append(indices[num_samples:num_samples+batch_size]) | |||||
num_samples += batch_size | |||||
return batches | |||||
def set_epoch(self, epoch): | |||||
self.epoch = epoch | |||||
@property | |||||
def batch_idx_in_epoch(self): | |||||
if self.drop_last: | |||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | |||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | |||||
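# As a quick numeric check of this formula (a sketch with drop_last=False, one replica,
# batch_size=4, len(dataset)=10): the epoch has (10 + 4 - 1) // 4 = 3 batches in total;
# after one batch has been consumed, num_left_samples = 6, so (6 + 4 - 1) // 4 = 2 batches
# remain, and batch_idx_in_epoch = 3 - 2 = 1, as expected.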
@property | |||||
def total_size(self): | |||||
""" | |||||
The total number of indices this sampler will eventually produce, including those for other ranks; because of
replicas and padding, this value may be equal to, greater than, or smaller than len(dataset).
:return:
""" | |||||
return self.num_consumed_samples + self.num_replicas*self.num_left_samples | |||||
@property | |||||
def num_left_samples(self): | |||||
""" | |||||
The number of samples left before the current iteration finishes, counted for the current rank only.
:return: | |||||
""" | |||||
num_consumed_samples = self.num_consumed_samples | |||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
def __len__(self) -> int:
"""
The number of batches this sampler will still return.
:return: | |||||
""" | |||||
num_samples_per_rank = self.total_size // self.num_replicas
num_batches = num_samples_per_rank // self.batch_size if self.drop_last else \
(num_samples_per_rank + self.batch_size - 1) // self.batch_size
return num_batches | |||||
def state_dict(self) -> Dict: | |||||
if self.old_batch_size != self.batch_size:
raise RuntimeError("RandomBatchSampler does not support saving before the states from the last checkpoint "
"have been consumed.")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | |||||
'num_replicas': self.num_replicas} | |||||
return states | |||||
def load_state_dict(self, states: Dict): | |||||
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
"during an unfinished iteration." | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}, " \
f"so we cannot use {self.__class__.__name__} to load it."
length = states['length'] | |||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | |||||
self.seed = states['seed'] | |||||
self.epoch = states['epoch'] | |||||
self.num_consumed_samples = states['num_consumed_samples'] | |||||
if self.num_consumed_samples >= length:  # if the last sample had already been reached when saving, simply reset the counter to 0
self.num_consumed_samples = 0 | |||||
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle in the checkpoint is {states['shuffle']}, while it is set to {self.shuffle} here; "
f"shuffle={states['shuffle']} will be used.")
self.shuffle = states["shuffle"] | |||||
self.old_batch_size = states['batch_size'] | |||||
self.old_num_replicas = states['num_replicas'] | |||||
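Taken together, the new class behaves like RandomSampler at batch granularity. A minimal sketch of the intended life cycle (single process; the two-rank construction only simulates distributed use, and the example batch values are seed-dependent):

from fastNLP.core.samplers import RandomBatchSampler

dataset = list(range(10))        # only __len__ is required
sampler = RandomBatchSampler(dataset, batch_size=4, shuffle=True, seed=0)
sampler.set_epoch(0)             # bump this every epoch, or each epoch repeats the same order
batches = list(sampler)          # e.g. [[3, 7, 0, 1], [9, 2, 5, 8], [6, 4]]

# distributed use: each rank draws a disjoint share; pad=True rounds the shares up
rank0 = RandomBatchSampler(dataset, batch_size=4, seed=0).set_distributed(num_replicas=2, rank=0, pad=True)
rank1 = RandomBatchSampler(dataset, batch_size=4, seed=0).set_distributed(num_replicas=2, rank=1, pad=True)
assert len(rank0) == len(rank1) == 2   # ceil(10 / 2) = 5 samples per rank -> 2 batches

# mid-epoch checkpointing goes through state_dict() / load_state_dict(), as exercised by the driver tests below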
class BucketedBatchSampler(ReproducibleBatchSampler): | class BucketedBatchSampler(ReproducibleBatchSampler): | ||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
@@ -54,13 +54,12 @@ class RandomSampler(ReproducibleSampler): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
:param dataset: a data container that implements __len__
:param shuffle: whether to shuffle the order on each iteration.
:param seed: the random seed.
:param kwargs: not meant for users; used internally by fastNLP
""" | """ | ||||
super(RandomSampler, self).__init__() | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = seed | self.seed = seed | ||||
@@ -21,7 +21,6 @@ __all__ = [ | |||||
'nullcontext', | 'nullcontext', | ||||
'pretty_table_printer', | 'pretty_table_printer', | ||||
'Option', | 'Option', | ||||
'indice_collate_wrapper', | |||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'rank_zero_rm', | 'rank_zero_rm', | ||||
@@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||||
from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | ||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
from ..dataloaders.utils import indice_collate_wrapper | |||||
@@ -7,13 +7,13 @@ from collections.abc import Mapping, Callable | |||||
from functools import wraps | from functools import wraps | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor as jt | import jittor as jt | ||||
from fastNLP.core.dataset import Instance | from fastNLP.core.dataset import Instance | ||||
def is_jittor_dataset(dataset) -> bool: | def is_jittor_dataset(dataset) -> bool: | ||||
try: | try: | ||||
if isinstance(dataset, jt.dataset.Dataset): | if isinstance(dataset, jt.dataset.Dataset): | ||||
@@ -32,6 +32,7 @@ def jittor_collate_wraps(func, auto_collator: Callable): | |||||
:param auto_collator: | :param auto_collator: | ||||
:return: | :return: | ||||
""" | """ | ||||
@wraps(func) | @wraps(func) | ||||
def wrapper(batch): | def wrapper(batch): | ||||
if isinstance(batch[0], Instance): | if isinstance(batch[0], Instance): | ||||
@@ -6,7 +6,7 @@ import warnings | |||||
from dataclasses import is_dataclass | from dataclasses import is_dataclass | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from collections import defaultdict, OrderedDict | from collections import defaultdict, OrderedDict | ||||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional | |||||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence | |||||
from typing import Tuple, Optional | from typing import Tuple, Optional | ||||
from time import sleep | from time import sleep | ||||
@@ -35,7 +35,6 @@ __all__ = [ | |||||
'nullcontext', | 'nullcontext', | ||||
'pretty_table_printer', | 'pretty_table_printer', | ||||
'Option', | 'Option', | ||||
'indice_collate_wrapper', | |||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'rank_zero_rm', | 'rank_zero_rm', | ||||
@@ -513,24 +512,6 @@ class Option(dict): | |||||
self.update(state) | self.update(state) | ||||
def indice_collate_wrapper(func): | |||||
""" | |||||
Wraps a collate_fn so that the (idx, instance) tuples fetched from the dataset are split apart, packing the idx values into indices.
:param func: the function to wrap
:return: | |||||
""" | |||||
def wrapper(tuple_data): | |||||
indice, ins_list = [], [] | |||||
for idx, ins in tuple_data: | |||||
indice.append(idx) | |||||
ins_list.append(ins) | |||||
return indice, func(ins_list) | |||||
return wrapper | |||||
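The wrapper itself is unchanged by the move to fastNLP.core.dataloaders.utils; for reference, its behaviour on a toy collate_fn looks like this:

from fastNLP.core.dataloaders.utils import indice_collate_wrapper

def collate_fn(ins_list):
    return {"batch": ins_list}

wrapped = indice_collate_wrapper(collate_fn)
# the dataset side yields (idx, instance) tuples; the wrapper splits them apart
indices, batch = wrapped([(0, "a"), (3, "b")])
assert indices == [0, 3]
assert batch == {"batch": ["a", "b"]}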
_emitted_deprecation_warnings = set() | _emitted_deprecation_warnings = set() | ||||
@@ -0,0 +1,106 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
if _NEED_IMPORT_PADDLE: | |||||
import paddle | |||||
@pytest.mark.paddle | |||||
class TestpaddleNumberPadder: | |||||
def test_run(self): | |||||
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [1, 2, 3] | |||||
t_a = padder(a) | |||||
assert isinstance(t_a, paddle.Tensor) | |||||
assert (t_a == paddle.to_tensor(a, dtype='int64')).sum() == 3 | |||||
@pytest.mark.paddle | |||||
class TestpaddleSequencePadder: | |||||
def test_run(self): | |||||
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [[1, 2, 3], [3]] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, paddle.Tensor) | |||||
assert tuple(shape) == (2, 3) | |||||
b = paddle.to_tensor([[1, 2, 3], [3, -1, -1]], dtype='int64') | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
def test_dtype_check(self): | |||||
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) | |||||
a = padder([[1], [2, 322]]) | |||||
# assert (a > 67).sum() == 0  # because the int8 range is -128 to 127
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
@pytest.mark.paddle | |||||
class TestpaddleTensorPadder: | |||||
def test_run(self): | |||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) | |||||
a = [paddle.zeros((3,)), paddle.zeros((2,))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, paddle.Tensor) | |||||
assert tuple(shape) == (2, 3) | |||||
b = paddle.to_tensor([[0, 0, 0], [0, 0, -1]], dtype='int64') | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | |||||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 2))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, paddle.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, 0], [-1, -1], [-1, -1]]], dtype='int64') | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 1))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, paddle.Tensor) | |||||
assert tuple(shape) == (3, 3, 2) | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) | |||||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, paddle.Tensor) | |||||
assert tuple(shape) == (2, 3, 2) | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | |||||
]) | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) | |||||
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] | |||||
a = padder(a) | |||||
shape = a.shape | |||||
assert isinstance(a, paddle.Tensor) | |||||
assert tuple(shape) == (2, 3, 2) | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]]], dtype='float32') | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||||
def test_dtype_check(self): | |||||
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | |||||
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) | |||||
def test_v1(self): | |||||
print(paddle.zeros((3, )).dtype) |
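Reading across these constructor calls, the padder contract is: ele_dtype declares the dtype of the incoming elements and is validated (str raises DtypeError), while dtype selects the dtype of the padded output, with dtype=None keeping the input dtype. A minimal sketch, with the semantics inferred from the tests above:

from fastNLP.core.collators.padders.paddle_padder import paddleSequencePadder

padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=0)
batch = padder([[1, 2, 3], [4]])   # ragged python lists in
# -> paddle.Tensor of shape [2, 3]; the second row is padded out with 0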
@@ -40,8 +40,8 @@ class TestJittor: | |||||
""" | """ | ||||
dataset = MyDataset() | dataset = MyDataset() | ||||
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | ||||
jtl.set_pad_val('x', 'y') | |||||
jtl.set_input('x') | |||||
# jtl.set_pad_val('x', 'y') | |||||
# jtl.set_input('x') | |||||
for batch in jtl: | for batch in jtl: | ||||
print(batch) | print(batch) | ||||
print(jtl.get_batch_indices()) | print(jtl.get_batch_indices()) | ||||
@@ -54,15 +54,17 @@ class TestJittor: | |||||
""" | """ | ||||
dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | ||||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) | jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) | ||||
jtl.set_pad_val('x', val=-1) | |||||
jtl.set_input('x', 'y') | |||||
jtl.set_pad("x", -1) | |||||
jtl.set_ignore("y") | |||||
# jtl.set_pad_val('x', val=-1) | |||||
# jtl.set_input('x', 'y') | |||||
for batch in jtl: | for batch in jtl: | ||||
assert batch['x'].size() == (16, 4) | assert batch['x'].size() == (16, 4) | ||||
def test_v3(self): | def test_v3(self): | ||||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | ||||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) | jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) | ||||
jtl.set_input('x', 'y') | |||||
# jtl.set_input('x', 'y') | |||||
for batch in jtl: | for batch in jtl: | ||||
print(batch) | print(batch) | ||||
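These edits show the API migration the diff performs across all the dataloader tests: the old set_pad_val/set_input pair becomes set_pad/set_ignore on the Collator-backed loaders. A minimal sketch mirroring the calls above (the JittorDataLoader import path is an assumption):

from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders import JittorDataLoader  # import path assumed

ds = DataSet({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
jtl = JittorDataLoader(ds, batch_size=16, drop_last=True)
jtl.set_pad('x', -1)    # pad field 'x' with -1 up to the longest 'x' in the batch
jtl.set_ignore('y')     # field 'y' is left out of the collated batch
for batch in jtl:
    assert batch['x'].size() == (16, 4)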
@@ -3,6 +3,8 @@ import numpy as np | |||||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.log import logger | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.io import Dataset, DataLoader | from paddle.io import Dataset, DataLoader | ||||
@@ -11,11 +13,12 @@ else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
class RandomDataset(Dataset): | class RandomDataset(Dataset): | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
image = np.random.random((10, 5)).astype('float32') | image = np.random.random((10, 5)).astype('float32') | ||||
return {'image': paddle.Tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]} | |||||
return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]} | |||||
def __len__(self): | def __len__(self): | ||||
return 10 | return 10 | ||||
@@ -36,23 +39,30 @@ class TestPaddle: | |||||
def test_fdl_batch_indices(self): | def test_fdl_batch_indices(self): | ||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | ||||
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | ||||
fdl.set_input("x", "y") | |||||
for batch in fdl: | for batch in fdl: | ||||
assert len(fdl.get_batch_indices()) == 4 | assert len(fdl.get_batch_indices()) == 4 | ||||
print(batch) | print(batch) | ||||
print(fdl.get_batch_indices()) | print(fdl.get_batch_indices()) | ||||
def test_set_inputs_and_set_pad_val(self): | def test_set_inputs_and_set_pad_val(self): | ||||
logger.setLevel("DEBUG") | |||||
ds = RandomDataset() | ds = RandomDataset() | ||||
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) | fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) | ||||
fdl.set_input('image', 'label') | |||||
fdl.set_pad_val('label', val=-1) | |||||
fdl.set_pad('label', -1) | |||||
for batch in fdl: | for batch in fdl: | ||||
print(batch['image']) | |||||
assert batch['image'].shape == [2, 10, 5] | assert batch['image'].shape == [2, 10, 5] | ||||
print(batch) | print(batch) | ||||
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) | fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) | ||||
fdl1.set_input('image', 'label') | |||||
fdl1.set_pad_val('image', val=None) | |||||
fdl1.set_ignore('label') | |||||
for batch in fdl1: | for batch in fdl1: | ||||
assert batch['image'].shape == [4, 10, 5] | assert batch['image'].shape == [4, 10, 5] | ||||
print(batch) | print(batch) | ||||
def test_v2(self): | |||||
from fastNLP.core.collators import Collator | |||||
logger.setLevel("DEBUG") | |||||
data = [paddle.Tensor(np.random.random((10, 5)).astype('float32')), paddle.Tensor(np.random.random((10, 5)).astype('float32'))] | |||||
col = Collator(backend="jittor") | |||||
res = col(data) | |||||
print(res) |
@@ -13,42 +13,23 @@ class TestFdl: | |||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | ||||
# for batch in fdl: | # for batch in fdl: | ||||
# print(batch) | # print(batch) | ||||
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||||
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||||
# for batch in fdl1: | # for batch in fdl1: | ||||
# print(batch) | # print(batch) | ||||
def test_set_padding(self): | def test_set_padding(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
ds.set_pad_val("x", val=-1) | |||||
fdl = TorchDataLoader(ds, batch_size=3) | fdl = TorchDataLoader(ds, batch_size=3) | ||||
fdl.set_input("x", "y") | |||||
fdl.set_pad_val("x", val=None) | |||||
fdl.set_pad("x", -1) | |||||
for batch in fdl: | for batch in fdl: | ||||
print(batch) | print(batch) | ||||
# fdl.set_pad_val("x", val=-2) | # fdl.set_pad_val("x", val=-2) | ||||
# for batch in fdl: | # for batch in fdl: | ||||
# print(batch) | # print(batch) | ||||
def test_add_collator(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
def collate_fn(ins_list): | |||||
_dict = {"Y": []} | |||||
for ins in ins_list: | |||||
_dict["Y"].append(ins['y']) | |||||
return _dict | |||||
fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) | |||||
fdl.set_input("x", "y") | |||||
# fdl.set_pad_val("x", val=None) | |||||
fdl.add_collator(collate_fn) | |||||
for batch in fdl: | |||||
print(batch) | |||||
def test_get_batch_indices(self): | def test_get_batch_indices(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | ||||
fdl.set_input("y", "x") | |||||
for batch in fdl: | for batch in fdl: | ||||
print(fdl.get_batch_indices()) | print(fdl.get_batch_indices()) | ||||
@@ -2,7 +2,7 @@ import pytest | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -278,7 +278,7 @@ class TestPaddleDriverFunctions: | |||||
dataset = PaddleNormalDataset() | dataset = PaddleNormalDataset() | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset, | dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), | BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), | ||||
batch_size, | batch_size, | ||||
drop_last, | drop_last, | ||||
@@ -287,7 +287,7 @@ class TestPaddleDriverFunctions: | |||||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | res = PaddleSingleDriver.get_dataloader_args(dataloader) | ||||
assert isinstance(res.dataset, PaddleNormalDataset) | assert isinstance(res.dataset, PaddleNormalDataset) | ||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler) | |||||
if shuffle: | if shuffle: | ||||
assert isinstance(res.sampler, paddle.io.RandomSampler) | assert isinstance(res.sampler, paddle.io.RandomSampler) | ||||
else: | else: | ||||
@@ -387,7 +387,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
Test the behaviour of set_dist_repro_dataloader when the parameter `reproducible` is True.
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True),
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
@@ -400,7 +400,7 @@ class TestSetDistReproDataloader: | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
else: | else: | ||||
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -414,11 +414,11 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned, with its batch_sampler replaced by the sampler corresponding to dist.
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | ||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) | |||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | assert replaced_loader.batch_sampler is dist | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | ||||
@@ -450,7 +450,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=self.dataset, | dataset=self.dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), | BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), | ||||
batch_size=4, | batch_size=4, | ||||
drop_last=False, | drop_last=False, | ||||
@@ -459,7 +459,7 @@ class TestSetDistReproDataloader: | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -500,20 +500,20 @@ class TestSetDistReproDataloader: | |||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | already_seen_idx.update(batch) | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
# after reloading, the remaining content should be produced; for PaddleNormalDataset, the sorted indices should form a range
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# rebuild the dataloader
new_loader = DataLoader( | new_loader = DataLoader( | ||||
dataset=replaced_loader.dataset, | dataset=replaced_loader.dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), | BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), | ||||
batch_size=batch_size, | batch_size=batch_size, | ||||
drop_last=False, | drop_last=False, | ||||
@@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
dataset = PaddleRandomMaxDataset(40, 10) | dataset = PaddleRandomMaxDataset(40, 10) | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||||
) | ) | ||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | ||||
@@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# change the batch_size
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | |||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | |||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | ||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
@@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# 2. check that the batch_sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | assert replaced_loader.batch_sampler is dataloader.batch_sampler | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | ||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | ||||
@@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import ( | |||||
replace_batch_sampler, | replace_batch_sampler, | ||||
replace_sampler, | replace_sampler, | ||||
) | ) | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, | |||||
def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
dataset = PaddleNormalDataset(10) | dataset = PaddleNormalDataset(10) | ||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, PaddleNormalDataset) | assert isinstance(replaced_loader.dataset, PaddleNormalDataset) | ||||
assert len(replaced_loader.dataset) == len(dataset) | assert len(replaced_loader.dataset) == len(dataset) | ||||
assert replaced_loader.batch_sampler.batch_size == 16 | assert replaced_loader.batch_sampler.batch_size == 16 | ||||
@@ -2,7 +2,7 @@ import pytest | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset | ||||
@@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE: | |||||
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | ||||
""" | """ | ||||
Build a dataloader whose batch_sampler is a RandomBatchSampler
Build a dataloader whose batch_sampler is a ReproduceBatchSampler
""" | """ | ||||
if shuffle: | if shuffle: | ||||
sampler = torch.utils.data.RandomSampler(dataset) | sampler = torch.utils.data.RandomSampler(dataset) | ||||
@@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | |||||
sampler = torch.utils.data.SequentialSampler(dataset) | sampler = torch.utils.data.SequentialSampler(dataset) | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler( | BatchSampler( | ||||
sampler, batch_size=batch_size, drop_last=drop_last | sampler, batch_size=batch_size, drop_last=drop_last | ||||
), | ), | ||||
@@ -306,7 +306,7 @@ class TestTorchDriverFunctions: | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | res = TorchSingleDriver.get_dataloader_args(dataloader) | ||||
assert isinstance(res.dataset, TorchNormalDataset) | assert isinstance(res.dataset, TorchNormalDataset) | ||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler) | |||||
if shuffle: | if shuffle: | ||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler) | assert isinstance(res.sampler, torch.utils.data.RandomSampler) | ||||
else: | else: | ||||
@@ -401,7 +401,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
Test the behaviour of set_dist_repro_dataloader when the parameter `reproducible` is True.
When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler(shuffle=True),
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
@@ -414,7 +414,7 @@ class TestSetDistReproDataloader: | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
else: | else: | ||||
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -428,11 +428,11 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned, with its batch_sampler replaced by the sampler corresponding to dist.
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | assert replaced_loader.batch_sampler is dist | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | ||||
@@ -466,7 +466,7 @@ class TestSetDistReproDataloader: | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -502,14 +502,14 @@ class TestSetDistReproDataloader: | |||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | already_seen_idx.update(batch) | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
# after reloading, the remaining content should be produced; for TorchNormalDataset, the sorted indices should form a range
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# rebuild the dataloader
@@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# 2. check that the batch_sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | assert replaced_loader.batch_sampler is dataloader.batch_sampler | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | ||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | ||||
@@ -30,7 +30,7 @@ class SequenceDataSet: | |||||
def check_replace_sampler(driver): | def check_replace_sampler(driver): | ||||
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / RandomBatchSampler
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproduceBatchSampler
# reproducible is either True or False
# need to check that both the returned sampler and the returned dataloader are different
@@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
replace_batch_sampler, | replace_batch_sampler, | ||||
replace_sampler, | replace_sampler, | ||||
) | ) | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from torch.utils.data import DataLoader, BatchSampler | from torch.utils.data import DataLoader, BatchSampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
dataset = TorchNormalDataset(10) | dataset = TorchNormalDataset(10) | ||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, TorchNormalDataset) | assert isinstance(replaced_loader.dataset, TorchNormalDataset) | ||||
assert len(replaced_loader.dataset) == len(dataset) | assert len(replaced_loader.dataset) == len(dataset) | ||||
assert replaced_loader.batch_sampler.batch_size == 16 | assert replaced_loader.batch_sampler.batch_size == 16 | ||||
@@ -5,7 +5,7 @@ import pytest | |||||
from itertools import chain | from itertools import chain | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -19,7 +19,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
 #     before_batch_size = 7
 #     dataset = TorchNormalDataset(num_of_data=100)
 #     dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#     re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#     re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
 #     dataloader = replace_batch_sampler(dataloader, re_batchsampler)
 #
 #     forward_steps = 3
@@ -29,15 +29,15 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
 #
 #     # 1. save the state
 #     _get_re_batchsampler = dataloader.batch_sampler
-#     assert isinstance(_get_re_batchsampler, RandomBatchSampler)
+#     assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
 #     state = _get_re_batchsampler.state_dict()
 #     assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
-#                      "sampler_type": "RandomBatchSampler"}
+#                      "sampler_type": "ReproduceBatchSampler"}
 #
 #     # 2. resume from the checkpoint: rebuild the dataloader;
 #     # keep batch_size unchanged;
 #     dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#     re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#     re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
 #     re_batchsampler.load_state_dict(state)
 #     dataloader = replace_batch_sampler(dataloader, re_batchsampler)
 #
@@ -54,7 +54,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
 #     # change batch_size;
 #     after_batch_size = 3
 #     dataloader = DataLoader(dataset, batch_size=after_batch_size)
-#     re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#     re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
 #     re_batchsampler.load_state_dict(state)
 #     dataloader = replace_batch_sampler(dataloader, re_batchsampler)
 #
@@ -100,7 +100,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
 #     dataset = TorchNormalDataset(num_of_data=100)
 #     # enable shuffle, to verify that the second epoch's index list is regenerated after resuming;
 #     dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-#     re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#     re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
 #     dataloader = replace_batch_sampler(dataloader, re_batchsampler)
 #
 #     # record all the data of one epoch, to check that restoration is correct;
@@ -112,13 +112,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
 #
 #     # 1. save the state
 #     _get_re_batchsampler = dataloader.batch_sampler
-#     assert isinstance(_get_re_batchsampler, RandomBatchSampler)
+#     assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
 #     state = _get_re_batchsampler.state_dict()
 #
 #     # 2. resume from the checkpoint: rebuild the dataloader;
 #     # keep batch_size unchanged;
 #     dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-#     re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+#     re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
 #     re_batchsampler.load_state_dict(state)
 #     dataloader = replace_batch_sampler(dataloader, re_batchsampler)
 #
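
Un-commented and condensed, the flow these disabled tests describe is the checkpoint-resume round trip below; a sketch under the same assumptions, with the state keys being exactly the ones asserted above:

    import torch
    from torch.utils.data import DataLoader
    from fastNLP.core.samplers import ReproduceBatchSampler
    from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler

    dataset = torch.arange(100)  # stand-in for TorchNormalDataset(num_of_data=100)
    dataloader = DataLoader(dataset, batch_size=7)
    wrapped = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=7, drop_last=False)
    dataloader = replace_batch_sampler(dataloader, wrapped)

    iterator = iter(dataloader)
    for _ in range(3):
        next(iterator)                            # consume 3 batches = 21 samples
    state = dataloader.batch_sampler.state_dict() # index_list / num_consumed_samples / sampler_type

    # resume: rebuild the dataloader, re-wrap, restore the sampler's progress
    dataloader = DataLoader(dataset, batch_size=7)
    wrapped = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=7, drop_last=False)
    wrapped.load_state_dict(state)
    dataloader = replace_batch_sampler(dataloader, wrapped)
    # iteration now continues from sample 21 instead of restarting at sample 0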
@@ -511,3 +511,313 @@ class TestBucketedBatchSampler:
             already_seen_set.update(batch)
         assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(dataset)

class TestRandomBatchSampler:
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
    def test_single_num_batch(self, shuffle, drop_last, num):
        # must not raise even when there are not enough samples for a full batch
        dataset = DatasetWithVaryLength(num_of_data=num)
        before_batch_size = 7
        re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                             drop_last=drop_last, shuffle=shuffle)
        count = len(list(iter(re_batchsampler)))
        if drop_last:
            assert count == num // before_batch_size, num
        else:
            assert count == (num + before_batch_size - 1) // before_batch_size, num
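
The two expected counts are just floor versus ceiling division over the dataset size; for example, with num = 15 and a batch size of 7:

    num, batch_size = 15, 7
    assert num // batch_size == 2                     # drop_last=True: the final partial batch (1 sample) is dropped
    assert (num + batch_size - 1) // batch_size == 3  # drop_last=False: ceiling division keeps the partial batch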
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    def test_single(self, shuffle, drop_last):
        before_batch_size = 7
        num_batch_per_bucket = 4  # the length difference within any batch should then not exceed 4
        dataset = DatasetWithVaryLength(num_of_data=1000)
        re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                             drop_last=drop_last, shuffle=shuffle)
        re_batchsampler.set_epoch(0)
        forward_steps = 10
        iterator = iter(re_batchsampler)
        already_generate_indices = set()
        for _ in range(forward_steps):
            batch = next(iterator)
            already_generate_indices.update(batch)

        # 1. save the state
        state = re_batchsampler.state_dict()

        # 2. resume from the checkpoint and continue training
        re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                              drop_last=drop_last, shuffle=shuffle)
        re_batchsampler2.load_state_dict(state)
        re_batchsampler2.set_epoch(0)
        new_already_generate_indices = set()
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_generate_indices)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices) - before_batch_size * num_batch_per_bucket):
            max_diff = max(max_diff, indices[i + before_batch_size * num_batch_per_bucket] - indices[i])
        for batch in re_batchsampler2:
            for b in batch:
                assert b not in already_generate_indices
            new_already_generate_indices.update(batch)
        if drop_last is False:
            assert len(new_already_generate_indices.union(already_generate_indices)) == len(dataset)

        # change batch_size;
        after_batch_size = 3
        re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
                                              drop_last=drop_last, shuffle=shuffle)
        re_batchsampler3.load_state_dict(state)
        re_batchsampler3.set_epoch(0)
        count = 0
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_generate_indices)] = 0
        indices = np.arange(len(dataset))[mask]
        for batch in re_batchsampler3:
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
            count += 1
            if count > 5:
                break

        # save again: not allowed while the previous epoch is still unfinished
        after_batch_size = 5
        with pytest.raises(RuntimeError):
            state = re_batchsampler3.state_dict()

        for batch in re_batchsampler3:  # consume everything, so that saving becomes possible
            pass
        already_generate_indices = set()
        count = 0
        for batch in re_batchsampler3:  # start a fresh epoch
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
            count += 1
            if count > 5:
                break
        state = re_batchsampler3.state_dict()
        # drop_last is False here, so in the end every sample must appear
        re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
                                              drop_last=False, shuffle=shuffle)
        re_batchsampler4.load_state_dict(state)
        re_batchsampler4.set_epoch(0)
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_generate_indices)] = 0
        for batch in re_batchsampler4:
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
        assert len(already_generate_indices) == len(dataset)
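
The RuntimeError branch above pins down a save rule: after a mid-epoch restore, state_dict() is refused until the restored epoch has been fully drained. A condensed sketch of that contract (a plain list stands in for DatasetWithVaryLength; the APIs are the ones used in the test):

    from fastNLP.core.samplers import RandomBatchSampler

    data = list(range(100))  # stand-in for DatasetWithVaryLength(num_of_data=100)
    sampler = RandomBatchSampler(data, length=data, batch_size=7, shuffle=True, drop_last=False)
    sampler.set_epoch(0)
    iterator = iter(sampler)
    for _ in range(3):
        next(iterator)                # consume part of the epoch
    state = sampler.state_dict()      # saving here is fine

    resumed = RandomBatchSampler(data, length=data, batch_size=3, shuffle=True, drop_last=False)
    resumed.load_state_dict(state)    # a different batch_size is allowed
    resumed.set_epoch(0)
    next(iter(resumed))               # resume, then stop mid-epoch again
    try:
        resumed.state_dict()          # refused: the restored epoch is unfinished
    except RuntimeError:
        for _ in resumed:             # drain the rest of the epoch...
            pass
        state = resumed.state_dict()  # ...after which saving works again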
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    def test_multi(self, shuffle, drop_last, pad):
        # def test_multi(self, shuffle=True, drop_last=False, pad=False):
        # no shuffle
        num_replica = 2
        dataset = DatasetWithVaryLength(num_of_data=1000)
        batch_size = 5
        num_batch_per_bucket = 10
        lengths = []
        rank0_already_seen_indexes = None
        max_diff = num_batch_per_bucket * batch_size * num_replica
        for rank in range(num_replica):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            already_seen_indexes = set()
            repeat_count = 0
            for batch in sampler:
                for b in batch:
                    repeat_count += int(b in already_seen_indexes)
                    if rank0_already_seen_indexes:  # indices must not appear on both ranks
                        assert b not in rank0_already_seen_indexes
                already_seen_indexes.update(batch)
            if rank0_already_seen_indexes is None:
                rank0_already_seen_indexes = already_seen_indexes
            if pad:  # padding may repeat at most one sample
                assert repeat_count <= 1
            else:
                assert repeat_count == 0

        assert len(set(lengths)) == 1, lengths  # every process gets the same number of batches

        # saving in the multi-process setting
        already_seen_indexes = set()
        for rank in range(num_replica):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            count = 0
            for batch in sampler:
                already_seen_indexes.update(batch)
                if count > 5:
                    break
                count += 1
            state = sampler.state_dict()

        # switch to a single process
        new_batch_size = 6
        num_batch_per_bucket = 3
        new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
        new_sampler.load_state_dict(state)
        repeat_count = 0
        new_already_seen_indexes = set(list(already_seen_indexes))
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_seen_indexes)] = 0
        indices = np.arange(len(dataset))[mask]
        for batch in new_sampler:
            for b in batch:
                repeat_count += int(b in new_already_seen_indexes)
            new_already_seen_indexes.update(batch)
        if pad:  # padding may repeat at most one sample
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
        if drop_last is False:  # without dropping, every sample must have been seen
            assert len(new_already_seen_indexes) == len(dataset)

        # test changing the number of devices.
        num_replica = 3
        new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
        new_sampler.set_epoch(0)
        new_sampler.load_state_dict(state)
        new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
        repeat_count = 0
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_seen_indexes)] = 0
        indices = np.arange(len(dataset))[mask]
        for batch in new_sampler:
            for b in batch:
                repeat_count += int(b in already_seen_indexes)
        if pad:  # padding may repeat at most one sample
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
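
set_distributed is what shards one reproducible stream across ranks: every rank walks the same global order but keeps only its own slice, so the shards stay disjoint (with pad=True, at most one sample per rank may be repeated to even out the counts). A single-process simulation of two ranks, under the same assumptions as above:

    from fastNLP.core.samplers import RandomBatchSampler

    data = list(range(20))  # stand-in for DatasetWithVaryLength(num_of_data=20)
    shards = []
    for rank in range(2):   # simulate two workers inside one process
        sampler = RandomBatchSampler(data, length=data, batch_size=5, shuffle=False, drop_last=False)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=rank, pad=False)
        shards.append({idx for batch in sampler for idx in batch})

    assert shards[0].isdisjoint(shards[1])      # no sample is seen by both ranks
    assert shards[0] | shards[1] == set(data)   # together the ranks cover the dataset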
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
    @pytest.mark.parametrize('num_replicas', [2, 3])
    def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
        # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        batch_size = 6
        if num_replicas * batch_size > num_samples:
            return
        num_batch_per_bucket = 10
        samplers = []
        lengths = []
        for i in range(num_replicas):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_distributed(num_replicas, rank=i, pad=pad)
            sampler.set_epoch(0)
            samplers.append(sampler)
            lengths.append(len(list(iter(sampler))))
        assert len(set(lengths)) == 1
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
    @pytest.mark.parametrize('num_replicas', [1, 2, 3])
    def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
        """
        Test that data already consumed (forwarded) before a save is correctly skipped after restoring.
        """
        batch_size = 6
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        samplers = []
        num_consumed_samples_array = list(range(0, num_samples + num_replicas, num_replicas))
        for i in range(num_replicas):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
            samplers.append(sampler)
        count = 0
        already_seen_sets = [set()]
        already_seen_set = set()
        for batchs in zip(*samplers):
            batch = chain(*batchs)
            already_seen_set.update(batch)
            already_seen_sets.append(deepcopy(already_seen_set))
            count += 1
            if count > 3:
                break

        states = samplers[0].state_dict()
        for i in range(len(already_seen_sets)):
            states['num_consumed_samples'] = num_consumed_samples_array[i]
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.load_state_dict(states)  # restore the patched progress before iterating
            sampler.set_epoch(0)
            already_seen_set = deepcopy(already_seen_sets[i])
            for batch in sampler:
                already_seen_set.update(batch)
            assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
                dataset)

        # test saving again after a save
        sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
                                     shuffle=shuffle,
                                     drop_last=drop_last)
        sampler.set_epoch(0)
        states['num_consumed_samples'] = num_consumed_samples_array[2]
        if len(already_seen_sets) < 3:
            return
        already_seen_set = already_seen_sets[2]
        sampler.load_state_dict(states)  # resume from the third checkpointed offset
        count = 0
        for batch in sampler:
            already_seen_set.update(batch)
            count += 1
            if count > 6:
                break
        states = sampler.state_dict()
        num_consumed_samples_array = list(range(len(dataset)))
        states['num_consumed_samples'] = num_consumed_samples_array[count]
        sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size // 2,
                                     shuffle=shuffle,
                                     drop_last=drop_last)
        sampler.load_state_dict(states)
        sampler.set_epoch(0)
        for batch in sampler:
            already_seen_set.update(batch)
        assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(dataset)
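
The state surgery in this last test works because num_consumed_samples is the sampler's only notion of progress: overwrite it in a saved state, load that state into any compatible sampler (even one with a different batch_size), and iteration resumes from that offset. A compact sketch under that assumption:

    from fastNLP.core.samplers import RandomBatchSampler

    data = list(range(30))  # stand-in for DatasetWithVaryLength(num_of_data=30)
    src = RandomBatchSampler(data, length=data, batch_size=6, shuffle=True, drop_last=False)
    src.set_epoch(0)
    iterator = iter(src)
    for _ in range(2):
        next(iterator)                       # 12 samples consumed
    state = src.state_dict()

    state['num_consumed_samples'] = 12       # could equally be patched to any earlier offset
    dst = RandomBatchSampler(data, length=data, batch_size=7, shuffle=True, drop_last=False)
    dst.load_state_dict(state)
    dst.set_epoch(0)
    remaining = {idx for batch in dst for idx in batch}
    assert len(remaining) == len(data) - 12  # only the 18 unseen samples are produced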