@@ -1,5 +1,4 @@
__all__ = [
    'AutoCollator',
    'Collator'
]
from .collator import AutoCollator, Collator
from .collator import Collator
@@ -1,386 +1,573 @@
__all__ = [
    'AutoCollator',
    'Collator',
]

from typing import List, Union, Dict, Callable, Sequence, Mapping
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Callable, Union, Tuple
from numbers import Number
import warnings

import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH

if _NEED_IMPORT_PADDLE:
    import paddle

if _NEED_IMPORT_TORCH:
    import torch


class ApplyResultException(Exception):
    def __init__(self, msg, index=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # marks which sample ran into the problem

from fastNLP.core.log import logger
from .padders.get_padder import get_padder
import re


class SetInputOrTargetException(Exception):
    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # marks which sample ran into the problem
        self.field_name = field_name  # marks the name of the current field

from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
    pack_batch_sequence

sequence_idx_str = re.compile(r'^_\d+$')  # matches field names like _0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None]
def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
    r"""
    Identify the element type of `cell` and its number of dimensions.
    numpy scalar types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html

    :param cell:
    :param dim:
    :return:
    """
    if isinstance(cell, (str, Number, np.bool_)):
        if hasattr(cell, 'dtype'):
            return cell.dtype.type, dim
        return type(cell), dim
    elif isinstance(cell, list):
        dim += 1
        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
        types = set([i for i, j in res])
        dims = set([j for i, j in res])
        if len(types) > 1:
            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
        elif len(types) == 0:
            raise SetInputOrTargetException("Empty value encountered.")
        if len(dims) > 1:
            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
        return types.pop(), dims.pop()
    elif isinstance(cell, torch.Tensor):
        return cell.dtype, cell.dim() + dim  # the result of torch.mean, for example, is 0-dim
    elif isinstance(cell, paddle.Tensor):
        return cell.dtype, cell.dim() + dim
    elif isinstance(cell, np.ndarray):
        if cell.dtype != np.dtype('O'):  # if the dtype is not object, the array is already well-formatted
            return cell.dtype.type, cell.ndim + dim  # dtype.type returns np.int32, np.float64, etc.
        # otherwise keep iterating into the object array
        dim += 1
        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
        types = set([i for i, j in res])
        dims = set([j for i, j in res])
        if len(types) > 1:
            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
        elif len(types) == 0:
            raise SetInputOrTargetException("Empty value encountered.")
        if len(dims) > 1:
            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
        return types.pop(), dims.pop()
    else:  # tuple, set, dict and any other types
        raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
def _get_ds_type_dim(ds: dict):
    # get the type and dimension of each field from the first row of the dataset
    field_dtype, field_dim = {}, {}
    for field_name, field_content in ds.items():
        type_0, dim_0 = _get_ele_type_and_dim(field_content)
        field_dtype[field_name], field_dim[field_name] = type_0, dim_0
    return field_dtype, field_dim
class Collator(metaclass=ABCMeta):
    r"""
    Helper class for the DataLoader to manage collate_fn.
    """
    def __init__(self):
        super(Collator, self).__init__()
        self.collate_fn = []

    @abstractmethod
    def __call__(self, ins_lst: List) -> Any:
        raise NotImplementedError

    @abstractmethod
    def set_pad_val(self, *field_names: str, value=0):
        raise NotImplementedError


class _MultiCollator:
    """
    Container that manages all collators.
    Follows an override policy: a collate_fn added later overrides the output of earlier ones.
    """
    def __init__(self, collate_fns: Union[Callable, List[Callable], None]):
        if collate_fns is None:
            collate_fns = []
        if isinstance(collate_fns, Callable):
            collate_fns = [collate_fns]
        self._collators: list = collate_fns

    def __call__(self, ins_lst) -> Dict:
        out, list_out = {}, []
        for idx, _collate_fn in enumerate(self._collators):
            res = _collate_fn(ins_lst)
            if isinstance(res, Dict):
                out.update(res)
            else:
                list_out.append(res)
            # else:
            #     raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but a dict is required")
        if len(out) > 0 and len(list_out) > 0:
            raise ValueError("collate_fns returned mixed types; they must all return dict or all return list")
        if len(list_out) == 1:
            list_out = list_out[-1]
        return out if len(out) > 0 else list_out

    def get_collators(self):
        return self._collators

    def add_collator(self, collator: Callable):
        self._collators.append(collator)

    def set_as_numpy(self, as_numpy: bool):
        """
        When an AutoCollator is present, as_numpy controls the type of its return value.

        :param as_numpy:
        :return:
class Collator:
    def __init__(self, backend='torch'):
        """
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_as_numpy(as_numpy)
        return self
        Object used to pad data. Every field that can be padded (fastNLP decides paddability from the data) is padded
        automatically; the default pad value is 0 and can be adjusted with set_pad(). Fields that should not be output
        can be suppressed with set_ignore(). On the first pad call the Collator picks a padder for every field based
        on the settings and the data; every later call reuses the corresponding Padder for each field.

    def set_pad_val(self, *field_names, val=0):
        :param backend: which kind of tensor to use for paddable fields; one of ['torch','jittor','paddle','numpy','raw',None].
            If None, no padding is performed. This parameter has no effect on data that cannot be padded; such data is
            always returned as a list.
        """
        self.unpack_batch_func = None
        self.pack_batch_func = None
        self.ignore_fields = set()
        self.padders = {}
        self.input_fields = {}
        self.batch_data_type = None  # one of 'd', 's', 'l': each sample in the batch is a dict, a single value, or a list
        self.set_backend(backend)

    def __call__(self, batch) -> Union[List, Dict]:
        """
        When an AutoCollator is present, set the padding value of field_name.
        A batch can come in three shapes:
        List[Dict], List[List], List[Sample]
        Step 1: use unpack_batch_func to gather the contents of the same field into one list.
        Step 2: pad each field with its own padder.
        Step 3: return the result in the same type as each sample of the input batch.
        The first call decides, based on the current batch, which unpack_batch_func to use; that function gathers the
        same field of different samples into one list. It also decides the pack_batch_func, which restores the padded
        batch to the type of an input sample before returning.
        The first call also decides the Padder for each field.

        :param field_names: field names of the dataset
        :param val: the padding value
        :return:
        """
        flag = True
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_pad_val(*field_names, val=val)
                flag = False
        if flag:
            warnings.warn("AutoCollator is removed, set_pad_val is unavailable!")
        if self.unpack_batch_func is None:
            # decide which unpack_batch_func to use; all of them return a dict
            if self.batch_data_type is None:
                if isinstance(batch[0], Mapping):
                    self.batch_data_type = 'd'
                elif isinstance(batch[0], Sequence):  # there is a risk of misjudging the type here
                    self.batch_data_type = 'l'
                else:
                    self.batch_data_type = 's'
                logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
                             f"is `{self.batch_data_type}`.")
            if self.batch_data_type == 's':
                self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch}  # no adjustment needed
                self.pack_batch_func = lambda x: x['_single']
            elif self.batch_data_type == 'l':
                self.unpack_batch_func = unpack_batch_sequence
                self.pack_batch_func = pack_batch_sequence
            elif self.batch_data_type == 'd':
                if any([isinstance(v, Mapping) for v in batch[0].values()]):  # there may be nested dicts: {'a': {'b': xx}} -> {('a', 'b'): value}
                    self.unpack_batch_func = unpack_batch_nested_mapping
                    self.pack_batch_func = pack_batch_nested_mapping
                else:
                    self.unpack_batch_func = unpack_batch_mapping
                    self.pack_batch_func = lambda x: x

        # filter out the ignored fields here
        if self.unpack_batch_func is unpack_batch_nested_mapping:  # special case: keep the unpacking from descending further
            unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys()))
        else:
            unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields)  # gather each field into batch form

        pad_batch = {}
        if len(self.padders) == 0:  # first run: prepare the padders
            for key in unpack_batch.keys():
                if key not in self.input_fields and key not in self.ignore_fields:
                    self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}

            for field_name, setting in self.input_fields.items():
                pad_fn = setting.get('pad_fn', None)
                if callable(pad_fn):
                    padder = pad_fn
                else:
                    batch_field = unpack_batch.get(field_name)
                    padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
                                        dtype=setting['dtype'], backend=setting['backend'],
                                        field_name=field_name)
                self.padders[field_name] = padder
            if self.batch_data_type == 'l':
                self.padders = dict(sorted(self.padders.items(), key=lambda x: int(x[0][1:])))  # sort so that _0, _1, ... keep their order

        for key, padder in self.padders.items():
            batch = unpack_batch.get(key)
            pad_batch[key] = padder(batch)

        return self.pack_batch_func(pad_batch)  # restore to the same type as the input
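A minimal sketch of how two of the batch shapes flow through `__call__`, with a numpy backend so the output is easy to inspect (outputs are inferred from the code above, not verified; note the first batch a collator sees fixes its `batch_data_type`, so a fresh instance is used per shape):

```python
# List[Dict]: fields are padded per key and a dict comes back
Collator(backend='numpy')([{'x': [1, 2], 'y': 0},
                           {'x': [3],    'y': 1}])
# -> {'x': array([[1, 2], [3, 0]]), 'y': array([0, 1])}

# List[List]: positions become '_0', '_1', ... internally and a list comes back
Collator(backend='numpy')([[[1, 2], 0],
                           [[3],    1]])
# -> [array([[1, 2], [3, 0]]), array([0, 1])]
```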
    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
                pad_fn: Callable = None) -> "Collator":
        """
        Use this function when the contents of a particular field need special handling.

        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key
            can be used directly; for nested dicts, a tuple can express the multi-level key, e.g. ('a', 'b') for
            {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the
            sequence. If the field is not found in the data, an error is raised; if __getitem__ returns the content
            itself, use "_single".
        :param pad_val: the default pad value for this field. If None, the field is not padded. fastNLP only pads
            fields that can be padded, so a field that is not paddable in the first place does not need to be set to
            None explicitly.
        :param dtype: for a field that needs padding, the dtype its data should have.
        :param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing list, numpy.ndarray,
            torch.Tensor, paddle.Tensor or jittor.Var output respectively. If pad_val is None, this may only be None
            or 'numpy'.
        :param pad_fn: a custom pad function for this field; if given, pad_val, dtype and backend are ignored. The
            Collator automatically unbatches the data and gathers each field into its own batch; pad_fn receives the
            field in batch form and its output is used directly as the result.
        :return: the Collator itself
        """
        self.padders.clear()  # regenerate the padders
        if self.batch_data_type is not None:
            if self.batch_data_type == 's':
                logger.debug("Set as single field mode.")
                self.input_fields.clear()
            elif self.batch_data_type == 'd':
                assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
                                                                   f"index, but other field is set as dict mode."
            elif self.batch_data_type == 'l':
                assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
                                                                       f"field name is {field_name}."

        if field_name == '_single':
            self.batch_data_type = 's'
        elif isinstance(field_name, str) and sequence_idx_str.match(field_name):
            self.batch_data_type = 'l'
        else:
            self.batch_data_type = 'd'

        if field_name in self.ignore_fields:
            logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
        if backend is None:
            backend = self.backend
        else:
            assert backend in SUPPORTED_BACKENDS
        self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}
        return self
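A short usage sketch for `set_pad` (the field names here are illustrative, not from the diff):

```python
collator = Collator(backend='torch')
collator.set_pad('input_ids', pad_val=-100)            # pad this field with -100 instead of 0
collator.set_pad(('target', 'span'), backend='numpy')  # nested dict key: {'target': {'span': ...}}
collator.set_pad('raw_text', pad_val=None)             # do not pad; the field comes back as a list
# a custom pad_fn receives the whole field in batch form and its return value is used as-is
collator.set_pad('seq_len', pad_fn=lambda batch_field: [len(x) for x in batch_field])
```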
    def set_input(self, *field_names):
    def set_backend(self, backend: str):
        """
        Set the field_names the AutoCollator needs; fields that are not set are filtered out by default.
        Set which kind of tensor paddable fields are padded to by default.

        :param field_names:
        :param backend: which tensor type to use for paddable fields, one of ['torch','jittor','paddle','numpy','raw',None];
            if None, no padding is performed.
        :return:
        """
        flag = True
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_input(*field_names)
                flag = False
        if flag:
            warnings.warn("AutoCollator is removed, set_input is unavailable!")
        assert backend in SUPPORTED_BACKENDS
        self.padders.clear()
        self.backend = backend
    def set_ignore(self, *field_names) -> "Collator":
        """
        Use this to suppress fields that should not be output; the given fields are omitted from the batch output.

        Ex::
            collator.set_ignore('field1', 'field2')

        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's
            key can be used directly; for nested dicts, a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}. If
            __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the sequence.
        :return: the Collator itself
        """
        for field_name in field_names:
            if field_name in self.input_fields:
                self.input_fields.pop(field_name)
                logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
            self.padders.pop(field_name, None)  # drop its padder, if there is one
            self.ignore_fields.add(field_name)
        return self
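And a matching sketch for `set_ignore`, including a nested key (field names again illustrative):

```python
collator = Collator(backend='torch')
collator.set_ignore('raw_words', ('meta', 'source'))
batch = collator([{'raw_words': 'a b', 'x': [1, 2], 'meta': {'source': 's1'}},
                  {'raw_words': 'c',   'x': [3],    'meta': {'source': 's2'}}])
# the padded batch should then only contain 'x'
```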
class AutoCollator(Collator):
    def __init__(self, as_numpy: bool):
        super(AutoCollator, self).__init__()
        self.pad_field_value = {}  # custom padding value per field, default 0
        self.need_inputs = set()  # required field names
        self.field_dtypes = None  # dtype of each column's elements
        self.field_dims = None  # dimension of each column's elements
        self.as_numpy = as_numpy

    def __call__(self, ins_lst: List[Dict]) -> dict:
        if len(self.need_inputs) == 0:
            raise ValueError("need_inputs is empty, you should call the set_input method first!")
        # TODO should first check which fields need padding, then check whether those can be padded

        # Case 1: values were set via set_input
        # Case 2: decide whether to pad based on the data type
        if self.field_dtypes is None and self.field_dims is None:
            field_dtypes, field_dims = {}, {}
            for key, value in ins_lst[0].items():
                if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
                    field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
            self.field_dtypes = field_dtypes
            self.field_dims = field_dims

        pack_ins_lst, pad_ins_lst = {field_name: []
                                     for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
        # pack the data in the instance list by column name
        for per_ins in ins_lst:
            for field_name, _field_content in per_ins.items():
                if field_name in self.need_inputs:
                    pack_ins_lst[field_name].append(_field_content)

        pad_field_kv = {field_name: 0 for field_name in self.need_inputs}
        pad_field_kv.update(self.pad_field_value)
        self.pad_field_value = pad_field_kv

        if len(self.pad_field_value.keys()) > 0:
            # drop the columns that need no padding; columns passed to set_input but absent are ignored
            non_pad_field_names = []
            for k, v in self.pad_field_value.items():
                if v is None:
                    non_pad_field_names.append(k)

            # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
            for field_name in non_pad_field_names:
                field_array = pack_ins_lst.pop(field_name)
                pad_ins_lst[field_name] = np.array(field_array)

            for field_name, field_array in pack_ins_lst.items():
                content = pad_content(field_array, field_name, self.field_dtypes[field_name],
                                      self.field_dims[field_name],
                                      self.pad_field_value[field_name],
                                      as_numpy=self.as_numpy)
                pad_ins_lst[field_name] = content

        # else:
        #     # take out each column's data and decide from its type whether it can be padded
        #     for field_name, field_array in pack_ins_lst.items():
        #         pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name],
        #                                       self.field_dims[field_name],
        #                                       pad_val=0, as_numpy=self.as_numpy)
        #         pad_ins_lst[field_name] = pad_field_array

        return pad_ins_lst

    def set_pad_val(self, *field_names, val=0):
        for field_name in field_names:
            self.pad_field_value[field_name] = val

    def set_as_numpy(self, as_numpy: bool):
        self.as_numpy = as_numpy

    def set_input(self, *field_names):
        for field_name in field_names:
            self.need_inputs.add(field_name)
def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
    if field_type:
        # do not process; return as np.array
        if field_dim > 3:
            return np.array(content)
        # element type is numeric: np.int64, np.float64, int, float, etc.
        if isinstance(field_type, type) and \
                (issubclass(field_type, np.number) or issubclass(field_type, Number)):
            if field_dim == 0:
                array = np.array(content, dtype=field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                array = np.full((len(content), max_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    array[i, :len(content_i)] = content_i
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        array[i, j, :len(content_ii)] = content_ii
            else:
                shape = np.shape(content)
                if len(shape) == 4:  # all dimensions have the same size
                    array = np.array(content, dtype=field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            if as_numpy is False:
                array = torch.tensor(array)
            return array
        # element type is a torch numeric type, e.g. torch.float
        elif str(field_type).startswith('torch'):
            if field_dim == 0:
                tensor = torch.tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                    dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = torch.full([len(content)] + list(shape), fill_value=pad_val,
                                        dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor
        # TODO add jittor/paddle?
        elif str(field_type).startswith('paddle'):
            if field_dim == 0:
                tensor = paddle.Tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                     dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val,
                                         dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor
        else:
            return np.array(content)  # no processing at all
    else:
        return np.array(content)
@@ -27,7 +27,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
    :param field_name: used for clearer error messages.
    :return:
    """
    logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field))
    logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
    if pad_val is None:
        logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
        return NullPadder()
@@ -0,0 +1,174 @@ | |||
from inspect import isclass | |||
import numpy as np | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
numpy_to_paddle_dtype_dict = { | |||
np.bool_: 'bool', | |||
np.uint8: 'uint8', | |||
np.int8: "int8", | |||
np.int16: "int16", | |||
np.int32: "int32", | |||
np.int64: "int64", | |||
np.float16: "float16", | |||
np.float32: 'float32', | |||
np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 | |||
np.complex64: 'complex64', | |||
np.complex128: "complex128" | |||
} | |||
number_to_paddle_dtype_dict = { | |||
float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64 | |||
int: 'int64', | |||
bool: 'bool' | |||
} | |||
from .padder import Padder | |||
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class | |||
from .exceptions import * | |||
def is_paddle_tensor(dtype):
    if not isclass(dtype) and isinstance(dtype, paddle.dtype):
        return True
    return False


def is_paddle_dtype_str(dtype):
    try:
        if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
                                                'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
                                                u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
                                                u'int16', u'int32', u'int64', u'uint8', u'complex64',
                                                u'complex128'}:
            return True
    except:
        pass
    return False


def _get_dtype(ele_dtype, dtype, class_name):
    if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")

    if dtype is not None:
        if not (is_paddle_tensor(dtype) or is_number(dtype) or is_paddle_dtype_str(dtype)):
            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
                                        f"or paddle.dtype but get `{dtype}`.")
        dtype = number_to_paddle_dtype_dict.get(dtype, dtype)
    else:
        if is_number(ele_dtype) or is_paddle_tensor(ele_dtype):
            ele_dtype = number_to_paddle_dtype_dict.get(ele_dtype, ele_dtype)
            dtype = ele_dtype
        elif is_numpy_number_dtype(ele_dtype):  # there is a conversion issue here
            dtype = numpy_to_paddle_dtype_dict.get(ele_dtype.type)
        elif is_numpy_generic_class(ele_dtype):
            dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
        else:
            dtype = ele_dtype
    return dtype
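Illustrative dtype resolutions performed by `_get_dtype` (sketched from the mapping dicts above; the exact behavior of the `is_number*` helpers is assumed):

```python
_get_dtype(int, None, 'paddleNumberPadder')         # -> 'int64'   (python int maps to int64)
_get_dtype(np.float64, None, 'paddleNumberPadder')  # -> 'float32' (float64 is deliberately narrowed)
_get_dtype(int, float, 'paddleNumberPadder')        # -> 'float32' (an explicit dtype wins over ele_dtype)
_get_dtype(str, None, 'paddleNumberPadder')         # raises EleDtypeUnsupportedError
```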
class paddleNumberPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        # only valid when ele_dtype is a python number, a numpy number, or a tensor
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        return paddle.to_tensor(batch_field, dtype=dtype)


class paddleSequencePadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        tensor = get_padded_paddle_tensor(batch_field, dtype=dtype, pad_val=pad_val)
        return tensor


class paddleTensorPadder(Padder):
    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        Currently only supports inputs like [paddle.to_tensor([3, 2]), paddle.to_tensor([1])].

        :param ele_dtype:
        :param pad_val:
        :param dtype:
        """
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        shapes = [field.shape for field in batch_field]
        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
        tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
        for i, field in enumerate(batch_field):
            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
            if isinstance(field, np.ndarray):
                field = paddle.to_tensor(field)
            tensor[slices] = field
        return tensor
def fill_tensor(batch_field, padded_batch, dtype):
    """
    Fill the values from batch_field into the tensor.

    :param batch_field: the content to fill into the array
    :param padded_batch: the tensor to be filled
    :param dtype: the data type
    :return:
    """
    if padded_batch.ndim == 2:
        for i, content_i in enumerate(batch_field):
            padded_batch[i, :len(content_i)] = paddle.Tensor(content_i, dtype=dtype)
    elif padded_batch.ndim == 3:
        for i, content_i in enumerate(batch_field):
            for j, content_ii in enumerate(content_i):
                padded_batch[i, j, :len(content_ii)] = paddle.Tensor(content_ii, dtype=dtype)
    elif padded_batch.ndim == 4:
        try:  # probably an image, so converting directly should just work
            padded_batch = np.array(batch_field)
        except:
            for i, content_i in enumerate(batch_field):
                for j, content_ii in enumerate(content_i):
                    for k, content_iii in enumerate(content_ii):
                        padded_batch[i, j, k, :len(content_iii)] = paddle.Tensor(content_iii, dtype=dtype)
    elif padded_batch.ndim == 1:
        padded_batch[:] = paddle.Tensor(batch_field, dtype=dtype)
    else:
        raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
                           "report.")
    return padded_batch


def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
    """
    For example:
        [[1,2], [3]] -> paddle.LongTensor([[1, 2], [3, 0]])

    :param batch_field: the object to pad; it must already be paddable. Supports 1d (mostly sentence lengths) /
        2d (mostly token sequences) / 3d (mostly character sequences) / 4d (mostly images).
    :param dtype: the target dtype
    :param pad_val: the value to pad with
    :return:
    """
    shapes = get_shape(batch_field)
    tensor = paddle.full(shapes, dtype=dtype, fill_value=pad_val)
    tensor = fill_tensor(batch_field, tensor, dtype=dtype)
    return tensor
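A quick usage sketch (assumes paddle is installed; the dtype spelling follows the mapping dicts above):

```python
batch_field = [[1, 2, 3], [4, 5]]
tensor = get_padded_paddle_tensor(batch_field, dtype='int64', pad_val=0)
# a paddle int64 tensor of shape [2, 3]:
# [[1, 2, 3],
#  [4, 5, 0]]
```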
@@ -3,16 +3,17 @@ __all__ = [
    'prepare_jittor_dataloader'
]

from typing import Callable, Optional, List
from typing import Callable, Optional, List, Union

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
    from jittor.dataset.utils import collate_batch
    from jittor.dataset import Dataset
else:
    from fastNLP.core.dataset import DataSet as Dataset

from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
@@ -48,7 +49,7 @@ class JittorDataLoader:
    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
                 drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                 stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                 collate_fn: Callable = None) -> None:
                 collate_fn: Union[None, str, Callable] = "auto") -> None:
        """

        :param dataset: a dataset implementing __getitem__ and __len__
@@ -66,11 +67,20 @@ class JittorDataLoader:
        # TODO support the fastNLP dataset
        # TODO verify support for a replacement sampler (to be done later)
        # is the dataset a jittor-style dataset?
        if isinstance(dataset, FDataSet):
            collator = dataset.get_collator().set_as_numpy(as_numpy=True)
        if isinstance(collate_fn, str):
            if collate_fn == "auto":
                if isinstance(dataset, FDataSet):
                    self._collate_fn = dataset.collator
                    self._collate_fn.set_backend(backend="jittor")
                else:
                    self._collate_fn = Collator(backend="jittor")
            else:
                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
        elif isinstance(collate_fn, Callable):
            if collate_fn is not collate_batch:
                self._collate_fn = collate_fn
        else:
            collator = None
            self._collate_fn = collate_batch

        self.dataset = _JittorDataset(dataset)
@@ -80,17 +90,13 @@ class JittorDataLoader:
        if isinstance(self.dataset.dataset, Dataset):
            self.dataset.dataset.set_attrs(batch_size=1)
        # if the user provides a collate_fn, it automatically replaces the collate_batch function jittor provides
        self.collate_fn = collate_fn
        if self.collate_fn is None:
            self.collate_fn = collate_batch
        self.auto_collator = collator
        self.cur_batch_indices = None
        # self._collate_fn = _collate_fn

    def __iter__(self):
        # TODO collate_fn cannot be changed after the first iteration; setting it then has no effect
        self.collate_fn = self._collate_fn
        if self.cur_batch_indices is None:
            self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn,
                                                                                             self.auto_collator)))
            self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
        for indices, data in self.dataset.__iter__():
            self.cur_batch_indices = indices
            yield data
@@ -100,30 +106,48 @@ class JittorDataLoader:
            return len(self.dataset) // self.dataset.batch_size
        return (len(self.dataset) - 1) // self.dataset.batch_size + 1

    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
                pad_fn: Callable = None) -> "JittorDataLoader":
        """
        Set the padding value of each field_name (default 0). This method only takes effect when an auto collator
        exists; if there is none, an auto_collator is added.
        When val=None, none of the given field_names are padded.

        :param field_names:
        :param val: the padding value, default 0
        :return:
        Use this function when the contents of a particular field need special handling.

        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key
            can be used directly; for nested dicts, a tuple can express the multi-level key, e.g. ('a', 'b') for
            {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the
            sequence. If the field is not found in the data, an error is raised; if __getitem__ returns the content
            itself, use "_single".
        :param pad_val: the default pad value for this field. If None, the field is not padded. fastNLP only pads
            fields that can be padded, so a field that is not paddable in the first place does not need to be set to
            None explicitly.
        :param dtype: for a field that needs padding, the dtype its data should have.
        :param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing list, numpy.ndarray,
            torch.Tensor, paddle.Tensor or jittor.Var output respectively. If pad_val is None, this may only be None
            or 'numpy'.
        :param pad_fn: a custom pad function for this field; if given, pad_val, dtype and backend are ignored. The
            Collator automatically unbatches the data and gathers each field into its own batch; pad_fn receives the
            field in batch form and its output is used directly as the result.
        :return: the Collator itself
        """
        if self.auto_collator is None:
            self.auto_collator = AutoCollator(as_numpy=True)
        self.auto_collator.set_pad_val(*field_names, val=val)
        if isinstance(self._collate_fn, Collator):
            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
                                     backend=backend)
            return self
        else:
            raise ValueError("collate_fn is not a fastNLP Collator")

    def set_input(self, *field_names) -> None:
    def set_ignore(self, *field_names) -> "JittorDataLoader":
        """
        field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default.

        :param field_names:
        :return:
        Use this to suppress fields that should not be output; the given fields are omitted from the batch output.

        Ex::
            collator.set_ignore('field1', 'field2')

        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's
            key can be used directly; for nested dicts, a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}. If
            __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the sequence.
        :return: the Collator itself
        """
        if self.auto_collator is None:
            self.auto_collator = AutoCollator(as_numpy=True)
        self.auto_collator.set_input(*field_names)
        if isinstance(self._collate_fn, Collator):
            self._collate_fn.set_ignore(*field_names)
            return self
        else:
            raise ValueError("collate_fn is not a fastNLP Collator")
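A sketch of the new 'auto' collate flow on the jittor side (dataset and field names are illustrative; assumes jittor is installed):

```python
dl = JittorDataLoader(my_dataset, batch_size=16, collate_fn="auto")  # pads to jittor.Var by default
dl.set_pad('words', pad_val=0)   # forwarded to the underlying Collator.set_pad
dl.set_ignore('raw_words')       # forwarded to Collator.set_ignore
for batch in dl:
    ...
```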
    def get_batch_indices(self) -> List[int]:
        """
@@ -6,6 +6,7 @@ __all__ = [
from typing import Callable, List, Optional, Union, Dict, Sequence

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
    from paddle.io import DataLoader, Dataset
    from paddle.fluid.dataloader.collate import default_collate_fn
@@ -13,7 +14,7 @@ else:
    from fastNLP.core.utils.dummy_class import DummyClass as Dataset
    from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.collators.collator import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
@@ -45,7 +46,7 @@ class PaddleDataLoader(DataLoader):
    def __init__(self, dataset, feed_list=None, places=None,
                 return_list: bool = True, batch_sampler=None,
                 batch_size: int = 1, shuffle: bool = False,
                 drop_last: bool = False, collate_fn: Callable = None,
                 drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                 num_workers: int = 0, use_buffer_reader: bool = True,
                 use_shared_memory: bool = True, timeout: int = 0,
                 worker_init_fn: Callable = None, persistent_workers=False) -> None:
@@ -60,13 +61,23 @@ class PaddleDataLoader(DataLoader):
                         use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
                         timeout=timeout, worker_init_fn=worker_init_fn,
                         persistent_workers=persistent_workers)
        if isinstance(dataset.dataset, FDataSet):
            self._collate_fn = dataset.dataset.get_collator()
            self._collate_fn.set_as_numpy(as_numpy=True)
            if collate_fn is not None:
                self._collate_fn.add_collator(collate_fn)
        if isinstance(collate_fn, str):
            if collate_fn == 'auto':
                if isinstance(dataset.dataset, FDataSet):
                    self._collate_fn = dataset.dataset.collator
                    self._collate_fn.set_backend(backend="paddle")
                    # if collate_fn is not None:
                    #     self._collate_fn.add_collator(collate_fn)
                else:
                    self._collate_fn = Collator(backend="paddle")
            else:
                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
        elif isinstance(collate_fn, Callable):
            if collate_fn is not default_collate_fn:
                self._collate_fn = collate_fn
        else:
            self._collate_fn = _MultiCollator(collate_fn)
            self._collate_fn = default_collate_fn
        # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
        # if collate_fn is not None:
        #     _collate_fn.add_collator(collate_fn)
@@ -75,64 +86,56 @@ class PaddleDataLoader(DataLoader):
    def __iter__(self):
        # if there is neither an auto_collator nor a custom collate_fn, use the dataloader's built-in collate_fn to batch the data
        if len(self._collate_fn.get_collators()) == 0:
            self._collate_fn.add_collator(default_collate_fn)
        # self._collate_fn = default_collate_fn
        # if len(self._collate_fn.get_collators()) == 0:
        #     self._collate_fn.add_collator(default_collate_fn)
        # self._collate_fn = default_collate_fn
        self.collate_fn = indice_collate_wrapper(self._collate_fn)
        for indices, data in super().__iter__():
            self.cur_batch_indices = indices
            yield data

    def __getattr__(self, item):
        """
        Expose the dataset's methods and attributes on the dataloader; with this in place, users can call dataset
        methods such as apply on the dataloader after instantiation.

        :param item:
        :return:
        """
        try:
            return self.dataset.__getattr__(item)
        except AttributeError as e:
            raise e
    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
        """
        Set the padding value of each field_name (default 0). This method only takes effect when an auto collator
        exists; if there is none, an auto_collator is added.
        When val=None, none of the given field_names are padded.

        :param field_names:
        :param val: the padding value, default 0
        :return:
        """
        for field_name in field_names:
            self._collate_fn.set_pad_val(field_name, val=val)

    def set_input(self, *field_names) -> None:
        """
        field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default.

        :param field_names:
        :return:
        """
        self._collate_fn.set_input(*field_names)

    def set_collator(self, collator: Callable) -> None:
    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
                pad_fn: Callable = None) -> "PaddleDataLoader":
        """
        Set the collate_fn; calling this overrides all current collate_fns, including the auto collator.

        :param collator: a user-defined callable
        :return:
        Use this function when the contents of a particular field need special handling.

        :param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key
            can be used directly; for nested dicts, a tuple can express the multi-level key, e.g. ('a', 'b') for
            {'a': {'b': 1}}. If __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the
            sequence. If the field is not found in the data, an error is raised; if __getitem__ returns the content
            itself, use "_single".
        :param pad_val: the default pad value for this field. If None, the field is not padded. fastNLP only pads
            fields that can be padded, so a field that is not paddable in the first place does not need to be set to
            None explicitly.
        :param dtype: for a field that needs padding, the dtype its data should have.
        :param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing list, numpy.ndarray,
            torch.Tensor, paddle.Tensor or jittor.Var output respectively. If pad_val is None, this may only be None
            or 'numpy'.
        :param pad_fn: a custom pad function for this field; if given, pad_val, dtype and backend are ignored. The
            Collator automatically unbatches the data and gathers each field into its own batch; pad_fn receives the
            field in batch form and its output is used directly as the result.
        :return: the Collator itself
        """
        self._collate_fn = _MultiCollator(collator)
        if isinstance(self._collate_fn, Collator):
            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
                                     backend=backend)
            return self
        else:
            raise ValueError("collate_fn is not a fastNLP Collator")
    def add_collator(self, collator) -> None:
    def set_ignore(self, *field_names) -> "PaddleDataLoader":
        """
        Append a collate_fn; calling this adds it after the existing collate_fns.

        :param collator:
        :return:
        Use this to suppress fields that should not be output; the given fields are omitted from the batch output.

        Ex::
            collator.set_ignore('field1', 'field2')

        :param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's
            key can be used directly; for nested dicts, a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}. If
            __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the sequence.
        :return: the Collator itself
        """
        self._collate_fn.add_collator(collator)
        if isinstance(self._collate_fn, Collator):
            self._collate_fn.set_ignore(*field_names)
            return self
        else:
            raise ValueError("collate_fn is not a fastNLP Collator")
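The paddle side mirrors the same chaining API (illustrative names; assumes paddle is installed):

```python
dl = PaddleDataLoader(my_dataset, batch_size=32, collate_fn='auto')  # pads to paddle.Tensor by default
dl.set_pad('input_ids', pad_val=0).set_ignore('raw_text')
for batch in dl:
    ...
```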
def get_batch_indices(self) -> List[int]: | |||
""" | |||
@@ -144,20 +147,21 @@ class PaddleDataLoader(DataLoader): | |||
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
return_list: bool = True, batch_sampler=None, | |||
train_batch_size: int = 1, shuffle: bool = False, | |||
drop_last: bool = False, collate_fn: Callable = None, | |||
num_workers: int = 0, use_buffer_reader: bool = True, | |||
use_shared_memory: bool = True, timeout: int = 0, | |||
worker_init_fn: Callable = None, persistent_workers=False, | |||
non_train_batch_size: int = 16, | |||
input_fields: Union[List[str], str] = None)\ | |||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||
if isinstance(input_fields, str): | |||
input_fields = [input_fields] | |||
return_list: bool = True, batch_sampler=None, | |||
train_batch_size: int = 1, shuffle: bool = False, | |||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, | |||
num_workers: int = 0, use_buffer_reader: bool = True, | |||
use_shared_memory: bool = True, timeout: int = 0, | |||
worker_init_fn: Callable = None, persistent_workers=False, | |||
non_train_batch_size: int = 16) \ | |||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||
if isinstance(ds_or_db, Dataset): | |||
... | |||
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
return dl | |||
elif isinstance(ds_or_db, Sequence): | |||
ds_seq = [] | |||
for ds in ds_or_db: | |||
@@ -166,7 +170,6 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
dl.set_input(*input_fields) | |||
ds_seq.append(dl) | |||
return ds_seq | |||
@@ -178,14 +181,15 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
persistent_workers=persistent_workers) | |||
else: | |||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
dl.set_input(*input_fields) | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
persistent_workers=persistent_workers) | |||
ds_dict[name] = dl | |||
return ds_dict | |||
else: | |||
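A hedged usage sketch of prepare_paddle_dataloader (train_ds and dev_ds are placeholder datasets): a single Dataset yields a single PaddleDataLoader, while a mapping of datasets yields a dict of loaders in which 'train' gets train_batch_size and the rest get non_train_batch_size.

dls = prepare_paddle_dataloader({'train': train_ds, 'dev': dev_ds},
                                train_batch_size=32, non_train_batch_size=16,
                                shuffle=True)
for batch in dls['train']:
    ...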
@@ -6,8 +6,7 @@ __all__ = [ | |||
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.collators import AutoCollator | |||
from fastNLP.core.collators.collator import _MultiCollator | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
@@ -51,11 +50,11 @@ class TorchDataLoader(DataLoader): | |||
def __init__(self, dataset, batch_size: int = 1, | |||
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: | |||
persistent_workers: bool = False, **kwargs) -> None: | |||
""" | |||
:param dataset: a data container implementing __getitem__ and __len__ | |||
@@ -64,7 +63,7 @@ class TorchDataLoader(DataLoader): | |||
:param sampler: an instantiated sampler object | |||
:param batch_sampler: an instantiated batch_sampler object that, when iterated, yields a list of indices | |||
:param num_workers: number of worker processes; when num_workers=0 no multiprocessing is used | |||
:param collate_fn: a callable that batches the fetched data | |||
:param collate_fn: [None, 'auto', callable] a callable that batches the fetched data | |||
:param pin_memory: | |||
:param drop_last: whether to drop the last batch when it does not fill batch_size | |||
:param timeout: | |||
@@ -73,7 +72,6 @@ class TorchDataLoader(DataLoader): | |||
:param generator: | |||
:param prefetch_factor: | |||
:param persistent_workers: | |||
:param as_numpy: whether the returned data is numpy; otherwise it is torch.Tensor | |||
""" | |||
if not isinstance(dataset, _FDataSet): | |||
dataset = _FDataSet(dataset) | |||
@@ -84,91 +82,76 @@ class TorchDataLoader(DataLoader): | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers) | |||
if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is used | |||
self._collate_fn = dataset.dataset.get_collator() | |||
self._collate_fn.set_as_numpy(as_numpy) | |||
if collate_fn is not None and collate_fn is not default_collate: | |||
# Prevent the torch dataloader's default collate from being added when DDP re-initializes | |||
self._collate_fn.add_collator(collate_fn) | |||
if isinstance(collate_fn, str): | |||
if collate_fn == 'auto': | |||
if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is used | |||
self._collate_fn = dataset.dataset.collator | |||
self._collate_fn.set_backend(backend="torch") | |||
# if collate_fn is not None and collate_fn is not default_collate: | |||
# # Prevent the torch dataloader's default collate from being added when DDP re-initializes | |||
# self._collate_fn.add_collator(collate_fn) | |||
else: | |||
self._collate_fn = Collator(backend='torch') | |||
else: | |||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||
elif isinstance(collate_fn, Callable): | |||
if collate_fn is not default_collate: | |||
self._collate_fn = collate_fn | |||
else: | |||
self._collate_fn = _MultiCollator(collate_fn) | |||
self._collate_fn = default_collate | |||
self.cur_batch_indices = None | |||
self.as_numpy = as_numpy | |||
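A summary sketch of the collate_fn dispatch above (ds stands in for any fastNLP DataSet):

# 'auto'    -> reuse the DataSet's own Collator with the backend forced to torch,
#              or a fresh Collator(backend='torch') for non-fastNLP datasets;
# callable  -> used as-is, unless it is torch's default_collate (which DDP
#              re-initialization may inject);
# None      -> fall back to torch's default_collate.
dl_auto = TorchDataLoader(ds, batch_size=3, collate_fn='auto')
dl_custom = TorchDataLoader(ds, batch_size=3, collate_fn=lambda ins_list: ins_list)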
def __getattr__(self, item): | |||
""" | |||
Expose the dataset's methods and attributes on the FDataLoader; once implemented, users can call dataset methods such as apply on an instantiated FDataLoader. | |||
:param item: | |||
:return: | |||
""" | |||
try: | |||
return self.dataset.__getattr__(item) | |||
except AttributeError as e: | |||
raise e | |||
def __iter__(self): | |||
# If there is neither an auto_collator nor a custom collate_fn, fall back to the dataloader's built-in collate_fn to batch the data. | |||
if len(self._collate_fn.get_collators()) == 0: | |||
self._collate_fn.add_collator(self.collate_fn) | |||
# if len(self._collate_fn.get_collators()) == 0: | |||
# self._collate_fn.add_collator(self.collate_fn) | |||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||
for indices, data in super().__iter__(): | |||
self.cur_batch_indices = indices | |||
yield data | |||
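Because __iter__ wraps collate_fn with indice_collate_wrapper, each step also records which dataset indices produced the batch; a short sketch (ds is a placeholder DataSet):

dl = TorchDataLoader(ds, batch_size=3, shuffle=True)
for batch in dl:
    print(dl.get_batch_indices())  # indices backing the current batch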
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||
pad_fn: Callable = None) -> "TorchDataLoader": | |||
""" | |||
Set the padding value for each field_name (0 by default). This method only works when an auto collator exists; if none exists an auto_collator is added. | |||
When val=None, none of the given field_names are padded. | |||
:param field_names: | |||
:param val: padding value, 0 by default | |||
:return: | |||
""" | |||
flag = False | |||
for collator in self._collate_fn.get_collators(): | |||
if isinstance(collator, AutoCollator): | |||
flag = True | |||
break | |||
if flag is False: | |||
self._collate_fn.add_collator(AutoCollator(self.as_numpy)) | |||
for field_name in field_names: | |||
self._collate_fn.set_pad_val(field_name, val=val) | |||
def set_input(self, *field_names) -> None: | |||
Use this method if the content of a field needs special handling. | |||
:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, the field's key can | |||
be used directly; for a nested dict a tuple can express a multi-level key, e.g. ('a', 'b') for {'a': {'b': 1}}; | |||
if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element. If the field is not | |||
found in the data, an error is raised; if __getitem__ returns the whole sample itself, use "_single" . | |||
:param pad_val: the default pad value for this field. If set to None, the field will not be padded; fastNLP only | |||
pads fields that can be padded, so there is no need to set this to None explicitly for unpaddable fields. | |||
:param dtype: for fields that need padding, the dtype the field's data should have. | |||
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray, | |||
torch.Tensor, paddle.Tensor or jittor.Var respectively. If pad_val is None, this may only be None or 'numpy'. | |||
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored. The | |||
Collator automatically unbatches the data and groups each field into its own batch, so the input of pad_fn | |||
is the batch form of the current field, and its output is used directly as the result. | |||
:return: returns self, i.e. the dataloader object | |||
""" | |||
Field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default. | |||
:param field_names: | |||
:return: | |||
""" | |||
flag = False | |||
for collator in self._collate_fn.get_collators(): | |||
if isinstance(collator, AutoCollator): | |||
flag = True | |||
break | |||
if flag is False: | |||
self._collate_fn.add_collator(AutoCollator(self.as_numpy)) | |||
self._collate_fn.set_input(*field_names) | |||
def set_collator(self, collator: Callable) -> None: | |||
""" | |||
Set the collate_fn; calling this overrides all existing collate_fn, including the auto collator. | |||
:param collator: a user-defined callable | |||
:return: | |||
""" | |||
self._collate_fn = _MultiCollator(collator) | |||
if isinstance(self._collate_fn, Collator): | |||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||
return self | |||
else: | |||
raise ValueError(f"collate_fn is not fastnlp collator") | |||
def add_collator(self, collator) -> None: | |||
def set_ignore(self, *field_names) -> "TorchDataLoader": | |||
""" | |||
Append a collate_fn; after this call it is added after the existing collate_fn. | |||
:param collator: | |||
:return: | |||
If some content should not be output, set it here; the specified fields will be ignored in the batch output. | |||
Ex:: | |||
collator.set_ignore('field1', 'field2') | |||
:param field_names: names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the field's | |||
key can be used directly; for a nested dict a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}; if | |||
__getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element of the sequence. | |||
:return: returns self, i.e. the dataloader object | |||
""" | |||
self._collate_fn.add_collator(collator) | |||
if isinstance(self._collate_fn, Collator): | |||
self._collate_fn.set_ignore(*field_names) | |||
return self | |||
else: | |||
raise ValueError(f"collate_fn is not fastnlp collator") | |||
def get_batch_indices(self) -> List[int]: | |||
""" | |||
@@ -183,13 +166,12 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
batch_size: int = 1, | |||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, | |||
non_train_batch_size: int = 16, as_numpy: bool = False, | |||
input_fields: Union[List, str, None] = None) \ | |||
non_train_batch_size: int = 16) \ | |||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||
""" | |||
Given a dataset or data_bundle, process it and return the corresponding FDataLoader instance(s). | |||
@@ -201,7 +183,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
:param sampler: an instantiated sampler object | |||
:param batch_sampler: an instantiated batch_sampler object that, when iterated, yields a list of indices | |||
:param num_workers: number of worker processes; when num_workers=0 no multiprocessing is used | |||
:param collate_fn: a callable that batches the fetched data | |||
:param collate_fn: ['auto', None, callable] a callable that batches the fetched data | |||
:param pin_memory: | |||
:param drop_last: whether to drop the last batch when it does not fill batch_size | |||
:param timeout: | |||
@@ -212,11 +194,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
:param persistent_workers: | |||
:param non_train_sampler: the Sampler used for non-'train' data, and for the second and later datasets of a Sequence | |||
:param non_train_batch_size: | |||
:param as_numpy: whether the returned data is numpy; otherwise it is set to torch.Tensor as appropriate. | |||
""" | |||
# TODO the dict and sequence cases still need handling | |||
if isinstance(input_fields, str): | |||
input_fields = [input_fields] | |||
if isinstance(ds_or_db, DataSet): | |||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||
@@ -225,9 +203,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
if input_fields: | |||
dl.set_input(*input_fields) | |||
) | |||
return dl | |||
elif isinstance(ds_or_db, DataBundle): | |||
@@ -241,7 +217,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
) | |||
else: | |||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||
shuffle=shuffle, sampler=non_train_sampler, | |||
@@ -251,9 +227,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
if input_fields: | |||
dl_bundle[name].set_input(*input_fields) | |||
) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Sequence): | |||
@@ -267,7 +241,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
) | |||
) | |||
else: | |||
dl_bundle.append( | |||
@@ -277,11 +251,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
) | |||
) | |||
if input_fields: | |||
for dl in dl_bundle: | |||
dl.set_input(*input_fields) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Mapping): | |||
@@ -295,7 +266,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
) | |||
else: | |||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||
shuffle=shuffle, sampler=non_train_sampler, | |||
@@ -305,10 +276,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
if input_fields: | |||
dl_bundle[name].set_input(*input_fields) | |||
) | |||
return dl_bundle | |||
else: | |||
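A hedged sketch of prepare_torch_dataloader with a DataBundle (data_bundle is a placeholder): every split becomes its own TorchDataLoader, with only 'train' using batch_size and the others using non_train_batch_size and non_train_sampler.

dl_bundle = prepare_torch_dataloader(data_bundle, batch_size=32,
                                     non_train_batch_size=128, collate_fn='auto')
train_dl, dev_dl = dl_bundle['train'], dl_bundle['dev']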
@@ -23,9 +23,8 @@ except: | |||
from .field import FieldArray | |||
from .instance import Instance | |||
from fastNLP.core.utils.utils import pretty_table_printer, deprecated | |||
from fastNLP.core.collators import AutoCollator | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.core.utils.rich_progress import f_rich_progress | |||
from fastNLP.core.collators.collator import _MultiCollator | |||
class ApplyResultException(Exception): | |||
@@ -114,7 +113,7 @@ class DataSet: | |||
Each element should be an :class:`~fastNLP.Instance` with the same fields. | |||
""" | |||
self.field_arrays = {} | |||
self.collate_fns: _MultiCollator = _MultiCollator(AutoCollator(as_numpy=False)) | |||
self._collator = Collator(backend="numpy") | |||
if data is not None: | |||
if isinstance(data, Dict): | |||
length_set = set() | |||
@@ -181,7 +180,7 @@ class DataSet: | |||
dataset = DataSet() | |||
for field_name, field in self.field_arrays.items(): | |||
dataset.add_field(field_name=field_name, fields=field.content[idx]) | |||
dataset.collate_fns = deepcopy(self.collate_fns) | |||
dataset._collator = deepcopy(self.collator) | |||
return dataset | |||
elif isinstance(idx, str): | |||
if idx not in self: | |||
@@ -193,7 +192,7 @@ class DataSet: | |||
assert isinstance(i, int), "Only int index allowed." | |||
instance = self[i] | |||
dataset.append(instance) | |||
dataset.collate_fns = deepcopy(self.collate_fns) | |||
dataset._collator = deepcopy(self.collator) | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
@@ -676,8 +675,8 @@ class DataSet: | |||
dev_set.append(self[idx]) | |||
for idx in train_indices: | |||
train_set.append(self[idx]) | |||
dev_set.collate_fns = deepcopy(self.collate_fns) | |||
train_set.collate_fns = deepcopy(self.collate_fns) | |||
dev_set._collator = deepcopy(self.collator) | |||
train_set._collator = deepcopy(self.collator) | |||
return dev_set, train_set | |||
@@ -772,63 +771,17 @@ class DataSet: | |||
df = self.to_pandas() | |||
df.to_csv(path, encoding="utf-8") | |||
def add_collate_fn(self, collate_fn: Callable) -> None: | |||
""" | |||
Append a collate_fn; after this call it is added after the existing collate_fn. | |||
:param collate_fn: Callable的函数 | |||
:return: | |||
""" | |||
self.collate_fns.add_collator(collate_fn) | |||
def set_collate_fn(self, collate_fn: Callable) -> None: | |||
""" | |||
Set the collate_fn; calling this overrides all existing collate_fn, including the auto collator. | |||
:param collate_fn: | |||
:return: | |||
""" | |||
self.collate_fns = _MultiCollator(collate_fn) | |||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
""" | |||
Set the padding value for each field_name (0 by default). This method only works when an AutoCollator exists. | |||
When val=None, none of the given field_names are padded. | |||
:param field_names: field_names that exist in the dataset | |||
:param val: 0 by default. If None, the field is not padded. | |||
:return: | |||
""" | |||
# TODO must not be empty | |||
for field_name in field_names: | |||
self.collate_fns.set_pad_val(field_name, val=val) | |||
def set_input(self, *field_names) -> None: | |||
""" | |||
Field_names set as inputs are fed into the AutoCollator; fields that are not set are filtered out by default. | |||
:param field_names: | |||
:return: | |||
""" | |||
# | |||
self.collate_fns.set_input(*field_names) | |||
def get_collator(self) -> _MultiCollator: | |||
""" | |||
Get the collate_fn bound to the dataset, including the auto_collate. | |||
:return: | |||
""" | |||
return self.collate_fns | |||
@deprecated() | |||
def set_target(self, *field_names) -> None: | |||
def set_ignore(self, *field_names) -> None: | |||
""" | |||
The given field_names will be ignored in the batch output; this forwards to collator.set_ignore. | |||
:param field_names: | |||
:return: | |||
""" | |||
self.collate_fns.set_input(*field_names) | |||
self.collator.set_ignore(*field_names) | |||
@property | |||
def collator(self): | |||
if self._collator is None: | |||
self._collator = Collator() | |||
return self._collator |
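With collation now owned by a single Collator object on the DataSet, pad/ignore settings travel with the dataset into any dataloader built from it; a minimal sketch with illustrative field names:

ds = DataSet({'x': [[1, 2], [2, 3, 4]] * 10, 'y': [0, 1] * 10})
ds.collator.set_pad('x', pad_val=-1)
ds.collator.set_ignore('y')
dl = TorchDataLoader(ds, batch_size=4, collate_fn='auto')  # picks up ds.collator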
@@ -7,13 +7,13 @@ from collections.abc import Mapping, Callable | |||
from functools import wraps | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from fastNLP.core.dataset import Instance | |||
def is_jittor_dataset(dataset) -> bool: | |||
try: | |||
if isinstance(dataset, jt.dataset.Dataset): | |||
@@ -32,6 +32,7 @@ def jittor_collate_wraps(func, auto_collator: Callable): | |||
:param auto_collator: | |||
:return: | |||
""" | |||
@wraps(func) | |||
def wrapper(batch): | |||
if isinstance(batch[0], Instance): | |||
@@ -0,0 +1,107 @@ | |||
import numpy as np | |||
import pytest | |||
from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder | |||
from fastNLP.core.collators.padders.exceptions import DtypeError | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
@pytest.mark.paddle | |||
class TestpaddleNumberPadder: | |||
def test_run(self): | |||
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
a = [1, 2, 3] | |||
t_a = padder(a) | |||
assert isinstance(t_a, paddle.Tensor) | |||
assert (t_a == paddle.to_tensor(a, dtype='int64')).sum() == 3 | |||
@pytest.mark.paddle | |||
class TestpaddleSequencePadder: | |||
def test_run(self): | |||
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
a = [[1, 2, 3], [3]] | |||
a = padder(a) | |||
shape = a.shape | |||
assert isinstance(a, paddle.Tensor) | |||
assert tuple(shape) == (2, 3) | |||
b = paddle.to_tensor([[1, 2, 3], [3, -1, -1]], dtype='int64') | |||
assert (a == b).sum().item() == shape[0]*shape[1] | |||
def test_dtype_check(self): | |||
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
with pytest.raises(DtypeError): | |||
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||
padder = paddleSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) | |||
a = padder([[1], [2, 322]]) | |||
assert (a > 67).sum() == 0  # 322 overflows int8 (range -128..127) and wraps around to 66 | |||
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||
@pytest.mark.paddle | |||
class TestpaddleTensorPadder: | |||
def test_run(self): | |||
padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=int, pad_val=-1) | |||
a = [paddle.zeros(3), paddle.zeros(2), paddle.zeros(0)] | |||
a = padder(a) | |||
shape = a.shape | |||
assert isinstance(a, paddle.Tensor) | |||
assert tuple(shape) == (3, 3) | |||
b = paddle.to_tensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]], dtype='int64') | |||
assert (a == b).sum().item() == shape[0]*shape[1] | |||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 2))] | |||
a = padder(a) | |||
shape = a.shape | |||
assert isinstance(a, paddle.Tensor) | |||
assert tuple(shape) == (3, 3, 2) | |||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]],  # paddle has no LongTensor; integer literals give int64 | |||
[[0, 0], [0, 0], [-1, -1]], | |||
[[0, 0], [-1, -1], [-1, -1]]]) | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 1))] | |||
a = padder(a) | |||
shape = a.shape | |||
assert isinstance(a, paddle.Tensor) | |||
assert tuple(shape) == (3, 3, 2) | |||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||
[[0, 0], [0, 0], [-1, -1]], | |||
[[0, -1], [-1, -1], [-1, -1]]]) | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=int, pad_val=-1) | |||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 0))] | |||
a = padder(a) | |||
shape = a.shape | |||
assert isinstance(a, paddle.Tensor) | |||
assert tuple(shape) == (3, 3, 2) | |||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||
[[0, 0], [0, 0], [-1, -1]], | |||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
padder = paddleTensorPadder(ele_dtype=paddle.zeros(3).dtype, dtype=None, pad_val=-1) | |||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||
a = padder(a) | |||
shape = a.shape | |||
assert isinstance(a, paddle.Tensor) | |||
assert tuple(shape) == (3, 3, 2) | |||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||
[[0, 0], [0, 0], [-1, -1]], | |||
[[-1, -1], [-1, -1], [-1, -1]]], dtype='float32')  # paddle has no FloatTensor | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
def test_dtype_check(self): | |||
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
with pytest.raises(DtypeError): | |||
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = paddleTensorPadder(ele_dtype=paddle.long, dtype=int, pad_val=-1) | |||
padder = paddleTensorPadder(ele_dtype=int, dtype=paddle.long, pad_val=-1) | |||
@@ -36,8 +36,8 @@ class TestJittor: | |||
""" | |||
dataset = MyDataset() | |||
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | |||
jtl.set_pad_val('x', 'y') | |||
jtl.set_input('x') | |||
# jtl.set_pad_val('x', 'y') | |||
# jtl.set_input('x') | |||
for batch in jtl: | |||
print(batch) | |||
print(jtl.get_batch_indices()) | |||
@@ -50,15 +50,17 @@ class TestJittor: | |||
""" | |||
dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | |||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) | |||
jtl.set_pad_val('x', val=-1) | |||
jtl.set_input('x', 'y') | |||
jtl.set_pad("x", -1) | |||
jtl.set_ignore("y") | |||
# jtl.set_pad_val('x', val=-1) | |||
# jtl.set_input('x', 'y') | |||
for batch in jtl: | |||
assert batch['x'].size() == (16, 4) | |||
def test_v3(self): | |||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | |||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) | |||
jtl.set_input('x', 'y') | |||
# jtl.set_input('x', 'y') | |||
for batch in jtl: | |||
print(batch) | |||
@@ -2,6 +2,7 @@ import pytest | |||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.log import logger | |||
from paddle.io import Dataset, DataLoader | |||
import numpy as np | |||
import paddle | |||
@@ -11,7 +12,7 @@ class RandomDataset(Dataset): | |||
def __getitem__(self, idx): | |||
image = np.random.random((10, 5)).astype('float32') | |||
return {'image': paddle.Tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]} | |||
return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]} | |||
def __len__(self): | |||
return 10 | |||
@@ -32,23 +33,30 @@ class TestPaddle: | |||
def test_fdl_batch_indices(self): | |||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | |||
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | |||
fdl.set_input("x", "y") | |||
for batch in fdl: | |||
assert len(fdl.get_batch_indices()) == 4 | |||
print(batch) | |||
print(fdl.get_batch_indices()) | |||
def test_set_inputs_and_set_pad_val(self): | |||
logger.setLevel("DEBUG") | |||
ds = RandomDataset() | |||
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) | |||
fdl.set_input('image', 'label') | |||
fdl.set_pad_val('label', val=-1) | |||
fdl.set_pad('label', -1) | |||
for batch in fdl: | |||
print(batch['image']) | |||
assert batch['image'].shape == [2, 10, 5] | |||
print(batch) | |||
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) | |||
fdl1.set_input('image', 'label') | |||
fdl1.set_pad_val('image', val=None) | |||
fdl1.set_ignore('image') | |||
for batch in fdl1: | |||
assert batch['image'].shape == [4, 10, 5] | |||
print(batch) | |||
def test_v2(self): | |||
from fastNLP.core.collators import Collator | |||
logger.setLevel("DEBUG") | |||
data = [paddle.to_tensor(np.random.random((10, 5)).astype('float32')), | |||
paddle.to_tensor(np.random.random((10, 5)).astype('float32'))] | |||
col = Collator(backend="paddle")  # the test data are paddle tensors, so use the paddle backend | |||
res = col(data) | |||
print(res) |
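A hedged sketch of using a Collator standalone, assuming dict-style samples and the default pad_val of 0; with backend='paddle', list fields are padded into paddle.Tensor:

col = Collator(backend='paddle')
batch = col([{'x': [1, 2]}, {'x': [1, 2, 3]}])
print(batch['x'].shape)  # [2, 3]; the shorter sequence is padded with 0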
@@ -13,42 +13,23 @@ class TestFdl: | |||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||
# for batch in fdl: | |||
# print(batch) | |||
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||
# for batch in fdl1: | |||
# print(batch) | |||
def test_set_padding(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
ds.set_pad_val("x", val=-1) | |||
fdl = TorchDataLoader(ds, batch_size=3) | |||
fdl.set_input("x", "y") | |||
fdl.set_pad_val("x", val=None) | |||
fdl.set_pad("x", -1) | |||
for batch in fdl: | |||
print(batch) | |||
# fdl.set_pad_val("x", val=-2) | |||
# for batch in fdl: | |||
# print(batch) | |||
def test_add_collator(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
def collate_fn(ins_list): | |||
_dict = {"Y": []} | |||
for ins in ins_list: | |||
_dict["Y"].append(ins['y']) | |||
return _dict | |||
fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) | |||
fdl.set_input("x", "y") | |||
# fdl.set_pad_val("x", val=None) | |||
fdl.add_collator(collate_fn) | |||
for batch in fdl: | |||
print(batch) | |||
def test_get_batch_indices(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | |||
fdl.set_input("y", "x") | |||
for batch in fdl: | |||
print(fdl.get_batch_indices()) | |||