Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

Tag: v1.0.0alpha
Author: x54-729 · 3 years ago
Parent commit: 9375872b7c
32 changed files with 1755 additions and 1016 deletions

  1. +1    -2    fastNLP/core/collators/__init__.py
  2. +611  -349  fastNLP/core/collators/collator.py
  3. +0    -253  fastNLP/core/collators/new_collator.py
  4. +9    -1    fastNLP/core/collators/padders/get_padder.py
  5. +178  -0    fastNLP/core/collators/padders/paddle_padder.py
  6. +1    -0    fastNLP/core/controllers/trainer.py
  7. +59   -36   fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  8. +83   -75   fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  9. +74   -123  fastNLP/core/dataloaders/torch_dataloader/fdl.py
  10. +16  -0    fastNLP/core/dataloaders/utils.py
  11. +9   -60   fastNLP/core/dataset/dataset.py
  12. +2   -2    fastNLP/core/drivers/paddle_driver/fleet.py
  13. +3   -3    fastNLP/core/drivers/paddle_driver/paddle_driver.py
  14. +2   -2    fastNLP/core/drivers/paddle_driver/single_device.py
  15. +2   -2    fastNLP/core/drivers/torch_driver/single_device.py
  16. +3   -3    fastNLP/core/drivers/torch_driver/torch_driver.py
  17. +3   -2    fastNLP/core/samplers/__init__.py
  18. +209 -3    fastNLP/core/samplers/reproducible_batch_sampler.py
  19. +1   -2    fastNLP/core/samplers/reproducible_sampler.py
  20. +2   -2    fastNLP/core/utils/__init__.py
  21. +2   -1    fastNLP/core/utils/jittor_utils.py
  22. +1   -20   fastNLP/core/utils/utils.py
  23. +106 -0    tests/core/collators/padders/test_paddle_padder.py
  24. +7   -5    tests/core/dataloaders/jittor_dataloader/test_fdl.py
  25. +16  -6    tests/core/dataloaders/paddle_dataloader/test_fdl.py
  26. +2   -21   tests/core/dataloaders/torch_dataloader/test_fdl.py
  27. +15  -15   tests/core/drivers/paddle_driver/test_single_device.py
  28. +3   -3    tests/core/drivers/paddle_driver/test_utils.py
  29. +12  -12   tests/core/drivers/torch_driver/test_single_device.py
  30. +1   -1    tests/core/drivers/torch_driver/test_torch_replace_sampler.py
  31. +3   -3    tests/core/drivers/torch_driver/test_utils.py
  32. +319 -9    tests/core/samplers/test_reproducible_batch_sampler.py

+1 -2   fastNLP/core/collators/__init__.py

@@ -1,5 +1,4 @@
 __all__ = [
-    'AutoCollator',
     'Collator'
 ]
-from .collator import AutoCollator, Collator
+from .collator import Collator

+611 -349   fastNLP/core/collators/collator.py

@@ -1,386 +1,648 @@
__all__ = [
'AutoCollator',
'Collator',
'Collator'
]

from typing import List, Union, Dict, Callable, Sequence, Mapping
import os
import sys
import inspect

from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Callable, Union, Tuple
from numbers import Number
import warnings
from fastNLP.core.log import logger
from .padders.get_padder import get_padder

import numpy as np
import re

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
pack_batch_sequence

if _NEED_IMPORT_PADDLE:
import paddle
sequence_idx_str = re.compile(r'^_\d+$') # matches field names like _0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backends probed when backend is 'auto'

if _NEED_IMPORT_TORCH:
import torch


class ApplyResultException(Exception):
def __init__(self, msg, index=None):
super().__init__(msg)
self.msg = msg
self.index = index # index of the sample where the problem occurred


class SetInputOrTargetException(Exception):
def __init__(self, msg, index=None, field_name=None):
super().__init__(msg)
self.msg = msg
self.index = index # index of the sample where the problem occurred
self.field_name = field_name # name of the field concerned


def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
r"""
识别cell的类别与dimension的数量

numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
:param cell:
:param dim:
:return:
def _get_backend() -> str:
"""
if isinstance(cell, (str, Number, np.bool_)):
if hasattr(cell, 'dtype'):
return cell.dtype.type, dim
return type(cell), dim

elif isinstance(cell, list):
dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
types = set([i for i, j in res])
dims = set([j for i, j in res])
if len(types) > 1:
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
elif len(types) == 0:
raise SetInputOrTargetException("Empty value encountered.")
if len(dims) > 1:
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
return types.pop(), dims.pop()

elif isinstance(cell, torch.Tensor):
return cell.dtype, cell.dim() + dim # results of e.g. torch.mean are 0-dim

elif isinstance(cell, paddle.Tensor):
return cell.dtype, cell.dim() + dim

elif isinstance(cell, np.ndarray):
if cell.dtype != np.dtype('O'): # not object dtype, so it is already well-formatted
return cell.dtype.type, cell.ndim + dim # dtype.type gives np.int32, np.float64, etc.
# otherwise keep iterating downwards
dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
types = set([i for i, j in res])
dims = set([j for i, j in res])
if len(types) > 1:
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
elif len(types) == 0:
raise SetInputOrTargetException("Empty value encountered.")
if len(dims) > 1:
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
return types.pop(), dims.pop()

else: # includes tuple, set, dict and other types
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")


def _get_ds_type_dim(ds: dict):
# get the type and dimension of each field's contents in the first row of the dataset
field_dtype, field_dim = {}, {}
for field_name, field_content in ds.items():
type_0, dim_0 = _get_ele_type_and_dim(field_content)
field_dtype[field_name], field_dim[field_name] = type_0, dim_0
return field_dtype, field_dim


class Collator(metaclass=ABCMeta):
r"""
Helper class for DataLoader that manages collate_fn.
When a Collator's backend is None/'auto', this function infers the backend automatically. Two strategies are tried:
(1) Walk up the call stack from the collator, inspect each caller's module, and use '/site-packages/{backend}'
in the module's file path to detect whether the call comes from some backend's dataloader.
(2) If (1) finds nothing, search the contents of sys.modules instead.

If neither finds anything, 'numpy' is returned.
:return:
def _check_module(module):
"""
Check whether this module carries the signature of one of the candidate backends.

def __init__(self):
super(Collator, self).__init__()
self.collate_fn = []

@abstractmethod
def __call__(self, ins_lst: List) -> Any:
raise NotImplementedError

@abstractmethod
def set_pad_val(self, *field_names: str, value=0):
raise NotImplementedError

:param module: a module object
:return:
"""
catch_backend = []
try:
file = module.__file__
for backend in CHECK_BACKEND:
if f'{os.sep}site-packages{os.sep}{backend}' in file:
catch_backend = [backend, file]
except:
pass
return catch_backend

currentframe = inspect.currentframe()
# Approach (1)
catch_backend = []
for i in range(100):
currentframe = currentframe.f_back
if currentframe is not None:
module = inspect.getmodule(currentframe)
if module is not None:
catch_backend = _check_module(module)
if len(catch_backend): # one catch is enough; stop here
break
else:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.")
return catch_backend[0]

# Approach (2)
for backend in CHECK_BACKEND:
if backend in sys.modules:
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
return backend
for key, module in sys.modules.items():
catch_backend = _check_module(module)
if catch_backend:
break
if len(catch_backend):
logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
return catch_backend[0]

return 'numpy'


class Collator:
def __init__(self, backend='auto'):
"""
An object used to pad data. All fields that fastNLP judges paddable are padded automatically, with a default
pad value of 0; use set_pad() to adjust this. Fields that should not be output can be excluded with set_ignore().
On its first pad the Collator picks a padder for every field based on the settings and the data; every later
call reuses the corresponding Padder for the corresponding field.

class _MultiCollator:
"""
Container that manages all collators.
Follows an override policy: a collate_fn added later overrides data produced by earlier ones.
"""
:param backend: which kind of tensor to produce for paddable fields; one of
['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]. With 'auto' the backend is decided from the
calling environment at pad time. The parameter has no effect on data that cannot be padded; such data is
always returned as a list.
"""
self.unpack_batch_func = None
self.pack_batch_func = None
self.ignore_fields = set()
self.padders = {}
self.input_fields = {}
self.batch_data_type = None # one of 'd', 's', 'l': each sample in the batch is a dict, a single value, or a list
self.set_backend(backend)

def __call__(self, batch)->Union[List, Dict]:
"""
A batch can come in three shapes:
List[Dict], List[List], List[Sample]

def __init__(self, collate_fns: Union[Callable, List[Callable], None]):
Step 1: use unpack_batch_func to gather the contents of each field into one list.
Step 2: pad each field with its own padder.
Step 3: return a batch of the same type as the samples in the input batch.

if collate_fns is None:
collate_fns = []
The first call decides which unpack_batch_func to use, based on the current batch; that function puts the
same field of different samples into one list. It also decides pack_batch_func, which restores the batch to
the input sample type before the padded batch is returned.
The first call also decides the Padder for each field.

if isinstance(collate_fns, Callable):
collate_fns = [collate_fns]
"""
if self.unpack_batch_func is None:
# decide which unpack_batch_func to use; all of them return a dict
if self.batch_data_type is None:
if isinstance(batch[0], Mapping):
self.batch_data_type = 'd'
elif isinstance(batch[0], Sequence): # there is some risk of misclassification here
self.batch_data_type = 'l'
else:
self.batch_data_type = 's'
logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
f"is `{self.batch_data_type}`.")
if self.batch_data_type == 's':
self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # no adjustment needed
self.pack_batch_func = lambda x: x['_single']
elif self.batch_data_type == 'l':
self.unpack_batch_func = unpack_batch_sequence
self.pack_batch_func = pack_batch_sequence
elif self.batch_data_type == 'd':
if any([isinstance(v, Mapping) for v in batch[0].values()]): # there may be nested dicts: {'a': {'b': xx}} -> {('a', 'b'): value}
self.unpack_batch_func = unpack_batch_nested_mapping
self.pack_batch_func = pack_batch_nested_mapping
else:
self.unpack_batch_func = unpack_batch_mapping
self.pack_batch_func = lambda x:x

self._collators: list = collate_fns
if self.unpack_batch_func is unpack_batch_nested_mapping: # special case: prevent unpacking from descending further
unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys()))
else:
unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # assemble each field into batch form

pad_batch = {}
if len(self.padders)==0: # first run: prepare the padders
if self.backend == 'auto': # with backend 'auto', try to infer the backend from the call stack etc.
self.backend = _get_backend()

for key in unpack_batch.keys():
if key not in self.input_fields and key not in self.ignore_fields:
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}
elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto':
self.input_fields[key]['backend'] = self.backend

for field_name, setting in self.input_fields.items():
pad_fn = setting.get('pad_fn', None)
if callable(pad_fn):
padder = pad_fn
else:
backend = self.backend if setting['backend'] == 'auto' else setting['backend']
batch_field = unpack_batch.get(field_name)
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
dtype=setting['dtype'], backend=backend,
field_name=field_name)
self.padders[field_name] = padder
if self.batch_data_type == 'l':
self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort so that _0, _1 keep their order

for key, padder in self.padders.items():
batch = unpack_batch.get(key)
pad_batch[key] = padder(batch)

return self.pack_batch_func(pad_batch) # restore the type of the input as appropriate

def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto',
pad_fn:Callable=None) -> "Collator":
"""
Use this function to make special adjustments to how a particular field is handled.

:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key
can be used directly; for nested dicts, a tuple expresses the multi-level key, e.g. ('a', 'b') for
{'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to element 0 or 1 of the sequence. An
error is raised if the field is not found in the data; if __getitem__ returns the content as a whole, use
"_single".
:param pad_val: default pad value of this field. None means the field needs no padding; fastNLP only pads
fields that are paddable anyway, so fields that are not in a paddable form need not be explicitly set to
None. Meaningless when backend is None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. Meaningless when pad_val is None.
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
The Collator unbatches the data and assembles each field into its own batch; pad_fn receives the field in
this batch form and its output is used directly as the result.
:return: the Collator itself
"""
self.padders.clear() # force regeneration

if self.batch_data_type is not None:
if self.batch_data_type == 's':
logger.debug("Set as single field mode.")
self.input_fields.clear()
elif self.batch_data_type == 'd':
assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
f"index, but other field is set as dict mode."
elif self.batch_data_type == 'l':
assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
f"field name is {field_name}."

if field_name == '_single':
self.batch_data_type = 's'
elif isinstance(field_name, str) and sequence_idx_str.match(field_name):
self.batch_data_type = 'l'
else:
self.batch_data_type = 'd'

def __call__(self, ins_lst) -> Dict:
out, list_out = {}, []
for idx, _collate_fn in enumerate(self._collators):
res = _collate_fn(ins_lst)
if isinstance(res, Dict):
out.update(res)
else:
list_out.append(res)
# else:
# raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict")
if len(out) > 0 and len(list_out) > 0:
raise ValueError("the return of collate_fns is not the same, must be dict or list")
if len(list_out) == 1:
list_out = list_out[-1]
# print(list_out)
return out if len(out) > 0 else list_out
if field_name in self.ignore_fields:
logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
if backend is None:
backend = self.backend
else:
assert backend in SUPPORTED_BACKENDS

def get_collators(self):
return self._collators
self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}

def add_collator(self, collator: Callable):
self._collators.append(collator)
return self

def set_as_numpy(self, as_numpy: bool):
def set_backend(self, backend:str):
"""
When an AutoCollator is present, as_numpy controls the type of its return value.
Set what kind of tensor paddable fields are padded to by default.

:param as_numpy:
:param backend: which kind of tensor to produce for paddable fields; one of
['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]. With 'auto' the backend is decided
automatically from the calling environment at pad time.
:return:
"""
for collator in self._collators:
if isinstance(collator, AutoCollator):
collator.set_as_numpy(as_numpy)
return self
assert backend in SUPPORTED_BACKENDS
self.padders.clear()
self.backend = backend

def set_pad_val(self, *field_names, val=0):
def set_ignore(self, *field_names) -> "Collator":
"""
When an AutoCollator is present, set the padding value of field_name.

:param field_names: the dataset's field names
:param val: the padding value
:return:
Use this to suppress fields you do not want in the output; the fields set here are ignored in the output batch.
Ex::
collator.set_ignore('field1', 'field2')

:param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's
key can be used directly; for nested dicts, a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}. If
__getitem__ returns a Sequence, '_0', '_1' refer to element 0 or 1 of the sequence.
:return: the Collator itself
"""
flag = True
for collator in self._collators:
if isinstance(collator, AutoCollator):
collator.set_pad_val(*field_names, val=val)
flag = False
if flag:
warnings.warn("AutoCollator is remove, set_padding is unavailable!!")
for field_name in field_names:
if field_name in self.input_fields:
self.input_fields.pop(field_name)
logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
self.padders.pop(field_name, None) # drop its padder if it has one
self.ignore_fields.add(field_name)
return self

def set_input(self, *field_names):
"""
Set the field_names the AutoCollator needs; fields not set are filtered out by default.

:param field_names:
:return:
"""
flag = True
for collator in self._collators:
if isinstance(collator, AutoCollator):
collator.set_input(*field_names)
flag = False
if flag:
warnings.warn("AutoCollator is removed, set_input is unavailable!!")
return self


class AutoCollator(Collator):

def __init__(self, as_numpy: bool):
super(AutoCollator, self).__init__()
self.pad_field_value = {} # custom padding value per field, default 0
self.need_inputs = set() # names of the fields that are needed
self.field_dtypes = None # dtype of each column's data cells
self.field_dims = None # dimensionality of each column's data cells
self.as_numpy = as_numpy

def __call__(self, ins_lst: List[Dict]) -> dict:
if len(self.need_inputs) == 0:
raise ValueError({"set_inputs is None, you should use set_inputs method first!!"})
# TODO should first check which fields need padding, then check whether those fields are paddable

# case 1: values were configured via set_input
# case 2: decide from the data types whether to pad
if self.field_dtypes is None and self.field_dims is None:
field_dtypes, field_dims = {}, {}
for key, value in ins_lst[0].items():
if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
self.field_dtypes = field_dtypes
self.field_dims = field_dims

pack_ins_lst, pad_ins_lst = {field_name: []
for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
# pack the data in the list by column name
for per_ins in ins_lst:
for field_name, _field_content in per_ins.items():
if field_name in self.need_inputs:
pack_ins_lst[field_name].append(_field_content)

pad_field_kv = {field_name: 0 for field_name in self.need_inputs}
pad_field_kv.update(self.pad_field_value)
self.pad_field_value = pad_field_kv

if len(self.pad_field_value.keys()) > 0:
# drop columns that need no padding; columns from set_input that do not exist are ignored
non_pad_field_names = []
for k, v in self.pad_field_value.items():
if v is None:
non_pad_field_names.append(k)

# drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
for field_name in non_pad_field_names:
field_array = pack_ins_lst.pop(field_name)
pad_ins_lst[field_name] = np.array(field_array)

for field_name, field_array in pack_ins_lst.items():
content = pad_content(field_array, field_name, self.field_dtypes[field_name],
self.field_dims[field_name],
self.pad_field_value[field_name],
as_numpy=self.as_numpy)
pad_ins_lst[field_name] = content

# else:
# # 取出每列的数据,根据类型判断是否能 pad
# for field_name, field_array in pack_ins_lst.items():
# pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name],
# self.field_dims[field_name],
# pad_val=0, as_numpy=self.as_numpy)
# pad_ins_lst[field_name] = pad_field_array

return pad_ins_lst

def set_pad_val(self, *field_names, val=0):
for field_name in field_names:
self.pad_field_value[field_name] = val

def set_as_numpy(self, as_numpy: bool):
self.as_numpy = as_numpy

def set_input(self, *field_names):
for field_name in field_names:
self.need_inputs.add(field_name)


def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):

if field_type:
# no processing; return an np.array
if field_dim > 3:
return np.array(content)
# element type is numeric: np.int64, np.float64, int, float, etc.
if isinstance(field_type, type) and \
(issubclass(field_type, np.number) or issubclass(field_type, Number)):
if field_dim == 0:
array = np.array(content, dtype=field_type)
elif field_dim == 1:
max_len = max(map(len, content))
array = np.full((len(content), max_len), pad_val, dtype=field_type)
for i, content_i in enumerate(content):
array[i, :len(content_i)] = content_i
elif field_dim == 2:
max_len = max(map(len, content))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in content])
array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type)
for i, content_i in enumerate(content):
for j, content_ii in enumerate(content_i):
array[i, j, :len(content_ii)] = content_ii
else:
shape = np.shape(content)
if len(shape) == 4: # all dimensions share the same size
array = np.array(content, dtype=field_type)
else:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
if as_numpy is False:
array = torch.tensor(array)
return array
# element type is a torch numeric type, e.g. torch.float
elif str(field_type).startswith('torch'):
if field_dim == 0:
tensor = torch.tensor(content).to(field_type)
elif field_dim == 1:
max_len = max(map(len, content))
tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
for i, content_i in enumerate(content):
tensor[i, :len(content_i)] = content_i.clone().detach()
elif field_dim == 2:
max_len = max(map(len, content))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in content])
tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val,
dtype=field_type)
for i, content_i in enumerate(content):
for j, content_ii in enumerate(content_i):
tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
else:
shapes = set([np.shape(content_i) for content_i in content])
if len(shapes) > 1:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
shape = shapes.pop()
if len(shape) == 3:
tensor = torch.full([len(content)] + list(shape), fill_value=pad_val,
dtype=field_type)
for i, content_i in enumerate(content):
tensor[i] = content_i.clone().detach().to(field_type)
else:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return tensor
# TODO add jittor/paddle?
elif str(field_type).startswith('paddle'):
if field_dim == 0:
tensor = paddle.Tensor(content).to(field_type)
elif field_dim == 1:
max_len = max(map(len, content))
tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
for i, content_i in enumerate(content):
tensor[i, :len(content_i)] = content_i.clone().detach()
elif field_dim == 2:
max_len = max(map(len, content))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in content])
tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val,
dtype=field_type)
for i, content_i in enumerate(content):
for j, content_ii in enumerate(content_i):
tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
else:
shapes = set([np.shape(content_i) for content_i in content])
if len(shapes) > 1:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
shape = shapes.pop()
if len(shape) == 3:
tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val,
dtype=field_type)
for i, content_i in enumerate(content):
tensor[i] = content_i.clone().detach().to(field_type)
else:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return tensor

else:
return np.array(content) # 不进行任何操作
else:
return np.array(content)
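
Taken together, the new Collator replaces the old set_input/set_pad_val workflow shown above. A minimal usage sketch (hypothetical field names and data; assumes torch is installed so that the 'torch' backend is available):

    from fastNLP.core.collators import Collator

    collator = Collator(backend='torch')      # or 'auto' to infer from the caller
    collator.set_pad('words', pad_val=-1)     # pad the 'words' field with -1
    collator.set_ignore('raw_text')           # drop 'raw_text' from every batch

    batch = [
        {'words': [3, 4, 5], 'seq_len': 3, 'raw_text': 'a b c'},
        {'words': [6, 7],    'seq_len': 2, 'raw_text': 'd e'},
    ]
    out = collator(batch)
    # out['words']   -> tensor([[ 3,  4,  5], [ 6,  7, -1]])
    # out['seq_len'] -> tensor([3, 2]); 'raw_text' is absent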

+0 -253   fastNLP/core/collators/new_collator.py

@@ -1,253 +0,0 @@


+9 -1   fastNLP/core/collators/padders/get_padder.py

@@ -13,6 +13,7 @@ from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from .exceptions import *


@@ -27,7 +28,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
:param field_name: used to make error messages clearer.
:return:
"""
logger.debug(f"The content in the field:`{field_name}` is:\n"+str(batch_field))

logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
if pad_val is None:
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
return NullPadder()
@@ -89,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'torch':
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

if depth > 1 and shape_len == 0: # e.g. [[0, 1], [2]]
if backend == 'raw':
@@ -97,12 +101,16 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'torch':
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

if depth == 1 and shape_len != 0:
if backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'torch':
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

if shape_len != 0 and depth>1:
msg = "Does not support pad tensor under nested list. If you need this, please report."


+178 -0   fastNLP/core/collators/padders/paddle_padder.py

@@ -0,0 +1,178 @@
__all__ = [
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder"
]
from inspect import isclass
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
import paddle
numpy_to_paddle_dtype_dict = {
np.bool_: 'bool',
np.uint8: 'uint8',
np.int8: "int8",
np.int16: "int16",
np.int32: "int32",
np.int64: "int64",
np.float16: "float16",
np.float32: 'float32',
np.float64: 'float32', # unify everything to float32 here, since numpy defaults to float64 most of the time
np.complex64: 'complex64',
np.complex128: "complex128"
}
number_to_paddle_dtype_dict = {
float: 'float32', # because paddle.to_tensor([1], dtype=float) is paddle.float64
int: 'int64',
bool: 'bool'
}

from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *


def is_paddle_tensor(dtype):
if not isclass(dtype) and isinstance(dtype, paddle.dtype):
return True

return False


def is_paddle_dtype_str(dtype):
try:
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
u'int16', u'int32', u'int64', u'uint8', u'complex64',
u'complex128'}:
return True
except:
pass
return False



def _get_dtype(ele_dtype, dtype, class_name):
if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")

if dtype is not None:
if not (is_paddle_tensor(dtype) or is_number(dtype) or is_paddle_dtype_str(dtype)):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
f"or paddle.dtype but get `{dtype}`.")
dtype = number_to_paddle_dtype_dict.get(dtype, dtype)
else:
if (is_number(ele_dtype) or is_paddle_tensor(ele_dtype)):
ele_dtype = number_to_paddle_dtype_dict.get(ele_dtype, ele_dtype)
dtype = ele_dtype
elif is_numpy_number_dtype(ele_dtype): # a conversion problem arises here
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype.type)
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
else:
dtype = ele_dtype

return dtype


class PaddleNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
# only when ele_dtype is a python number / numpy number or a tensor
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
return paddle.to_tensor(batch_field, dtype=dtype)


class PaddleSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val)
return tensor


class PaddleTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
"""
Currently only supports inputs like [paddle.to_tensor([3, 2]), paddle.to_tensor([1])].

:param ele_dtype:
:param pad_val:
:param dtype:
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
if isinstance(field, np.ndarray):
field = paddle.to_tensor(field)
tensor[slices] = field
return tensor


def fill_tensor(batch_field, padded_batch, dtype):
"""
Fill the values from batch_field into the tensor.

:param batch_field: the content to fill into the array
:param padded_batch: the tensor to be filled
:param dtype: dtype of the data

:return:
"""
if padded_batch.ndim == 2:
for i, content_i in enumerate(batch_field):
padded_batch[i, :len(content_i)] = paddle.to_tensor(content_i, dtype=dtype)
elif padded_batch.ndim == 3:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
padded_batch[i, j, :len(content_ii)] = paddle.to_tensor(content_ii, dtype=dtype)
elif padded_batch.ndim == 4:
try: # probably images, so converting directly should just work
padded_batch = np.array(batch_field)
except:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
for k, content_iii in enumerate(content_ii):
padded_batch[i, j, k, :len(content_iii)] = paddle.to_tensor(content_iii, dtype=dtype)
elif padded_batch.ndim == 1:
padded_batch[:] = paddle.to_tensor(batch_field, dtype=dtype)
else:
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
"report.")
return padded_batch


def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
"""
For example:
[[1, 2], [3]] -> paddle.to_tensor([[1, 2], [3, 0]])

:param batch_field: the object to pad; it must already be paddable. Supports 1d (usually sentence lengths) /
2d (usually text sequences) / 3d (usually character sequences) / 4d (usually images).
:param dtype: the target dtype
:param pad_val: the value to pad with
:return:
"""
shapes = get_shape(batch_field)
tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor
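
A short sketch of how these padders behave (hypothetical values; assumes paddle is installed):

    from fastNLP.core.collators.padders.paddle_padder import (
        PaddleNumberPadder, PaddleSequencePadder)

    # python ints map to 'int64' via number_to_paddle_dtype_dict
    num_padder = PaddleNumberPadder(ele_dtype=int, pad_val=0)
    # num_padder([1, 2, 3]) -> paddle.to_tensor([1, 2, 3], dtype='int64')

    seq_padder = PaddleSequencePadder(ele_dtype=int, pad_val=0)
    # seq_padder([[1, 2], [3]]) -> a (2, 2) tensor [[1, 2], [3, 0]]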

+1 -0   fastNLP/core/controllers/trainer.py

@@ -440,6 +440,7 @@ class Trainer(TrainerEventTrigger):
"""
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
_own_callbacks.extend(self._custom_callbacks[None])
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
self._custom_callbacks[None] = []
if self.marker is not None:
if len(self._custom_callbacks[self.marker]) == 0:


+59 -36   fastNLP/core/dataloaders/jittor_dataloader/fdl.py

@@ -3,17 +3,18 @@ __all__ = [
'prepare_jittor_dataloader'
]

from typing import Callable, Optional, List
from typing import Callable, Optional, List, Union

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
from jittor.dataset.utils import collate_batch
from jittor.dataset import Dataset
else:
from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import AutoCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet


@@ -48,7 +49,7 @@ class JittorDataLoader:
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
collate_fn: Callable = None) -> None:
collate_fn: Union[None, str, Callable] = "auto") -> None:
"""

:param dataset: a dataset implementing __getitem__ and __len__
@@ -66,11 +67,20 @@ class JittorDataLoader:
# TODO support fastNLP dataset
# TODO verify support for replacing the sampler (to be finished later)
# whether the dataset is a jittor-style dataset

if isinstance(dataset, FDataSet):
collator = dataset.get_collator().set_as_numpy(as_numpy=True)
if isinstance(collate_fn, str):
if collate_fn == "auto":
if isinstance(dataset, FDataSet):
self._collate_fn = dataset.collator
self._collate_fn.set_backend(backend="jittor")
else:
self._collate_fn = Collator(backend="jittor")
else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not collate_batch:
self._collate_fn = collate_fn
else:
collator = None
self._collate_fn = collate_batch

self.dataset = _JittorDataset(dataset)

@@ -80,17 +90,13 @@ class JittorDataLoader:
if isinstance(self.dataset.dataset, Dataset):
self.dataset.dataset.set_attrs(batch_size=1)
# a user-supplied collate_fn automatically replaces the collate_batch function provided by jittor
self.collate_fn = collate_fn
if self.collate_fn is None:
self.collate_fn = collate_batch
self.auto_collator = collator
self.cur_batch_indices = None
# self._collate_fn = _collate_fn

def __iter__(self):
# TODO after the first iteration collate_fn can no longer be set; setting it has no effect
self.collate_fn = self._collate_fn
if self.cur_batch_indices is None:
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn,
self.auto_collator)))
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
for indices, data in self.dataset.__iter__():
self.cur_batch_indices = indices
yield data
@@ -100,39 +106,56 @@ class JittorDataLoader:
return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
Set the padding value of each field_name (default 0). This method only takes effect when an autocollator is
present; if there is none, an auto_collator function is added.
val=None means none of the given field_names should attempt padding.

:param field_names:
:param val: the padding value, default 0
:return:
Use this function to make special adjustments to how a particular field is handled.

:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key
can be used directly; for nested dicts, a tuple expresses the multi-level key, e.g. ('a', 'b') for
{'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to element 0 or 1 of the sequence. An
error is raised if the field is not found in the data; if __getitem__ returns the content as a whole, use
"_single".
:param pad_val: default pad value of this field. None means the field needs no padding; fastNLP only pads
fields that are paddable anyway, so fields that are not in a paddable form need not be explicitly set to
None. Meaningless when backend is None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. Meaningless when pad_val is None.
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
The Collator unbatches the data and assembles each field into its own batch; pad_fn receives the field in
this batch form and its output is used directly as the result.
:return: the Collator itself
"""
if self.auto_collator is None:
self.auto_collator = AutoCollator(as_numpy=True)
self.auto_collator.set_pad_val(*field_names, val=val)
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

def set_input(self, *field_names) -> None:
def set_ignore(self, *field_names) -> Collator:
"""
field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default.

:param field_names:
:return:
If some content should not be included in the output, set it here; the given fields will be ignored in the batch output.
Ex::
    collator.set_ignore('field1', 'field2')

:param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, use the
    field's key directly; for a nested dict, use a tuple, e.g. ('a', 'b') for {'a': {'b': 1}}. If __getitem__
    returns a Sequence, use '_0', '_1' to refer to the 0th or 1st element.
:return: the Collator itself
"""
if self.auto_collator is None:
self.auto_collator = AutoCollator(as_numpy=True)

self.auto_collator.set_input(*field_names)
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

def get_batch_indices(self) -> List[int]:
"""
Get the indices of the current data.
Get the indices of the current batch.

:return:
"""
return self.cur_batch_indices


def prepare_jittor_dataloader():
...

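For orientation, a minimal usage sketch of the new JittorDataLoader API (a hedged example assuming fastNLP dev0.8.0 with jittor installed; the field names 'x' and 'y' are illustrative, not taken from the diff):

from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader

ds = DataSet({'x': [[1, 2], [0], [2, 3, 4, 5]] * 4, 'y': [0, 1, 2] * 4})
dl = JittorDataLoader(ds, batch_size=4)  # collate_fn defaults to 'auto' and reuses ds.collator
dl.set_pad('x', pad_val=-1)              # pad the variable-length 'x' with -1
dl.set_ignore('y')                       # 'y' is dropped from the batch output
for batch in dl:
    print(dl.get_batch_indices(), batch['x'].shape)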
+ 83
- 75
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -6,6 +6,7 @@ __all__ = [
from typing import Callable, List, Optional, Union, Dict, Sequence

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
from paddle.io import DataLoader, Dataset
from paddle.fluid.dataloader.collate import default_collate_fn
@@ -13,9 +14,10 @@ else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.collators.collator import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler


class _PaddleDataset(Dataset):
@@ -45,7 +47,7 @@ class PaddleDataLoader(DataLoader):
def __init__(self, dataset, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Callable = None,
drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False) -> None:
@@ -53,6 +55,10 @@ class PaddleDataLoader(DataLoader):
if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset)

if batch_sampler is None:
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last)

super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
@@ -60,13 +66,21 @@ class PaddleDataLoader(DataLoader):
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.get_collator()
self._collate_fn.set_as_numpy(as_numpy=True)
if collate_fn is not None:
self._collate_fn.add_collator(collate_fn)
if isinstance(collate_fn, str):
if collate_fn == 'auto':
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle")
else:
self._collate_fn = Collator(backend="paddle")

else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not default_collate_fn:
self._collate_fn = collate_fn
else:
self._collate_fn = _MultiCollator(collate_fn)
self._collate_fn = default_collate_fn
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
# if collate_fn is not None:
# _collate_fn.add_collator(collate_fn)
@@ -75,68 +89,60 @@ class PaddleDataLoader(DataLoader):

def __iter__(self):
# if there is neither an auto_collator nor a custom collate_fn, fall back to the dataloader's built-in collate_fn to pack the data.
if len(self._collate_fn.get_collators()) == 0:
self._collate_fn.add_collator(default_collate_fn)
# self._collate_fn = default_collate_fn
# if len(self._collate_fn.get_collators()) == 0:
# self._collate_fn.add_collator(default_collate_fn)
# self._collate_fn = default_collate_fn
self.collate_fn = indice_collate_wrapper(self._collate_fn)
for indices, data in super().__iter__():
self.cur_batch_indices = indices
yield data

def __getattr__(self, item):
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
Expose the dataset's methods and attributes on FDataLoader; with this implemented, users can call dataset
methods such as apply after instantiating FDataLoader.

:param item:
:return:
If you need to adjust how a particular field is handled, use this function.

:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the field's
    key directly; for a nested dict, use a tuple for the multi-level key, e.g. ('a', 'b') for {'a': {'b': 1}}.
    If __getitem__ returns a Sequence, use '_0', '_1' to refer to the 0th or 1st element. If the field cannot
    be found in the data, an error is raised; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value for this field. If None, the field is not padded; fastNLP only pads
    fields that are in a paddable form, so a field that is not paddable does not need to be set to None
    explicitly. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
    torch.Tensor, paddle.Tensor or jittor.Var output respectively. If pad_val is None, this value is meaningless.
:param pad_fn: a pad function for this field; if given, pad_val, dtype and backend are ignored. The input to
    pad_fn is this field's batch: the Collator automatically unbatches the data and regroups each field into
    its own batch, pad_fn receives that batch, and its output is used directly as the result.
:return: the Collator itself
"""
try:
return self.dataset.__getattr__(item)
except AttributeError as e:
raise e

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
Set the padding value for each field_name (default 0). This method only takes effect when an auto collator exists; otherwise an auto_collator function is added.
When val=None, none of the given field_names will be padded.

:param field_names:
:param val: padding value, defaults to 0
:return:
"""
for field_name in field_names:
self._collate_fn.set_pad_val(field_name, val=val)

def set_input(self, *field_names) -> None:
"""
field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default.

:param field_names:
:return:
"""
self._collate_fn.set_input(*field_names)

def set_collator(self, collator: Callable) -> None:
"""
Set the collate_fn; calling this overrides all current collate_fn, including Auto_Collate.

:param collator: a user-defined Callable
:return:
"""
self._collate_fn = _MultiCollator(collator)
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

def add_collator(self, collator) -> None:
def set_ignore(self, *field_names) -> Collator:
"""
Append a collate_fn; calling this adds it after the existing collate_fn.

:param collator:
:return:
If some content should not be included in the output, set it here; the given fields will be ignored in the batch output.
Ex::
    collator.set_ignore('field1', 'field2')

:param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, use the
    field's key directly; for a nested dict, use a tuple, e.g. ('a', 'b') for {'a': {'b': 1}}. If __getitem__
    returns a Sequence, use '_0', '_1' to refer to the 0th or 1st element.
:return: the Collator itself
"""
self._collate_fn.add_collator(collator)
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

def get_batch_indices(self) -> List[int]:
"""
Get the indices of the current data.
Get the indices of the current batch.

:return:
"""
@@ -144,20 +150,22 @@ class PaddleDataLoader(DataLoader):


def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Callable = None,
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False,
non_train_batch_size: int = 16,
input_fields: Union[List[str], str] = None)\
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
if isinstance(input_fields, str):
input_fields = [input_fields]

return_list: bool = True,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False,
non_train_batch_size: int = 16) \
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
if isinstance(ds_or_db, Dataset):
...
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
return dl
elif isinstance(ds_or_db, Sequence):
ds_seq = []
for ds in ds_or_db:
@@ -166,7 +174,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
dl.set_input(*input_fields)
ds_seq.append(dl)
return ds_seq

@@ -178,14 +185,15 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
else:
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
dl.set_input(*input_fields)
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
ds_dict[name] = dl
return ds_dict
else:


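A short sketch of the pad_fn contract described in set_pad above (a hedged example assuming paddle and numpy are installed; the dataset, the field name 'words', and the fixed length 8 are all illustrative):

import numpy as np
from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader

def pad_to_len_8(field_batch):
    # field_batch is the raw list of this field's values in the current batch;
    # the return value is used directly as the batched output.
    out = np.full((len(field_batch), 8), fill_value=-1, dtype='int64')
    for i, seq in enumerate(field_batch):
        out[i, :len(seq)] = seq
    return out

ds = DataSet({'words': [[1, 2], [3, 4, 5]] * 8})
dl = PaddleDataLoader(ds, batch_size=4)    # collate_fn defaults to 'auto'
dl.set_pad('words', pad_fn=pad_to_len_8)   # pad_val/dtype/backend are ignored now
for batch in dl:
    assert batch['words'].shape == (4, 8)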
+ 74
- 123
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -3,15 +3,14 @@ __all__ = [
'prepare_torch_dataloader'
]

from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler

if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
@@ -51,11 +50,11 @@ class TorchDataLoader(DataLoader):
def __init__(self, dataset, batch_size: int = 1,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto',
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None:
persistent_workers: bool = False, **kwargs) -> None:
"""

:param dataset: a data container implementing __getitem__ and __len__
@@ -64,7 +63,7 @@ class TorchDataLoader(DataLoader):
:param sampler: a sampler instance
:param batch_sampler: a batch_sampler instance that iterates and returns lists of indices
:param num_workers: number of worker processes; multiprocessing is disabled when num_workers=0
:param collate_fn: a callable that packs the fetched data. [None, 'auto', callable]
:param collate_fn: [None, 'auto', callable] a callable that packs the fetched data
:param pin_memory:
:param drop_last: whether to drop the last batch when it does not fill batch_size
:param timeout:
@@ -73,133 +72,99 @@ class TorchDataLoader(DataLoader):
:param generator:
:param prefetch_factor:
:param persistent_workers:
:param as_numpy: whether to return data as numpy arrays; otherwise torch.Tensor
"""
if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset)

if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)
if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is in use
self._collate_fn = dataset.dataset.get_collator()
self._collate_fn.set_as_numpy(as_numpy)
if collate_fn is not None and collate_fn is not default_collate:
# prevent torch dataloader's default collate from being added when DDP re-initializes
self._collate_fn.add_collator(collate_fn)
if isinstance(collate_fn, str):
if collate_fn == 'auto':
if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is in use
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
else:
self._collate_fn = Collator(backend="torch")
else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not default_collate:
self._collate_fn = collate_fn
else:
self._collate_fn = _MultiCollator(collate_fn)
self._collate_fn = default_collate

self.cur_indices_batch = None
self.as_numpy = as_numpy

def __getattr__(self, item):
"""
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法

:param item:
:return:
"""
try:
return self.dataset.__getattr__(item)
except AttributeError as e:
raise e

def __iter__(self):
# if there is neither an auto_collator nor a custom collate_fn, fall back to the dataloader's built-in collate_fn to pack the data.
if len(self._collate_fn.get_collators()) == 0:
self._collate_fn.add_collator(self.collate_fn)
# if len(self._collate_fn.get_collators()) == 0:
# self._collate_fn.add_collator(self.collate_fn)
self.collate_fn = indice_collate_wrapper(self._collate_fn)
for indices, data in super().__iter__():
self.cur_batch_indices = indices
yield data

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
Set the padding value for each field_name (default 0). This method only takes effect when an auto collator exists; otherwise an auto_collator function is added.
When val=None, none of the given field_names will be padded.

:param field_names:
:param val: padding value, defaults to 0
:return:
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
flag = False
for collator in self._collate_fn.get_collators():
if isinstance(collator, AutoCollator):
flag = True
break
if flag is False:
self._collate_fn.add_collator(AutoCollator(self.as_numpy))
for field_name in field_names:
self._collate_fn.set_pad_val(field_name, val=val)
def set_input(self, *field_names) -> None:
"""
field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default.
:param field_names:
:return:
If you need to adjust how a particular field is handled, use this function.

:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the field's
    key directly; for a nested dict, use a tuple for the multi-level key, e.g. ('a', 'b') for {'a': {'b': 1}}.
    If __getitem__ returns a Sequence, use '_0', '_1' to refer to the 0th or 1st element. If the field cannot
    be found in the data, an error is raised; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value for this field. If None, the field is not padded; fastNLP only pads
    fields that are in a paddable form, so a field that is not paddable does not need to be set to None
    explicitly. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
    torch.Tensor, paddle.Tensor or jittor.Var output respectively. If pad_val is None, this value is meaningless.
:param pad_fn: a pad function for this field; if given, pad_val, dtype and backend are ignored. The input to
    pad_fn is this field's batch: the Collator automatically unbatches the data and regroups each field into
    its own batch, pad_fn receives that batch, and its output is used directly as the result.
:return: the Collator itself
"""
flag = False
for collator in self._collate_fn.get_collators():
if isinstance(collator, AutoCollator):
flag = True
break
if flag is False:
self._collate_fn.add_collator(AutoCollator(self.as_numpy))
self._collate_fn.set_input(*field_names)

def set_collator(self, collator: Callable) -> None:
"""
Set the collate_fn; calling this overrides all current collate_fn, including Auto_Collate.

:param collator: a user-defined Callable
:return:
"""
self._collate_fn = _MultiCollator(collator)

def add_collator(self, collator) -> None:
"""
Append a collate_fn; calling this adds it after the existing collate_fn.
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

:param collator:
:return:
def set_ignore(self, *field_names) -> Collator:
"""
self._collate_fn.add_collator(collator)

def get_batch_indices(self) -> List[int]:
If some content should not be included in the output, set it here; the given fields will be ignored in the batch output.
Ex::
    collator.set_ignore('field1', 'field2')

:param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, use the
    field's key directly; for a nested dict, use a tuple, e.g. ('a', 'b') for {'a': {'b': 1}}. If __getitem__
    returns a Sequence, use '_0', '_1' to refer to the 0th or 1st element.
:return: the Collator itself
"""
Get the indices of the current data.

:return:
"""
return self.cur_batch_indices

def set_pad(self):
pass

def set_ignore(self):
pass

def set_backend(self):
pass

if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str, None] = None) \
non_train_batch_size: int = 16) \
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
"""
Given a dataset or data_bundle, process it and return the corresponding FDataLoader instance(s).
@@ -211,7 +176,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
:param sampler: a sampler instance
:param batch_sampler: a batch_sampler instance that iterates and returns lists of indices
:param num_workers: number of worker processes; multiprocessing is disabled when num_workers=0
:param collate_fn: a callable that packs the fetched data
:param collate_fn: ['auto', None, callable] a callable that packs the fetched data
:param pin_memory:
:param drop_last: whether to drop the last batch when it does not fill batch_size
:param timeout:
@@ -222,11 +187,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
:param persistent_workers:
:param non_train_sampler: the Sampler used for non-'train' data, and for the second and later datasets of a Sequence
:param non_train_batch_size:
:param as_numpy: whether to return data as numpy arrays; otherwise torch.Tensor as appropriate.
"""
# TODO needs to be provided for the dict and sequence cases
if isinstance(input_fields, str):
input_fields = [input_fields]

if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
@@ -235,9 +196,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
if input_fields:
dl.set_input(*input_fields)
)
return dl

elif isinstance(ds_or_db, DataBundle):
@@ -251,7 +210,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler,
@@ -261,9 +220,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
if input_fields:
dl_bundle[name].set_input(*input_fields)
)
return dl_bundle

elif isinstance(ds_or_db, Sequence):
@@ -277,7 +234,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
)
else:
dl_bundle.append(
@@ -287,11 +244,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
)
if input_fields:
for dl in dl_bundle:
dl.set_input(*input_fields)
return dl_bundle

elif isinstance(ds_or_db, Mapping):
@@ -305,7 +259,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler,
@@ -315,10 +269,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)

if input_fields:
dl_bundle[name].set_input(*input_fields)
)

return dl_bundle
else:

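As a hedged sketch of how the pruned prepare_torch_dataloader is now driven (the split names and fields are illustrative; collate_fn='auto' is passed explicitly since the helper's default is None):

from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

splits = {
    'train': DataSet({'x': [[1, 2], [3]] * 6, 'y': [0, 1] * 6}),
    'dev': DataSet({'x': [[4, 5, 6]] * 4, 'y': [1] * 4}),
}
dls = prepare_torch_dataloader(splits, batch_size=4, collate_fn='auto',
                               non_train_batch_size=2)
dls['train'].set_pad('x', pad_val=-1)  # collator tweaks now live on each loader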

+ 16
- 0
fastNLP/core/dataloaders/utils.py View File

@@ -0,0 +1,16 @@
def indice_collate_wrapper(func):
"""
Wraps a collate_fn so that the (idx, instance) tuples fetched from the dataset are split apart, packing the idx values into indices.

:param func: the function to wrap
:return:
"""

def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)

return wrapper
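
The contract is easiest to see in isolation; a tiny sketch using the helper defined above:

collate = indice_collate_wrapper(list)          # func here just keeps the instances
indices, batch = collate([(0, 'a'), (3, 'b')])  # the dataset yields (idx, instance)
assert indices == [0, 3] and batch == ['a', 'b']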

+ 9
- 60
fastNLP/core/dataset/dataset.py View File

@@ -23,9 +23,8 @@ except:
from .field import FieldArray
from .instance import Instance
from fastNLP.core.utils.utils import pretty_table_printer, deprecated
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.collators.collator import _MultiCollator


class ApplyResultException(Exception):
@@ -114,7 +113,7 @@ class DataSet:
Each element should be a :class:`~fastNLP.Instance` with the same fields.
"""
self.field_arrays = {}
self.collate_fns: _MultiCollator = _MultiCollator(AutoCollator(as_numpy=False))
self._collator = Collator(backend="numpy")
if data is not None:
if isinstance(data, Dict):
length_set = set()
@@ -181,7 +180,7 @@ class DataSet:
dataset = DataSet()
for field_name, field in self.field_arrays.items():
dataset.add_field(field_name=field_name, fields=field.content[idx])
dataset.collate_fns = deepcopy(self.collate_fns)
dataset._collator = deepcopy(self.collator)
return dataset
elif isinstance(idx, str):
if idx not in self:
@@ -193,7 +192,7 @@ class DataSet:
assert isinstance(i, int), "Only int index allowed."
instance = self[i]
dataset.append(instance)
dataset.collate_fns = deepcopy(self.collate_fns)
dataset._collator = deepcopy(self.collator)
return dataset
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -676,8 +675,8 @@ class DataSet:
dev_set.append(self[idx])
for idx in train_indices:
train_set.append(self[idx])
dev_set.collate_fns = deepcopy(self.collate_fns)
train_set.collate_fns = deepcopy(self.collate_fns)
dev_set._collator = deepcopy(self.collator)
train_set._collator = deepcopy(self.collator)

return dev_set, train_set

@@ -771,67 +770,17 @@ class DataSet:
df = self.to_pandas()
return df.to_csv(path, encoding="utf-8")

def add_collate_fn(self, collate_fn: Callable) -> None:
"""
Append a collate_fn; calling this adds it after the existing collate_fn.

:param collate_fn: a Callable
:return:
"""
self.collate_fns.add_collator(collate_fn)

def set_collate_fn(self, collate_fn: Callable) -> None:
"""
Set the collate_fn; calling this overrides all current collate_fn, including Auto_Collate.

:param collate_fn:
:return:
"""
self.collate_fns = _MultiCollator(collate_fn)

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
Set the padding value for each field_name (default 0); only effective when an AutoCollator exists.
When val=None, none of the given field_names will be padded.

:param field_names: field names that exist in the dataset
:param val: defaults to 0. If None, the field is not padded.
:return:
"""
# TODO must not be empty
for field_name in field_names:
self.collate_fns.set_pad_val(field_name, val=val)

def set_input(self, *field_names) -> None:
"""
field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default.

:param field_names:
:return:
"""
#
self.collate_fns.set_input(*field_names)

def get_collator(self) -> _MultiCollator:
"""
Get the collate_fn bound to the dataset, including auto_collate.

:return:
"""
return self.collate_fns

@deprecated()
def set_target(self, *field_names) -> None:
def set_ignore(self, *field_names) -> None:
"""
The fields set here will be ignored by the collator and excluded from the batch output.

:param field_names:
:return:
"""
self.collate_fns.set_input(*field_names)
self.collator.set_ignore(*field_names)

@property
def collator(self):
def collator(self) -> Collator:
if self._collator is None:
self._collator = Collator()
return self._collator

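A minimal sketch of the replacement API at the dataset level (field names illustrative): the collator is created lazily by the property, and the hunks above show it being deep-copied when the dataset is sliced or split.

from fastNLP.core.dataset import DataSet

ds = DataSet({'x': [[1, 2], [3]] * 2, 'y': [0, 1] * 2})
ds.collator.set_pad('x', pad_val=0)  # configure padding once, on the dataset
ds.set_ignore('y')                   # 'y' will be absent from batches
sub = ds[:2]                         # the slice carries a deep copy of _collator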
+ 2
- 2
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -22,7 +22,7 @@ from fastNLP.core.utils import (
rank_zero_rm
)
from fastNLP.core.samplers import (
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
ReproducibleBatchSampler,
RandomSampler,
@@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):

return self.model, model.forward

def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
reproducible: bool = False):
r"""
From the given dataloader, build a dataloader that supports distributed and reproducible training.


+ 3
- 3
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -22,7 +22,7 @@ from fastNLP.core.log import logger
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
RandomBatchSampler,
ReproduceBatchSampler,
RandomSampler,
)

@@ -345,7 +345,7 @@ class PaddleDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -476,7 +476,7 @@ class PaddleDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler


+ 2
- 2
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -14,7 +14,7 @@ from fastNLP.core.utils import (
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
RandomSampler,
re_instantiate_sampler,
@@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver):
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 2
- 2
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -15,7 +15,7 @@ from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger

@@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver):
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 3
- 3
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler


class TorchDriver(Driver):
@@ -293,7 +293,7 @@ class TorchDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -407,7 +407,7 @@ class TorchDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler

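To make the rename concrete, a hedged sketch of what the drivers do on checkpoint reload: wrap the existing torch BatchSampler in a ReproduceBatchSampler so its position can be saved and restored (the toy data is illustrative):

from torch.utils.data import BatchSampler, SequentialSampler
from fastNLP.core.samplers import ReproduceBatchSampler

data = list(range(8))
inner = BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False)
sampler = ReproduceBatchSampler(inner, batch_size=3, drop_last=False)
batches = list(sampler)       # expected: [[0, 1, 2], [3, 4, 5], [6, 7]]
state = sampler.state_dict()  # holds index_list, num_consumed_samples, ...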

+ 3
- 2
fastNLP/core/samplers/__init__.py View File

@@ -14,9 +14,10 @@ __all__ = [
"UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler",

"RandomBatchSampler",
"ReproduceBatchSampler",
"BucketedBatchSampler",
"ReproducibleBatchSampler",
"RandomBatchSampler",

"re_instantiate_sampler"
]
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler


+ 209
- 3
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -1,5 +1,6 @@
__all__ = [
'BucketedBatchSampler',
"ReproduceBatchSampler",
"RandomBatchSampler"
]

@@ -54,13 +55,13 @@ class ReproducibleBatchSampler:
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")


class RandomBatchSampler(ReproducibleBatchSampler):
class ReproduceBatchSampler(ReproducibleBatchSampler):
# the values of these two parameters should be fetched via the driver's get_dataloader_args function;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
A wrapper that makes the state of a batch_sampler object recoverable.

:param batch_sampler: an iterable that yields numbers or lists of numbers. RandomBatchSampler first iterates
:param batch_sampler: an iterable that yields numbers or lists of numbers. ReproduceBatchSampler first iterates
    over it once, caches the yielded indices, and then emits index lists in batches of batch_size when used.
:param batch_size: the size of each batch.
:param drop_last: whether to drop the last batch if it cannot gather batch_size samples.
@@ -143,7 +144,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.need_reinitialize = False

def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")

def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -158,6 +159,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size


class RandomBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
drop_last: bool = False, seed: int = 0, **kwargs):
"""
A batch_sampler that forms batches at random.

:param dataset: a data container implementing __len__.
:param batch_size: the size of each batch
:param shuffle: whether to shuffle the sample order on each iteration.
:param drop_last: whether to drop the last batch when it cannot gather batch_size samples.
:param seed: the random seed to use
:param kwargs: reserved for fastNLP internal use
"""
super().__init__()

self.dataset = dataset

self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.seed = seed

self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total samples iterated so far, including those emitted on other cards in the multi-card case

# multi-card related parameters
self.num_replicas = kwargs.get("num_replicas", 1)
self.rank = kwargs.get("rank", 0)
self.epoch = kwargs.get("epoch", -1)
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card;

# whether an iteration is in progress; when True, set_distributed() and load_state_dict() may not be called
self.during_iter = kwargs.get("during_iter", False)

# the variables below are internal state used for restoring progress.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)

def set_distributed(self, num_replicas, rank, pad=True):
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
"during an unfinished iteration."
assert num_replicas > 0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0 <= rank < num_replicas
# note: when this is called, all state should be the default state of an epoch that has just started;
self.num_replicas = num_replicas
self.rank = rank
self.pad = pad

return self

def __iter__(self):
if self.during_iter:  # if during_iter is True, the previous iteration has not finished; force a re-initialization
self.num_consumed_samples = 0
self.during_iter = True

indices = list(range(len(self.dataset)))

if self.shuffle:
if self.num_consumed_samples > 0:  # first rebuild the original ordering, then drop the already-consumed part
_batches = []
for _i in range(self.old_num_replicas):
_indices = indices[_i:len(indices):self.old_num_replicas]
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
_batches.append(__batches)
batches = list(chain(*[_ for _ in zip(*_batches)]))
indices = list(chain(*batches))
indices = indices[self.num_consumed_samples:]
# take this rank's slice
indices = indices[self.rank:len(indices):self.num_replicas]
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
batches = list(map(list, batches))
else:
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
_num_batches = len(indices) // self.batch_size
if _num_batches == 0:
batches = [indices]
else:
batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
if len(indices)%self.batch_size!=0:
batches.append(indices[_num_batches*self.batch_size:])

need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
batches[-1].append(batches[-1][0])  # this guarantees the batch's length is not broken.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1])==0:
batches.pop(-1)

assert sum(map(len, batches)) == self.num_left_samples

if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]

for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
self.old_batch_size = self.batch_size
self.old_num_replicas = self.num_replicas
if self.epoch < 0:  # guard against the user never setting epoch, which would otherwise make every epoch identical
self.epoch -= 1

def batchify(self, indices, batch_size, seed):
"""
Split indices into batches.

:param indices: List[int]
:param batch_size: int
:param seed: int
:return: List[List[int]]
"""
# shuffle the indices with a reproducible RNG
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
num_samples = 0
batches = []
while num_samples<len(indices):
batches.append(indices[num_samples:num_samples+batch_size])
num_samples += batch_size
return batches

def set_epoch(self, epoch):
self.epoch = epoch

@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size

@property
def total_size(self):
"""
The total number of indices this sampler will eventually emit (including those for other ranks); because of
replicas and padding, this value may be equal to, greater than, or smaller than len(dataset)

:return:
"""
return self.num_consumed_samples + self.num_replicas*self.num_left_samples

@property
def num_left_samples(self):
"""
Return how many samples are left in the current iteration, counted for the current rank.

:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

def __len__(self)->int:
"""
Return how many more batches this sampler will yield.

:return:
"""
num_sampler_per_rank = self.total_size//self.num_replicas
num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
(num_sampler_per_rank+self.batch_size-1)//self.batch_size
return num_batches

def state_dict(self) -> Dict:
if self.old_batch_size != self.batch_size:
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_replicas': self.num_replicas}

return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
f"we cannot use {self.__class__.__name__} to load it."

length = states['length']
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
"and current dataset."
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length:  # if the last sample had already been reached when saving, simply reset to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
f"we use shuffle={states['shuffle']}")
self.shuffle = states["shuffle"]
self.old_batch_size = states['batch_size']
self.old_num_replicas = states['num_replicas']


class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):


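A hedged resume sketch for the new RandomBatchSampler added above (single card, pad left at its default): consume one batch, snapshot the state, and finish the epoch in a fresh instance.

from fastNLP.core.samplers import RandomBatchSampler

dataset = list(range(10))                  # anything implementing __len__
sampler = RandomBatchSampler(dataset, batch_size=4, shuffle=True)

it = iter(sampler)
first = next(it)                           # consume one batch of 4 indices
states = sampler.state_dict()              # num_consumed_samples == 4

restored = RandomBatchSampler(dataset, batch_size=4, shuffle=True)
restored.load_state_dict(states)
rest = list(restored)                      # only the remaining 6 samples
assert sum(map(len, rest)) == len(dataset) - len(first)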
+ 1
- 2
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -54,13 +54,12 @@ class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""


:param dataset: a data container implementing __len__
:param shuffle: whether to shuffle the order on each iteration.
:param seed: the random seed.
:param kwargs: not for users; used internally by fastNLP
"""
super(RandomSampler, self).__init__()
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed


+ 2
- 2
fastNLP/core/utils/__init__.py View File

@@ -21,7 +21,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
from ..dataloaders.utils import indice_collate_wrapper



+ 2
- 1
fastNLP/core/utils/jittor_utils.py View File

@@ -7,13 +7,13 @@ from collections.abc import Mapping, Callable
from functools import wraps

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
import jittor as jt

from fastNLP.core.dataset import Instance



def is_jittor_dataset(dataset) -> bool:
try:
if isinstance(dataset, jt.dataset.Dataset):
@@ -32,6 +32,7 @@ def jittor_collate_wraps(func, auto_collator: Callable):
:param auto_collator:
:return:
"""

@wraps(func)
def wrapper(batch):
if isinstance(batch[0], Instance):


+ 1
- 20
fastNLP/core/utils/utils.py View File

@@ -6,7 +6,7 @@ import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
from typing import Tuple, Optional
from time import sleep

@@ -35,7 +35,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@@ -513,24 +512,6 @@ class Option(dict):
self.update(state)


def indice_collate_wrapper(func):
"""
Wraps a collate_fn so that the (idx, instance) tuples fetched from the dataset are split apart, packing the idx values into indices.

:param func: the function to wrap
:return:
"""

def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)

return wrapper


_emitted_deprecation_warnings = set()




+ 106
- 0
tests/core/collators/padders/test_paddle_padder.py View File

@@ -0,0 +1,106 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
import paddle


@pytest.mark.paddle
class TestpaddleNumberPadder:
def test_run(self):
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, paddle.Tensor)
assert (t_a == paddle.to_tensor(a, dtype='int64')).sum() == 3


@pytest.mark.paddle
class TestpaddleSequencePadder:
def test_run(self):
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
assert isinstance(a, paddle.Tensor)
assert tuple(shape) == (2, 3)
b = paddle.to_tensor([[1, 2, 3], [3, -1, -1]], dtype='int64')
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]])
# assert (a>67).sum()==0  # the int8 range is -128 to 127
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)


@pytest.mark.paddle
class TestpaddleTensorPadder:
def test_run(self):
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
a = [paddle.zeros((3,)), paddle.zeros((2,))]
a = padder(a)
shape = a.shape
assert isinstance(a, paddle.Tensor)
assert tuple(shape) == (2, 3)
b = paddle.to_tensor([[0, 0, 0], [0, 0, -1]], dtype='int64')
assert (a == b).sum().item() == shape[0]*shape[1]

a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 2))]
a = padder(a)
shape = a.shape
assert isinstance(a, paddle.Tensor)
assert tuple(shape) == (3, 3, 2)
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, 0], [-1, -1], [-1, -1]]], dtype='int64')
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 1))]
a = padder(a)
shape = a.shape
assert isinstance(a, paddle.Tensor)
assert tuple(shape) == (3, 3, 2)
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))]
a = padder(a)
shape = a.shape
assert isinstance(a, paddle.Tensor)
assert tuple(shape) == (2, 3, 2)
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)]
a = padder(a)
shape = a.shape
assert isinstance(a, paddle.Tensor)
assert tuple(shape) == (2, 3, 2)
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]]], dtype='float32')
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)

def test_v1(self):
print(paddle.zeros((3, )).dtype)

+ 7
- 5
tests/core/dataloaders/jittor_dataloader/test_fdl.py View File

@@ -40,8 +40,8 @@ class TestJittor:
"""
dataset = MyDataset()
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
jtl.set_pad_val('x', 'y')
jtl.set_input('x')
# jtl.set_pad_val('x', 'y')
# jtl.set_input('x')
for batch in jtl:
print(batch)
print(jtl.get_batch_indices())
@@ -54,15 +54,17 @@ class TestJittor:
"""
dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True)
jtl.set_pad_val('x', val=-1)
jtl.set_input('x', 'y')
jtl.set_pad("x", -1)
jtl.set_ignore("y")
# jtl.set_pad_val('x', val=-1)
# jtl.set_input('x', 'y')
for batch in jtl:
assert batch['x'].size() == (16, 4)

def test_v3(self):
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True)
jtl.set_input('x', 'y')
# jtl.set_input('x', 'y')
for batch in jtl:
print(batch)



+ 16
- 6
tests/core/dataloaders/paddle_dataloader/test_fdl.py View File

@@ -3,6 +3,8 @@ import numpy as np

from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
from fastNLP.core.dataset import DataSet
from fastNLP.core.log import logger

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
from paddle.io import Dataset, DataLoader
@@ -11,11 +13,12 @@ else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset



class RandomDataset(Dataset):

def __getitem__(self, idx):
image = np.random.random((10, 5)).astype('float32')
return {'image': paddle.Tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]}
return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]}

def __len__(self):
return 10
@@ -36,23 +39,30 @@ class TestPaddle:
def test_fdl_batch_indices(self):
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
fdl.set_input("x", "y")
for batch in fdl:
assert len(fdl.get_batch_indices()) == 4
print(batch)
print(fdl.get_batch_indices())

def test_set_inputs_and_set_pad_val(self):
logger.setLevel("DEBUG")
ds = RandomDataset()
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True)
fdl.set_input('image', 'label')
fdl.set_pad_val('label', val=-1)
fdl.set_pad('label', -1)
for batch in fdl:
print(batch['image'])
assert batch['image'].shape == [2, 10, 5]
print(batch)
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True)
fdl1.set_input('image', 'label')
fdl1.set_pad_val('image', val=None)
fdl1.set_ignore('label')
for batch in fdl1:
assert batch['image'].shape == [4, 10, 5]
print(batch)

def test_v2(self):
from fastNLP.core.collators import Collator
logger.setLevel("DEBUG")
data = [paddle.Tensor(np.random.random((10, 5)).astype('float32')), paddle.Tensor(np.random.random((10, 5)).astype('float32'))]
col = Collator(backend="jittor")
res = col(data)
print(res)

+ 2
- 21
tests/core/dataloaders/torch_dataloader/test_fdl.py View File

@@ -13,42 +13,23 @@ class TestFdl:
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
# for batch in fdl:
# print(batch)
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
# for batch in fdl1:
# print(batch)

def test_set_padding(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
ds.set_pad_val("x", val=-1)
fdl = TorchDataLoader(ds, batch_size=3)
fdl.set_input("x", "y")
fdl.set_pad_val("x", val=None)
fdl.set_pad("x", -1)
for batch in fdl:
print(batch)
# fdl.set_pad_val("x", val=-2)
# for batch in fdl:
# print(batch)

def test_add_collator(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})

def collate_fn(ins_list):
_dict = {"Y": []}
for ins in ins_list:
_dict["Y"].append(ins['y'])
return _dict

fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
fdl.set_input("x", "y")
# fdl.set_pad_val("x", val=None)
fdl.add_collator(collate_fn)
for batch in fdl:
print(batch)

def test_get_batch_indices(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
fdl.set_input("y", "x")
for batch in fdl:
print(fdl.get_batch_indices())



+ 15
- 15
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path

from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -278,7 +278,7 @@ class TestPaddleDriverFunctions:
dataset = PaddleNormalDataset()
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
@@ -287,7 +287,7 @@ class TestPaddleDriverFunctions:
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
@@ -387,7 +387,7 @@ class TestSetDistReproDataloader:
"""
Test the behavior of set_dist_repro_dataloader when `reproducible` is True.
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True),
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -400,7 +400,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# the batch_sampler is replaced in this case
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -414,11 +414,11 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with the batch_sampler replaced by the sampler passed as dist.
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -450,7 +450,7 @@ class TestSetDistReproDataloader:
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
@@ -459,7 +459,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -500,20 +500,20 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

# reload; the remaining content should be produced, and for PaddleNormalDataset the sorted result should form a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
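Condensed, the resume protocol this hunk exercises looks as follows (a sketch assuming a partially consumed ReproduceBatchSampler and the same dataset):

states = replaced_loader.batch_sampler.state_dict()            # 1. snapshot mid-epoch
states["num_consumed_samples"] = num_consumed_batches * batch_size
new_sampler = ReproduceBatchSampler(                           # 2. rebuild ...
    BatchSampler(dataset, shuffle=shuffle, batch_size=batch_size),
    batch_size=batch_size,
    drop_last=False,
)
new_sampler.load_state_dict(states)                            # 3. ... and restore; iteration skips the consumed samples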
@@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")

@@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# change batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
@@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. check that the batch_sampler is correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
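The round trip these save/load tests perform, condensed (driver2.load's signature is taken from the hunk above; everything else reuses the test's own names):

load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")      # a new dataloader resuming mid-epoch
sampler = replaced_loader.batch_sampler              # a restored ReproduceBatchSampler
assert sampler.num_consumed_samples == num_consumed_batches * 4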



+ 3
- 3
tests/core/drivers/paddle_driver/test_utils.py View File

@@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle
@@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices,
def test_replace_batch_sampler():
dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16


+ 12
- 12
tests/core/drivers/torch_driver/test_single_device.py View File

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path

from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
@@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE:

def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
"""
Build a dataloader whose batch_sampler is a RandomBatchSampler
Build a dataloader whose batch_sampler is a ReproduceBatchSampler
"""
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
@@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
sampler = torch.utils.data.SequentialSampler(dataset)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(
sampler, batch_size=batch_size, drop_last=drop_last
),
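For readability, the helper shown across the two hunks above, reconstructed in full (the wrapper's trailing batch_size/drop_last arguments are assumed to mirror the inner BatchSampler, matching the other call sites in this diff):

def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
    """Build a dataloader whose batch_sampler is a ReproduceBatchSampler."""
    if shuffle:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    return DataLoader(
        dataset=dataset,
        batch_sampler=ReproduceBatchSampler(
            BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last),
            batch_size=batch_size,
            drop_last=drop_last,
        ),
    )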
@@ -306,7 +306,7 @@ class TestTorchDriverFunctions:
res = TorchSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, torch.utils.data.RandomSampler)
else:
@@ -401,7 +401,7 @@ class TestSetDistReproDataloader:
"""
Test the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True.
When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler(shuffle=True),
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -414,7 +414,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -428,11 +428,11 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with its batch_sampler replaced by the sampler passed in as dist
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -466,7 +466,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -502,14 +502,14 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

# reload; the remaining items should come out, and for TorchNormalDataset the sorted result should form a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
@@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. check that the batch_sampler is correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4



+ 1
- 1
tests/core/drivers/torch_driver/test_torch_replace_sampler.py View File

@@ -30,7 +30,7 @@ class SequenceDataSet:


def check_replace_sampler(driver):
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / RandomBatchSampler
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproduceBatchSampler
# reproducible can be True or False

# need to check that both the returned sampler and the dataloader are different now
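Schematically, the matrix check_replace_sampler walks through (a sketch of the combinations, not the test's literal code):

for dist in ("dist", "unrepeatdist", None):   # or a ReproducibleSampler / ReproduceBatchSampler instance
    for reproducible in (True, False):
        new_loader = driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=reproducible)
        # both the returned sampler and the dataloader must differ from the originals
        assert new_loader is not dataloader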


+ 3
- 3
tests/core/drivers/torch_driver/test_utils.py View File

@@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from torch.utils.data import DataLoader, BatchSampler

from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
def test_replace_batch_sampler():
dataset = TorchNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, TorchNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16


+ 319
- 9
tests/core/samplers/test_reproducible_batch_sampler.py View File

@@ -5,7 +5,7 @@ import pytest
from itertools import chain
from copy import deepcopy

from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset

@@ -19,7 +19,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# forward_steps = 3
@@ -29,15 +29,15 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "RandomBatchSampler"}
# "sampler_type": "ReproduceBatchSampler"}
#
# # 2. resume from the checkpoint and build a new dataloader;
# # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
@@ -54,7 +54,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# # change batch_size;
# after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
@@ -100,7 +100,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle to check that the second epoch's index list is regenerated after resuming;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # save one full epoch of data to check that recovery is correct;
@@ -112,13 +112,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
#
# # 2. resume from the checkpoint and build a new dataloader;
# # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
@@ -511,3 +511,313 @@ class TestBucketedBatchSampler:
already_seen_set.update(batch)

assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
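The TestRandomBatchSampler class added below drives the new RandomBatchSampler; a minimal usage sketch, with the interface inferred from the calls in these tests (dataset is a DatasetWithVaryLength as used throughout):

sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=8,
                             shuffle=True, drop_last=False)
sampler.set_distributed(num_replicas=2, rank=0, pad=True)  # optional, for multi-process sharding
sampler.set_epoch(0)
for batch in sampler:              # each batch is a list of dataset indices
    ...
state = sampler.state_dict()       # contains "num_consumed_samples"; restorable via load_state_dict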


class TestRandomBatchSampler:
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
def test_single_num_batch(self, shuffle, drop_last, num):
# too little data should not raise an error; `num` is supplied by the parametrize above
dataset = DatasetWithVaryLength(num_of_data=num)
before_batch_size = 7
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
count = len(list(iter(re_batchsampler)))
if drop_last:
assert count==num//before_batch_size, num
else:
assert count==(num+before_batch_size-1)//before_batch_size, num

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
def test_single(self, shuffle, drop_last):

before_batch_size = 7
num_batch_per_bucket = 4 # so the length difference within any batch should not exceed 4

dataset = DatasetWithVaryLength(num_of_data=1000)
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler.set_epoch(0)
forward_steps = 10
iterator = iter(re_batchsampler)
already_generate_indices = set()
for _ in range(forward_steps):
batch = next(iterator)
already_generate_indices.update(batch)

# 1. save the state
state = re_batchsampler.state_dict()

# 2. resume from the checkpoint and continue training
re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler2.load_state_dict(state)
re_batchsampler2.set_epoch(0)
new_already_generate_indices = set()
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
for batch in re_batchsampler2:
for b in batch:
assert b not in already_generate_indices
new_already_generate_indices.update(batch)
if drop_last is False:
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)

# change batch_size;
after_batch_size = 3
re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler3.load_state_dict(state)
re_batchsampler3.set_epoch(0)
count = 0

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]

for batch in re_batchsampler3:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break

# save again; sampling must not continue while the previous epoch is unfinished
after_batch_size = 5
with pytest.raises(RuntimeError):
state = re_batchsampler3.state_dict()

for batch in re_batchsampler3: # consume all batches so that saving is allowed
pass

already_generate_indices = set()
count = 0
for batch in re_batchsampler3: # start over
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break

state = re_batchsampler3.state_dict()
# drop_last is False here, so in the end all samples must be covered
re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=False,
shuffle=shuffle)
re_batchsampler4.load_state_dict(state)
re_batchsampler4.set_epoch(0)

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
for batch in re_batchsampler4:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)

assert len(already_generate_indices) == len(dataset)

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
def test_multi(self, shuffle, drop_last, pad):
# def test_multi(self, shuffle=True, drop_last=False, pad=False):

# no shuffle
num_replica = 2
dataset = DatasetWithVaryLength(num_of_data=1000)
batch_size = 5
num_batch_per_bucket = 10
lengths = []
rank0_already_seen_indexes = None
max_diff = num_batch_per_bucket * batch_size * num_replica
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
already_seen_indexes = set()
repeat_count = 0
for batch in sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
if rank0_already_seen_indexes: # indices must not appear on more than one rank
assert b not in rank0_already_seen_indexes
already_seen_indexes.update(batch)
if rank0_already_seen_indexes is None:
rank0_already_seen_indexes = already_seen_indexes
if pad: # at most one repeated sample is allowed
assert repeat_count<=1
else:
assert repeat_count==0

assert len(set(lengths))==1, lengths # every process yields the same number of batches

# saving in the multi-process setting
already_seen_indexes = set()
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
count = 0
for batch in sampler:
already_seen_indexes.update(batch)
if count>5:
break
count += 1
state = sampler.state_dict()

# switch to a single machine
new_batch_size = 6
num_batch_per_bucket = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.load_state_dict(state)
repeat_count = 0
new_already_seen_indexes = set(list(already_seen_indexes))

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]

for batch in new_sampler:
for b in batch:
repeat_count += int(b in new_already_seen_indexes)
new_already_seen_indexes.update(batch)
if pad: # at most one repeated sample is allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
if drop_last is False: # without dropping, the counts should be equal
assert len(new_already_seen_indexes)==len(dataset)

# test changing the number of devices.
num_replica = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.set_epoch(0)
new_sampler.load_state_dict(state)
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
repeat_count = 0

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]

for batch in new_sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
if pad: # at most one repeated sample is allowed
assert repeat_count <= 1
else:
assert repeat_count == 0

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
Test that data that has already been consumed (forwarded) is correctly recovered

:return:
"""
batch_size = 6
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)

sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batchs in zip(*samplers):
batch = chain(*batchs)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
states['num_consumed_samples'] = num_consumed_samples_array[i]
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
shuffle=shuffle, drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(dataset)

# test saving again after a save
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
states['num_consumed_samples'] = num_consumed_samples_array[2]
if len(already_seen_sets)<3:
return
sampler.load_state_dict(states)
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break

states = sampler.state_dict()
num_consumed_samples_array = list(range(len(dataset)))
states['num_consumed_samples'] = num_consumed_samples_array[count]
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)

assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
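Finally, the replica-switch scenario test_multi exercises above, condensed (state comes from a 2-replica run; names as in the test):

new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=6,
                                 shuffle=True, drop_last=False)
new_sampler.set_epoch(0)
new_sampler.load_state_dict(state)                          # state saved by the 2-replica run
new_sampler.set_distributed(num_replicas=3, rank=1, pad=True)
# with pad=True at most one index may repeat; with pad=False none may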
