From b8026f786fa38498908414f8a5181cc8be389be2 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Mon, 2 May 2022 17:12:45 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9collator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/__init__.py | 3 +- fastNLP/core/collators/collator.py | 905 +++++++++++------- fastNLP/core/collators/padders/get_padder.py | 2 +- .../core/collators/padders/paddle_padder.py | 174 ++++ .../core/dataloaders/jittor_dataloader/fdl.py | 90 +- .../core/dataloaders/paddle_dataloader/fdl.py | 150 +-- .../core/dataloaders/torch_dataloader/fdl.py | 168 ++-- fastNLP/core/dataloaders/utils/__init__.py | 0 fastNLP/core/dataset/dataset.py | 73 +- fastNLP/core/utils/jittor_utils.py | 3 +- .../collators/padders/test_paddle_padder.py | 107 +++ .../dataloaders/jittor_dataloader/test_fdl.py | 12 +- .../dataloaders/paddle_dataloader/test_fdl.py | 20 +- .../dataloaders/torch_dataloader/test_fdl.py | 23 +- 14 files changed, 1069 insertions(+), 661 deletions(-) create mode 100644 fastNLP/core/collators/padders/paddle_padder.py create mode 100644 fastNLP/core/dataloaders/utils/__init__.py create mode 100644 tests/core/collators/padders/test_paddle_padder.py diff --git a/fastNLP/core/collators/__init__.py b/fastNLP/core/collators/__init__.py index c896d08d..17cbb6ae 100644 --- a/fastNLP/core/collators/__init__.py +++ b/fastNLP/core/collators/__init__.py @@ -1,5 +1,4 @@ __all__ = [ - 'AutoCollator', 'Collator' ] -from .collator import AutoCollator, Collator +from .collator import Collator diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index b6b6de14..3bbc6141 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -1,386 +1,573 @@ __all__ = [ - 'AutoCollator', 'Collator', ] +from typing import List, Union, Dict, Callable, Sequence, Mapping -from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Callable, Union, Tuple -from numbers import Number -import warnings - -import numpy as np - -from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH - -if _NEED_IMPORT_PADDLE: - import paddle - -if _NEED_IMPORT_TORCH: - import torch - - -class ApplyResultException(Exception): - def __init__(self, msg, index=None): - super().__init__(msg) - self.msg = msg - self.index = index # 标示在哪个数据遭遇到问题了 +from fastNLP.core.log import logger +from .padders.get_padder import get_padder +import re -class SetInputOrTargetException(Exception): - def __init__(self, msg, index=None, field_name=None): - super().__init__(msg) - self.msg = msg - self.index = index # 标示在哪个数据遭遇到问题了 - self.field_name = field_name # 标示当前 field 的名称 +from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ + pack_batch_sequence +sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 +SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] -def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: - r""" - 识别cell的类别与dimension的数量 - - numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html - :param cell: - :param dim: - :return: - """ - if isinstance(cell, (str, Number, np.bool_)): - if hasattr(cell, 'dtype'): - return cell.dtype.type, dim - return type(cell), dim - - elif isinstance(cell, list): - dim += 1 - res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] - types = set([i for i, j in res]) - dims = set([j for i, 
j in res]) - if len(types) > 1: - raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) - elif len(types) == 0: - raise SetInputOrTargetException("Empty value encountered.") - if len(dims) > 1: - raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) - return types.pop(), dims.pop() - - elif isinstance(cell, torch.Tensor): - return cell.dtype, cell.dim() + dim # 如果是 torch.mean 的结果是0 - - elif isinstance(cell, paddle.Tensor): - return cell.dtype, cell.dim() + dim - - elif isinstance(cell, np.ndarray): - if cell.dtype != np.dtype('O'): # 如果不是 object 的话说明是 well-formatted 的了 - return cell.dtype.type, cell.ndim + dim # dtype.type 返回的会是 np.int32, np.float 等 - # 否则需要继续往下 iterate - dim += 1 - res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] - types = set([i for i, j in res]) - dims = set([j for i, j in res]) - if len(types) > 1: - raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) - elif len(types) == 0: - raise SetInputOrTargetException("Empty value encountered.") - if len(dims) > 1: - raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) - return types.pop(), dims.pop() - - else: # 包含 tuple, set, dict 以及其它的类型 - raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") - - -def _get_ds_type_dim(ds: dict): - # 获取数据集第一行的 field 内部函数的类型和维度 - field_dtype, field_dim = {}, {} - for field_name, field_content in ds.items(): - type_0, dim_0 = _get_ele_type_and_dim(field_content) - field_dtype[field_name], field_dim[field_name] = type_0, dim_0 - return field_dtype, field_dim - - -class Collator(metaclass=ABCMeta): - r""" - 辅助DataLoader管理collate_fn的类 - - """ - - def __init__(self): - super(Collator, self).__init__() - self.collate_fn = [] - - @abstractmethod - def __call__(self, ins_lst: List) -> Any: - raise NotImplementedError - - @abstractmethod - def set_pad_val(self, *field_names: str, value=0): - raise NotImplementedError - - -class _MultiCollator: - """ - 管理所有collator的容器, - 遵循覆盖原则,后加入的collate_fn会覆盖之前处理的数据。 - """ - - def __init__(self, collate_fns: Union[Callable, List[Callable], None]): - - if collate_fns is None: - collate_fns = [] - - if isinstance(collate_fns, Callable): - collate_fns = [collate_fns] - - self._collators: list = collate_fns - - def __call__(self, ins_lst) -> Dict: - out, list_out = {}, [] - for idx, _collate_fn in enumerate(self._collators): - res = _collate_fn(ins_lst) - if isinstance(res, Dict): - out.update(res) - else: - list_out.append(res) - # else: - # raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict") - if len(out) > 0 and len(list_out) > 0: - raise ValueError("the return of collate_fns is not the same, must be dict or list") - if len(list_out) == 1: - list_out = list_out[-1] - # print(list_out) - return out if len(out) > 0 else list_out - - def get_collators(self): - return self._collators - - def add_collator(self, collator: Callable): - self._collators.append(collator) - - def set_as_numpy(self, as_numpy: bool): - """ - 存在AutoCollator时,as_numpy控制其返回值的类型 - :param as_numpy: - :return: +class Collator: + def __init__(self, backend='torch'): """ - for collator in self._collators: - if isinstance(collator, AutoCollator): - collator.set_as_numpy(as_numpy) - return self + 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 + 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 + 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 
field 。 - def set_pad_val(self, *field_names, val=0): + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 + 若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 + """ + self.unpack_batch_func = None + self.pack_batch_func = None + self.ignore_fields = set() + self.padders = {} + self.input_fields = {} + self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。 + self.set_backend(backend) + + def __call__(self, batch)->Union[List, Dict]: """ - 存在AutoCollator时,设置field_name的padding值 + batch可能存在三种可能性 + List[Dict], List[List], List[Sample] + + 第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 + 第二步:使用每个 field 各自的 padder 进行 pad 。 + 第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。 + + 第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个 + list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample + 的类别。 + 第一次调用会根据当前 field 决定对应的 Padder 。 - :param field_names: 数据集的field名 - :param val: padding的值 - :return: """ - flag = True - for collator in self._collators: - if isinstance(collator, AutoCollator): - collator.set_pad_val(*field_names, val=val) - flag = False - if flag: - warnings.warn("AutoCollator is remove, set_padding is unavailable!!") + if self.unpack_batch_func is None: + # 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型 + if self.batch_data_type is None: + if isinstance(batch[0], Mapping): + self.batch_data_type = 'd' + elif isinstance(batch[0], Sequence): # 这里存在误判的风险 + self.batch_data_type = 'l' + else: + self.batch_data_type = 's' + logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " + f"is `{self.batch_data_type}`.") + if self.batch_data_type == 's': + self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 + self.pack_batch_func = lambda x: x['_single'] + elif self.batch_data_type == 'l': + self.unpack_batch_func = unpack_batch_sequence + self.pack_batch_func = pack_batch_sequence + elif self.batch_data_type == 'd': + if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} + self.unpack_batch_func = unpack_batch_nested_mapping + self.pack_batch_func = pack_batch_nested_mapping + else: + self.unpack_batch_func = unpack_batch_mapping + self.pack_batch_func = lambda x:x + # 在这里用ignore_field过滤掉 + if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 + unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) + else: + unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 + + pad_batch = {} + if len(self.padders)==0: # 第一次运行,准备 padder + for key in unpack_batch.keys(): + if key not in self.input_fields and key not in self.ignore_fields: + self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} + + for field_name, setting in self.input_fields.items(): + pad_fn = setting.get('pad_fn', None) + if callable(pad_fn): + padder = pad_fn + else: + batch_field = unpack_batch.get(field_name) + padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], + dtype=setting['dtype'], backend=setting['backend'], + field_name=field_name) + self.padders[field_name] = padder + if self.batch_data_type == 'l': + self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 + + for key, padder in self.padders.items(): + batch = unpack_batch.get(key) + pad_batch[key] = padder(batch) + + 
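+        # note: the loop above rebinds `batch` to each field's own list of
+        # values, shadowing the `batch` argument that was passed into __call__.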
return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 + + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> "Collator": + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, + paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 + """ + self.padders.clear() # 重新生成 + + if self.batch_data_type is not None: + if self.batch_data_type == 's': + logger.debug("Set as single field mode.") + self.input_fields.clear() + elif self.batch_data_type == 'd': + assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ + f"index, but other field is set as dict mode." + elif self.batch_data_type == 'l': + assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ + f"field name is {field_name}." + + if field_name == '_single': + self.batch_data_type = 's' + elif isinstance(field_name, str) and sequence_idx_str.match(field_name): + self.batch_data_type = 'l' + else: + self.batch_data_type = 'd' + + if field_name in self.ignore_fields: + logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") + if backend is None: + backend = self.backend + else: + assert backend in SUPPORTED_BACKENDS + + self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn} + return self - def set_input(self, *field_names): + def set_backend(self, backend:str): """ - 设置AutoCollator需要的field_names,未被设置默认过滤掉 + 设置可以 pad 的 field 默认 pad 为什么类型的 tensor - :param field_names: + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], + 若为 None ,则不进行 padding 。 :return: """ - flag = True - for collator in self._collators: - if isinstance(collator, AutoCollator): - collator.set_input(*field_names) - flag = False - if flag: - warnings.warn("AutoCollator is removed, set_input is unavailable!!") + assert backend in SUPPORTED_BACKENDS + self.padders.clear() + self.backend = backend + + def set_ignore(self, *field_names) -> "Collator": + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 + """ + for field_name in field_names: + if field_name in self.input_fields: + self.input_fields.pop(field_name) + logger.warning(f"Field:{field_name} has been set as input before. 
It will be ignored afterwards.") + self.padders.pop(field_name, None) # 如果有的话,将它的 padder 扔掉。 + self.ignore_fields.add(field_name) + return self -class AutoCollator(Collator): - - def __init__(self, as_numpy: bool): - super(AutoCollator, self).__init__() - self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 - self.need_inputs = set() # 需要的 field name - self.field_dtypes = None # 每列数据单元的 dtype 类型 - self.field_dims = None # 每列数据单元维度 - self.as_numpy = as_numpy - - def __call__(self, ins_lst: List[Dict]) -> dict: - if len(self.need_inputs) == 0: - raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) - # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 - - # 第一种情况,设置了 set_input 的值 - # 第二种情况, 根据数据的类型的判断是否 padding - if self.field_dtypes is None and self.field_dims is None: - field_dtypes, field_dims = {}, {} - for key, value in ins_lst[0].items(): - if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: - field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) - self.field_dtypes = field_dtypes - self.field_dims = field_dims - - pack_ins_lst, pad_ins_lst = {field_name: [] - for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} - # 将 list 列表内数据按列名打包 - for per_ins in ins_lst: - for field_name, _field_content in per_ins.items(): - if field_name in self.need_inputs: - pack_ins_lst[field_name].append(_field_content) - - pad_field_kv = {field_name: 0 for field_name in self.need_inputs} - pad_field_kv.update(self.pad_field_value) - self.pad_field_value = pad_field_kv - - if len(self.pad_field_value.keys()) > 0: - # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 - non_pad_field_names = [] - for k, v in self.pad_field_value.items(): - if v is None: - non_pad_field_names.append(k) - - # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) - for field_name in non_pad_field_names: - field_array = pack_ins_lst.pop(field_name) - pad_ins_lst[field_name] = np.array(field_array) - - for field_name, field_array in pack_ins_lst.items(): - content = pad_content(field_array, field_name, self.field_dtypes[field_name], - self.field_dims[field_name], - self.pad_field_value[field_name], - as_numpy=self.as_numpy) - pad_ins_lst[field_name] = content - - # else: - # # 取出每列的数据,根据类型判断是否能 pad - # for field_name, field_array in pack_ins_lst.items(): - # pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name], - # self.field_dims[field_name], - # pad_val=0, as_numpy=self.as_numpy) - # pad_ins_lst[field_name] = pad_field_array - - return pad_ins_lst - - def set_pad_val(self, *field_names, val=0): - for field_name in field_names: - self.pad_field_value[field_name] = val - def set_as_numpy(self, as_numpy: bool): - self.as_numpy = as_numpy - def set_input(self, *field_names): - for field_name in field_names: - self.need_inputs.add(field_name) - - -def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): - - if field_type: - # 不处理, 返回 np.array 类型 - if field_dim > 3: - return np.array(content) - # 元素类型为数值类型 np.int64, np.float64, int, float 等 - if isinstance(field_type, type) and \ - (issubclass(field_type, np.number) or issubclass(field_type, Number)): - if field_dim == 0: - array = np.array(content, dtype=field_type) - elif field_dim == 1: - max_len = max(map(len, content)) - array = np.full((len(content), max_len), pad_val, dtype=field_type) - for i, content_i in enumerate(content): - array[i, :len(content_i)] = content_i - elif field_dim == 2: - max_len = 
max(map(len, content)) - max_word_len = max([max([len(content_ii) for content_ii in content_i]) for - content_i in content]) - array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type) - for i, content_i in enumerate(content): - for j, content_ii in enumerate(content_i): - array[i, j, :len(content_ii)] = content_ii - else: - shape = np.shape(content) - if len(shape) == 4: # 说明各 dimension 是相同的大小 - array = np.array(content, dtype=field_type) - else: - raise RuntimeError( - f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") - if as_numpy is False: - array = torch.tensor(array) - return array - # 元素类型为数值类型 torch.float 等 - elif str(field_type).startswith('torch'): - if field_dim == 0: - tensor = torch.tensor(content).to(field_type) - elif field_dim == 1: - max_len = max(map(len, content)) - tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type) - for i, content_i in enumerate(content): - tensor[i, :len(content_i)] = content_i.clone().detach() - elif field_dim == 2: - max_len = max(map(len, content)) - max_word_len = max([max([len(content_ii) for content_ii in content_i]) for - content_i in content]) - tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val, - dtype=field_type) - for i, content_i in enumerate(content): - for j, content_ii in enumerate(content_i): - tensor[i, j, :len(content_ii)] = content_ii.clone().detach() - else: - shapes = set([np.shape(content_i) for content_i in content]) - if len(shapes) > 1: - raise RuntimeError( - f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") - shape = shapes.pop() - if len(shape) == 3: - tensor = torch.full([len(content)] + list(shape), fill_value=pad_val, - dtype=field_type) - for i, content_i in enumerate(content): - tensor[i] = content_i.clone().detach().to(field_type) - else: - raise RuntimeError( - f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") - return tensor - # TODO 增加jittor/paddle? 
- elif str(field_type).startswith('paddle'): - if field_dim == 0: - tensor = paddle.Tensor(content).to(field_type) - elif field_dim == 1: - max_len = max(map(len, content)) - tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type) - for i, content_i in enumerate(content): - tensor[i, :len(content_i)] = content_i.clone().detach() - elif field_dim == 2: - max_len = max(map(len, content)) - max_word_len = max([max([len(content_ii) for content_ii in content_i]) for - content_i in content]) - tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val, - dtype=field_type) - for i, content_i in enumerate(content): - for j, content_ii in enumerate(content_i): - tensor[i, j, :len(content_ii)] = content_ii.clone().detach() - else: - shapes = set([np.shape(content_i) for content_i in content]) - if len(shapes) > 1: - raise RuntimeError( - f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") - shape = shapes.pop() - if len(shape) == 3: - tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val, - dtype=field_type) - for i, content_i in enumerate(content): - tensor[i] = content_i.clone().detach().to(field_type) - else: - raise RuntimeError( - f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") - return tensor - else: - return np.array(content) # 不进行任何操作 - else: - return np.array(content) +# +# from abc import ABCMeta, abstractmethod +# from typing import Any, Dict, List, Callable, Union, Tuple +# from numbers import Number +# import warnings +# +# import numpy as np +# +# from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH +# +# if _NEED_IMPORT_PADDLE: +# import paddle +# +# if _NEED_IMPORT_TORCH: +# import torch +# +# +# class ApplyResultException(Exception): +# def __init__(self, msg, index=None): +# super().__init__(msg) +# self.msg = msg +# self.index = index # 标示在哪个数据遭遇到问题了 +# +# +# class SetInputOrTargetException(Exception): +# def __init__(self, msg, index=None, field_name=None): +# super().__init__(msg) +# self.msg = msg +# self.index = index # 标示在哪个数据遭遇到问题了 +# self.field_name = field_name # 标示当前 field 的名称 +# +# +# def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: +# r""" +# 识别cell的类别与dimension的数量 +# +# numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html +# :param cell: +# :param dim: +# :return: +# """ +# if isinstance(cell, (str, Number, np.bool_)): +# if hasattr(cell, 'dtype'): +# return cell.dtype.type, dim +# return type(cell), dim +# +# elif isinstance(cell, list): +# dim += 1 +# res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] +# types = set([i for i, j in res]) +# dims = set([j for i, j in res]) +# if len(types) > 1: +# raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) +# elif len(types) == 0: +# raise SetInputOrTargetException("Empty value encountered.") +# if len(dims) > 1: +# raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) +# return types.pop(), dims.pop() +# +# elif isinstance(cell, torch.Tensor): +# return cell.dtype, cell.dim() + dim # 如果是 torch.mean 的结果是0 +# +# elif isinstance(cell, paddle.Tensor): +# return cell.dtype, cell.dim() + dim +# +# elif isinstance(cell, np.ndarray): +# if cell.dtype != np.dtype('O'): # 如果不是 object 的话说明是 well-formatted 的了 +# return cell.dtype.type, cell.ndim + dim # dtype.type 返回的会是 np.int32, np.float 等 +# # 否则需要继续往下 iterate +# dim += 1 +# res = [_get_ele_type_and_dim(cell_i, dim) for 
cell_i in cell] +# types = set([i for i, j in res]) +# dims = set([j for i, j in res]) +# if len(types) > 1: +# raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) +# elif len(types) == 0: +# raise SetInputOrTargetException("Empty value encountered.") +# if len(dims) > 1: +# raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) +# return types.pop(), dims.pop() +# +# else: # 包含 tuple, set, dict 以及其它的类型 +# raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") +# +# +# def _get_ds_type_dim(ds: dict): +# # 获取数据集第一行的 field 内部函数的类型和维度 +# field_dtype, field_dim = {}, {} +# for field_name, field_content in ds.items(): +# type_0, dim_0 = _get_ele_type_and_dim(field_content) +# field_dtype[field_name], field_dim[field_name] = type_0, dim_0 +# return field_dtype, field_dim +# +# +# class Collator(metaclass=ABCMeta): +# r""" +# 辅助DataLoader管理collate_fn的类 +# +# """ +# +# def __init__(self): +# super(Collator, self).__init__() +# self.collate_fn = [] +# +# @abstractmethod +# def __call__(self, ins_lst: List) -> Any: +# raise NotImplementedError +# +# @abstractmethod +# def set_pad_val(self, *field_names: str, value=0): +# raise NotImplementedError +# +# +# class _MultiCollator: +# """ +# 管理所有collator的容器, +# 遵循覆盖原则,后加入的collate_fn会覆盖之前处理的数据。 +# """ +# +# def __init__(self, collate_fns: Union[Callable, List[Callable], None]): +# +# if collate_fns is None: +# collate_fns = [] +# +# if isinstance(collate_fns, Callable): +# collate_fns = [collate_fns] +# +# self._collators: list = collate_fns +# +# def __call__(self, ins_lst) -> Dict: +# out, list_out = {}, [] +# for idx, _collate_fn in enumerate(self._collators): +# res = _collate_fn(ins_lst) +# if isinstance(res, Dict): +# out.update(res) +# else: +# list_out.append(res) +# # else: +# # raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict") +# if len(out) > 0 and len(list_out) > 0: +# raise ValueError("the return of collate_fns is not the same, must be dict or list") +# if len(list_out) == 1: +# list_out = list_out[-1] +# # print(list_out) +# return out if len(out) > 0 else list_out +# +# def get_collators(self): +# return self._collators +# +# def add_collator(self, collator: Callable): +# self._collators.append(collator) +# +# def set_as_numpy(self, as_numpy: bool): +# """ +# 存在AutoCollator时,as_numpy控制其返回值的类型 +# +# :param as_numpy: +# :return: +# """ +# for collator in self._collators: +# if isinstance(collator, AutoCollator): +# collator.set_as_numpy(as_numpy) +# return self +# +# def set_pad_val(self, *field_names, val=0): +# """ +# 存在AutoCollator时,设置field_name的padding值 +# +# :param field_names: 数据集的field名 +# :param val: padding的值 +# :return: +# """ +# flag = True +# for collator in self._collators: +# if isinstance(collator, AutoCollator): +# collator.set_pad_val(*field_names, val=val) +# flag = False +# if flag: +# warnings.warn("AutoCollator is remove, set_padding is unavailable!!") +# return self +# +# def set_input(self, *field_names): +# """ +# 设置AutoCollator需要的field_names,未被设置默认过滤掉 +# +# :param field_names: +# :return: +# """ +# flag = True +# for collator in self._collators: +# if isinstance(collator, AutoCollator): +# collator.set_input(*field_names) +# flag = False +# if flag: +# warnings.warn("AutoCollator is removed, set_input is unavailable!!") +# return self +# +# +# class AutoCollator(Collator): +# +# def __init__(self, as_numpy: bool): +# super(AutoCollator, self).__init__() +# self.pad_field_value = {} # field padding 自定义的 
padding 值, 默认为0 +# self.need_inputs = set() # 需要的 field name +# self.field_dtypes = None # 每列数据单元的 dtype 类型 +# self.field_dims = None # 每列数据单元维度 +# self.as_numpy = as_numpy +# +# def __call__(self, ins_lst: List[Dict]) -> dict: +# if len(self.need_inputs) == 0: +# raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) +# # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 +# +# # 第一种情况,设置了 set_input 的值 +# # 第二种情况, 根据数据的类型的判断是否 padding +# if self.field_dtypes is None and self.field_dims is None: +# field_dtypes, field_dims = {}, {} +# for key, value in ins_lst[0].items(): +# if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: +# field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) +# self.field_dtypes = field_dtypes +# self.field_dims = field_dims +# +# pack_ins_lst, pad_ins_lst = {field_name: [] +# for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} +# # 将 list 列表内数据按列名打包 +# for per_ins in ins_lst: +# for field_name, _field_content in per_ins.items(): +# if field_name in self.need_inputs: +# pack_ins_lst[field_name].append(_field_content) +# +# pad_field_kv = {field_name: 0 for field_name in self.need_inputs} +# pad_field_kv.update(self.pad_field_value) +# self.pad_field_value = pad_field_kv +# +# if len(self.pad_field_value.keys()) > 0: +# # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 +# non_pad_field_names = [] +# for k, v in self.pad_field_value.items(): +# if v is None: +# non_pad_field_names.append(k) +# +# # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) +# for field_name in non_pad_field_names: +# field_array = pack_ins_lst.pop(field_name) +# pad_ins_lst[field_name] = np.array(field_array) +# +# for field_name, field_array in pack_ins_lst.items(): +# content = pad_content(field_array, field_name, self.field_dtypes[field_name], +# self.field_dims[field_name], +# self.pad_field_value[field_name], +# as_numpy=self.as_numpy) +# pad_ins_lst[field_name] = content +# +# # else: +# # # 取出每列的数据,根据类型判断是否能 pad +# # for field_name, field_array in pack_ins_lst.items(): +# # pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name], +# # self.field_dims[field_name], +# # pad_val=0, as_numpy=self.as_numpy) +# # pad_ins_lst[field_name] = pad_field_array +# +# return pad_ins_lst +# +# def set_pad_val(self, *field_names, val=0): +# for field_name in field_names: +# self.pad_field_value[field_name] = val +# +# def set_as_numpy(self, as_numpy: bool): +# self.as_numpy = as_numpy +# +# def set_input(self, *field_names): +# for field_name in field_names: +# self.need_inputs.add(field_name) +# +# +# def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): +# +# if field_type: +# # 不处理, 返回 np.array 类型 +# if field_dim > 3: +# return np.array(content) +# # 元素类型为数值类型 np.int64, np.float64, int, float 等 +# if isinstance(field_type, type) and \ +# (issubclass(field_type, np.number) or issubclass(field_type, Number)): +# if field_dim == 0: +# array = np.array(content, dtype=field_type) +# elif field_dim == 1: +# max_len = max(map(len, content)) +# array = np.full((len(content), max_len), pad_val, dtype=field_type) +# for i, content_i in enumerate(content): +# array[i, :len(content_i)] = content_i +# elif field_dim == 2: +# max_len = max(map(len, content)) +# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for +# content_i in content]) +# array = np.full((len(content), max_len, max_word_len), pad_val, 
dtype=field_type) +# for i, content_i in enumerate(content): +# for j, content_ii in enumerate(content_i): +# array[i, j, :len(content_ii)] = content_ii +# else: +# shape = np.shape(content) +# if len(shape) == 4: # 说明各 dimension 是相同的大小 +# array = np.array(content, dtype=field_type) +# else: +# raise RuntimeError( +# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") +# if as_numpy is False: +# array = torch.tensor(array) +# return array +# # 元素类型为数值类型 torch.float 等 +# elif str(field_type).startswith('torch'): +# if field_dim == 0: +# tensor = torch.tensor(content).to(field_type) +# elif field_dim == 1: +# max_len = max(map(len, content)) +# tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type) +# for i, content_i in enumerate(content): +# tensor[i, :len(content_i)] = content_i.clone().detach() +# elif field_dim == 2: +# max_len = max(map(len, content)) +# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for +# content_i in content]) +# tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val, +# dtype=field_type) +# for i, content_i in enumerate(content): +# for j, content_ii in enumerate(content_i): +# tensor[i, j, :len(content_ii)] = content_ii.clone().detach() +# else: +# shapes = set([np.shape(content_i) for content_i in content]) +# if len(shapes) > 1: +# raise RuntimeError( +# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") +# shape = shapes.pop() +# if len(shape) == 3: +# tensor = torch.full([len(content)] + list(shape), fill_value=pad_val, +# dtype=field_type) +# for i, content_i in enumerate(content): +# tensor[i] = content_i.clone().detach().to(field_type) +# else: +# raise RuntimeError( +# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") +# return tensor +# # TODO 增加jittor/paddle? 
+# elif str(field_type).startswith('paddle'): +# if field_dim == 0: +# tensor = paddle.Tensor(content).to(field_type) +# elif field_dim == 1: +# max_len = max(map(len, content)) +# tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type) +# for i, content_i in enumerate(content): +# tensor[i, :len(content_i)] = content_i.clone().detach() +# elif field_dim == 2: +# max_len = max(map(len, content)) +# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for +# content_i in content]) +# tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val, +# dtype=field_type) +# for i, content_i in enumerate(content): +# for j, content_ii in enumerate(content_i): +# tensor[i, j, :len(content_ii)] = content_ii.clone().detach() +# else: +# shapes = set([np.shape(content_i) for content_i in content]) +# if len(shapes) > 1: +# raise RuntimeError( +# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") +# shape = shapes.pop() +# if len(shape) == 3: +# tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val, +# dtype=field_type) +# for i, content_i in enumerate(content): +# tensor[i] = content_i.clone().detach().to(field_type) +# else: +# raise RuntimeError( +# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") +# return tensor +# +# else: +# return np.array(content) # 不进行任何操作 +# else: +# return np.array(content) diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 051a0ffc..d6c7f40c 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -27,7 +27,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> :param field_name: 方便报错的。 :return: """ - logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field)) + logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field)) if pad_val is None: logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") return NullPadder() diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py new file mode 100644 index 00000000..83784cfe --- /dev/null +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -0,0 +1,174 @@ + +from inspect import isclass +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE + +if _NEED_IMPORT_PADDLE: + import paddle + numpy_to_paddle_dtype_dict = { + np.bool_: 'bool', + np.uint8: 'uint8', + np.int8: "int8", + np.int16: "int16", + np.int32: "int32", + np.int64: "int64", + np.float16: "float16", + np.float32: 'float32', + np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 + np.complex64: 'complex64', + np.complex128: "complex128" + } + number_to_paddle_dtype_dict = { + float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64 + int: 'int64', + bool: 'bool' + } + +from .padder import Padder +from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class +from .exceptions import * + + +def is_paddle_tensor(dtype): + if not isclass(dtype) and isinstance(dtype, paddle.dtype): + return True + + return False + + +def is_paddle_dtype_str(dtype): + try: + if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', + 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', + u'bool', u'float16', u'uint16', 
u'float32', u'float64', u'int8',
+                                              u'int16', u'int32', u'int64', u'uint8', u'complex64',
+                                              u'complex128'}:
+            return True
+    except:
+        pass
+    return False
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+    if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
+        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+                                       f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")
+
+    if dtype is not None:
+        if not (is_paddle_tensor(dtype) or is_number(dtype) or is_paddle_dtype_str(dtype)):
+            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+                                        f"or paddle.dtype but get `{dtype}`.")
+        dtype = number_to_paddle_dtype_dict.get(dtype, dtype)
+    else:
+        if is_number(ele_dtype) or is_paddle_tensor(ele_dtype):
+            ele_dtype = number_to_paddle_dtype_dict.get(ele_dtype, ele_dtype)
+            dtype = ele_dtype
+        elif is_numpy_number_dtype(ele_dtype):  # numpy dtypes need an extra conversion step here
+            dtype = numpy_to_paddle_dtype_dict.get(ele_dtype.type)
+        elif is_numpy_generic_class(ele_dtype):
+            dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
+        else:
+            dtype = ele_dtype
+
+    return dtype
+
+
+class paddleNumberPadder(Padder):
+    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+        # only used when ele_dtype is a python number, a numpy number, or a tensor
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        return paddle.to_tensor(batch_field, dtype=dtype)
+
+
+class paddleSequencePadder(Padder):
+    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+        return tensor
+
+
+class paddleTensorPadder(Padder):
+    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+        """
+        Currently only supports inputs like [paddle.tensor([3, 2]), paddle.tensor([1])].
+
+        :param ele_dtype:
+        :param pad_val:
+        :param dtype:
+        """
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        shapes = [field.shape for field in batch_field]
+        max_shape = [len(batch_field)] + [max(_) for _ in zip(*shapes)]  # max over each dim, also valid for batch size 1
+        tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
+        for i, field in enumerate(batch_field):
+            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
+            if isinstance(field, np.ndarray):
+                field = paddle.to_tensor(field)
+            tensor[slices] = field
+        return tensor
+
+
+def fill_tensor(batch_field, padded_batch, dtype):
+    """
+    Fill the values in batch_field into the tensor.
+
+    :param batch_field: the content to be filled into the array
+    :param padded_batch: the tensor to be filled
+    :param dtype: the dtype of the data
+
+    :return:
+    """
+    if padded_batch.ndim == 2:
+        for i, content_i in enumerate(batch_field):
+            padded_batch[i, :len(content_i)] = paddle.Tensor(content_i, dtype=dtype)
+    elif padded_batch.ndim == 3:
+        for i, content_i in enumerate(batch_field):
+            for j, content_ii in enumerate(content_i):
+                padded_batch[i, j, :len(content_ii)] = paddle.Tensor(content_ii, dtype=dtype)
+    elif padded_batch.ndim == 4:
+        try:  # these are probably images, so converting directly should work
+            padded_batch = np.array(batch_field)
+        except:
+            for i, content_i in enumerate(batch_field):
+                for j, content_ii in enumerate(content_i):
+                    for
k, content_iii in enumerate(content_ii): + padded_batch[i, j, k, :len(content_iii)] = paddle.Tensor(content_iii, dtype=dtype) + elif padded_batch.ndim == 1: + padded_batch[:] = paddle.Tensor(batch_field, dtype=dtype) + else: + raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " + "report.") + return padded_batch + + +def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): + """ + 例如: + [[1,2], [3]] -> paddle.LongTensor([[1, 2], [3, 0]]) + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param dtype: 目标类别是什么 + :param pad_val: pad 的 value + :return: + """ + shapes = get_shape(batch_field) + tensor = paddle.full(shapes, dtype=dtype, fill_value=pad_val) + tensor = fill_tensor(batch_field, tensor, dtype=dtype) + return tensor diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 2cf85fd8..3e9cf17a 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -3,16 +3,17 @@ __all__ = [ 'prepare_jittor_dataloader' ] -from typing import Callable, Optional, List +from typing import Callable, Optional, List, Union from fastNLP.envs.imports import _NEED_IMPORT_JITTOR + if _NEED_IMPORT_JITTOR: from jittor.dataset.utils import collate_batch from jittor.dataset import Dataset else: from fastNLP.core.dataset import DataSet as Dataset from fastNLP.core.utils.jittor_utils import jittor_collate_wraps -from fastNLP.core.collators import AutoCollator +from fastNLP.core.collators import Collator from fastNLP.core.utils.utils import indice_collate_wrapper from fastNLP.core.dataset import DataSet as FDataSet @@ -48,7 +49,7 @@ class JittorDataLoader: def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, - collate_fn: Callable = None) -> None: + collate_fn: Union[None, str, Callable] = "auto") -> None: """ :param dataset: 实现__getitem__和__len__的dataset @@ -66,11 +67,20 @@ class JittorDataLoader: # TODO 支持fastnlp dataset # TODO 验证支持replacesampler (以后完成) # 是否为 jittor 类型的 dataset - - if isinstance(dataset, FDataSet): - collator = dataset.get_collator().set_as_numpy(as_numpy=True) + if isinstance(collate_fn, str): + if collate_fn == "auto": + if isinstance(dataset, FDataSet): + self._collate_fn = dataset.collator + self._collate_fn.set_backend(backend="jittor") + else: + self._collate_fn = Collator(backend="jittor") + else: + raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") + elif isinstance(collate_fn, Callable): + if collate_fn is not collate_batch: + self._collate_fn = collate_fn else: - collator = None + self._collate_fn = collate_batch self.dataset = _JittorDataset(dataset) @@ -80,17 +90,13 @@ class JittorDataLoader: if isinstance(self.dataset.dataset, Dataset): self.dataset.dataset.set_attrs(batch_size=1) # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 - self.collate_fn = collate_fn - if self.collate_fn is None: - self.collate_fn = collate_batch - self.auto_collator = collator - self.cur_batch_indices = None + # self._collate_fn = _collate_fn def __iter__(self): # TODO 第一次迭代后不能设置collate_fn,设置是无效的 + self.collate_fn = self._collate_fn if self.cur_batch_indices is None: - self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn, - 
self.auto_collator))) + self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn)) for indices, data in self.dataset.__iter__(): self.cur_batch_indices = indices yield data @@ -100,30 +106,48 @@ class JittorDataLoader: return len(self.dataset) // self.dataset.batch_size return (len(self.dataset) - 1) // self.dataset.batch_size + 1 - def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: + def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, + pad_fn: Callable = None) -> "JittorDataLoader": """ - 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 - 当val=None时,意味着给定的field_names都不需要尝试padding - - :param field_names: - :param val: padding值,默认为0 - :return: + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, + paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ - if self.auto_collator is None: - self.auto_collator = AutoCollator(as_numpy=True) - self.auto_collator.set_pad_val(*field_names, val=val) + if isinstance(self._collate_fn, Collator): + self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, + backend=backend) + return self + else: + raise ValueError(f"collate_fn is not fastnlp collator") - def set_input(self, *field_names) -> None: + def set_ignore(self, *field_names) -> "JittorDataLoader": """ - 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 - - :param field_names: - :return: + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 """ - if self.auto_collator is None: - self.auto_collator = AutoCollator(as_numpy=True) - - self.auto_collator.set_input(*field_names) + if isinstance(self._collate_fn, Collator): + self._collate_fn.set_ignore(*field_names) + return self + else: + raise ValueError(f"collate_fn is not fastnlp collator") def get_batch_indices(self) -> List[int]: """ diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index b54b9cff..b4b675c4 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -6,6 +6,7 @@ __all__ = [ from typing import Callable, List, Optional, Union, Dict, Sequence from fastNLP.envs.imports import _NEED_IMPORT_PADDLE + if _NEED_IMPORT_PADDLE: from paddle.io import DataLoader, Dataset from 
paddle.fluid.dataloader.collate import default_collate_fn @@ -13,7 +14,7 @@ else: from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as DataLoader -from fastNLP.core.collators.collator import _MultiCollator +from fastNLP.core.collators.collator import Collator from fastNLP.core.utils.utils import indice_collate_wrapper from fastNLP.core.dataset import DataSet as FDataSet @@ -45,7 +46,7 @@ class PaddleDataLoader(DataLoader): def __init__(self, dataset, feed_list=None, places=None, return_list: bool = True, batch_sampler=None, batch_size: int = 1, shuffle: bool = False, - drop_last: bool = False, collate_fn: Callable = None, + drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto', num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, worker_init_fn: Callable = None, persistent_workers=False) -> None: @@ -60,13 +61,23 @@ class PaddleDataLoader(DataLoader): use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) - if isinstance(dataset.dataset, FDataSet): - self._collate_fn = dataset.dataset.get_collator() - self._collate_fn.set_as_numpy(as_numpy=True) - if collate_fn is not None: - self._collate_fn.add_collator(collate_fn) + if isinstance(collate_fn, str): + if collate_fn == 'auto': + if isinstance(dataset.dataset, FDataSet): + self._collate_fn = dataset.dataset.collator + self._collate_fn.set_backend(backend="paddle") + # if collate_fn is not None: + # self._collate_fn.add_collator(collate_fn) + else: + self._collate_fn = Collator(backend="paddle") + + else: + raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") + elif isinstance(collate_fn, Callable): + if collate_fn is not default_collate_fn: + self._collate_fn = collate_fn else: - self._collate_fn = _MultiCollator(collate_fn) + self._collate_fn = default_collate_fn # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) # if collate_fn is not None: # _collate_fn.add_collator(collate_fn) @@ -75,64 +86,56 @@ class PaddleDataLoader(DataLoader): def __iter__(self): # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 - if len(self._collate_fn.get_collators()) == 0: - self._collate_fn.add_collator(default_collate_fn) - # self._collate_fn = default_collate_fn + # if len(self._collate_fn.get_collators()) == 0: + # self._collate_fn.add_collator(default_collate_fn) + # self._collate_fn = default_collate_fn self.collate_fn = indice_collate_wrapper(self._collate_fn) for indices, data in super().__iter__(): self.cur_batch_indices = indices yield data - def __getattr__(self, item): - """ - 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 - - :param item: - :return: - """ - try: - return self.dataset.__getattr__(item) - except AttributeError as e: - raise e - - def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: - """ - 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 - 当val=None时,意味着给定的field_names都不需要尝试padding - - :param field_names: - :param val: padding值,默认为0 - :return: - """ - for field_name in field_names: - self._collate_fn.set_pad_val(field_name, val=val) - - def set_input(self, *field_names) -> None: - """ - 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 - - :param field_names: - :return: - """ - self._collate_fn.set_input(*field_names) - - def set_collator(self, collator: Callable) -> None: + def 
set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, + pad_fn: Callable = None) -> "PaddleDataLoader": """ - 设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate - - :param collator: 用户自定义的Callable函数 - :return: + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, + paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ - self._collate_fn = _MultiCollator(collator) + if isinstance(self._collate_fn, Collator): + self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, + backend=backend) + return self + else: + raise ValueError(f"collate_fn is not fastnlp collator") - def add_collator(self, collator) -> None: + def set_ignore(self, *field_names) -> "PaddleDataLoader": """ - 添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 - - :param collator: - :return: + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 """ - self._collate_fn.add_collator(collator) + if isinstance(self._collate_fn, Collator): + self._collate_fn.set_ignore(*field_names) + return self + else: + raise ValueError(f"collate_fn is not fastnlp collator") def get_batch_indices(self) -> List[int]: """ @@ -144,20 +147,21 @@ class PaddleDataLoader(DataLoader): def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, - return_list: bool = True, batch_sampler=None, - train_batch_size: int = 1, shuffle: bool = False, - drop_last: bool = False, collate_fn: Callable = None, - num_workers: int = 0, use_buffer_reader: bool = True, - use_shared_memory: bool = True, timeout: int = 0, - worker_init_fn: Callable = None, persistent_workers=False, - non_train_batch_size: int = 16, - input_fields: Union[List[str], str] = None)\ - -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: - if isinstance(input_fields, str): - input_fields = [input_fields] - + return_list: bool = True, batch_sampler=None, + train_batch_size: int = 1, shuffle: bool = False, + drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, + num_workers: int = 0, use_buffer_reader: bool = True, + use_shared_memory: bool = True, timeout: int = 0, + worker_init_fn: Callable = None, persistent_workers=False, + non_train_batch_size: int = 16) \ + -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: if isinstance(ds_or_db, Dataset): - ... 
+ dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, + use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, + timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) + return dl elif isinstance(ds_or_db, Sequence): ds_seq = [] for ds in ds_or_db: @@ -166,7 +170,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) - dl.set_input(*input_fields) ds_seq.append(dl) return ds_seq @@ -178,14 +181,15 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, - timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) + timeout=timeout, worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers) else: dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, - timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) - dl.set_input(*input_fields) + timeout=timeout, worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers) ds_dict[name] = dl return ds_dict else: diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 02721aaf..689d24b1 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -6,8 +6,7 @@ __all__ = [ from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping from fastNLP.core.dataset import DataSet -from fastNLP.core.collators import AutoCollator -from fastNLP.core.collators.collator import _MultiCollator +from fastNLP.core.collators import Collator from fastNLP.core.utils.utils import indice_collate_wrapper from fastNLP.io.data_bundle import DataBundle from fastNLP.envs.imports import _NEED_IMPORT_TORCH @@ -51,11 +50,11 @@ class TorchDataLoader(DataLoader): def __init__(self, dataset, batch_size: int = 1, shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, - num_workers: int = 0, collate_fn: Optional[Callable] = None, + num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, - persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: + persistent_workers: bool = False, **kwargs) -> None: """ :param dataset: 实现了__getitem__和__len__的数据容器 @@ -64,7 +63,7 @@ class TorchDataLoader(DataLoader): :param sampler: sampler实例化对象 :param batch_sampler: 
batch_sampler实例化对象,其能迭代返回一个list的index数据 :param num_workers: 进程的数量,当num_worker=0时不开启多进程 - :param collate_fn: 对取得到的数据进行打包的callable函数 + :param collate_fn: [None, 'auto', callable] 对取得到的数据进行打包的callable函数 :param pin_memory: :param drop_last: 是否去掉最后一个不符合batch_size的数据 :param timeout: @@ -73,7 +72,6 @@ class TorchDataLoader(DataLoader): :param generator: :param prefetch_factor: :param persistent_workers: - :param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 """ if not isinstance(dataset, _FDataSet): dataset = _FDataSet(dataset) @@ -84,91 +82,76 @@ class TorchDataLoader(DataLoader): multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers) - if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset - self._collate_fn = dataset.dataset.get_collator() - self._collate_fn.set_as_numpy(as_numpy) - if collate_fn is not None and collate_fn is not default_collate: - # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 - self._collate_fn.add_collator(collate_fn) + if isinstance(collate_fn, str): + if collate_fn == 'auto': + if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset + self._collate_fn = dataset.dataset.collator + self._collate_fn.set_backend(backend="torch") + # if collate_fn is not None and collate_fn is not default_collate: + # # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 + # self._collate_fn.add_collator(collate_fn) + else: + self._collate_fn = Collator(backend='torch') + else: + raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") + elif isinstance(collate_fn, Callable): + if collate_fn is not default_collate: + self._collate_fn = collate_fn else: - self._collate_fn = _MultiCollator(collate_fn) + self._collate_fn = default_collate self.cur_indices_batch = None - self.as_numpy = as_numpy - - def __getattr__(self, item): - """ - 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 - - :param item: - :return: - """ - try: - return self.dataset.__getattr__(item) - except AttributeError as e: - raise e def __iter__(self): # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 - if len(self._collate_fn.get_collators()) == 0: - self._collate_fn.add_collator(self.collate_fn) + # if len(self._collate_fn.get_collators()) == 0: + # self._collate_fn.add_collator(self.collate_fn) self.collate_fn = indice_collate_wrapper(self._collate_fn) for indices, data in super().__iter__(): self.cur_batch_indices = indices yield data - def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> "TorchDataLoader": """ - 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 - 当val=None时,意味着给定的field_names都不需要尝试padding - - :param field_names: - :param val: padding值,默认为0 - :return: - """ - flag = False - for collator in self._collate_fn.get_collators(): - if isinstance(collator, AutoCollator): - flag = True - break - if flag is False: - self._collate_fn.add_collator(AutoCollator(self.as_numpy)) - for field_name in field_names: - self._collate_fn.set_pad_val(field_name, val=val) - - def set_input(self, *field_names) -> None: + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 
在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, + paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ - 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 - - :param field_names: - :return: - """ - flag = False - for collator in self._collate_fn.get_collators(): - if isinstance(collator, AutoCollator): - flag = True - break - if flag is False: - self._collate_fn.add_collator(AutoCollator(self.as_numpy)) - self._collate_fn.set_input(*field_names) - - def set_collator(self, collator: Callable) -> None: - """ - 设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate - - :param collator: 用户自定义的Callable函数 - :return: - """ - self._collate_fn = _MultiCollator(collator) + if isinstance(self._collate_fn, Collator): + self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return self + else: + raise ValueError(f"collate_fn is not fastnlp collator") - def add_collator(self, collator) -> None: + def set_ignore(self, *field_names) -> "TorchDataLoader": """ - 添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 - - :param collator: - :return: + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 """ - self._collate_fn.add_collator(collator) + if isinstance(self._collate_fn, Collator): + self._collate_fn.set_ignore(*field_names) + return self + else: + raise ValueError(f"collate_fn is not fastnlp collator") def get_batch_indices(self) -> List[int]: """ @@ -183,13 +166,12 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS batch_size: int = 1, shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, batch_sampler: Optional["Sampler[Sequence[int]]"] = None, - num_workers: int = 0, collate_fn: Optional[Callable] = None, + num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, - non_train_batch_size: int = 16, as_numpy: bool = False, - input_fields: Union[List, str, None] = None) \ + non_train_batch_size: int = 16) \ -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: """ 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 @@ -201,7 +183,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS :param sampler: sampler实例化对象 :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 :param num_workers: 进程的数量,当num_worker=0时不开启多进程 - :param collate_fn: 对取得到的数据进行打包的callable函数 + :param collate_fn: ['auto', 
None, callable]对取得到的数据进行打包的callable函数 :param pin_memory: :param drop_last: 是否去掉最后一个不符合batch_size的数据 :param timeout: @@ -212,11 +194,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS :param persistent_workers: :param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler :param non_train_batch_size: - :param as_numpy: 返回数据是否设置为numpy类型,否则根据情况设置为 torch.tensor 类型。 """ - # TODO dict, sequence情况下需要提供 - if isinstance(input_fields, str): - input_fields = [input_fields] if isinstance(ds_or_db, DataSet): dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, @@ -225,9 +203,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) - if input_fields: - dl.set_input(*input_fields) + ) return dl elif isinstance(ds_or_db, DataBundle): @@ -241,7 +217,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) + ) else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, shuffle=shuffle, sampler=non_train_sampler, @@ -251,9 +227,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) - if input_fields: - dl_bundle[name].set_input(*input_fields) + ) return dl_bundle elif isinstance(ds_or_db, Sequence): @@ -267,7 +241,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) + ) ) else: dl_bundle.append( @@ -277,11 +251,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) + ) ) - if input_fields: - for dl in dl_bundle: - dl.set_input(*input_fields) return dl_bundle elif isinstance(ds_or_db, Mapping): @@ -295,7 +266,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) + ) else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, shuffle=shuffle, sampler=non_train_sampler, @@ -305,10 +276,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - as_numpy=as_numpy) - - if input_fields: - dl_bundle[name].set_input(*input_fields) + ) return dl_bundle else: diff --git a/fastNLP/core/dataloaders/utils/__init__.py b/fastNLP/core/dataloaders/utils/__init__.py new file mode 100644 index 
00000000..e69de29b diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index cd887253..9e65ea95 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -23,9 +23,8 @@ except: from .field import FieldArray from .instance import Instance from fastNLP.core.utils.utils import pretty_table_printer, deprecated -from fastNLP.core.collators import AutoCollator +from fastNLP.core.collators import Collator from fastNLP.core.utils.rich_progress import f_rich_progress -from fastNLP.core.collators.collator import _MultiCollator class ApplyResultException(Exception): @@ -114,7 +113,7 @@ class DataSet: 每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 """ self.field_arrays = {} - self.collate_fns: _MultiCollator = _MultiCollator(AutoCollator(as_numpy=False)) + self._collator = Collator(backend="numpy") if data is not None: if isinstance(data, Dict): length_set = set() @@ -181,7 +180,7 @@ class DataSet: dataset = DataSet() for field_name, field in self.field_arrays.items(): dataset.add_field(field_name=field_name, fields=field.content[idx]) - dataset.collate_fns = deepcopy(self.collate_fns) + dataset._collator = deepcopy(self.collator) return dataset elif isinstance(idx, str): if idx not in self: @@ -193,7 +192,7 @@ class DataSet: assert isinstance(i, int), "Only int index allowed." instance = self[i] dataset.append(instance) - dataset.collate_fns = deepcopy(self.collate_fns) + dataset._collator = deepcopy(self.collator) return dataset else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) @@ -676,8 +675,8 @@ class DataSet: dev_set.append(self[idx]) for idx in train_indices: train_set.append(self[idx]) - dev_set.collate_fns = deepcopy(self.collate_fns) - train_set.collate_fns = deepcopy(self.collate_fns) + dev_set._collator = deepcopy(self.collator) + train_set._collator = deepcopy(self.collator) return dev_set, train_set @@ -772,63 +771,17 @@ class DataSet: df = self.to_pandas() df.to_csv(path, encoding="utf-8") - def add_collate_fn(self, collate_fn: Callable) -> None: - """ - 添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 - - :param collate_fn: Callable的函数 - :return: - """ - self.collate_fns.add_collator(collate_fn) - - def set_collate_fn(self, collate_fn: Callable) -> None: - """ - 设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate - - :param collate_fn: - :return: - """ - self.collate_fns = _MultiCollator(collate_fn) - - def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: - """ - 设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效 - 当val=None时,意味着给定的field_names都不需要尝试padding - - :param field_names: dataset存在的field_name - :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 - :return: - """ - # TODO 不能为空 - for field_name in field_names: - self.collate_fns.set_pad_val(field_name, val=val) - - def set_input(self, *field_names) -> None: - """ - 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 - - :param field_names: - :return: - """ - # - self.collate_fns.set_input(*field_names) - - def get_collator(self) -> _MultiCollator: - """ - 获取dataset绑定的collate_fn,其中包括auto_collate - - :return: - """ - return self.collate_fns - - @deprecated() - def set_target(self, *field_names) -> None: + def set_ignore(self, *field_names) -> None: """ 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 :param field_names: :return: """ - self.collate_fns.set_input(*field_names) - + self.collator.set_ignore(*field_names) + @property + def collator(self): + if self._collator is None: + self._collator = Collator() + return 
self._collator diff --git a/fastNLP/core/utils/jittor_utils.py b/fastNLP/core/utils/jittor_utils.py index 3784f991..89686cff 100644 --- a/fastNLP/core/utils/jittor_utils.py +++ b/fastNLP/core/utils/jittor_utils.py @@ -7,13 +7,13 @@ from collections.abc import Mapping, Callable from functools import wraps from fastNLP.envs.imports import _NEED_IMPORT_JITTOR + if _NEED_IMPORT_JITTOR: import jittor as jt from fastNLP.core.dataset import Instance - def is_jittor_dataset(dataset) -> bool: try: if isinstance(dataset, jt.dataset.Dataset): @@ -32,6 +32,7 @@ def jittor_collate_wraps(func, auto_collator: Callable): :param auto_collator: :return: """ + @wraps(func) def wrapper(batch): if isinstance(batch[0], Instance): diff --git a/tests/core/collators/padders/test_paddle_padder.py b/tests/core/collators/padders/test_paddle_padder.py new file mode 100644 index 00000000..3674cd48 --- /dev/null +++ b/tests/core/collators/padders/test_paddle_padder.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder +from fastNLP.core.collators.padders.exceptions import DtypeError +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE + +if _NEED_IMPORT_PADDLE: + import paddle + + +@pytest.mark.paddle +class TestpaddleNumberPadder: + def test_run(self): + padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + t_a = padder(a) + assert isinstance(t_a, paddle.Tensor) + assert (t_a == paddle.to_tensor(a, dtype='int64')).sum() == 3 + + +@pytest.mark.paddle +class TestpaddleSequencePadder: + def test_run(self): + padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = a.shape + assert isinstance(a, paddle.Tensor) + assert tuple(shape) == (2, 3) + b = paddle.to_tensor([[1, 2, 3], [3, -1, -1]], dtype='int64') + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = paddleSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) + a = padder([[1], [2, 322]]) + assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 + padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + + +@pytest.mark.paddle +class TestpaddleTensorPadder: + def test_run(self): + padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=int, pad_val=-1) + a = [paddle.zeros(3), paddle.zeros(2), paddle.zeros(0)] + a = padder(a) + shape = a.shape + assert isinstance(a, paddle.Tensor) + assert tuple(shape) == (3, 3) + b = paddle.to_tensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]], dtype='int64') + assert (a == b).sum().item() == shape[0]*shape[1] + + a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 2))] + a = padder(a) + shape = a.shape + assert isinstance(a, paddle.Tensor) + assert tuple(shape) == (3, 3, 2) + b = paddle.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, 0], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 1))] + a = padder(a) + shape = a.shape + assert isinstance(a, paddle.Tensor) + assert tuple(shape) == (3, 3, 2) + b = paddle.LongTensor([[[0, 0], [0, 0], [0, 
0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=int, pad_val=-1) + a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, paddle.Tensor) + assert tuple(shape) == (3, 3, 2) + b = paddle.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=None, pad_val=-1) + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, paddle.Tensor) + assert tuple(shape) == (3, 3, 2) + b = paddle.FloatTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + def test_dtype_check(self): + padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = paddleTensorPadder(ele_dtype=paddle.long, dtype=int, pad_val=-1) + padder = paddleTensorPadder(ele_dtype=int, dtype=paddle.long, pad_val=-1) + + + diff --git a/tests/core/dataloaders/jittor_dataloader/test_fdl.py b/tests/core/dataloaders/jittor_dataloader/test_fdl.py index f2021923..90eae486 100644 --- a/tests/core/dataloaders/jittor_dataloader/test_fdl.py +++ b/tests/core/dataloaders/jittor_dataloader/test_fdl.py @@ -36,8 +36,8 @@ class TestJittor: """ dataset = MyDataset() jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) - jtl.set_pad_val('x', 'y') - jtl.set_input('x') + # jtl.set_pad_val('x', 'y') + # jtl.set_input('x') for batch in jtl: print(batch) print(jtl.get_batch_indices()) @@ -50,15 +50,17 @@ class TestJittor: """ dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) - jtl.set_pad_val('x', val=-1) - jtl.set_input('x', 'y') + jtl.set_pad("x", -1) + jtl.set_ignore("y") + # jtl.set_pad_val('x', val=-1) + # jtl.set_input('x', 'y') for batch in jtl: assert batch['x'].size() == (16, 4) def test_v3(self): dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) - jtl.set_input('x', 'y') + # jtl.set_input('x', 'y') for batch in jtl: print(batch) diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 83e40610..8a603c51 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -2,6 +2,7 @@ import pytest from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader from fastNLP.core.dataset import DataSet +from fastNLP.core.log import logger from paddle.io import Dataset, DataLoader import numpy as np import paddle @@ -11,7 +12,7 @@ class RandomDataset(Dataset): def __getitem__(self, idx): image = np.random.random((10, 5)).astype('float32') - return {'image': paddle.Tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]} + return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]} def __len__(self): return 10 @@ -32,23 +33,30 @@ class TestPaddle: def test_fdl_batch_indices(self): ds = DataSet({'x': [[1, 2], [2, 
3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) - fdl.set_input("x", "y") for batch in fdl: assert len(fdl.get_batch_indices()) == 4 print(batch) print(fdl.get_batch_indices()) def test_set_inputs_and_set_pad_val(self): + logger.setLevel("DEBUG") ds = RandomDataset() fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) - fdl.set_input('image', 'label') - fdl.set_pad_val('label', val=-1) + fdl.set_pad('label', -1) for batch in fdl: + print(batch['image']) assert batch['image'].shape == [2, 10, 5] print(batch) fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) - fdl1.set_input('image', 'label') - fdl1.set_pad_val('image', val=None) + fdl1.set_ignore('image') for batch in fdl1: assert batch['image'].shape == [4, 10, 5] print(batch) + + def test_v2(self): + from fastNLP.core.collators import Collator + logger.setLevel("DEBUG") + data = [paddle.Tensor(np.random.random((10, 5)).astype('float32')), paddle.Tensor(np.random.random((10, 5)).astype('float32'))] + col = Collator(backend="jittor") + res = col(data) + print(res) \ No newline at end of file diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index 1b521ca9..52fe48ff 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -13,42 +13,23 @@ class TestFdl: fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) # for batch in fdl: # print(batch) - fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) + fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) # for batch in fdl1: # print(batch) def test_set_padding(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) - ds.set_pad_val("x", val=-1) fdl = TorchDataLoader(ds, batch_size=3) - fdl.set_input("x", "y") - fdl.set_pad_val("x", val=None) + fdl.set_pad("x", -1) for batch in fdl: print(batch) # fdl.set_pad_val("x", val=-2) # for batch in fdl: # print(batch) - def test_add_collator(self): - ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) - - def collate_fn(ins_list): - _dict = {"Y": []} - for ins in ins_list: - _dict["Y"].append(ins['y']) - return _dict - - fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) - fdl.set_input("x", "y") - # fdl.set_pad_val("x", val=None) - fdl.add_collator(collate_fn) - for batch in fdl: - print(batch) - def test_get_batch_indices(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) - fdl.set_input("y", "x") for batch in fdl: print(fdl.get_batch_indices()) From 755912520e9b0137a55c12aafa5edaecb757f448 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Mon, 2 May 2022 19:26:00 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0paddle=20padder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/padders/get_padder.py | 7 ++++ .../core/collators/padders/paddle_padder.py | 23 +++++++----- .../core/dataloaders/torch_dataloader/fdl.py | 4 +- .../collators/padders/test_paddle_padder.py | 37 +++++++++---------- .../dataloaders/paddle_dataloader/test_fdl.py | 2 +- 5 files changed, 42 insertions(+), 31 deletions(-) diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index b5fb1e39..3e136d7d 100644 --- 
a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -13,6 +13,7 @@ from .padder import Padder, NullPadder from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder from .raw_padder import RawNumberPadder, RawSequencePadder +from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from .exceptions import * @@ -90,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'torch': return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + elif backend == 'paddle': + return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 if backend == 'raw': @@ -98,12 +101,16 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'torch': return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + elif backend == 'paddle': + return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) if depth == 1 and shape_len != 0: if backend == 'numpy': return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'torch': return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + elif backend == 'paddle': + return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) if shape_len != 0 and depth>1: msg = "Does not support pad tensor under nested list. If you need this, please report." 
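For illustration, a sketch of what the new 'paddle' branches are expected to produce (the batch below is a made-up example; the depth/shape detection follows the logic in this hunk)::

    from fastNLP.core.collators.padders.get_padder import get_padder

    batch_field = [[1, 2, 3], [4]]  # depth 2, shape_len 0 -> the sequence branch
    padder = get_padder(batch_field, pad_val=0, dtype=None,
                        backend='paddle', field_name='words')
    # padder should be a PaddleSequencePadder; calling it is expected to
    # yield a paddle.Tensor of shape (2, 3) with the short row padded by 0
    padded = padder(batch_field)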
diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 83784cfe..7a569003 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -1,4 +1,8 @@ - +__all__ = [ + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder" +] from inspect import isclass import numpy as np @@ -75,7 +79,7 @@ def _get_dtype(ele_dtype, dtype, class_name): return dtype -class paddleNumberPadder(Padder): +class PaddleNumberPadder(Padder): def __init__(self, ele_dtype, pad_val=0, dtype=None): # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) @@ -86,7 +90,7 @@ class paddleNumberPadder(Padder): return paddle.to_tensor(batch_field, dtype=dtype) -class paddleSequencePadder(Padder): +class PaddleSequencePadder(Padder): def __init__(self, ele_dtype, pad_val=0, dtype=None): dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -97,7 +101,7 @@ class paddleSequencePadder(Padder): return tensor -class paddleTensorPadder(Padder): +class PaddleTensorPadder(Padder): def __init__(self, ele_dtype, pad_val=0, dtype=None): """ 目前仅支持 [paddle.tensor([3, 2], paddle.tensor([1])] 类似的 @@ -136,11 +140,11 @@ def fill_tensor(batch_field, padded_batch, dtype): """ if padded_batch.ndim == 2: for i, content_i in enumerate(batch_field): - padded_batch[i, :len(content_i)] = paddle.Tensor(content_i, dtype=dtype) + padded_batch[i, :len(content_i)] = paddle.to_tensor(content_i, dtype=dtype) elif padded_batch.ndim == 3: for i, content_i in enumerate(batch_field): for j, content_ii in enumerate(content_i): - padded_batch[i, j, :len(content_ii)] = paddle.Tensor(content_ii, dtype=dtype) + padded_batch[i, j, :len(content_ii)] = paddle.to_tensor(content_ii, dtype=dtype) elif padded_batch.ndim == 4: try: # 应该是图像,所以直接应该就 ok 了。 padded_batch = np.array(batch_field) @@ -148,9 +152,9 @@ def fill_tensor(batch_field, padded_batch, dtype): for i, content_i in enumerate(batch_field): for j, content_ii in enumerate(content_i): for k, content_iii in enumerate(content_ii): - padded_batch[i, j, k, :len(content_iii)] = paddle.Tensor(content_iii, dtype=dtype) + padded_batch[i, j, k, :len(content_iii)] = paddle.to_tensor(content_iii, dtype=dtype) elif padded_batch.ndim == 1: - padded_batch[:] = paddle.Tensor(batch_field, dtype=dtype) + padded_batch[:] = paddle.to_tensor(batch_field, dtype=dtype) else: raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. 
If you need this, please " "report.") @@ -169,6 +173,7 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): :return: """ shapes = get_shape(batch_field) - tensor = paddle.full(shapes, dtype=dtype, fill_value=pad_val) + tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) + # tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val) tensor = fill_tensor(batch_field, tensor, dtype=dtype) return tensor diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index e41bd4d2..3ee838c4 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -86,12 +86,12 @@ class TorchDataLoader(DataLoader): if collate_fn == 'auto': if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset self._collate_fn = dataset.dataset.collator - self._collate_fn.set_backend() + self._collate_fn.set_backend(backend="torch") # if collate_fn is not None and collate_fn is not default_collate: # # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 # self._collate_fn.add_collator(collate_fn) else: - self._collate_fn = Collator() + self._collate_fn = Collator(backend="torch") else: raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") elif isinstance(collate_fn, Callable): diff --git a/tests/core/collators/padders/test_paddle_padder.py b/tests/core/collators/padders/test_paddle_padder.py index f7ef4a07..80abf30a 100644 --- a/tests/core/collators/padders/test_paddle_padder.py +++ b/tests/core/collators/padders/test_paddle_padder.py @@ -32,26 +32,26 @@ class TestpaddleSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) + padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) a = padder([[1], [2, 322]]) - assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 + # assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) @pytest.mark.paddle class TestpaddleTensorPadder: def test_run(self): - padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=int, pad_val=-1) - a = [paddle.zeros(3), paddle.zeros(2), paddle.zeros(0)] + padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) + a = [paddle.zeros((3,)), paddle.zeros((2,))] a = padder(a) shape = a.shape assert isinstance(a, paddle.Tensor) - assert tuple(shape) == (3, 3) - b = paddle.to_tensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]], dtype='int64') + assert tuple(shape) == (2, 3) + b = paddle.to_tensor([[0, 0, 0], [0, 0, -1]], dtype='int64') assert (a == b).sum().item() == shape[0]*shape[1] a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 2))] @@ -61,7 +61,7 @@ class TestpaddleTensorPadder: assert tuple(shape) == (3, 3, 2) b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [-1, -1]], - [[0, 0], [-1, -1], [-1, -1]]], dtype='in') + [[0, 0], [-1, -1], [-1, -1]]], dtype='int64') assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] a = [paddle.zeros((3, 2)), 
paddle.zeros((2, 2)), paddle.zeros((1, 1))] @@ -74,26 +74,25 @@ class TestpaddleTensorPadder: [[0, -1], [-1, -1], [-1, -1]]]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=int, pad_val=-1) - a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 0))] + padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) + a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] a = padder(a) shape = a.shape assert isinstance(a, paddle.Tensor) - assert tuple(shape) == (3, 3, 2) + assert tuple(shape) == (2, 3, 2) b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], [[0, 0], [0, 0], [-1, -1]], - [[-1, -1], [-1, -1], [-1, -1]]]) + ]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=None, pad_val=-1) - a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) + a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] a = padder(a) shape = a.shape assert isinstance(a, paddle.Tensor) - assert tuple(shape) == (3, 3, 2) + assert tuple(shape) == (2, 3, 2) b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], - [[0, 0], [0, 0], [-1, -1]], - [[-1, -1], [-1, -1], [-1, -1]]], dtype='float32') + [[0, 0], [0, 0], [-1, -1]]], dtype='float32') assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] def test_dtype_check(self): @@ -103,5 +102,5 @@ class TestpaddleTensorPadder: padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) - - + def test_v1(self): + print(paddle.zeros((3, )).dtype) diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 8a603c51..c2281ffd 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -48,7 +48,7 @@ class TestPaddle: assert batch['image'].shape == [2, 10, 5] print(batch) fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) - fdl1.set_ignore('image') + fdl1.set_ignore('label') for batch in fdl1: assert batch['image'].shape == [4, 10, 5] print(batch) From 4e1b74c4cb9907a9b66b80df514b9bf19d9c3e09 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Mon, 2 May 2022 19:28:52 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9collator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/collator.py | 13 +- fastNLP/core/collators/new_collator.py | 253 ------------------------- 2 files changed, 8 insertions(+), 258 deletions(-) delete mode 100644 fastNLP/core/collators/new_collator.py diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 54570e92..ceb50a29 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -20,7 +20,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend -def _get_backend(): +def _get_backend() -> str: """ 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 @@ -61,7 +61,7 @@ def _get_backend(): else: break if len(catch_backend): - logger.debug(f"Find a file 
named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") + logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") return catch_backend[0] # 方式 (2) @@ -70,7 +70,7 @@ def _get_backend(): if catch_backend: break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") + logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") return catch_backend[0] return 'numpy' @@ -84,7 +84,7 @@ class Collator: 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 - 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad + 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 """ self.unpack_batch_func = None @@ -148,15 +148,18 @@ class Collator: for key in unpack_batch.keys(): if key not in self.input_fields and key not in self.ignore_fields: self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} + elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': + self.input_fields[key]['backend'] = self.backend for field_name, setting in self.input_fields.items(): pad_fn = setting.get('pad_fn', None) if callable(pad_fn): padder = pad_fn else: + backend = self.backend if setting['backend'] == 'auto' else setting['backend'] batch_field = unpack_batch.get(field_name) padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], - dtype=setting['dtype'], backend=setting['backend'], + dtype=setting['dtype'], backend=backend, field_name=field_name) self.padders[field_name] = padder if self.batch_data_type == 'l': diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py deleted file mode 100644 index cee713f2..00000000 --- a/fastNLP/core/collators/new_collator.py +++ /dev/null @@ -1,253 +0,0 @@ -from typing import List, Union, Dict, Callable, Sequence, Mapping -import os -import sys -import inspect - -from fastNLP.core.log import logger -from .padders.get_padder import get_padder - -import re - -from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ - pack_batch_sequence - -sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 -SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] -CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend - - -def _get_backend() -> str: - """ - 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: - (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 - 某个 backend 的 dataloader 。 - (2)如果方式(1)没找,则通过分析 sys.modules 中的内容进行寻找。 - - 如果都没有找到则返回 numpy 。 - :return: - """ - def _check_module(module): - """ - 检查该 module 是否含有 某个 backend 的特征 - - :param module: module 对象 - :return: - """ - catch_backend = [] - try: - file = module.__file__ - for backend in CHECK_BACKEND: - if f'{os.sep}site-packages{os.sep}{backend}' in file: - catch_backend = [backend, file] - except: - pass - return catch_backend - - currentframe = inspect.currentframe() - # 方式(1) - catch_backend = [] - for i in range(100): - currentframe = currentframe.f_back - if currentframe is not None: - module = inspect.getmodule(currentframe) - if module is not None: - catch_backend = _check_module(module) - if 
len(catch_backend): # 主要捕获到一个就结束吧 - break - else: - break - if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") - return catch_backend[0] - - # 方式 (2) - for key, module in sys.modules.items(): - catch_backend = _check_module(module) - if catch_backend: - break - if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") - return catch_backend[0] - - return 'numpy' - - -class Collator: - def __init__(self, backend='auto'): - """ - 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 - 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 - 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 - - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 - 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad - 的数据返回一定是 list 。 - """ - self.unpack_batch_func = None - self.pack_batch_func = None - self.ignore_fields = set() - self.padders = {} - self.input_fields = {} - self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。 - self.set_backend(backend) - - def __call__(self, batch)->Union[List, Dict]: - """ - batch可能存在三种可能性 - List[Dict], List[List], List[Sample] - - 第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 - 第二步:使用每个 field 各自的 padder 进行 pad 。 - 第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。 - - 第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个 - list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample - 的类别。 - 第一次调用会根据当前 field 决定对应的 Padder 。 - - """ - if self.unpack_batch_func is None: - # 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型 - if self.batch_data_type is None: - if isinstance(batch[0], Mapping): - self.batch_data_type = 'd' - elif isinstance(batch[0], Sequence): # 这里存在误判的风险 - self.batch_data_type = 'l' - else: - self.batch_data_type = 's' - logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " - f"is `{self.batch_data_type}`.") - if self.batch_data_type == 's': - self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 - self.pack_batch_func = lambda x: x['_single'] - elif self.batch_data_type == 'l': - self.unpack_batch_func = unpack_batch_sequence - self.pack_batch_func = pack_batch_sequence - elif self.batch_data_type == 'd': - if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} - self.unpack_batch_func = unpack_batch_nested_mapping - self.pack_batch_func = pack_batch_nested_mapping - else: - self.unpack_batch_func = unpack_batch_mapping - self.pack_batch_func = lambda x:x - - if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 - unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) - else: - unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 - - pad_batch = {} - if len(self.padders)==0: # 第一次运行,准备 padder - if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。 - self.backend = _get_backend() - - for key in unpack_batch.keys(): - if key not in self.input_fields and key not in self.ignore_fields: - self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} - elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': - 
self.input_fields[key]['backend'] = self.backend - - for field_name, setting in self.input_fields.items(): - pad_fn = setting.get('pad_fn', None) - if callable(pad_fn): - padder = pad_fn - else: - backend = self.backend if setting['backend'] == 'auto' else setting['backend'] - batch_field = unpack_batch.get(field_name) - padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], - dtype=setting['dtype'], backend=backend, - field_name=field_name) - self.padders[field_name] = padder - if self.batch_data_type == 'l': - self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 - - for key, padder in self.padders.items(): - batch = unpack_batch.get(key) - pad_batch[key] = padder(batch) - - return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 - - def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto', - pad_fn:Callable=None) -> "Collator": - """ - 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 - """ - self.padders.clear() # 重新生成 - - if self.batch_data_type is not None: - if self.batch_data_type == 's': - logger.debug("Set as single field mode.") - self.input_fields.clear() - elif self.batch_data_type == 'd': - assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ - f"index, but other field is set as dict mode." - elif self.batch_data_type == 'l': - assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ - f"field name is {field_name}." - - if field_name == '_single': - self.batch_data_type = 's' - elif isinstance(field_name, str) and sequence_idx_str.match(field_name): - self.batch_data_type = 'l' - else: - self.batch_data_type = 'd' - - if field_name in self.ignore_fields: - logger.warning(f"Field:{field_name} has been set as ignored before. 
It will not be ignored afterwards.") - if backend is None: - backend = self.backend - else: - assert backend in SUPPORTED_BACKENDS - - self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn} - - return self - - def set_backend(self, backend:str): - """ - 设置可以 pad 的 field 默认 pad 为什么类型的 tensor - - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], - 若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。 - :return: - """ - assert backend in SUPPORTED_BACKENDS - self.padders.clear() - self.backend = backend - - def set_ignore(self, *field_names) -> "Collator": - """ - 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 - Ex:: - collator.set_ignore('field1', 'field2') - - :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 - __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 - :return: 返回 Collator 自身 - """ - for field_name in field_names: - if field_name in self.input_fields: - self.input_fields.pop(field_name) - logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") - self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 - self.ignore_fields.add(field_name) - - return self - - From 7d5ce620f47651b3c9c144ab585a6741cd9d6756 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 2 May 2022 23:08:50 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0RandomBatchSampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/collator.py | 8 +- .../core/collators/padders/paddle_padder.py | 3 +- fastNLP/core/controllers/trainer.py | 2 +- .../core/dataloaders/jittor_dataloader/fdl.py | 53 ++- .../core/dataloaders/paddle_dataloader/fdl.py | 62 ++-- .../core/dataloaders/torch_dataloader/fdl.py | 66 ++-- fastNLP/core/dataloaders/utils.py | 16 + fastNLP/core/dataloaders/utils/__init__.py | 0 fastNLP/core/dataset/dataset.py | 2 +- fastNLP/core/drivers/paddle_driver/fleet.py | 4 +- .../drivers/paddle_driver/paddle_driver.py | 6 +- .../drivers/paddle_driver/single_device.py | 4 +- .../drivers/torch_driver/single_device.py | 4 +- .../core/drivers/torch_driver/torch_driver.py | 6 +- fastNLP/core/samplers/__init__.py | 5 +- .../samplers/reproducible_batch_sampler.py | 212 ++++++++++- fastNLP/core/samplers/reproducible_sampler.py | 3 +- fastNLP/core/utils/__init__.py | 4 +- fastNLP/core/utils/utils.py | 21 +- .../paddle_driver/test_single_device.py | 30 +- .../core/drivers/paddle_driver/test_utils.py | 6 +- .../torch_driver/test_single_device.py | 24 +- .../test_torch_replace_sampler.py | 2 +- tests/core/drivers/torch_driver/test_utils.py | 6 +- .../test_reproducible_batch_sampler.py | 328 +++++++++++++++++- 25 files changed, 694 insertions(+), 183 deletions(-) create mode 100644 fastNLP/core/dataloaders/utils.py delete mode 100644 fastNLP/core/dataloaders/utils/__init__.py diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index ceb50a29..5c5abda4 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -65,12 +65,16 @@ def _get_backend() -> str: return catch_backend[0] # 方式 (2) + for backend in CHECK_BACKEND: + if backend in sys.modules: + logger.debug(f"sys.modules contains backend:{catch_backend[0]}.") + return backend for key, module in sys.modules.items(): catch_backend = _check_module(module) if catch_backend: 
break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") + logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") return catch_backend[0] return 'numpy' @@ -227,7 +231,7 @@ class Collator: 设置可以 pad 的 field 默认 pad 为什么类型的 tensor :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], - 若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。 + 若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。 :return: """ assert backend in SUPPORTED_BACKENDS diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 7a569003..13eda4a9 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -74,7 +74,7 @@ def _get_dtype(ele_dtype, dtype, class_name): elif is_numpy_generic_class(ele_dtype): dtype = numpy_to_paddle_dtype_dict.get(ele_dtype) else: - dtype == ele_dtype + dtype = ele_dtype return dtype @@ -174,6 +174,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): """ shapes = get_shape(batch_field) tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) - # tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val) tensor = fill_tensor(batch_field, tensor, dtype=dtype) return tensor diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 5223c9d8..6fed9dc1 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger): raise e finally: self.on_train_end() - self.driver.barrier() def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: @@ -441,6 +440,7 @@ class Trainer(TrainerEventTrigger): """ _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"]) _own_callbacks.extend(self._custom_callbacks[None]) + logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().") self._custom_callbacks[None] = [] if self.marker is not None: if len(self._custom_callbacks[self.marker]) == 0: diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 3e9cf17a..9b67629e 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -14,7 +14,7 @@ else: from fastNLP.core.dataset import DataSet as Dataset from fastNLP.core.utils.jittor_utils import jittor_collate_wraps from fastNLP.core.collators import Collator -from fastNLP.core.utils.utils import indice_collate_wrapper +from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.core.dataset import DataSet as FDataSet @@ -106,33 +106,33 @@ class JittorDataLoader: return len(self.dataset) // self.dataset.batch_size return (len(self.dataset) - 1) // self.dataset.batch_size + 1 - def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, - pad_fn: Callable = None) -> "JittorDataLoader": + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> Collator: """ - 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 
('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, - paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ if isinstance(self._collate_fn, Collator): - self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, - backend=backend) - return self + self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") - def set_ignore(self, *field_names) -> "JittorDataLoader": + def set_ignore(self, *field_names) -> Collator: """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Ex:: @@ -145,18 +145,17 @@ class JittorDataLoader: """ if isinstance(self._collate_fn, Collator): self._collate_fn.set_ignore(*field_names) - return self + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") def get_batch_indices(self) -> List[int]: """ - 获取当前数据的idx + 获取当前 batch 的 idx :return: """ return self.cur_batch_indices - def prepare_jittor_dataloader(): ... 
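Note that set_pad()/set_ignore() now return the underlying Collator rather than the dataloader, so chained configuration continues on the Collator object. A hedged sketch (the dataset `ds` and the field names 'x' and 'y' are assumptions)::

    dl = JittorDataLoader(ds, batch_size=16, drop_last=True)
    collator = dl.set_pad('x', pad_val=-1)  # returns the Collator, not the dataloader
    collator.set_ignore('y')                # further calls chain on the Collator

Before this patch both methods returned the dataloader itself, so existing call sites that chained on the return value may need adjusting.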
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index b4b675c4..fa99be22 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -15,8 +15,9 @@ else:
     from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
 
 from fastNLP.core.collators.collator import Collator
-from fastNLP.core.utils.utils import indice_collate_wrapper
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
+from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
 
 
 class _PaddleDataset(Dataset):
@@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader):
         if not isinstance(dataset, _PaddleDataset):
             dataset = _PaddleDataset(dataset)
 
+        if batch_sampler is None:
+            batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+            # paddle 的 DataLoader 不允许在传入 batch_sampler 的同时再设置 batch_size/shuffle/drop_last
+            batch_size, shuffle, drop_last = 1, False, False
 
         super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
                                                return_list=return_list, batch_sampler=batch_sampler,
                                                batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
@@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader):
             if isinstance(dataset.dataset, FDataSet):
                 self._collate_fn = dataset.dataset.collator
                 self._collate_fn.set_backend(backend="paddle")
-                # if collate_fn is not None:
-                #     self._collate_fn.add_collator(collate_fn)
             else:
                 self._collate_fn = Collator(backend="paddle")
@@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader):
             self.cur_batch_indices = indices
             yield data
 
-    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
-                pad_fn: Callable = None) -> "PaddleDataLoader":
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> Collator:
         """
-        如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
-
-        :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
-            field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
-            如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
-            有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
-        :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
-            field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
-        :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
-        :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
-            paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
-        :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
-            batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
-            形式,输出将被直接作为结果输出。
-        :return: 返回 Collator 自身
+        如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
+
+        :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
+            field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中可以使用 ('a', 'b');
+            如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
+            有找到,则报错;如果 __getitem__ 返回的就是整体内容,请使用 "_single" 。
+        :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
+            field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
+            无意义。
+        :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
+        :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
+            torch.Tensor, paddle.Tensor, jittor.Var 类型;若为 'auto' ,则会根据调用的环境自动选择 backend 。若 pad_val 为
+            None ,该值无意义。
+        :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
+            batch 形式, Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch ;pad_fn 的输出将被直接作为结果输出。
+        :return: 返回 Collator 自身
         """
         if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
-                                     backend=backend)
-            return self
+            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
+            return self._collate_fn
         else:
-            raise ValueError(f"collate_fn is not fastnlp collator")
+            raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")
 
-    def set_ignore(self, *field_names) -> "PaddleDataLoader":
+    def set_ignore(self, *field_names) -> Collator:
         """
         如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
         Ex::
@@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader):
         """
         if isinstance(self._collate_fn, Collator):
             self._collate_fn.set_ignore(*field_names)
-            return self
+            return self._collate_fn
         else:
-            raise ValueError(f"collate_fn is not fastnlp collator")
+            raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")
 
     def get_batch_indices(self) -> List[int]:
         """
-        获取当前数据的idx
+        获取当前 batch 的 idx
 
         :return:
         """
@@ -147,7 +150,8 @@
 
 def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
-                              return_list: bool = True, batch_sampler=None,
+                              return_list: bool = True,
+                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                               train_batch_size: int = 1, shuffle: bool = False,
                               drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
                               num_workers: int = 0, use_buffer_reader: bool = True,
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 3ee838c4..12356074 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -3,14 +3,14 @@ __all__ = [
     'prepare_torch_dataloader'
 ]
 
-from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
-from fastNLP.core.utils.utils import indice_collate_wrapper
+from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
 
 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader, Sampler
@@ -76,6 +76,9 @@ class TorchDataLoader(DataLoader):
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)
 
+        if sampler is None and batch_sampler is None:
+            sampler = RandomSampler(dataset, shuffle=shuffle)
+            shuffle = False  # torch 的 DataLoader 不允许 sampler 与 shuffle=True 同时传入
 
         super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                          batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
                          pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -87,9 +90,6 @@ class TorchDataLoader(DataLoader):
             if isinstance(dataset.dataset, DataSet):  # 使用了 fastnlp dataset
                 self._collate_fn = dataset.dataset.collator
                 self._collate_fn.set_backend(backend="torch")
-                # if collate_fn is not None and collate_fn is not default_collate:
-                #     # 防止ddp重新初始化时候将torch dataloader的默认collate加进来
-                #     self._collate_fn.add_collator(collate_fn)
             else:
                 self._collate_fn = Collator(backend="torch")
         else:
@@ -112,31 +112,32 @@ class TorchDataLoader(DataLoader):
             yield data
 
     def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
-                pad_fn:Callable=None) -> "TorchDataLoader":
+                pad_fn:Callable=None) -> Collator:
         """
-        如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
+        如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
 
-        :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
-            field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
-            如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
-            有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
-        :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
-            field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
-        :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
-        :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
-            paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
-        :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
-            batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
-            形式,输出将被直接作为结果输出。
-        :return: 返回 Collator 自身
+        :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
+            field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中可以使用 ('a', 'b');
+            如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
+            有找到,则报错;如果 __getitem__ 返回的就是整体内容,请使用 "_single" 。
+        :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
+            field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
+            无意义。
+        :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
+        :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
+            torch.Tensor, paddle.Tensor, jittor.Var 类型;若为 'auto' ,则会根据调用的环境自动选择 backend 。若 pad_val 为
+            None ,该值无意义。
+        :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
+            batch 形式, Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch ;pad_fn 的输出将被直接作为结果输出。
+        :return: 返回 Collator 自身
         """
         if isinstance(self._collate_fn, Collator):
             self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
-            return self
+            return self._collate_fn
         else:
-            raise ValueError(f"collate_fn is not fastnlp collator")
+            raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")
 
-    def set_ignore(self, *field_names) -> "TorchDataLoader":
+    def set_ignore(self, *field_names) -> Collator:
         """
         如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
         Ex::
@@ -149,24 +150,15 @@ class TorchDataLoader(DataLoader):
         """
         if isinstance(self._collate_fn, Collator):
             self._collate_fn.set_ignore(*field_names)
-            return self
+            return self._collate_fn
         else:
-            raise ValueError(f"collate_fn is not fastnlp collator")
-
-    def get_batch_indices(self) -> List[int]:
-        """
-        获取当前数据的idx
-
-        :return:
-        """
-        return self.cur_batch_indices
-
+            raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")
 
 def
prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1, - shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, - batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py new file mode 100644 index 00000000..a71dc50c --- /dev/null +++ b/fastNLP/core/dataloaders/utils.py @@ -0,0 +1,16 @@ +def indice_collate_wrapper(func): + """ + 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 + + :param func: 需要修饰的函数 + :return: + """ + + def wrapper(tuple_data): + indice, ins_list = [], [] + for idx, ins in tuple_data: + indice.append(idx) + ins_list.append(ins) + return indice, func(ins_list) + + return wrapper \ No newline at end of file diff --git a/fastNLP/core/dataloaders/utils/__init__.py b/fastNLP/core/dataloaders/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 0c79bc92..11a2536c 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -780,7 +780,7 @@ class DataSet: self.collator.set_ignore(*field_names) @property - def collator(self): + def collator(self) -> Collator: if self._collator is None: self._collator = Collator() return self._collator diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index a1275bed..73342748 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -22,7 +22,7 @@ from fastNLP.core.utils import ( rank_zero_rm ) from fastNLP.core.samplers import ( - RandomBatchSampler, + ReproduceBatchSampler, ReproducibleSampler, ReproducibleBatchSampler, RandomSampler, @@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver): return self.model, model.forward - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]], reproducible: bool = False): r""" 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index ed1aad73..f65efd3d 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -22,7 +22,7 @@ from fastNLP.core.log import logger from fastNLP.core.samplers import ( ReproducibleBatchSampler, ReproducibleSampler, - RandomBatchSampler, + ReproduceBatchSampler, RandomSampler, ) @@ -345,7 +345,7 @@ class PaddleDriver(Driver): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " "`ReproducibleSampler`.") else: - sampler = RandomBatchSampler( + sampler = ReproduceBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last @@ -476,7 +476,7 @@ class 
PaddleDriver(Driver): res.shuffle = True else: res.shuffle = False - # RandomBatchSampler 的情况 + # ReproduceBatchSampler 的情况 elif hasattr(dataloader.batch_sampler, "batch_sampler"): batch_sampler = dataloader.batch_sampler.batch_sampler res.sampler = batch_sampler.sampler diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index f140ad69..52805a97 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -14,7 +14,7 @@ from fastNLP.core.utils import ( from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.samplers import ( ReproducibleBatchSampler, - RandomBatchSampler, + ReproduceBatchSampler, ReproducibleSampler, RandomSampler, re_instantiate_sampler, @@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver): logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) else: - batch_sampler = RandomBatchSampler( + batch_sampler = ReproduceBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 99ba754e..6c125a73 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -15,7 +15,7 @@ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call from fastNLP.core.utils.utils import _get_fun_msg -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler from fastNLP.core.samplers import RandomSampler from fastNLP.core.log import logger @@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver): logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) else: - batch_sampler = RandomBatchSampler( + batch_sampler = ReproduceBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 172a3cf0..8c332251 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler class TorchDriver(Driver): @@ -293,7 +293,7 @@ class TorchDriver(Driver): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " "`ReproducibleSampler`.") else: - sampler = RandomBatchSampler( + sampler = ReproduceBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, 
batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
@@ -407,7 +407,7 @@ class TorchDriver(Driver):
                 res.shuffle = True
             else:
                 res.shuffle = False
-        # RandomBatchSampler 的情况
+        # ReproduceBatchSampler 的情况
         elif hasattr(dataloader.batch_sampler, "batch_sampler"):
             batch_sampler = dataloader.batch_sampler.batch_sampler
             res.sampler = batch_sampler.sampler
diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py
index edc1f891..53c29689 100644
--- a/fastNLP/core/samplers/__init__.py
+++ b/fastNLP/core/samplers/__init__.py
@@ -14,9 +14,10 @@ __all__ = [
     "UnrepeatedSortedSampler",
     "UnrepeatedSequentialSampler",
 
-    "RandomBatchSampler",
+    "ReproduceBatchSampler",
     "BucketedBatchSampler",
     "ReproducibleBatchSampler",
+    "RandomBatchSampler",
 
     "re_instantiate_sampler"
 ]
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
 from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
 from .utils import re_instantiate_sampler
 from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
-from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
+from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler
 
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index 2bbf409f..958cf5b4 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -1,5 +1,6 @@
 __all__ = [
     'BucketedBatchSampler',
+    "ReproduceBatchSampler",
     "RandomBatchSampler"
 ]
 
@@ -54,13 +55,13 @@ class ReproducibleBatchSampler:
         raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
 
 
-class RandomBatchSampler(ReproducibleBatchSampler):
+class ReproduceBatchSampler(ReproducibleBatchSampler):
     # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
     def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
         """
         可以使得 batch_sampler 对象状态恢复的 wrapper 。
 
-        :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代
+        :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproduceBatchSampler 将首先遍历一遍该对象,然后将迭代
             出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。
         :param batch_size: 每个 batch 的大小是多少。
         :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。
@@ -143,7 +144,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.need_reinitialize = False
 
     def set_distributed(self, num_replicas, rank, pad=True):
-        raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
+        raise RuntimeError("ReproduceBatchSampler does not support switching to distributed training.")
 
     def set_epoch(self, epoch):
         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -158,6 +159,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
 
 
+class RandomBatchSampler(ReproducibleBatchSampler):
+    def __init__(self, dataset, batch_size: int = 32, shuffle: bool = True,
+                 drop_last: bool = False, seed: int = 0, **kwargs):
+        """
+        随机分 batch 的 batch_sampler 。
+
+        :param dataset: 实现了 __len__ 方法的数据容器。
+        :param batch_size: 每个 batch 的大小
+        :param shuffle: 是否在每次 iterate 的时候打乱顺序。
+        :param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。
+        :param seed: 设置的随机数种子
+        :param kwargs: fastNLP 保留使用
+        """
+        super().__init__()
+
+        self.dataset = dataset
+
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+        self.seed = seed
+
+        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量
+
+        # 多卡的相关的参数
+        self.num_replicas = kwargs.get("num_replicas", 1)
+        self.rank = kwargs.get("rank", 0)
+        self.epoch = kwargs.get("epoch", -1)
+        self.pad = kwargs.get("pad", False)  # 该参数在单卡上不具有任何意义;
+
+        # 是否处于 iteration 之间,为 True 时不允许调用 set_distributed() 和 load_state_dict()
+        self.during_iter = kwargs.get("during_iter", False)
+
+        # 以下变量为内部使用恢复状态的变量。
+        self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
+
+    def set_distributed(self, num_replicas, rank, pad=True):
+        assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
+                                          "during an unfinished iteration."
+        assert num_replicas > 0 and isinstance(num_replicas, int)
+        assert isinstance(rank, int) and 0 <= rank < num_replicas
+        # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.pad = pad
+
+        return self
+
+    def __iter__(self):
+        if self.during_iter:  # 如果发现 during_iter 为 True,说明之前的还没结束,只有强制重新初始化了
+            self.num_consumed_samples = 0
+        self.during_iter = True
+
+        indices = list(range(len(self.dataset)))
+
+        if self.shuffle:
+            if self.num_consumed_samples > 0:  # 需要先按照原来的排序,删掉多余的
+                _batches = []
+                for _i in range(self.old_num_replicas):
+                    _indices = indices[_i:len(indices):self.old_num_replicas]
+                    __batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
+                    _batches.append(__batches)
+                batches = list(chain(*[_ for _ in zip(*_batches)]))
+                indices = list(chain(*batches))
+                indices = indices[self.num_consumed_samples:]
+            # 取出这个 rank ,
+            indices = indices[self.rank:len(indices):self.num_replicas]
+            batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
+            batches = list(map(list, batches))
+        else:
+            indices = indices[self.num_consumed_samples:]
+            indices = indices[self.rank:len(indices):self.num_replicas]
+            _num_batches = len(indices) // self.batch_size
+            if _num_batches == 0:
+                batches = [indices]
+            else:
+                batches = list(map(list, np.array_split(indices[:_num_batches * self.batch_size], _num_batches)))
+                if len(indices) % self.batch_size != 0:
+                    batches.append(indices[_num_batches * self.batch_size:])
+
+        need_pad_num = (len(self.dataset) - self.num_consumed_samples) % self.num_replicas
+        if self.pad and need_pad_num != 0 and need_pad_num <= self.rank:
+            if len(batches) > 0:
+                if len(batches[-1]) < self.batch_size:
+                    batches[-1].append(batches[-1][0])  # 复制一个已有的 sample 补齐该 batch
+                else:
+                    batches.append([batches[-1][0]])
+        elif self.pad is False and need_pad_num != 0 and need_pad_num > self.rank:
+            if len(batches):
+                batches[-1].pop(-1)
+                if len(batches[-1]) == 0:
+                    batches.pop(-1)
+
+        assert sum(map(len, batches)) == self.num_left_samples
+
+        if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
+            batches = batches[:-1]
+
+        for batch in batches:
+            self.num_consumed_samples += self.num_replicas * len(batch)
+            yield list(map(int, batch))
+        self.during_iter = False
+        self.num_consumed_samples = 0
+        self.old_batch_size = self.batch_size
+        self.old_num_replicas = self.num_replicas
+        if self.epoch < 0:  # 防止用户没有修改 epoch,导致每个 epoch 都一样了
+            self.epoch -= 1
+
+    def batchify(self, indices, batch_size, seed):
+        """
+        将 indices 分为 batches
+
+        :param indices: List[int]
+        :param batch_size: int
+        :param seed: int
+        :return: List[List[int]]
+        """
+        # 根据 seed 打乱 indices 后,按照 batch_size 大小切分
+        rng = np.random.default_rng(abs(seed))
+        rng.shuffle(indices)
+        num_samples = 0
+        batches = []
+        while num_samples < len(indices):
+            batches.append(indices[num_samples:num_samples + batch_size])
+            num_samples += batch_size
+        return batches
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    @property
+    def batch_idx_in_epoch(self):
+        if self.drop_last:
+            return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
+        else:
+            return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
+                   (self.num_left_samples + self.batch_size - 1) // self.batch_size
+
+    @property
+    def total_size(self):
+        """
+        当前 sampler 会最终产生出的 index 数量(包括了其它 rank 的),因为 replica 和 pad 的原因,这个值可能等于、
+        大于或者小于 len(dataset) 。
+
+        :return:
+        """
+        return self.num_consumed_samples + self.num_replicas * self.num_left_samples
+
+    @property
+    def num_left_samples(self):
+        """
+        当前的 sampler 还剩多少个 sample 没有被消费(只计算当前 rank 的部分)。
+
+        :return:
+        """
+        num_consumed_samples = self.num_consumed_samples
+        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
+            self.pad else math.floor((len(self.dataset) - num_consumed_samples) / self.num_replicas)
+
+    def __len__(self) -> int:
+        """
+        返回当前 sampler 还会返回多少个 batch 的数据
+
+        :return:
+        """
+        num_sampler_per_rank = self.total_size // self.num_replicas
+        num_batches = num_sampler_per_rank // self.batch_size if self.drop_last else \
+            (num_sampler_per_rank + self.batch_size - 1) // self.batch_size
+        return num_batches
+
+    def state_dict(self) -> Dict:
+        if self.old_batch_size != self.batch_size:
+            raise RuntimeError("RandomBatchSampler does not support saving before last checkpoint states have been"
+                               " consumed.")
+        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'batch_size': self.batch_size,
+                  'num_replicas': self.num_replicas}
+
+        return states
+
+    def load_state_dict(self, states: Dict):
+        # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0;
+        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
+                                          "during an unfinished iteration."
+
+        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
+                                                                  f"we cannot use {self.__class__.__name__} to load it."
+
+        length = states['length']
+        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
+                                            "and current dataset."
+        self.seed = states['seed']
+        self.epoch = states['epoch']
+        self.num_consumed_samples = states['num_consumed_samples']
+        if self.num_consumed_samples >= length:  # 如果保存的时候已经到达了最后一个 sample 了,则直接将结果重置为 0
+            self.num_consumed_samples = 0
+        if self.shuffle != states['shuffle']:
+            logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
+                        f"we use shuffle={states['shuffle']}")
+        self.shuffle = states["shuffle"]
+        self.old_batch_size = states['batch_size']
+        self.old_num_replicas = states['num_replicas']
+
+
 class BucketedBatchSampler(ReproducibleBatchSampler):
     def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
                  shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index c8425dc7..7edb607a 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -54,13 +54,12 @@ class RandomSampler(ReproducibleSampler):
 
     def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
         """
-
         :param dataset: 实现了 __len__ 方法的数据容器
         :param shuffle: 是否在每次 iterate 的时候打乱顺序。
         :param seed: 随机数种子。
         :param kwargs: 用户不需要使用,fastNLP 内部使用
         """
-
+        super(RandomSampler, self).__init__()
         self.dataset = dataset
         self.shuffle = shuffle
         self.seed = seed
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
index 910a2df0..9fb538a9 100644
--- a/fastNLP/core/utils/__init__.py
+++ b/fastNLP/core/utils/__init__.py
@@ -21,7 +21,6 @@ __all__ = [
     'nullcontext',
     'pretty_table_printer',
     'Option',
-    'indice_collate_wrapper',
     'deprecated',
     'seq_len_to_mask',
     'rank_zero_rm',
@@ -37,6 +36,7 @@
 from .torch_paddle_utils import torch_paddle_move_data_to_device
 from .torch_utils import torch_move_data_to_device
 from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
     dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
-    indice_collate_wrapper, deprecated,
seq_len_to_mask, rank_zero_rm, rank_zero_mkdir + deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir +from ..dataloaders.utils import indice_collate_wrapper diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index c3f57bcf..91b3c8f6 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -6,7 +6,7 @@ import warnings from dataclasses import is_dataclass from copy import deepcopy from collections import defaultdict, OrderedDict -from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional +from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence from typing import Tuple, Optional from time import sleep @@ -35,7 +35,6 @@ __all__ = [ 'nullcontext', 'pretty_table_printer', 'Option', - 'indice_collate_wrapper', 'deprecated', 'seq_len_to_mask', 'rank_zero_rm', @@ -513,24 +512,6 @@ class Option(dict): self.update(state) -def indice_collate_wrapper(func): - """ - 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 - - :param func: 需要修饰的函数 - :return: - """ - - def wrapper(tuple_data): - indice, ins_list = [], [] - for idx, ins in tuple_data: - indice.append(idx) - ins_list.append(ins) - return indice, func(ins_list) - - return wrapper - - _emitted_deprecation_warnings = set() diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index a00a41f5..b8ccd802 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -278,7 +278,7 @@ class TestPaddleDriverFunctions: dataset = PaddleNormalDataset() dataloader = DataLoader( dataset, - batch_sampler=RandomBatchSampler( + batch_sampler=ReproduceBatchSampler( BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), batch_size, drop_last, @@ -287,7 +287,7 @@ class TestPaddleDriverFunctions: res = PaddleSingleDriver.get_dataloader_args(dataloader) assert isinstance(res.dataset, PaddleNormalDataset) - assert isinstance(res.batch_sampler, RandomBatchSampler) + assert isinstance(res.batch_sampler, ReproduceBatchSampler) if shuffle: assert isinstance(res.sampler, paddle.io.RandomSampler) else: @@ -387,7 +387,7 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), - 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) @@ -400,7 +400,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) else: # 此时会替换 batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert 
isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -414,11 +414,11 @@ class TestSetDistReproDataloader: 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) - dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) + dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler is dist self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @@ -450,7 +450,7 @@ class TestSetDistReproDataloader: """ dataloader = DataLoader( dataset=self.dataset, - batch_sampler=RandomBatchSampler( + batch_sampler=ReproduceBatchSampler( BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), batch_size=4, drop_last=False, @@ -459,7 +459,7 @@ class TestSetDistReproDataloader: replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -500,20 +500,20 @@ class TestSetDistReproDataloader: if idx >= num_consumed_batches: break already_seen_idx.update(batch) - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, - batch_sampler=RandomBatchSampler( + batch_sampler=ReproduceBatchSampler( BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), batch_size=batch_size, drop_last=False, @@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): dataset = PaddleRandomMaxDataset(40, 10) dataloader = DataLoader( dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) ) driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") @@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 更改 batch_size dataloader = DataLoader( dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) + 
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) ) load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") @@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 2. 检查 batch_sampler 是否被正确地加载和替换 assert not (replaced_loader is dataloader) assert replaced_loader.batch_sampler is dataloader.batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index 4b683c1e..3b0fb9e0 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import ( replace_batch_sampler, replace_sampler, ) -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle @@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, def test_replace_batch_sampler(): dataset = PaddleNormalDataset(10) dataloader = DataLoader(dataset, batch_size=32) - batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) replaced_loader = replace_batch_sampler(dataloader, batch_sampler) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.dataset, PaddleNormalDataset) assert len(replaced_loader.dataset) == len(dataset) assert replaced_loader.batch_sampler.batch_size == 16 diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 8c761a95..ef60e2b6 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset @@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE: def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): """ - 建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader + 建立一个 batch_sampler 为 ReproduceBatchSampler 的 dataloader """ if shuffle: sampler = torch.utils.data.RandomSampler(dataset) @@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): sampler = torch.utils.data.SequentialSampler(dataset) dataloader = DataLoader( dataset=dataset, - batch_sampler=RandomBatchSampler( + 
batch_sampler=ReproduceBatchSampler( BatchSampler( sampler, batch_size=batch_size, drop_last=drop_last ), @@ -306,7 +306,7 @@ class TestTorchDriverFunctions: res = TorchSingleDriver.get_dataloader_args(dataloader) assert isinstance(res.dataset, TorchNormalDataset) - assert isinstance(res.batch_sampler, RandomBatchSampler) + assert isinstance(res.batch_sampler, ReproduceBatchSampler) if shuffle: assert isinstance(res.sampler, torch.utils.data.RandomSampler) else: @@ -401,7 +401,7 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), - 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) @@ -414,7 +414,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) else: # 此时会替换 batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -428,11 +428,11 @@ class TestSetDistReproDataloader: 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) - dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) + dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler is dist self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @@ -466,7 +466,7 @@ class TestSetDistReproDataloader: replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -502,14 +502,14 @@ class TestSetDistReproDataloader: if idx >= num_consumed_batches: break already_seen_idx.update(batch) - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() # 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range left_idxes = set() - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size 
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新改造 dataloader @@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 2. 检查 batch_sampler 是否被正确地加载和替换 assert not (replaced_loader is dataloader) assert replaced_loader.batch_sampler is dataloader.batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index 161bbfe8..56de18fe 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -30,7 +30,7 @@ class SequenceDataSet: def check_replace_sampler(driver): - # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler + # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproduceBatchSampler # reproducible 是 True 和 False # 需要 check 返回的 sampler 和 dataloader 都不同了 diff --git a/tests/core/drivers/torch_driver/test_utils.py b/tests/core/drivers/torch_driver/test_utils.py index 97037b71..8d5d3267 100644 --- a/tests/core/drivers/torch_driver/test_utils.py +++ b/tests/core/drivers/torch_driver/test_utils.py @@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import ( replace_batch_sampler, replace_sampler, ) -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from torch.utils.data import DataLoader, BatchSampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset def test_replace_batch_sampler(): dataset = TorchNormalDataset(10) dataloader = DataLoader(dataset, batch_size=32) - batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) replaced_loader = replace_batch_sampler(dataloader, batch_sampler) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.dataset, TorchNormalDataset) assert len(replaced_loader.dataset) == len(dataset) assert replaced_loader.batch_sampler.batch_size == 16 diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index 6cf4b7d4..cac595ba 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -5,7 +5,7 @@ import pytest from itertools import chain from copy import deepcopy -from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler +from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -19,7 +19,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset # before_batch_size = 7 # dataset = TorchNormalDataset(num_of_data=100) # 
dataloader = DataLoader(dataset, batch_size=before_batch_size) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) # dataloader = replace_batch_sampler(dataloader, re_batchsampler) # # forward_steps = 3 @@ -29,15 +29,15 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset # # # 1. 保存状态 # _get_re_batchsampler = dataloader.batch_sampler -# assert isinstance(_get_re_batchsampler, RandomBatchSampler) +# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler) # state = _get_re_batchsampler.state_dict() # assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, -# "sampler_type": "RandomBatchSampler"} +# "sampler_type": "ReproduceBatchSampler"} # # # 2. 断点重训,重新生成一个 dataloader; # # 不改变 batch_size; # dataloader = DataLoader(dataset, batch_size=before_batch_size) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) # re_batchsampler.load_state_dict(state) # dataloader = replace_batch_sampler(dataloader, re_batchsampler) # @@ -54,7 +54,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset # # 改变 batch_size; # after_batch_size = 3 # dataloader = DataLoader(dataset, batch_size=after_batch_size) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) # re_batchsampler.load_state_dict(state) # dataloader = replace_batch_sampler(dataloader, re_batchsampler) # @@ -100,7 +100,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset # dataset = TorchNormalDataset(num_of_data=100) # # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; # dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) # dataloader = replace_batch_sampler(dataloader, re_batchsampler) # # # 将一轮的所有数据保存下来,看是否恢复的是正确的; @@ -112,13 +112,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset # # # 1. 保存状态 # _get_re_batchsampler = dataloader.batch_sampler -# assert isinstance(_get_re_batchsampler, RandomBatchSampler) +# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler) # state = _get_re_batchsampler.state_dict() # # # 2. 
断点重训,重新生成一个 dataloader; # # 不改变 batch_size; # dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) +# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) # re_batchsampler.load_state_dict(state) # dataloader = replace_batch_sampler(dataloader, re_batchsampler) # @@ -511,3 +511,313 @@ class TestBucketedBatchSampler: already_seen_set.update(batch) assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) + + +class TestRandomBatchSampler: + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71]) + def test_single_num_batch(self, shuffle, drop_last, num): + # 数量不够不报错 + for num in [2, 7, 14, 15, 70, 71]: + dataset = DatasetWithVaryLength(num_of_data=num) + before_batch_size = 7 + re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + drop_last=drop_last, + shuffle=shuffle) + count = len(list(iter(re_batchsampler))) + if drop_last: + assert count==num//before_batch_size, num + else: + assert count==(num+before_batch_size-1)//before_batch_size, num + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + def test_single(self, shuffle, drop_last): + + before_batch_size = 7 + num_batch_per_bucket = 4 # 那么任意 batch 内的长度差值不应该超过4 + + dataset = DatasetWithVaryLength(num_of_data=1000) + re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + drop_last=drop_last, + shuffle=shuffle) + re_batchsampler.set_epoch(0) + forward_steps = 10 + iterator = iter(re_batchsampler) + already_generate_indices = set() + for _ in range(forward_steps): + batch = next(iterator) + already_generate_indices.update(batch) + + # 1. 保存状态 + state = re_batchsampler.state_dict() + + # 2. 
断点重训,继续训练 + re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + drop_last=drop_last, + shuffle=shuffle) + re_batchsampler2.load_state_dict(state) + re_batchsampler2.set_epoch(0) + new_already_generate_indices = set() + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices)-before_batch_size * num_batch_per_bucket): + max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i]) + for batch in re_batchsampler2: + for b in batch: + assert b not in already_generate_indices + new_already_generate_indices.update(batch) + if drop_last is False: + assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset) + + # 改变 batch_size; + after_batch_size = 3 + re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, + drop_last=drop_last, + shuffle=shuffle) + re_batchsampler3.load_state_dict(state) + re_batchsampler3.set_epoch(0) + count = 0 + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + + for batch in re_batchsampler3: + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + count += 1 + if count > 5: + break + + # 再 save ,不允许再上个epoch没结束继续sample + after_batch_size = 5 + with pytest.raises(RuntimeError): + state = re_batchsampler3.state_dict() + + for batch in re_batchsampler3: # consume all, 这样才能save + pass + + already_generate_indices = set() + count = 0 + for batch in re_batchsampler3: # 重新开始 + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + count += 1 + if count > 5: + break + + state = re_batchsampler3.state_dict() + # 这里的 drop_last 为 False,需要最终是所有 sample + re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, + drop_last=False, + shuffle=shuffle) + re_batchsampler4.load_state_dict(state) + re_batchsampler4.set_epoch(0) + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + for batch in re_batchsampler4: + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + + assert len(already_generate_indices) == len(dataset) + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + def test_multi(self, shuffle, drop_last, pad): + # def test_multi(self, shuffle=True, drop_last=False, pad=False): + + # no shuffle + num_replica = 2 + dataset = DatasetWithVaryLength(num_of_data=1000) + batch_size = 5 + num_batch_per_bucket = 10 + lengths = [] + rank0_already_seen_indexes = None + max_diff = num_batch_per_bucket * batch_size * num_replica + for rank in range(num_replica): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size, + shuffle = shuffle, drop_last=drop_last) + sampler.set_epoch(0) + sampler.set_distributed(num_replica, rank=rank, pad=pad) + lengths.append(len(sampler)) + already_seen_indexes = set() + repeat_count = 0 + for batch in sampler: + for b in batch: + repeat_count += int(b in already_seen_indexes) + if rank0_already_seen_indexes: # 不能交叉出现 + assert b not in rank0_already_seen_indexes + already_seen_indexes.update(batch) + if rank0_already_seen_indexes is None: + rank0_already_seen_indexes = 
already_seen_indexes + if pad: # 应该允许重复一次 + assert repeat_count<=1 + else: + assert repeat_count==0 + + assert len(set(lengths))==1, lengths # 每个进程的batch数量一致 + + # 多进程的保存 + already_seen_indexes = set() + for rank in range(num_replica): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size, + shuffle = shuffle, drop_last=drop_last) + sampler.set_epoch(0) + sampler.set_distributed(num_replica, rank=rank, pad=pad) + lengths.append(len(sampler)) + count = 0 + for batch in sampler: + already_seen_indexes.update(batch) + if count>5: + break + count += 1 + state = sampler.state_dict() + + # 切换成单机 + new_batch_size = 6 + num_batch_per_bucket = 3 + new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, + shuffle=shuffle, drop_last=drop_last) + new_sampler.load_state_dict(state) + repeat_count = 0 + new_already_seen_indexes = set(list(already_seen_indexes)) + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_seen_indexes)] = 0 + indices = np.arange(len(dataset))[mask] + + for batch in new_sampler: + for b in batch: + repeat_count += int(b in new_already_seen_indexes) + new_already_seen_indexes.update(batch) + if pad: # 应该允许重复一次 + assert repeat_count <= 1 + else: + assert repeat_count == 0 + if drop_last is False: # 如果没有drop应该相等 + assert len(new_already_seen_indexes)==len(dataset) + + # 测试替换卡的数量。 + num_replica = 3 + new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, + shuffle=shuffle, drop_last=drop_last) + new_sampler.set_epoch(0) + new_sampler.load_state_dict(state) + new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad) + repeat_count = 0 + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_seen_indexes)] = 0 + indices = np.arange(len(dataset))[mask] + + for batch in new_sampler: + for b in batch: + repeat_count += int(b in already_seen_indexes) + if pad: # 应该允许重复一次 + assert repeat_count <= 1 + else: + assert repeat_count == 0 + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replicas', [2, 3]) + def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): + # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): + dataset = DatasetWithVaryLength(num_of_data=num_samples) + batch_size = 6 + if num_replicas*batch_size > num_samples: + return + num_batch_per_bucket = 10 + samplers = [] + lengths = [] + for i in range(num_replicas): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size, + shuffle=shuffle, drop_last=drop_last) + sampler.set_distributed(num_replicas, rank=i, pad=pad) + sampler.set_epoch(0) + samplers.append(sampler) + lengths.append(len(list(iter(sampler)))) + assert len(set(lengths))==1 + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replicas', [1, 2, 3]) + def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): + """ + 测试是否能够正确地恢复使用过的(forward)数据 + + :return: + """ + batch_size = 6 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + samplers = [] + num_consumed_samples_array = list(range(0, 
num_samples+num_replicas, num_replicas))
+        for i in range(num_replicas):
+            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
+                                         shuffle=shuffle, drop_last=drop_last)
+
+            sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
+            samplers.append(sampler)
+        count = 0
+        already_seen_sets = [set()]
+        already_seen_set = set()
+        for batchs in zip(*samplers):
+            batch = chain(*batchs)
+            already_seen_set.update(batch)
+            already_seen_sets.append(deepcopy(already_seen_set))
+            count += 1
+            if count > 3:
+                break
+        states = samplers[0].state_dict()
+        for i in range(len(already_seen_sets)):
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
+            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
+                                         shuffle=shuffle, drop_last=drop_last)
+            sampler.load_state_dict(states)
+            sampler.set_epoch(0)
+            already_seen_set = deepcopy(already_seen_sets[i])
+            for batch in sampler:
+                already_seen_set.update(batch)
+            assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
+                dataset)
+
+        # 测试保存之后再次保存
+        sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
+                                     shuffle=shuffle,
+                                     drop_last=drop_last)
+        sampler.set_epoch(0)
+        states['num_consumed_samples'] = num_consumed_samples_array[2]
+        if len(already_seen_sets) < 3:
+            return
+        already_seen_set = already_seen_sets[2]
+        count = 0
+        for batch in sampler:
+            already_seen_set.update(batch)
+            count += 1
+            if count > 6:
+                break
+
+        states = sampler.state_dict()
+        num_consumed_samples_array = list(range(len(dataset)))
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
+        sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size // 2,
+                                     shuffle=shuffle,
+                                     drop_last=drop_last)
+        sampler.load_state_dict(states)
+        sampler.set_epoch(0)
+        for batch in sampler:
+            already_seen_set.update(batch)
+
+        assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(dataset)
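The checkpoint-resume behaviour these tests exercise hinges on the state_dict()/load_state_dict() pair of
the new RandomBatchSampler. A minimal single-process resume sketch, assuming only a container that
implements __len__ (the ToyDataset below is hypothetical, not part of this patch):

    from fastNLP.core.samplers import RandomBatchSampler

    class ToyDataset:
        def __len__(self):
            return 100

    sampler = RandomBatchSampler(ToyDataset(), batch_size=8, shuffle=True, drop_last=False)
    sampler.set_epoch(0)

    it = iter(sampler)
    consumed = [next(it) for _ in range(3)]  # consume 3 batches, i.e. 24 samples
    states = sampler.state_dict()            # records seed/epoch/num_consumed_samples/...

    # after a restart, a freshly built sampler resumes from the recorded position
    resumed = RandomBatchSampler(ToyDataset(), batch_size=8, shuffle=True, drop_last=False)
    resumed.load_state_dict(states)
    resumed.set_epoch(0)
    rest = [idx for batch in resumed for idx in batch]
    assert len(rest) == 100 - 24             # only the not-yet-consumed indices are replayed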