Browse Source

fix some importing bugs

tags/v0.4.10
ChenXin 5 years ago
parent
commit
d0354d8e28
3 changed files with 108 additions and 90 deletions
  1. +4
    -1
      fastNLP/__init__.py
  2. +1
    -1
      fastNLP/core/__init__.py
  3. +103
    -88
      fastNLP/core/field.py

+ 4
- 1
fastNLP/__init__.py View File

@@ -14,6 +14,7 @@ __all__ = [
"Instance", "Instance",
"FieldArray", "FieldArray",
"DataSetIter", "DataSetIter",
"BatchIter", "BatchIter",
"TorchLoaderIter", "TorchLoaderIter",
@@ -31,6 +32,7 @@ __all__ = [
"TensorboardCallback", "TensorboardCallback",
"LRScheduler", "LRScheduler",
"ControlC", "ControlC",
"LRFinder",
"Padder", "Padder",
"AutoPadder", "AutoPadder",
@@ -43,7 +45,8 @@ __all__ = [
"Optimizer", "Optimizer",
"SGD", "SGD",
"Adam", "Adam",
"AdamW",

"Sampler", "Sampler",
"SequentialSampler", "SequentialSampler",
"BucketSampler", "BucketSampler",


+ 1
- 1
fastNLP/core/__init__.py View File

@@ -22,7 +22,7 @@ from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
from .optimizer import Optimizer, SGD, Adam
from .optimizer import Optimizer, SGD, Adam, AdamW
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester from .tester import Tester
from .trainer import Trainer from .trainer import Trainer


+ 103
- 88
fastNLP/core/field.py View File

@@ -1,4 +1,8 @@

__all__ = [
"Padder",
"AutoPadder",
"EngChar2DPadder",
]


from numbers import Number from numbers import Number
import torch import torch
@@ -9,24 +13,27 @@ from copy import deepcopy
from collections import Counter from collections import Counter
from .utils import _is_iterable from .utils import _is_iterable



class SetInputOrTargetException(Exception): class SetInputOrTargetException(Exception):
def __init__(self, msg, index=None, field_name=None): def __init__(self, msg, index=None, field_name=None):
super().__init__(msg) super().__init__(msg)
self.msg = msg self.msg = msg
self.index = index # 标示在哪个数据遭遇到问题了 self.index = index # 标示在哪个数据遭遇到问题了
self.field_name = field_name # 标示当前field的名称
self.field_name = field_name # 标示当前field的名称



class AppendToTargetOrInputException(Exception): class AppendToTargetOrInputException(Exception):
def __init__(self, msg, index=None, field_name=None): def __init__(self, msg, index=None, field_name=None):
super().__init__(msg) super().__init__(msg)
self.msg = msg self.msg = msg
self.index = index # 标示在哪个数据遭遇到问题了 self.index = index # 标示在哪个数据遭遇到问题了
self.field_name = field_name # 标示当前field的名称
self.field_name = field_name # 标示当前field的名称



class FieldArray: class FieldArray:
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False, def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False,
use_1st_ins_infer_dim_type=True): use_1st_ins_infer_dim_type=True):
if len(content)==0:
if len(content) == 0:
raise RuntimeError("Empty fieldarray is not allowed.") raise RuntimeError("Empty fieldarray is not allowed.")
_content = content _content = content
try: try:
@@ -43,34 +50,34 @@ class FieldArray:
self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self._is_input = False self._is_input = False
self._is_target = False self._is_target = False
if is_input: if is_input:
self.is_input = is_input self.is_input = is_input
if is_target: if is_target:
self.is_target = is_target self.is_target = is_target
if padder is None: if padder is None:
padder = AutoPadder(pad_val=0) padder = AutoPadder(pad_val=0)
else: else:
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder."
padder = deepcopy(padder) padder = deepcopy(padder)
self.set_padder(padder) self.set_padder(padder)
@property @property
def ignore_type(self): def ignore_type(self):
return self._ignore_type return self._ignore_type
@ignore_type.setter @ignore_type.setter
def ignore_type(self, value): def ignore_type(self, value):
if value: if value:
self._cell_ndim = None self._cell_ndim = None
self.dtype = None self.dtype = None
self._ignore_type = value self._ignore_type = value
@property @property
def is_input(self): def is_input(self):
return self._is_input return self._is_input
@is_input.setter @is_input.setter
def is_input(self, value): def is_input(self, value):
""" """
@@ -85,11 +92,11 @@ class FieldArray:
self.dtype = None self.dtype = None
self._cell_ndim = None self._cell_ndim = None
self._is_input = value self._is_input = value
@property @property
def is_target(self): def is_target(self):
return self._is_target return self._is_target
@is_target.setter @is_target.setter
def is_target(self, value): def is_target(self, value):
""" """
@@ -103,7 +110,7 @@ class FieldArray:
self.dtype = None self.dtype = None
self._cell_ndim = None self._cell_ndim = None
self._is_target = value self._is_target = value
def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True):
""" """
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有
@@ -120,35 +127,37 @@ class FieldArray:
for cell in self.content[1:]: for cell in self.content[1:]:
index += 1 index += 1
type_i, dim_i = _get_ele_type_and_dim(cell) type_i, dim_i = _get_ele_type_and_dim(cell)
if type_i!=type_0:
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}."
".".format(type_i, index, type_0))
if dim_0!=dim_i:
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with "
"dimension:{}.".format(dim_i, index, dim_0))
if type_i != type_0:
raise SetInputOrTargetException(
"Type:{} in index {} is different from the first element with type:{}."
".".format(type_i, index, type_0))
if dim_0 != dim_i:
raise SetInputOrTargetException(
"Dimension:{} in index {} is different from the first element with "
"dimension:{}.".format(dim_i, index, dim_0))
self._cell_ndim = dim_0 self._cell_ndim = dim_0
self.dtype = type_0 self.dtype = type_0
except SetInputOrTargetException as e: except SetInputOrTargetException as e:
e.index = index e.index = index
raise e raise e
def append(self, val:Any):
def append(self, val: Any):
""" """
:param val: 把该val append到fieldarray。 :param val: 把该val append到fieldarray。
:return: :return:
""" """
if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type: if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type:
type_, dim_ = _get_ele_type_and_dim(val) type_, dim_ = _get_ele_type_and_dim(val)
if self.dtype!=type_:
if self.dtype != type_:
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with "
f"previous values(type:{self.dtype}).") f"previous values(type:{self.dtype}).")
if self._cell_ndim!=dim_:
if self._cell_ndim != dim_:
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with "
f"previous values(dim:{self._cell_ndim}).") f"previous values(dim:{self._cell_ndim}).")
self.content.append(val) self.content.append(val)
else: else:
self.content.append(val) self.content.append(val)
def pop(self, index): def pop(self, index):
""" """
删除该field中index处的元素 删除该field中index处的元素
@@ -156,22 +165,22 @@ class FieldArray:
:return: :return:
""" """
self.content.pop(index) self.content.pop(index)
def __getitem__(self, indices): def __getitem__(self, indices):
return self.get(indices, pad=False) return self.get(indices, pad=False)
def __setitem__(self, idx, val): def __setitem__(self, idx, val):
assert isinstance(idx, int) assert isinstance(idx, int)
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型
type_, dim_ = _get_ele_type_and_dim(val) type_, dim_ = _get_ele_type_and_dim(val)
if self.dtype!=type_:
if self.dtype != type_:
raise RuntimeError(f"Value(type:{type_}) are of different types with " raise RuntimeError(f"Value(type:{type_}) are of different types with "
f"other values(type:{self.dtype}).")
if self._cell_ndim!=dim_:
f"other values(type:{self.dtype}).")
if self._cell_ndim != dim_:
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with "
f"previous values(dim:{self._cell_ndim}).")
f"previous values(dim:{self._cell_ndim}).")
self.content[idx] = val self.content[idx] = val
def get(self, indices, pad=True): def get(self, indices, pad=True):
""" """
根据给定的indices返回内容 根据给定的indices返回内容
@@ -184,16 +193,16 @@ class FieldArray:
return self.content[indices] return self.content[indices]
if self.is_input is False and self.is_target is False: if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))
contents = [self.content[i] for i in indices] contents = [self.content[i] for i in indices]
if self.padder is None or pad is False: if self.padder is None or pad is False:
return np.array(contents) return np.array(contents)
else: else:
return self.pad(contents) return self.pad(contents)
def pad(self, contents): def pad(self, contents):
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
def set_padder(self, padder): def set_padder(self, padder):
""" """
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。
@@ -205,7 +214,7 @@ class FieldArray:
self.padder = deepcopy(padder) self.padder = deepcopy(padder)
else: else:
self.padder = None self.padder = None
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
""" """
修改padder的pad_val. 修改padder的pad_val.
@@ -215,7 +224,7 @@ class FieldArray:
if self.padder is not None: if self.padder is not None:
self.padder.set_pad_val(pad_val) self.padder.set_pad_val(pad_val)
return self return self
def __len__(self): def __len__(self):
""" """
Returns the size of FieldArray. Returns the size of FieldArray.
@@ -223,7 +232,7 @@ class FieldArray:
:return int length: :return int length:
""" """
return len(self.content) return len(self.content)
def to(self, other): def to(self, other):
""" """
将other的属性复制给本FieldArray(other必须为FieldArray类型). 将other的属性复制给本FieldArray(other必须为FieldArray类型).
@@ -233,15 +242,15 @@ class FieldArray:
:return: :class:`~fastNLP.FieldArray` :return: :class:`~fastNLP.FieldArray`
""" """
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other))
self.ignore_type = other.ignore_type self.ignore_type = other.ignore_type
self.is_input = other.is_input self.is_input = other.is_input
self.is_target = other.is_target self.is_target = other.is_target
self.padder = other.padder self.padder = other.padder
return self return self
def split(self, sep:str=None, inplace:bool=True):
def split(self, sep: str = None, inplace: bool = True):
""" """
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值


@@ -257,8 +266,8 @@ class FieldArray:
print(f"Exception happens when process value in index {index}.") print(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def int(self, inplace:bool=True):
def int(self, inplace: bool = True):
""" """
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), 将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) (2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
@@ -277,7 +286,7 @@ class FieldArray:
print(f"Exception happens when process value in index {index}.") print(f"Exception happens when process value in index {index}.")
print(e) print(e)
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def float(self, inplace=True): def float(self, inplace=True):
""" """
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), 将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -297,7 +306,7 @@ class FieldArray:
print(f"Exception happens when process value in index {index}.") print(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def bool(self, inplace=True): def bool(self, inplace=True):
""" """
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), 将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -316,9 +325,9 @@ class FieldArray:
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.") print(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def lower(self, inplace=True): def lower(self, inplace=True):
""" """
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -338,7 +347,7 @@ class FieldArray:
print(f"Exception happens when process value in index {index}.") print(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def upper(self, inplace=True): def upper(self, inplace=True):
""" """
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), 将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -358,7 +367,7 @@ class FieldArray:
print(f"Exception happens when process value in index {index}.") print(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def value_count(self): def value_count(self):
""" """
返回该field下不同value的数量。多用于统计label数量 返回该field下不同value的数量。多用于统计label数量
@@ -366,17 +375,18 @@ class FieldArray:
:return: Counter, key是label,value是出现次数 :return: Counter, key是label,value是出现次数
""" """
count = Counter() count = Counter()
def cum(cell): def cum(cell):
if _is_iterable(cell) and not isinstance(cell, str): if _is_iterable(cell) and not isinstance(cell, str):
for cell_ in cell: for cell_ in cell:
cum(cell_) cum(cell_)
else: else:
count[cell] += 1 count[cell] += 1
for cell in self.content: for cell in self.content:
cum(cell) cum(cell)
return count return count
def _after_process(self, new_contents, inplace): def _after_process(self, new_contents, inplace):
""" """
当调用处理函数之后,决定是否要替换field。 当调用处理函数之后,决定是否要替换field。
@@ -398,7 +408,7 @@ class FieldArray:
return new_contents return new_contents




def _get_ele_type_and_dim(cell:Any, dim=0):
def _get_ele_type_and_dim(cell: Any, dim=0):
""" """
识别cell的类别与dimension的数量 识别cell的类别与dimension的数量


@@ -414,13 +424,13 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
elif isinstance(cell, list): elif isinstance(cell, list):
dim += 1 dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
types = set([i for i,j in res])
dims = set([j for i,j in res])
if len(types)>1:
types = set([i for i, j in res])
dims = set([j for i, j in res])
if len(types) > 1:
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
elif len(types)==0:
elif len(types) == 0:
raise SetInputOrTargetException("Empty value encountered.") raise SetInputOrTargetException("Empty value encountered.")
if len(dims)>1:
if len(dims) > 1:
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
return types.pop(), dims.pop() return types.pop(), dims.pop()
elif isinstance(cell, torch.Tensor): elif isinstance(cell, torch.Tensor):
@@ -431,16 +441,16 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
# 否则需要继续往下iterate # 否则需要继续往下iterate
dim += 1 dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
types = set([i for i,j in res])
dims = set([j for i,j in res])
if len(types)>1:
types = set([i for i, j in res])
dims = set([j for i, j in res])
if len(types) > 1:
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
elif len(types)==0:
elif len(types) == 0:
raise SetInputOrTargetException("Empty value encountered.") raise SetInputOrTargetException("Empty value encountered.")
if len(dims)>1:
if len(dims) > 1:
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
return types.pop(), dims.pop() return types.pop(), dims.pop()
else: # 包含tuple, set, dict以及其它的类型
else: # 包含tuple, set, dict以及其它的类型
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")




@@ -462,15 +472,15 @@ class Padder:
:return: np.array([padded_element]) :return: np.array([padded_element])


""" """
def __init__(self, pad_val=0, **kwargs): def __init__(self, pad_val=0, **kwargs):
self.pad_val = pad_val self.pad_val = pad_val
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
self.pad_val = pad_val self.pad_val = pad_val
@abstractmethod @abstractmethod
def __call__(self, contents, field_name, field_ele_dtype, dim:int):
def __call__(self, contents, field_name, field_ele_dtype, dim: int):
""" """
传入的是List内容。假设有以下的DataSet。 传入的是List内容。假设有以下的DataSet。


@@ -537,23 +547,24 @@ class AutoPadder(Padder):


3 其它情况不进行处理,返回一个np.array类型。 3 其它情况不进行处理,返回一个np.array类型。
""" """
def __init__(self, pad_val=0): def __init__(self, pad_val=0):
super().__init__(pad_val=pad_val) super().__init__(pad_val=pad_val)
def __call__(self, contents, field_name, field_ele_dtype, dim): def __call__(self, contents, field_name, field_ele_dtype, dim):
if field_ele_dtype: if field_ele_dtype:
if dim>3:
if dim > 3:
return np.array(contents) return np.array(contents)
if isinstance(field_ele_dtype, type) and \ if isinstance(field_ele_dtype, type) and \
(issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)): (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
if dim==0:
if dim == 0:
array = np.array(contents, dtype=field_ele_dtype) array = np.array(contents, dtype=field_ele_dtype)
elif dim==1:
elif dim == 1:
max_len = max(map(len, contents)) max_len = max(map(len, contents))
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents): for i, content_i in enumerate(contents):
array[i, :len(content_i)] = content_i array[i, :len(content_i)] = content_i
elif dim==2:
elif dim == 2:
max_len = max(map(len, contents)) max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents]) content_i in contents])
@@ -563,20 +574,21 @@ class AutoPadder(Padder):
array[i, j, :len(content_ii)] = content_ii array[i, j, :len(content_ii)] = content_ii
else: else:
shape = np.shape(contents) shape = np.shape(contents)
if len(shape)==4: # 说明各dimension是相同的大小
if len(shape) == 4: # 说明各dimension是相同的大小
array = np.array(contents, dtype=field_ele_dtype) array = np.array(contents, dtype=field_ele_dtype)
else: else:
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return array return array
elif str(field_ele_dtype).startswith('torch'): elif str(field_ele_dtype).startswith('torch'):
if dim==0:
if dim == 0:
tensor = torch.tensor(contents).to(field_ele_dtype) tensor = torch.tensor(contents).to(field_ele_dtype)
elif dim==1:
elif dim == 1:
max_len = max(map(len, contents)) max_len = max(map(len, contents))
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents): for i, content_i in enumerate(contents):
tensor[i, :len(content_i)] = torch.tensor(content_i) tensor[i, :len(content_i)] = torch.tensor(content_i)
elif dim==2:
elif dim == 2:
max_len = max(map(len, contents)) max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents]) content_i in contents])
@@ -587,15 +599,18 @@ class AutoPadder(Padder):
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) tensor[i, j, :len(content_ii)] = torch.tensor(content_ii)
else: else:
shapes = set([np.shape(content_i) for content_i in contents]) shapes = set([np.shape(content_i) for content_i in contents])
if len(shapes)>1:
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
if len(shapes) > 1:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
shape = shapes.pop() shape = shapes.pop()
if len(shape)==3:
tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype)
if len(shape) == 3:
tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val,
dtype=field_ele_dtype)
for i, content_i in enumerate(contents): for i, content_i in enumerate(contents):
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype)
else: else:
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return tensor return tensor
else: else:
return np.array(contents) # 不进行任何操作 return np.array(contents) # 不进行任何操作
@@ -626,7 +641,7 @@ class EngChar2DPadder(Padder):
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder


""" """
def __init__(self, pad_val=0, pad_length=0): def __init__(self, pad_val=0, pad_length=0):
""" """
:param pad_val: int, pad的位置使用该index :param pad_val: int, pad的位置使用该index
@@ -634,9 +649,9 @@ class EngChar2DPadder(Padder):
都pad或截取到该长度. 都pad或截取到该长度.
""" """
super().__init__(pad_val=pad_val) super().__init__(pad_val=pad_val)
self.pad_length = pad_length self.pad_length = pad_length
def __call__(self, contents, field_name, field_ele_dtype, dim): def __call__(self, contents, field_name, field_ele_dtype, dim):
""" """
期望输入类似于 期望输入类似于
@@ -655,7 +670,7 @@ class EngChar2DPadder(Padder):
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format(
field_name, field_ele_dtype field_name, field_ele_dtype
)) ))
assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
assert dim == 2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
if self.pad_length < 1: if self.pad_length < 1:
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
else: else:
@@ -663,12 +678,12 @@ class EngChar2DPadder(Padder):
max_sent_length = max(len(word_lst) for word_lst in contents) max_sent_length = max(len(word_lst) for word_lst in contents)
batch_size = len(contents) batch_size = len(contents)
dtype = type(contents[0][0][0]) dtype = type(contents[0][0][0])
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
dtype=dtype) dtype=dtype)
for b_idx, word_lst in enumerate(contents): for b_idx, word_lst in enumerate(contents):
for c_idx, char_lst in enumerate(word_lst): for c_idx, char_lst in enumerate(word_lst):
chars = char_lst[:max_char_length] chars = char_lst[:max_char_length]
padded_array[b_idx, c_idx, :len(chars)] = chars padded_array[b_idx, c_idx, :len(chars)] = chars
return padded_array return padded_array

Loading…
Cancel
Save