
Added collators

tags/v1.0.0alpha
MorningForest committed ce1b837d13 · 2 years ago

2 changed files with 384 additions and 0 deletions

  1. +5 -0    fastNLP/core/collators/__init__.py
  2. +379 -0  fastNLP/core/collators/collator.py

+5 -0  fastNLP/core/collators/__init__.py

@@ -0,0 +1,5 @@
__all__ = [
    'AutoCollator',
    'Collator'
]
from .collator import AutoCollator, Collator

+379 -0  fastNLP/core/collators/collator.py

@@ -0,0 +1,379 @@
__all__ = [
    'AutoCollator',
    'Collator',
]


from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Callable, Union
from numbers import Number
import warnings

import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH

if _NEED_IMPORT_PADDLE:
    import paddle

if _NEED_IMPORT_TORCH:
    import torch


class ApplyResultException(Exception):
    def __init__(self, msg, index=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # index of the sample where the problem occurred


class SetInputOrTargetException(Exception):
    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # index of the sample where the problem occurred
        self.field_name = field_name  # name of the field being processed


def _get_ele_type_and_dim(cell: Any, dim=0):
    r"""
    Identify the element type of ``cell`` and its number of dimensions.

    numpy scalar types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html

    :param cell: the value to inspect
    :param dim: the number of dimensions accumulated so far
    :return: (element type, number of dimensions)
    """
    if isinstance(cell, (str, Number, np.bool_)):
        if hasattr(cell, 'dtype'):
            return cell.dtype.type, dim
        return type(cell), dim

    elif isinstance(cell, list):
        dim += 1
        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
        types = set([i for i, j in res])
        dims = set([j for i, j in res])
        if len(types) > 1:
            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
        elif len(types) == 0:
            raise SetInputOrTargetException("Empty value encountered.")
        if len(dims) > 1:
            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
        return types.pop(), dims.pop()

    elif _NEED_IMPORT_TORCH and isinstance(cell, torch.Tensor):
        return cell.dtype, cell.dim() + dim  # a 0-d tensor (e.g. the result of torch.mean) has dim() == 0

    elif _NEED_IMPORT_PADDLE and isinstance(cell, paddle.Tensor):
        return cell.dtype, cell.dim() + dim

    elif isinstance(cell, np.ndarray):
        if cell.dtype != np.dtype('O'):  # a non-object dtype means the array is already well-formatted
            return cell.dtype.type, cell.ndim + dim  # dtype.type yields np.int32, np.float64, etc.
        # otherwise keep iterating into the object array
        dim += 1
        res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
        types = set([i for i, j in res])
        dims = set([j for i, j in res])
        if len(types) > 1:
            raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
        elif len(types) == 0:
            raise SetInputOrTargetException("Empty value encountered.")
        if len(dims) > 1:
            raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
        return types.pop(), dims.pop()

    else:  # tuple, set, dict and any other type
        raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
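

# Illustrative sketch (not part of the original commit): the helper name
# `_example_ele_type_and_dim` is hypothetical; it shows how plain Python values and
# nested lists are classified by _get_ele_type_and_dim.
def _example_ele_type_and_dim():
    assert _get_ele_type_and_dim(3.5) == (float, 0)
    assert _get_ele_type_and_dim([1, 2, 3]) == (int, 1)
    assert _get_ele_type_and_dim([[1, 2], [3]]) == (int, 2)  # ragged lists are fine as long as types and dims agree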


def _get_ds_type_dim(ds: dict):
    # infer the element type and dimension of every field from the first instance of the dataset
    field_dtype, field_dim = {}, {}
    for field_name, field_content in ds.items():
        type_0, dim_0 = _get_ele_type_and_dim(field_content)
        field_dtype[field_name], field_dim[field_name] = type_0, dim_0
    return field_dtype, field_dim
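

# Illustrative sketch (not part of the original commit): the instance below and the
# helper name `_example_ds_type_dim` are hypothetical; they show the per-field
# type/dimension maps that AutoCollator builds from the first instance of a batch.
def _example_ds_type_dim():
    ins = {'words': [4, 5, 6], 'seq_len': 3}
    dtypes, dims = _get_ds_type_dim(ins)
    # dtypes == {'words': int, 'seq_len': int}; dims == {'words': 1, 'seq_len': 0}
    return dtypes, dims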


class Collator(metaclass=ABCMeta):
    r"""
    Helper class for a DataLoader to manage its collate_fn.

    """

    def __init__(self):
        super(Collator, self).__init__()
        self.collate_fn = []

    @abstractmethod
    def __call__(self, ins_lst: List) -> Any:
        raise NotImplementedError

    @abstractmethod
    def set_pad_val(self, *field_names: str, value=0):
        raise NotImplementedError


class _MultiCollator:
    """
    Container that manages all collators.
    It follows an override policy: collate_fns added later override the output produced by earlier ones.
    """

    def __init__(self, collate_fns: Union[Callable, List[Callable], None]):

        if collate_fns is None:
            collate_fns = []

        if isinstance(collate_fns, Callable):
            collate_fns = [collate_fns]

        self._collators: list = collate_fns

    def __call__(self, ins_lst) -> Dict:
        out, list_out = {}, []
        for idx, _collate_fn in enumerate(self._collators):
            res = _collate_fn(ins_lst)
            if isinstance(res, Dict):
                out.update(res)
            else:
                list_out.append(res)
            # else:
            #     raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict")
        if len(out) > 0 and len(list_out) > 0:
            raise ValueError("the collate_fns return inconsistent types; they must all return dict or all return list")
        if len(list_out) == 1:
            list_out = list_out[-1]
        return out if len(out) > 0 else list_out

    def get_collators(self):
        return self._collators

    def add_collator(self, collator: Callable):
        self._collators.append(collator)

    def set_as_numpy(self, as_numpy: bool):
        """
        When an AutoCollator is present, ``as_numpy`` controls the type of its return values.

        :param as_numpy: return numpy arrays if True, framework tensors otherwise
        :return:
        """
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_as_numpy(as_numpy)
        return self

    def set_pad_val(self, *field_names, val=0):
        """
        When an AutoCollator is present, set the padding value of the given fields.

        :param field_names: field names of the dataset
        :param val: the padding value
        :return:
        """
        flag = True
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_pad_val(*field_names, val=val)
                flag = False
        if flag:
            warnings.warn("AutoCollator has been removed, set_pad_val is unavailable!")
        return self

    def set_input(self, *field_names):
        """
        Register the field names the AutoCollator should keep; fields that are not registered are filtered out by default.

        :param field_names: field names of the dataset
        :return:
        """
        flag = True
        for collator in self._collators:
            if isinstance(collator, AutoCollator):
                collator.set_input(*field_names)
                flag = False
        if flag:
            warnings.warn("AutoCollator has been removed, set_input is unavailable!")
        return self
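

# Illustrative sketch (not part of the original commit): the setters above return
# `self`, so they can be chained. The helper name `_example_multi_collator` and the
# sample batch are hypothetical; AutoCollator is defined below in this module.
def _example_multi_collator():
    collator = _MultiCollator(AutoCollator(as_numpy=True))
    collator.set_input('words', 'label').set_pad_val('words', val=-1)
    batch = [{'words': [1, 2], 'label': 0}, {'words': [3], 'label': 1}]
    return collator(batch)  # {'words': array([[1, 2], [3, -1]]), 'label': array([0, 1])}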


class AutoCollator(Collator):

    def __init__(self, as_numpy: bool):
        super(AutoCollator, self).__init__()
        self.pad_field_value = {}  # user-defined padding value per field, defaults to 0
        self.need_inputs = []  # names of the required fields
        self.field_dtypes = None  # dtype of the cells in each field
        self.field_dims = None  # dimension of the cells in each field
        self.as_numpy = as_numpy

    def __call__(self, ins_lst: List[Dict]) -> dict:
        if len(self.need_inputs) == 0:
            raise ValueError("need_inputs is empty, you should call the set_input method first!")
        # case 1: the fields were registered through set_input
        # case 2: decide from the data type whether to pad
        if self.field_dtypes is None and self.field_dims is None:
            self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0])

        pack_ins_lst, pad_ins_lst = {field_name: []
                                     for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
        # group the values in the instance list by field name
        for per_ins in ins_lst:
            for field_name, _field_content in per_ins.items():
                if field_name in self.need_inputs:
                    pack_ins_lst[field_name].append(_field_content)

        pad_field_kv = {field_name: 0 for field_name in self.need_inputs}
        pad_field_kv.update(self.pad_field_value)
        self.pad_field_value = pad_field_kv

        if len(self.pad_field_value.keys()) > 0:
            # drop the fields that should not be padded; fields passed to set_input but absent from the data are ignored
            drop_field_names = []
            for k, v in self.pad_field_value.items():
                if v is None:
                    drop_field_names.append(k)

            for field_name in drop_field_names:
                field_array = pack_ins_lst.pop(field_name)
                pad_ins_lst[field_name] = np.array(field_array)

            for field_name, field_array in pack_ins_lst.items():
                content = pad_content(field_array, field_name, self.field_dtypes[field_name],
                                      self.field_dims[field_name],
                                      self.pad_field_value[field_name],
                                      as_numpy=self.as_numpy)
                pad_ins_lst[field_name] = content

        # else:
        #     # take each field's data and decide from its type whether it can be padded
        #     for field_name, field_array in pack_ins_lst.items():
        #         pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name],
        #                                       self.field_dims[field_name],
        #                                       pad_val=0, as_numpy=self.as_numpy)
        #         pad_ins_lst[field_name] = pad_field_array

        return pad_ins_lst

    def set_pad_val(self, *field_names, val=0):
        for field_name in field_names:
            self.pad_field_value[field_name] = val

    def set_as_numpy(self, as_numpy: bool):
        self.as_numpy = as_numpy

    def set_input(self, *field_names):
        for field_name in field_names:
            self.need_inputs.append(field_name)
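

# Illustrative sketch (not part of the original commit): only the fields registered
# through set_input survive collation, and with as_numpy=False the padded batch is
# returned as torch tensors (torch must be importable). The helper name
# `_example_auto_collator` and the sample batch are hypothetical.
def _example_auto_collator():
    collator = AutoCollator(as_numpy=False)
    collator.set_input('words')  # 'seq_len' is filtered out of the output
    batch = [{'words': [4, 5, 6], 'seq_len': 3}, {'words': [7], 'seq_len': 1}]
    return collator(batch)  # {'words': tensor([[4, 5, 6], [7, 0, 0]])}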


def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):

    if field_type:
        # more than 3 dimensions: do not pad, return an np.array as-is
        if field_dim > 3:
            return np.array(content)
        # the element type is numeric: np.int64, np.float64, int, float, etc.
        if isinstance(field_type, type) and \
                (issubclass(field_type, np.number) or issubclass(field_type, Number)):
            if field_dim == 0:
                array = np.array(content, dtype=field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                array = np.full((len(content), max_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    array[i, :len(content_i)] = content_i
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        array[i, j, :len(content_ii)] = content_ii
            else:
                shape = np.shape(content)
                if len(shape) == 4:  # every dimension has the same size
                    array = np.array(content, dtype=field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            if as_numpy is False:
                array = torch.tensor(array)
            return array
        # the element dtype is a torch dtype, e.g. torch.float
        elif str(field_type).startswith('torch'):
            if field_dim == 0:
                tensor = torch.tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                    dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = torch.full([len(content)] + list(shape), fill_value=pad_val,
                                        dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor
        # TODO add jittor support?
        # the element dtype is a paddle dtype
        elif str(field_type).startswith('paddle'):
            if field_dim == 0:
                tensor = paddle.Tensor(content).to(field_type)
            elif field_dim == 1:
                max_len = max(map(len, content))
                tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
                for i, content_i in enumerate(content):
                    tensor[i, :len(content_i)] = content_i.clone().detach()
            elif field_dim == 2:
                max_len = max(map(len, content))
                max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
                                    content_i in content])
                tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val,
                                     dtype=field_type)
                for i, content_i in enumerate(content):
                    for j, content_ii in enumerate(content_i):
                        tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
            else:
                shapes = set([np.shape(content_i) for content_i in content])
                if len(shapes) > 1:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
                shape = shapes.pop()
                if len(shape) == 3:
                    tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val,
                                         dtype=field_type)
                    for i, content_i in enumerate(content):
                        tensor[i] = content_i.clone().detach().to(field_type)
                else:
                    raise RuntimeError(
                        f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
            return tensor

        else:
            return np.array(content)  # no processing
    else:
        return np.array(content)
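

# Illustrative sketch (not part of the original commit): padding a 2-dimensional
# ragged field (e.g. characters per word) directly with pad_content. The helper name
# `_example_pad_content` and the sample data are hypothetical.
def _example_pad_content():
    chars = [[[1, 2], [3]], [[4, 5, 6]]]  # 2 samples, field_dim == 2
    padded = pad_content(chars, 'chars', int, 2, pad_val=0, as_numpy=True)
    # padded has shape (2, 2, 3):
    # [[[1, 2, 0], [3, 0, 0]],
    #  [[4, 5, 6], [0, 0, 0]]]
    return padded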
