|
@@ -0,0 +1,379 @@ |
|
|
|
|
|
# Explicit public API of this module; the remaining names are internal helpers.
__all__ = [
    'AutoCollator',
    'Collator',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABCMeta, abstractmethod |
|
|
|
|
|
from typing import Any, Dict, List, Callable, Union |
|
|
|
|
|
from numbers import Number |
|
|
|
|
|
import warnings |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH |
|
|
|
|
|
|
|
|
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
|
|
|
import paddle |
|
|
|
|
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ApplyResultException(Exception):
    """Raised when applying a function over a dataset fails on some sample."""

    def __init__(self, msg, index=None):
        super().__init__(msg)
        # Keep both pieces of context so callers can report precisely
        # which sample triggered the failure.
        self.msg = msg
        self.index = index  # index of the sample where the problem occurred
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SetInputOrTargetException(Exception):
    """Raised when a field cannot be processed while being set as input/target."""

    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        # Record where the problem happened: which sample, which field.
        self.msg = msg
        self.index = index  # index of the sample where the problem occurred
        self.field_name = field_name  # name of the field being processed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_ele_type_and_dim(cell: Any, dim=0): |
|
|
|
|
|
r""" |
|
|
|
|
|
识别cell的类别与dimension的数量 |
|
|
|
|
|
|
|
|
|
|
|
numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html |
|
|
|
|
|
:param cell: |
|
|
|
|
|
:param dim: |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
|
|
|
if isinstance(cell, (str, Number, np.bool_)): |
|
|
|
|
|
if hasattr(cell, 'dtype'): |
|
|
|
|
|
return cell.dtype.type, dim |
|
|
|
|
|
return type(cell), dim |
|
|
|
|
|
|
|
|
|
|
|
elif isinstance(cell, list): |
|
|
|
|
|
dim += 1 |
|
|
|
|
|
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] |
|
|
|
|
|
types = set([i for i, j in res]) |
|
|
|
|
|
dims = set([j for i, j in res]) |
|
|
|
|
|
if len(types) > 1: |
|
|
|
|
|
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) |
|
|
|
|
|
elif len(types) == 0: |
|
|
|
|
|
raise SetInputOrTargetException("Empty value encountered.") |
|
|
|
|
|
if len(dims) > 1: |
|
|
|
|
|
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) |
|
|
|
|
|
return types.pop(), dims.pop() |
|
|
|
|
|
|
|
|
|
|
|
elif isinstance(cell, torch.Tensor): |
|
|
|
|
|
return cell.dtype, cell.dim() + dim # 如果是 torch.mean 的结果是0 |
|
|
|
|
|
|
|
|
|
|
|
elif isinstance(cell, paddle.Tensor): |
|
|
|
|
|
return cell.dtype, cell.dim() + dim |
|
|
|
|
|
|
|
|
|
|
|
elif isinstance(cell, np.ndarray): |
|
|
|
|
|
if cell.dtype != np.dtype('O'): # 如果不是 object 的话说明是 well-formatted 的了 |
|
|
|
|
|
return cell.dtype.type, cell.ndim + dim # dtype.type 返回的会是 np.int32, np.float 等 |
|
|
|
|
|
# 否则需要继续往下 iterate |
|
|
|
|
|
dim += 1 |
|
|
|
|
|
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] |
|
|
|
|
|
types = set([i for i, j in res]) |
|
|
|
|
|
dims = set([j for i, j in res]) |
|
|
|
|
|
if len(types) > 1: |
|
|
|
|
|
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) |
|
|
|
|
|
elif len(types) == 0: |
|
|
|
|
|
raise SetInputOrTargetException("Empty value encountered.") |
|
|
|
|
|
if len(dims) > 1: |
|
|
|
|
|
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) |
|
|
|
|
|
return types.pop(), dims.pop() |
|
|
|
|
|
|
|
|
|
|
|
else: # 包含 tuple, set, dict 以及其它的类型 |
|
|
|
|
|
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_ds_type_dim(ds: dict):
    """Collect the element type and dimension of every field in one dataset row.

    :param ds: mapping from field name to that field's content.
    :return: two dicts keyed by field name — element types and dimensions.
    """
    dtype_by_field, dim_by_field = {}, {}
    for name, content in ds.items():
        dtype_by_field[name], dim_by_field[name] = _get_ele_type_and_dim(content)
    return dtype_by_field, dim_by_field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Collator(metaclass=ABCMeta):
    r"""Base class that helps a DataLoader manage its ``collate_fn``.

    Concrete subclasses must implement :meth:`__call__` (collate a list of
    instances into a batch) and :meth:`set_pad_val` (configure padding values).
    """

    def __init__(self):
        super(Collator, self).__init__()
        # Registered collate functions; subclasses may populate this list.
        self.collate_fn = []

    @abstractmethod
    def __call__(self, ins_lst: List) -> Any:
        raise NotImplementedError

    @abstractmethod
    def set_pad_val(self, *field_names: str, value=0):
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _MultiCollator: |
|
|
|
|
|
""" |
|
|
|
|
|
管理所有collator的容器, |
|
|
|
|
|
遵循覆盖原则,后加入的collate_fn会覆盖之前处理的数据。 |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, collate_fns: Union[Callable, List[Callable], None]): |
|
|
|
|
|
|
|
|
|
|
|
if collate_fns is None: |
|
|
|
|
|
collate_fns = [] |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(collate_fns, Callable): |
|
|
|
|
|
collate_fns = [collate_fns] |
|
|
|
|
|
|
|
|
|
|
|
self._collators: list = collate_fns |
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, ins_lst) -> Dict: |
|
|
|
|
|
out, list_out = {}, [] |
|
|
|
|
|
for idx, _collate_fn in enumerate(self._collators): |
|
|
|
|
|
res = _collate_fn(ins_lst) |
|
|
|
|
|
if isinstance(res, Dict): |
|
|
|
|
|
out.update(res) |
|
|
|
|
|
else: |
|
|
|
|
|
list_out.append(res) |
|
|
|
|
|
# else: |
|
|
|
|
|
# raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict") |
|
|
|
|
|
if len(out) > 0 and len(list_out) > 0: |
|
|
|
|
|
raise ValueError("the return of collate_fns is not the same, must be dict or list") |
|
|
|
|
|
if len(list_out) == 1: |
|
|
|
|
|
list_out = list_out[-1] |
|
|
|
|
|
# print(list_out) |
|
|
|
|
|
return out if len(out) > 0 else list_out |
|
|
|
|
|
|
|
|
|
|
|
def get_collators(self): |
|
|
|
|
|
return self._collators |
|
|
|
|
|
|
|
|
|
|
|
def add_collator(self, collator: Callable): |
|
|
|
|
|
self._collators.append(collator) |
|
|
|
|
|
|
|
|
|
|
|
def set_as_numpy(self, as_numpy: bool): |
|
|
|
|
|
""" |
|
|
|
|
|
存在AutoCollator时,as_numpy控制其返回值的类型 |
|
|
|
|
|
|
|
|
|
|
|
:param as_numpy: |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
|
|
|
for collator in self._collators: |
|
|
|
|
|
if isinstance(collator, AutoCollator): |
|
|
|
|
|
collator.set_as_numpy(as_numpy) |
|
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
def set_pad_val(self, *field_names, val=0): |
|
|
|
|
|
""" |
|
|
|
|
|
存在AutoCollator时,设置field_name的padding值 |
|
|
|
|
|
|
|
|
|
|
|
:param field_names: 数据集的field名 |
|
|
|
|
|
:param val: padding的值 |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
|
|
|
flag = True |
|
|
|
|
|
for collator in self._collators: |
|
|
|
|
|
if isinstance(collator, AutoCollator): |
|
|
|
|
|
collator.set_pad_val(*field_names, val=val) |
|
|
|
|
|
flag = False |
|
|
|
|
|
if flag: |
|
|
|
|
|
warnings.warn("AutoCollator is remove, set_padding is unavailable!!") |
|
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
def set_input(self, *field_names): |
|
|
|
|
|
""" |
|
|
|
|
|
设置AutoCollator需要的field_names,未被设置默认过滤掉 |
|
|
|
|
|
|
|
|
|
|
|
:param field_names: |
|
|
|
|
|
:return: |
|
|
|
|
|
""" |
|
|
|
|
|
flag = True |
|
|
|
|
|
for collator in self._collators: |
|
|
|
|
|
if isinstance(collator, AutoCollator): |
|
|
|
|
|
collator.set_input(*field_names) |
|
|
|
|
|
flag = False |
|
|
|
|
|
if flag: |
|
|
|
|
|
warnings.warn("AutoCollator is remove, set_input is unavailable!!") |
|
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AutoCollator(Collator):
    """Collator that automatically pads the requested input fields of a batch.

    Field dtypes/dims are inferred once from the first instance seen. Fields
    whose pad value is set to ``None`` are packed as ``np.array`` without
    padding; all other needed fields are padded via :func:`pad_content`.
    """

    def __init__(self, as_numpy: bool):
        """
        :param as_numpy: if True numeric fields are returned as numpy arrays,
            otherwise as framework tensors (see :func:`pad_content`).
        """
        super(AutoCollator, self).__init__()
        self.pad_field_value = {}  # per-field custom padding value; default is 0
        self.need_inputs = []  # names of the fields to keep
        self.field_dtypes = None  # element dtype of each field (lazily inferred)
        self.field_dims = None  # element dimension of each field (lazily inferred)
        self.as_numpy = as_numpy

    def __call__(self, ins_lst: List[Dict]) -> dict:
        """Pack a list of instances column-wise and pad each needed field.

        :param ins_lst: list of instance dicts (field name -> content).
        :return: dict mapping field name to padded batch content.
        :raises ValueError: if :meth:`set_input` was never called.
        """
        if len(self.need_inputs) == 0:
            # BUGFIX: previously raised ValueError({...}) — the braces made the
            # argument a *set* containing the message rather than a string.
            raise ValueError("set_inputs is None, you should use set_inputs method first!!")

        # Infer each field's dtype and dimension from the first instance.
        if self.field_dtypes is None and self.field_dims is None:
            self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0])

        pack_ins_lst, pad_ins_lst = {field_name: []
                                     for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
        # Group the per-instance contents column-wise by field name.
        for per_ins in ins_lst:
            for field_name, _field_content in per_ins.items():
                if field_name in self.need_inputs:
                    pack_ins_lst[field_name].append(_field_content)

        # Default every needed field's pad value to 0, keeping user overrides.
        pad_field_kv = {field_name: 0 for field_name in self.need_inputs}
        pad_field_kv.update(self.pad_field_value)
        self.pad_field_value = pad_field_kv

        if len(self.pad_field_value.keys()) > 0:
            # Fields with pad value None are not padded, only converted;
            # names set via set_pad_val but absent from the batch are ignored.
            drop_field_names = []
            for k, v in self.pad_field_value.items():
                if v is None:
                    drop_field_names.append(k)
            for field_name in drop_field_names:
                # BUGFIX: guard the pop — previously a None-padded field that
                # was missing from the batch raised KeyError instead of being
                # ignored as the original comment promised.
                if field_name in pack_ins_lst:
                    field_array = pack_ins_lst.pop(field_name)
                    pad_ins_lst[field_name] = np.array(field_array)

            for field_name, field_array in pack_ins_lst.items():
                content = pad_content(field_array, field_name, self.field_dtypes[field_name],
                                      self.field_dims[field_name],
                                      self.pad_field_value[field_name],
                                      as_numpy=self.as_numpy)
                pad_ins_lst[field_name] = content

        return pad_ins_lst

    def set_pad_val(self, *field_names, val=0):
        """Set the padding value for the given fields (``None`` disables padding)."""
        for field_name in field_names:
            self.pad_field_value[field_name] = val

    def set_as_numpy(self, as_numpy: bool):
        """Control whether numeric fields are returned as numpy arrays."""
        self.as_numpy = as_numpy

    def set_input(self, *field_names):
        """Register the fields to keep; anything unregistered is filtered out."""
        for field_name in field_names:
            self.need_inputs.append(field_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
    r"""Pad one batched field into a dense numpy array / torch tensor / paddle tensor.

    :param content: list holding this field's value for every sample in the batch.
    :param field_name: field name, used only in error messages.
    :param field_type: element type of the field (Python/numpy scalar type, or a
        torch/paddle dtype), as inferred by ``_get_ele_type_and_dim``.
    :param field_dim: nesting depth of the field (0 = one scalar per sample).
    :param pad_val: value used to fill padded positions.
    :param as_numpy: for numeric (Python/numpy) fields, return ``np.ndarray``
        when True and a ``torch.Tensor`` when False; the torch/paddle branches
        ignore this flag and always return framework tensors.
    :return: the padded batch.
    :raises RuntimeError: when a 3-dimensional field has samples of unequal shape.
    """

    if field_type:
        # Too deep to pad generically: return as np.array without processing.
        if field_dim > 3:
            return np.array(content)
        # Numeric element types such as np.int64, np.float64, int, float, ...
        if isinstance(field_type, type) and \
                (issubclass(field_type, np.number) or issubclass(field_type, Number)):
            if field_dim == 0:
                array = np.array(content, dtype=field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                array = np.full((len(content), max_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    array[i, :len(content_i)] = content_i
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        array[i, j, :len(content_ii)] = content_ii
            else:
                shape = np.shape(content)
                if len(shape) == 4:  # means every dimension already has the same size
                    array = np.array(content, dtype=field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            if as_numpy is False:
                array = torch.tensor(array)
            return array
        # Element type is a torch dtype such as torch.float.
        elif str(field_type).startswith('torch'):
            if field_dim == 0:
                tensor = torch.tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                    dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = torch.full([len(content)] + list(shape), fill_value=pad_val,
                                        dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor
        # TODO: add jittor/paddle?
        elif str(field_type).startswith('paddle'):
            if field_dim == 0:
                # NOTE(review): uses the paddle.Tensor(...) constructor, unlike the
                # torch branch's torch.tensor(...) factory — verify this is intended.
                tensor = paddle.Tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                     dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val,
                                         dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor

        else:
            return np.array(content)  # leave the content untouched
    else:
        return np.array(content)