|
|
@@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] |
|
|
|
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend |
|
|
|
|
|
|
|
|
|
|
|
def _get_backend(): |
|
|
|
def _get_backend() -> str: |
|
|
|
""" |
|
|
|
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: |
|
|
|
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 |
|
|
@@ -57,7 +57,7 @@ def _get_backend(): |
|
|
|
else: |
|
|
|
break |
|
|
|
if len(catch_backend): |
|
|
|
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") |
|
|
|
logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") |
|
|
|
return catch_backend[0] |
|
|
|
|
|
|
|
# 方式 (2) |
|
|
@@ -66,7 +66,7 @@ def _get_backend(): |
|
|
|
if catch_backend: |
|
|
|
break |
|
|
|
if len(catch_backend): |
|
|
|
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") |
|
|
|
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") |
|
|
|
return catch_backend[0] |
|
|
|
|
|
|
|
return 'numpy' |
|
|
@@ -80,7 +80,7 @@ class Collator: |
|
|
|
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 |
|
|
|
|
|
|
|
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 |
|
|
|
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad |
|
|
|
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad |
|
|
|
的数据返回一定是 list 。 |
|
|
|
""" |
|
|
|
self.unpack_batch_func = None |
|
|
@@ -144,15 +144,18 @@ class Collator: |
|
|
|
for key in unpack_batch.keys(): |
|
|
|
if key not in self.input_fields and key not in self.ignore_fields: |
|
|
|
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} |
|
|
|
elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': |
|
|
|
self.input_fields[key]['backend'] = self.backend |
|
|
|
|
|
|
|
for field_name, setting in self.input_fields.items(): |
|
|
|
pad_fn = setting.get('pad_fn', None) |
|
|
|
if callable(pad_fn): |
|
|
|
padder = pad_fn |
|
|
|
else: |
|
|
|
backend = self.backend if setting['backend'] == 'auto' else setting['backend'] |
|
|
|
batch_field = unpack_batch.get(field_name) |
|
|
|
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], |
|
|
|
dtype=setting['dtype'], backend=setting['backend'], |
|
|
|
dtype=setting['dtype'], backend=backend, |
|
|
|
field_name=field_name) |
|
|
|
self.padders[field_name] = padder |
|
|
|
if self.batch_data_type == 'l': |
|
|
|