Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
c39ed15c00
2 changed files with 77 additions and 12 deletions
  1. +75
    -10
      fastNLP/core/collators/new_collator.py
  2. +2
    -2
      fastNLP/core/collators/padders/get_padder.py

+ 75
- 10
fastNLP/core/collators/new_collator.py View File

@@ -1,4 +1,7 @@
from typing import List, Union, Dict, Callable, Sequence, Mapping from typing import List, Union, Dict, Callable, Sequence, Mapping
import os
import sys
import inspect


from fastNLP.core.log import logger from fastNLP.core.log import logger
from .padders.get_padder import get_padder from .padders.get_padder import get_padder
@@ -9,18 +12,76 @@ from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch
pack_batch_sequence pack_batch_sequence


sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None]
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend


def _get_backend():
"""
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个:
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是
某个 backend 的 dataloader 。
(2)如果方式(1)没找,则通过分析 sys.modules 中的内容进行寻找。

如果都没有找到则返回 numpy 。
:return:
"""
def _check_module(module):
"""
检查该 module 是否含有 某个 backend 的特征

:param module: module 对象
:return:
"""
catch_backend = []
try:
file = module.__file__
for backend in CHECK_BACKEND:
if f'{os.sep}site-packages{os.sep}{backend}' in file:
catch_backend = [backend, file]
except:
pass
return catch_backend

currentframe = inspect.currentframe()
# 方式(1)
catch_backend = []
for i in range(100):
currentframe = currentframe.f_back
if currentframe is not None:
module = inspect.getmodule(currentframe)
if module is not None:
catch_backend = _check_module(module)
if len(catch_backend): # 主要捕获到一个就结束吧
break
else:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.")
return catch_backend[0]

# 方式 (2)
for key, module in sys.modules.items():
catch_backend = _check_module(module)
if catch_backend:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.")
return catch_backend[0]

return 'numpy'




class Collator: class Collator:
def __init__(self, backend='torch'):
def __init__(self, backend='auto'):
""" """
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。


:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。
若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad
的数据返回一定是 list 。
""" """
self.unpack_batch_func = None self.unpack_batch_func = None
self.pack_batch_func = None self.pack_batch_func = None
@@ -77,6 +138,9 @@ class Collator:


pad_batch = {} pad_batch = {}
if len(self.padders)==0: # 第一次运行,准备 padder if len(self.padders)==0: # 第一次运行,准备 padder
if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。
self.backend = _get_backend()

for key in unpack_batch.keys(): for key in unpack_batch.keys():
if key not in self.input_fields and key not in self.ignore_fields: if key not in self.input_fields and key not in self.ignore_fields:
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}
@@ -100,7 +164,7 @@ class Collator:


return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型


def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto',
pad_fn:Callable=None) -> "Collator": pad_fn:Callable=None) -> "Collator":
""" """
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
@@ -110,10 +174,11 @@ class Collator:
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
无意义。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。 形式,输出将被直接作为结果输出。
@@ -154,8 +219,8 @@ class Collator:
""" """
设置可以 pad 的 field 默认 pad 为什么类型的 tensor 设置可以 pad 的 field 默认 pad 为什么类型的 tensor


:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None],
若为 None ,则不进行 padding
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None],
若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend
:return: :return:
""" """
assert backend in SUPPORTED_BACKENDS assert backend in SUPPORTED_BACKENDS


+ 2
- 2
fastNLP/core/collators/padders/get_padder.py View File

@@ -27,7 +27,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
:param field_name: 方便报错的。 :param field_name: 方便报错的。
:return: :return:
""" """
logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field))
logger.debug(f"The content in the field:`{field_name}` is:\n"+str(batch_field))
if pad_val is None: if pad_val is None:
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
return NullPadder() return NullPadder()
@@ -112,7 +112,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return NullPadder() return NullPadder()


except DtypeError as e: except DtypeError as e:
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
"information please set logger's level to DEBUG." "information please set logger's level to DEBUG."
if must_pad: if must_pad:
raise type(e)(msg=msg) raise type(e)(msg=msg)


Loading…
Cancel
Save