|
|
@@ -1,251 +1,162 @@ |
|
|
|
""" |
|
|
|
field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fastNLP.DataSet` 中一列的存储方式, |
|
|
|
原理部分请参考 :doc:`fastNLP.core.dataset` |
|
|
|
|
|
|
|
""" |
|
|
|
__all__ = [ |
|
|
|
"FieldArray", |
|
|
|
"Padder", |
|
|
|
"AutoPadder", |
|
|
|
"EngChar2DPadder" |
|
|
|
] |
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
from numbers import Number |
|
|
|
import torch |
|
|
|
import numpy as np |
|
|
|
from typing import Any |
|
|
|
from abc import abstractmethod |
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
|
|
|
|
class FieldArray(object): |
|
|
|
""" |
|
|
|
别名::class:`fastNLP.FieldArray` :class:`fastNLP.core.field.FieldArray` |
|
|
|
|
|
|
|
FieldArray 是用于保存 :class:`~fastNLP.DataSet` 中一个field的类型。 |
|
|
|
|
|
|
|
:param str name: FieldArray的名称 |
|
|
|
:param list,numpy.ndarray content: 列表的元素可以为list,int,float, |
|
|
|
:param bool is_target: 这个field是否是一个target field。 |
|
|
|
:param bool is_input: 这个field是否是一个input field。 |
|
|
|
:param padder: :class:`~fastNLP.Padder` 类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 |
|
|
|
fieldarray.set_pad_val()。默认为None,即使用 :class:`~fastNLP.AutoPadder` 。 |
|
|
|
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, |
|
|
|
就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): |
|
|
|
class SetInputOrTargetException(Exception):
    """Error raised while validating the element type / dimension of a field.

    :param msg: human-readable description of the problem.
    :param index: position of the sample where the problem was met, if known.
    :param field_name: name of the field being processed, if known.
    """

    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        # Keep the details as plain attributes so that a handler further up
        # can fill in or inspect the sample index and the field name.
        self.msg, self.index, self.field_name = msg, index, field_name
|
|
|
|
|
|
|
class AppendToTargetOrInputException(Exception):
    """Error raised when an appended value does not match the existing values of a field.

    :param msg: human-readable description of the problem.
    :param index: position of the sample where the problem was met, if known.
    :param field_name: name of the field being processed, if known.
    """

    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        # Details are exposed as attributes for handlers that want to add context.
        self.msg, self.index, self.field_name = msg, index, field_name
|
|
|
|
|
|
|
class FieldArray: |
|
|
|
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): |
|
|
|
if len(content)==0: |
|
|
|
raise RuntimeError("Empty fieldarray is not allowed.") |
|
|
|
_content = content |
|
|
|
try: |
|
|
|
_content = list(_content) |
|
|
|
except BaseException as e: |
|
|
|
print(f"Cannot convert content(of type:{type(content)}) into list.") |
|
|
|
raise e |
|
|
|
self.name = name |
|
|
|
if isinstance(content, list): |
|
|
|
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list |
|
|
|
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] |
|
|
|
for idx, item in enumerate(content): |
|
|
|
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) |
|
|
|
# 将[np.array] 转化为 list of list |
|
|
|
# 也可以支持[array, array, array]的情况 |
|
|
|
if isinstance(item, np.ndarray): |
|
|
|
content[idx] = content[idx].tolist() |
|
|
|
elif isinstance(content, np.ndarray): |
|
|
|
content = content.tolist() # convert np.ndarray into 2-D list |
|
|
|
else: |
|
|
|
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) |
|
|
|
if len(content) == 0: |
|
|
|
raise RuntimeError("Cannot initialize FieldArray with empty list.") |
|
|
|
|
|
|
|
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 |
|
|
|
self.content_dim = None # 表示content是多少维的list |
|
|
|
self.content = _content |
|
|
|
self._ignore_type = ignore_type |
|
|
|
# 根据input的情况设置input,target等 |
|
|
|
self._cell_ndim = None # 多少维度 |
|
|
|
self.dtype = None # 最内层的element都是什么类型的 |
|
|
|
self._is_input = False |
|
|
|
self._is_target = False |
|
|
|
|
|
|
|
if is_input: |
|
|
|
self.is_input = is_input |
|
|
|
if is_target: |
|
|
|
self.is_target = is_target |
|
|
|
|
|
|
|
if padder is None: |
|
|
|
padder = AutoPadder(pad_val=0) |
|
|
|
else: |
|
|
|
assert isinstance(padder, Padder), "padder must be of type Padder." |
|
|
|
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." |
|
|
|
padder = deepcopy(padder) |
|
|
|
self.set_padder(padder) |
|
|
|
self.ignore_type = ignore_type |
|
|
|
|
|
|
|
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array |
|
|
|
|
|
|
|
self.pytype = None |
|
|
|
self.dtype = None |
|
|
|
self._is_input = None |
|
|
|
self._is_target = None |
|
|
|
|
|
|
|
if is_input is not None or is_target is not None: |
|
|
|
self.is_input = is_input |
|
|
|
self.is_target = is_target |
|
|
|
|
|
|
|
def _set_dtype(self): |
|
|
|
if self.ignore_type is False: |
|
|
|
self.pytype = self._type_detection(self.content) |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def ignore_type(self): |
|
|
|
return self._ignore_type |
|
|
|
|
|
|
|
@ignore_type.setter |
|
|
|
def ignore_type(self, value): |
|
|
|
if value: |
|
|
|
self._cell_ndim = None |
|
|
|
self.dtype = None |
|
|
|
|
|
|
|
@property |
|
|
|
def is_input(self): |
|
|
|
return self._is_input |
|
|
|
|
|
|
|
|
|
|
|
@is_input.setter |
|
|
|
def is_input(self, value): |
|
|
|
""" |
|
|
|
当 field_array.is_input = True / False 时被调用 |
|
|
|
""" |
|
|
|
if value is True: |
|
|
|
self._set_dtype() |
|
|
|
# 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) |
|
|
|
if value is True and \ |
|
|
|
self._is_target is False and \ |
|
|
|
self._ignore_type is False: |
|
|
|
self._check_dtype_and_ndim() |
|
|
|
if value is False and self._is_target is False: |
|
|
|
self.dtype = None |
|
|
|
self._cell_ndim = None |
|
|
|
self._is_input = value |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def is_target(self): |
|
|
|
return self._is_target |
|
|
|
|
|
|
|
|
|
|
|
@is_target.setter |
|
|
|
def is_target(self, value): |
|
|
|
""" |
|
|
|
当 field_array.is_target = True / False 时被调用 |
|
|
|
""" |
|
|
|
if value is True: |
|
|
|
self._set_dtype() |
|
|
|
if value is True and \ |
|
|
|
self._is_input is False and \ |
|
|
|
self._ignore_type is False: |
|
|
|
self._check_dtype_and_ndim() |
|
|
|
if value is False and self._is_input is False: |
|
|
|
self.dtype = None |
|
|
|
self._cell_ndim = None |
|
|
|
self._is_target = value |
|
|
|
|
|
|
|
def _type_detection(self, content): |
|
|
|
""" |
|
|
|
当该field被设置为is_input或者is_target时被调用 |
|
|
|
|
|
|
|
def _check_dtype_and_ndim(self): |
|
|
|
""" |
|
|
|
if len(content) == 0: |
|
|
|
raise RuntimeError("Empty list in Field {}.".format(self.name)) |
|
|
|
|
|
|
|
type_set = set([type(item) for item in content]) |
|
|
|
|
|
|
|
if list in type_set: |
|
|
|
if len(type_set) > 1: |
|
|
|
# list 跟 非list 混在一起 |
|
|
|
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) |
|
|
|
# >1维list |
|
|
|
inner_type_set = set() |
|
|
|
for l in content: |
|
|
|
[inner_type_set.add(type(obj)) for obj in l] |
|
|
|
if list not in inner_type_set: |
|
|
|
# 二维list |
|
|
|
self.content_dim = 2 |
|
|
|
return self._basic_type_detection(inner_type_set) |
|
|
|
else: |
|
|
|
if len(inner_type_set) == 1: |
|
|
|
# >2维list |
|
|
|
inner_inner_type_set = set() |
|
|
|
for _2d_list in content: |
|
|
|
for _1d_list in _2d_list: |
|
|
|
[inner_inner_type_set.add(type(obj)) for obj in _1d_list] |
|
|
|
if list in inner_inner_type_set: |
|
|
|
raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") |
|
|
|
# 3维list |
|
|
|
self.content_dim = 3 |
|
|
|
return self._basic_type_detection(inner_inner_type_set) |
|
|
|
else: |
|
|
|
# list 跟 非list 混在一起 |
|
|
|
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) |
|
|
|
else: |
|
|
|
# 一维list |
|
|
|
for content_type in type_set: |
|
|
|
if content_type not in self.BASIC_TYPES: |
|
|
|
raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( |
|
|
|
self.name, self.BASIC_TYPES, content_type)) |
|
|
|
self.content_dim = 1 |
|
|
|
return self._basic_type_detection(type_set) |
|
|
|
|
|
|
|
def _basic_type_detection(self, type_set): |
|
|
|
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 |
|
|
|
通过将直接报错. |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
:param type_set: a set of Python types |
|
|
|
:return: one of self.BASIC_TYPES |
|
|
|
cell_0 = self.content[0] |
|
|
|
index = 0 |
|
|
|
try: |
|
|
|
type_0, dim_0 = _get_ele_type_and_dim(cell_0) |
|
|
|
for cell in self.content[1:]: |
|
|
|
index += 1 |
|
|
|
type_i, dim_i = _get_ele_type_and_dim(cell) |
|
|
|
if type_i!=type_0: |
|
|
|
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." |
|
|
|
".".format(type_i, index, type_0)) |
|
|
|
if dim_0!=dim_i: |
|
|
|
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " |
|
|
|
"dimension:{}.".format(dim_i, index, dim_0)) |
|
|
|
self._cell_ndim = dim_0 |
|
|
|
self.dtype = type_0 |
|
|
|
except SetInputOrTargetException as e: |
|
|
|
e.index = index |
|
|
|
raise e |
|
|
|
|
|
|
|
def append(self, val:Any): |
|
|
|
""" |
|
|
|
:param val: 把该val append到fieldarray。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if len(type_set) == 1: |
|
|
|
return type_set.pop() |
|
|
|
elif len(type_set) == 2: |
|
|
|
# 有多个basic type; 可能需要up-cast |
|
|
|
if float in type_set and int in type_set: |
|
|
|
# up-cast int to float |
|
|
|
return float |
|
|
|
else: |
|
|
|
# str 跟 int 或者 float 混在一起 |
|
|
|
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) |
|
|
|
if (self._is_target or self._is_input) and self._ignore_type is False: |
|
|
|
type_, dim_ = _get_ele_type_and_dim(val) |
|
|
|
if self.dtype!=type_: |
|
|
|
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " |
|
|
|
f"previous values(type:{self.dtype}).") |
|
|
|
if self._cell_ndim!=dim_: |
|
|
|
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " |
|
|
|
f"previous values(dim:{self._cell_ndim}).") |
|
|
|
self.content.append(val) |
|
|
|
else: |
|
|
|
# str, int, float混在一起 |
|
|
|
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) |
|
|
|
|
|
|
|
def _1d_list_check(self, val): |
|
|
|
"""如果不是1D list就报错 |
|
|
|
""" |
|
|
|
type_set = set((type(obj) for obj in val)) |
|
|
|
if any(obj not in self.BASIC_TYPES for obj in type_set): |
|
|
|
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) |
|
|
|
self._basic_type_detection(type_set) |
|
|
|
# otherwise: _basic_type_detection will raise error |
|
|
|
return True |
|
|
|
|
|
|
|
def _2d_list_check(self, val): |
|
|
|
"""如果不是2D list 就报错 |
|
|
|
""" |
|
|
|
type_set = set(type(obj) for obj in val) |
|
|
|
if list(type_set) != [list]: |
|
|
|
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) |
|
|
|
inner_type_set = set() |
|
|
|
for l in val: |
|
|
|
for obj in l: |
|
|
|
inner_type_set.add(type(obj)) |
|
|
|
self._basic_type_detection(inner_type_set) |
|
|
|
return True |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _map_to_np_type(basic_type): |
|
|
|
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} |
|
|
|
return type_mapping[basic_type] |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) |
|
|
|
|
|
|
|
def append(self, val): |
|
|
|
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 |
|
|
|
的内容是匹配的。 |
|
|
|
|
|
|
|
:param Any val: 需要append的值。 |
|
|
|
""" |
|
|
|
if self.ignore_type is False: |
|
|
|
if isinstance(val, list): |
|
|
|
pass |
|
|
|
elif isinstance(val, tuple): # 确保最外层是list |
|
|
|
val = list(val) |
|
|
|
elif isinstance(val, np.ndarray): |
|
|
|
val = val.tolist() |
|
|
|
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) |
|
|
|
|
|
|
|
if self.is_input is True or self.is_target is True: |
|
|
|
if type(val) == list: |
|
|
|
if len(val) == 0: |
|
|
|
raise ValueError("Cannot append an empty list.") |
|
|
|
if self.content_dim == 2 and self._1d_list_check(val): |
|
|
|
# 1维list检查 |
|
|
|
pass |
|
|
|
elif self.content_dim == 3 and self._2d_list_check(val): |
|
|
|
# 2维list检查 |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) |
|
|
|
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: |
|
|
|
# scalar检查 |
|
|
|
if type(val) == float and self.pytype == int: |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) |
|
|
|
self.content.append(val) |
|
|
|
|
|
|
|
self.content.append(val) |
|
|
|
|
|
|
|
def __getitem__(self, indices): |
|
|
|
return self.get(indices, pad=False) |
|
|
|
|
|
|
|
|
|
|
|
def __setitem__(self, idx, val): |
|
|
|
assert isinstance(idx, int) |
|
|
|
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 |
|
|
|
type_, dim_ = _get_ele_type_and_dim(val) |
|
|
|
if self.dtype!=type_: |
|
|
|
raise RuntimeError(f"Value(type:{type_}) are of different types with " |
|
|
|
f"other values(type:{self.dtype}).") |
|
|
|
if self._cell_ndim!=dim_: |
|
|
|
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " |
|
|
|
f"previous values(dim:{self._cell_ndim}).") |
|
|
|
self.content[idx] = val |
|
|
|
|
|
|
|
|
|
|
|
def get(self, indices, pad=True): |
|
|
|
""" |
|
|
|
根据给定的indices返回内容 |
|
|
@@ -257,14 +168,14 @@ class FieldArray(object): |
|
|
|
if isinstance(indices, int): |
|
|
|
return self.content[indices] |
|
|
|
if self.is_input is False and self.is_target is False: |
|
|
|
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) |
|
|
|
|
|
|
|
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) |
|
|
|
|
|
|
|
contents = [self.content[i] for i in indices] |
|
|
|
if self.padder is None or pad is False: |
|
|
|
return np.array(contents) |
|
|
|
else: |
|
|
|
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) |
|
|
|
|
|
|
|
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) |
|
|
|
|
|
|
|
def set_padder(self, padder): |
|
|
|
""" |
|
|
|
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 |
|
|
@@ -276,7 +187,7 @@ class FieldArray(object): |
|
|
|
self.padder = deepcopy(padder) |
|
|
|
else: |
|
|
|
self.padder = None |
|
|
|
|
|
|
|
|
|
|
|
def set_pad_val(self, pad_val): |
|
|
|
""" |
|
|
|
修改padder的pad_val. |
|
|
@@ -286,7 +197,7 @@ class FieldArray(object): |
|
|
|
if self.padder is not None: |
|
|
|
self.padder.set_pad_val(pad_val) |
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
""" |
|
|
|
Returns the size of FieldArray. |
|
|
@@ -294,7 +205,7 @@ class FieldArray(object): |
|
|
|
:return int length: |
|
|
|
""" |
|
|
|
return len(self.content) |
|
|
|
|
|
|
|
|
|
|
|
def to(self, other): |
|
|
|
""" |
|
|
|
将other的属性复制给本FieldArray(other必须为FieldArray类型). |
|
|
@@ -303,22 +214,63 @@ class FieldArray(object): |
|
|
|
:param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 |
|
|
|
:return: :class:`~fastNLP.FieldArray` |
|
|
|
""" |
|
|
|
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) |
|
|
|
|
|
|
|
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) |
|
|
|
|
|
|
|
self.ignore_type = other.ignore_type |
|
|
|
self.is_input = other.is_input |
|
|
|
self.is_target = other.is_target |
|
|
|
self.padder = other.padder |
|
|
|
self.ignore_type = other.ignore_type |
|
|
|
|
|
|
|
|
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
def _is_iterable(content): |
|
|
|
def _get_ele_type_and_dim(cell:Any, dim=0): |
|
|
|
""" |
|
|
|
识别cell的类别与dimension的数量 |
|
|
|
|
|
|
|
numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html |
|
|
|
:param cell: |
|
|
|
:param dim: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if isinstance(cell, (str, Number, np.bool_)): |
|
|
|
return type(cell), dim |
|
|
|
elif isinstance(cell, list): |
|
|
|
dim += 1 |
|
|
|
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] |
|
|
|
types = set([i for i,j in res]) |
|
|
|
dims = set([j for i,j in res]) |
|
|
|
if len(types)>1: |
|
|
|
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) |
|
|
|
if len(dims)>1: |
|
|
|
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) |
|
|
|
return types.pop(), dims.pop() |
|
|
|
elif isinstance(cell, torch.Tensor): |
|
|
|
return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 |
|
|
|
elif isinstance(cell, np.ndarray): |
|
|
|
if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 |
|
|
|
return cell.dtype.type, cell.ndim + dim |
|
|
|
# 否则需要继续往下iterate |
|
|
|
dim += 1 |
|
|
|
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] |
|
|
|
types = set([i for i,j in res]) |
|
|
|
dims = set([j for i,j in res]) |
|
|
|
if len(types)>1: |
|
|
|
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) |
|
|
|
if len(dims)>1: |
|
|
|
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) |
|
|
|
return types.pop(), dims.pop() |
|
|
|
else: # 包含tuple, set, dict以及其它的类型 |
|
|
|
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") |
|
|
|
|
|
|
|
|
|
|
|
def _is_iterable(value): |
|
|
|
# 检查是否是iterable的, duck typing |
|
|
|
try: |
|
|
|
_ = (e for e in content) |
|
|
|
except TypeError: |
|
|
|
iter(value) |
|
|
|
return True |
|
|
|
except BaseException as e: |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
class Padder: |
|
|
@@ -327,32 +279,35 @@ class Padder: |
|
|
|
|
|
|
|
所有padder都需要继承这个类,并覆盖__call__方法。 |
|
|
|
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 |
|
|
|
|
|
|
|
|
|
|
|
.. py:function:: __call__(self, contents, field_name, field_ele_dtype): |
|
|
|
传入的是List内容。假设有以下的DataSet。 |
|
|
|
|
|
|
|
|
|
|
|
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 |
|
|
|
deepcopy一份。 |
|
|
|
:param str, field_name: field的名称。 |
|
|
|
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 |
|
|
|
:return: np.array([padded_element]) |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, pad_val=0, **kwargs): |
|
|
|
self.pad_val = pad_val |
|
|
|
|
|
|
|
|
|
|
|
def set_pad_val(self, pad_val): |
|
|
|
self.pad_val = pad_val |
|
|
|
|
|
|
|
def __call__(self, contents, field_name, field_ele_dtype): |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def __call__(self, contents, field_name, field_ele_dtype, dim:int): |
|
|
|
""" |
|
|
|
传入的是List内容。假设有以下的DataSet。 |
|
|
|
|
|
|
|
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 |
|
|
|
deepcopy一份。 |
|
|
|
:param str, field_name: field的名称。 |
|
|
|
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 |
|
|
|
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True, |
|
|
|
该这个值为None。 |
|
|
|
:param dim: 这个field的维度。当ignore_type为True时,该值为None |
|
|
|
:return: np.array([padded_element]) |
|
|
|
|
|
|
|
Example:: |
|
|
@@ -394,50 +349,87 @@ class AutoPadder(Padder): |
|
|
|
根据contents的数据自动判定是否需要做padding。 |
|
|
|
|
|
|
|
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 |
|
|
|
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad |
|
|
|
型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad |
|
|
|
|
|
|
|
2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等 |
|
|
|
|
|
|
|
2 如果元素类型为(np.int64, np.float64), |
|
|
|
2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding |
|
|
|
|
|
|
|
2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding |
|
|
|
2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。 |
|
|
|
|
|
|
|
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 |
|
|
|
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad |
|
|
|
2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用 |
|
|
|
:class: fastNLP.EngChar2DPadder. |
|
|
|
|
|
|
|
2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片 |
|
|
|
的情况。 |
|
|
|
|
|
|
|
3 其它情况不进行处理,返回一个np.array类型。 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, pad_val=0): |
|
|
|
""" |
|
|
|
:param pad_val: int, padding的位置使用该index |
|
|
|
""" |
|
|
|
super().__init__(pad_val=pad_val) |
|
|
|
|
|
|
|
def _is_two_dimension(self, contents): |
|
|
|
""" |
|
|
|
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 |
|
|
|
:param contents: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
value = contents[0] |
|
|
|
if isinstance(value, (np.ndarray, list)): |
|
|
|
value = value[0] |
|
|
|
if isinstance(value, (np.ndarray, list)): |
|
|
|
return False |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
def __call__(self, contents, field_name, field_ele_dtype): |
|
|
|
|
|
|
|
if not _is_iterable(contents[0]): |
|
|
|
array = np.array([content for content in contents], dtype=field_ele_dtype) |
|
|
|
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): |
|
|
|
max_len = max([len(content) for content in contents]) |
|
|
|
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) |
|
|
|
for i, content in enumerate(contents): |
|
|
|
array[i][:len(content)] = content |
|
|
|
elif field_ele_dtype is None: |
|
|
|
array = np.array(contents) # 当ignore_type=True时,直接返回contents |
|
|
|
else: # should only be str |
|
|
|
array = np.array([content for content in contents]) |
|
|
|
return array |
|
|
|
|
|
|
|
def __call__(self, contents, field_name, field_ele_dtype, dim): |
|
|
|
if field_ele_dtype: |
|
|
|
if dim>3: |
|
|
|
return np.array(contents) |
|
|
|
if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str): |
|
|
|
if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool): |
|
|
|
if dim==0: |
|
|
|
array = np.array(contents, dtype=field_ele_dtype) |
|
|
|
elif dim==1: |
|
|
|
max_len = max(map(len, contents)) |
|
|
|
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) |
|
|
|
for i, content_i in enumerate(contents): |
|
|
|
array[i, :len(content_i)] = content_i |
|
|
|
elif dim==2: |
|
|
|
max_len = max(map(len, contents)) |
|
|
|
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for |
|
|
|
content_i in contents]) |
|
|
|
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) |
|
|
|
for i, content_i in enumerate(contents): |
|
|
|
for j, content_ii in enumerate(content_i): |
|
|
|
array[i, j, :len(content_ii)] = content_ii |
|
|
|
else: |
|
|
|
shape = np.shape(contents) |
|
|
|
if len(shape)==4: # 说明各dimension是相同的大小 |
|
|
|
array = np.array(contents, dtype=field_ele_dtype) |
|
|
|
else: |
|
|
|
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") |
|
|
|
return array |
|
|
|
return np.array(contents) |
|
|
|
elif str(field_ele_dtype).startswith('torch'): |
|
|
|
if dim==0: |
|
|
|
tensor = torch.tensor(contents).to(field_ele_dtype) |
|
|
|
elif dim==1: |
|
|
|
max_len = max(map(len, contents)) |
|
|
|
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) |
|
|
|
for i, content_i in enumerate(contents): |
|
|
|
tensor[i, :len(content_i)] = torch.tensor(content_i) |
|
|
|
elif dim==2: |
|
|
|
max_len = max(map(len, contents)) |
|
|
|
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for |
|
|
|
content_i in contents]) |
|
|
|
tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val, |
|
|
|
dtype=field_ele_dtype) |
|
|
|
for i, content_i in enumerate(contents): |
|
|
|
for j, content_ii in enumerate(content_i): |
|
|
|
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) |
|
|
|
else: |
|
|
|
shapes = set([np.shape(content_i) for content_i in contents]) |
|
|
|
if len(shapes)>1: |
|
|
|
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") |
|
|
|
shape = shapes.pop() |
|
|
|
if len(shape)==3: |
|
|
|
tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) |
|
|
|
for i, content_i in enumerate(contents): |
|
|
|
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) |
|
|
|
else: |
|
|
|
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") |
|
|
|
return tensor |
|
|
|
else: |
|
|
|
return np.array(contents) # 不进行任何操作 |
|
|
|
else: |
|
|
|
return np.array(contents) |
|
|
|
|
|
|
|
|
|
|
|
class EngChar2DPadder(Padder): |
|
|
@@ -463,7 +455,7 @@ class EngChar2DPadder(Padder): |
|
|
|
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, pad_val=0, pad_length=0): |
|
|
|
""" |
|
|
|
:param pad_val: int, pad的位置使用该index |
|
|
@@ -471,32 +463,10 @@ class EngChar2DPadder(Padder): |
|
|
|
都pad或截取到该长度. |
|
|
|
""" |
|
|
|
super().__init__(pad_val=pad_val) |
|
|
|
|
|
|
|
|
|
|
|
self.pad_length = pad_length |
|
|
|
|
|
|
|
def _exactly_three_dims(self, contents, field_name): |
|
|
|
""" |
|
|
|
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character |
|
|
|
:param contents: |
|
|
|
:param field_name: str |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not isinstance(contents, list): |
|
|
|
raise TypeError("contents should be a list, not {}.".format(type(contents))) |
|
|
|
value = contents[0] |
|
|
|
try: |
|
|
|
value = value[0] |
|
|
|
except: |
|
|
|
raise ValueError("Field:{} only has one dimension.".format(field_name)) |
|
|
|
try: |
|
|
|
value = value[0] |
|
|
|
except: |
|
|
|
raise ValueError("Field:{} only has two dimensions.".format(field_name)) |
|
|
|
|
|
|
|
if _is_iterable(value): |
|
|
|
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) |
|
|
|
|
|
|
|
def __call__(self, contents, field_name, field_ele_dtype): |
|
|
|
|
|
|
|
def __call__(self, contents, field_name, field_ele_dtype, dim): |
|
|
|
""" |
|
|
|
期望输入类似于 |
|
|
|
[ |
|
|
@@ -510,11 +480,11 @@ class EngChar2DPadder(Padder): |
|
|
|
:param field_ele_dtype |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if field_ele_dtype not in (np.int64, np.float64): |
|
|
|
if field_ele_dtype not in (np.int64, np.float64, int, float): |
|
|
|
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( |
|
|
|
field_name, field_ele_dtype |
|
|
|
)) |
|
|
|
self._exactly_three_dims(contents, field_name) |
|
|
|
assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions." |
|
|
|
if self.pad_length < 1: |
|
|
|
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) |
|
|
|
else: |
|
|
@@ -522,12 +492,12 @@ class EngChar2DPadder(Padder): |
|
|
|
max_sent_length = max(len(word_lst) for word_lst in contents) |
|
|
|
batch_size = len(contents) |
|
|
|
dtype = type(contents[0][0][0]) |
|
|
|
|
|
|
|
|
|
|
|
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, |
|
|
|
dtype=dtype) |
|
|
|
for b_idx, word_lst in enumerate(contents): |
|
|
|
for c_idx, char_lst in enumerate(word_lst): |
|
|
|
chars = char_lst[:max_char_length] |
|
|
|
padded_array[b_idx, c_idx, :len(chars)] = chars |
|
|
|
|
|
|
|
|
|
|
|
return padded_array |