Browse Source

bug fix for new_collator

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
296e7e9f2b
1 changed files with 8 additions and 5 deletions
  1. +8
    -5
      fastNLP/core/collators/new_collator.py

+ 8
- 5
fastNLP/core/collators/new_collator.py View File

@@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend


def _get_backend():
def _get_backend() -> str:
"""
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个:
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是
@@ -57,7 +57,7 @@ def _get_backend():
else:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.")
logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.")
return catch_backend[0]

# 方式 (2)
@@ -66,7 +66,7 @@ def _get_backend():
if catch_backend:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.")
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
return catch_backend[0]

return 'numpy'
@@ -80,7 +80,7 @@ class Collator:
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。

:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad
的数据返回一定是 list 。
"""
self.unpack_batch_func = None
@@ -144,15 +144,18 @@ class Collator:
for key in unpack_batch.keys():
if key not in self.input_fields and key not in self.ignore_fields:
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}
elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto':
self.input_fields[key]['backend'] = self.backend

for field_name, setting in self.input_fields.items():
pad_fn = setting.get('pad_fn', None)
if callable(pad_fn):
padder = pad_fn
else:
backend = self.backend if setting['backend'] == 'auto' else setting['backend']
batch_field = unpack_batch.get(field_name)
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
dtype=setting['dtype'], backend=setting['backend'],
dtype=setting['dtype'], backend=backend,
field_name=field_name)
self.padders[field_name] = padder
if self.batch_data_type == 'l':


Loading…
Cancel
Save