diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py index 9123a293..1d8636e3 100644 --- a/fastNLP/core/collators/new_collator.py +++ b/fastNLP/core/collators/new_collator.py @@ -1,4 +1,7 @@ from typing import List, Union, Dict, Callable, Sequence, Mapping +import os +import sys +import inspect from fastNLP.core.log import logger from .padders.get_padder import get_padder @@ -9,18 +12,76 @@ from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch pack_batch_sequence sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 -SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] +SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] +CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend + + +def _get_backend(): + """ + 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: + (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 + 某个 backend 的 dataloader 。 + (2)如果方式(1)没找,则通过分析 sys.modules 中的内容进行寻找。 + + 如果都没有找到则返回 numpy 。 + :return: + """ + def _check_module(module): + """ + 检查该 module 是否含有 某个 backend 的特征 + + :param module: module 对象 + :return: + """ + catch_backend = [] + try: + file = module.__file__ + for backend in CHECK_BACKEND: + if f'{os.sep}site-packages{os.sep}{backend}' in file: + catch_backend = [backend, file] + except: + pass + return catch_backend + + currentframe = inspect.currentframe() + # 方式(1) + catch_backend = [] + for i in range(100): + currentframe = currentframe.f_back + if currentframe is not None: + module = inspect.getmodule(currentframe) + if module is not None: + catch_backend = _check_module(module) + if len(catch_backend): # 主要捕获到一个就结束吧 + break + else: + break + if len(catch_backend): + logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") + return catch_backend[0] + + # 方式 (2) + for key, module in sys.modules.items(): + catch_backend = _check_module(module) + if catch_backend: + break + if len(catch_backend): + logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") + return catch_backend[0] + + return 'numpy' class Collator: - def __init__(self, backend='torch'): + def __init__(self, backend='auto'): """ 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 - 若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 + 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad + 的数据返回一定是 list 。 """ self.unpack_batch_func = None self.pack_batch_func = None @@ -77,6 +138,9 @@ class Collator: pad_batch = {} if len(self.padders)==0: # 第一次运行,准备 padder + if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。 + self.backend = _get_backend() + for key in unpack_batch.keys(): if key not in self.input_fields and key not in self.ignore_fields: self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} @@ -100,7 +164,7 @@ class Collator: return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 - def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto', pad_fn:Callable=None) -> "Collator": """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 @@ -110,10 +174,11 @@ class Collator: 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, - paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch 形式,输出将被直接作为结果输出。 @@ -154,8 +219,8 @@ class Collator: """ 设置可以 pad 的 field 默认 pad 为什么类型的 tensor - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], - 若为 None ,则不进行 padding 。 + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], + 若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。 :return: """ assert backend in SUPPORTED_BACKENDS diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index ecef9fcf..ae32b7b8 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -27,7 +27,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> :param field_name: 方便报错的。 :return: """ - logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field)) + logger.debug(f"The content in the field:`{field_name}` is:\n"+str(batch_field)) if pad_val is None: logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") return NullPadder() @@ -112,7 +112,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return NullPadder() except DtypeError as e: - msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ + msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ "information please set logger's level to DEBUG." if must_pad: raise type(e)(msg=msg)