Browse Source

1. 修复DataSet.delete_instance的bug; 2. FieldArray中支持只使用第一个instance推断dimension和type,节省时间

tags/v0.4.10
yh 5 years ago
parent
commit
c19499e60a
1 changed files with 26 additions and 14 deletions
  1. +26
    -14
      fastNLP/core/field.py

+ 26
- 14
fastNLP/core/field.py View File

@@ -23,7 +23,8 @@ class AppendToTargetOrInputException(Exception):
self.field_name = field_name # 标示当前field的名称

class FieldArray:
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False):
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False,
use_1st_ins_infer_dim_type=True):
if len(content)==0:
raise RuntimeError("Empty fieldarray is not allowed.")
_content = content
@@ -38,6 +39,7 @@ class FieldArray:
# 根据input的情况设置input,target等
self._cell_ndim = None # 多少维度
self.dtype = None # 最内层的element都是什么类型的
self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self._is_input = False
self._is_target = False

@@ -77,7 +79,7 @@ class FieldArray:
if value is True and \
self._is_target is False and \
self._ignore_type is False:
self._check_dtype_and_ndim()
self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
if value is False and self._is_target is False:
self.dtype = None
self._cell_ndim = None
@@ -95,32 +97,34 @@ class FieldArray:
if value is True and \
self._is_input is False and \
self._ignore_type is False:
self._check_dtype_and_ndim()
self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
if value is False and self._is_input is False:
self.dtype = None
self._cell_ndim = None
self._is_target = value

def _check_dtype_and_ndim(self):
def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True):
"""
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有
通过将直接报错.

:param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim
:return:
"""
cell_0 = self.content[0]
index = 0
try:
type_0, dim_0 = _get_ele_type_and_dim(cell_0)
for cell in self.content[1:]:
index += 1
type_i, dim_i = _get_ele_type_and_dim(cell)
if type_i!=type_0:
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}."
".".format(type_i, index, type_0))
if dim_0!=dim_i:
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with "
"dimension:{}.".format(dim_i, index, dim_0))
if not only_check_1st_ins_dim_type:
for cell in self.content[1:]:
index += 1
type_i, dim_i = _get_ele_type_and_dim(cell)
if type_i!=type_0:
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}."
".".format(type_i, index, type_0))
if dim_0!=dim_i:
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with "
"dimension:{}.".format(dim_i, index, dim_0))
self._cell_ndim = dim_0
self.dtype = type_0
except SetInputOrTargetException as e:
@@ -132,7 +136,7 @@ class FieldArray:
:param val: 把该val append到fieldarray。
:return:
"""
if (self._is_target or self._is_input) and self._ignore_type is False:
if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type:
type_, dim_ = _get_ele_type_and_dim(val)
if self.dtype!=type_:
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with "
@@ -144,6 +148,14 @@ class FieldArray:
else:
self.content.append(val)

def pop(self, index):
"""
删除该field中index处的元素
:param int index: 从0开始的数据下标。
:return:
"""
self.content.pop(index)

def __getitem__(self, indices):
return self.get(indices, pad=False)



Loading…
Cancel
Save