Browse Source

修复fieldarray中padder潜在bug

tags/v0.4.10
yh 5 years ago
parent
commit
3d7cfb3598
1 changed file with 14 additions and 4 deletions
  1. +14
    -4
      fastNLP/core/fieldarray.py

+ 14
- 4
fastNLP/core/fieldarray.py View File

@@ -1,5 +1,5 @@
import numpy as np import numpy as np
from copy import deepcopy


class PadderBase: class PadderBase:
""" """
@@ -98,11 +98,16 @@ class FieldArray(object):
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
:param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used to the model input. :param bool is_input: If True, this FieldArray is used to the model input.
:param PadderBase padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding)
:param PadderBase padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过
fieldarray.set_pad_val()。
默认为None,(1)如果某个field是scalar,则不进行任何padding;(2)如果为一维list, 且fieldarray的dtype为float或int类型
则会进行padding;(3)其它情况不进行padder。
假设需要对English word中character进行padding,则需要使用其他的padder。
或ignore_type为True但是需要进行padding。
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False) :param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False)
""" """


def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0), ignore_type=False):
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
"""DataSet在初始化时会有两类方法对FieldArray操作: """DataSet在初始化时会有两类方法对FieldArray操作:
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@@ -139,6 +144,11 @@ class FieldArray(object):


self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐
self.content_dim = None # 表示content是多少维的list self.content_dim = None # 表示content是多少维的list
if padder is None:
padder = AutoPadder(pad_val=0)
else:
assert isinstance(padder, PadderBase), "padder must be of type PadderBase."
padder = deepcopy(padder)
self.set_padder(padder) self.set_padder(padder)
self.ignore_type = ignore_type self.ignore_type = ignore_type


@@ -354,7 +364,7 @@ class FieldArray(object):
""" """
if padder is not None: if padder is not None:
assert isinstance(padder, PadderBase), "padder must be of type PadderBase." assert isinstance(padder, PadderBase), "padder must be of type PadderBase."
self.padder = padder
self.padder = deepcopy(padder)


def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
""" """


Loading…
Cancel
Save