| @@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | |||||
| CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | ||||
| def _get_backend(): | |||||
| def _get_backend() -> str: | |||||
| """ | """ | ||||
| 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: | 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: | ||||
| (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 | (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 | ||||
| @@ -57,7 +57,7 @@ def _get_backend(): | |||||
| else: | else: | ||||
| break | break | ||||
| if len(catch_backend): | if len(catch_backend): | ||||
| logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") | |||||
| logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") | |||||
| return catch_backend[0] | return catch_backend[0] | ||||
| # 方式 (2) | # 方式 (2) | ||||
| @@ -66,7 +66,7 @@ def _get_backend(): | |||||
| if catch_backend: | if catch_backend: | ||||
| break | break | ||||
| if len(catch_backend): | if len(catch_backend): | ||||
| logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") | |||||
| logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") | |||||
| return catch_backend[0] | return catch_backend[0] | ||||
| return 'numpy' | return 'numpy' | ||||
| @@ -80,7 +80,7 @@ class Collator: | |||||
| 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | ||||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | ||||
| 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad | |||||
| 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||||
| 的数据返回一定是 list 。 | 的数据返回一定是 list 。 | ||||
| """ | """ | ||||
| self.unpack_batch_func = None | self.unpack_batch_func = None | ||||
| @@ -144,15 +144,18 @@ class Collator: | |||||
| for key in unpack_batch.keys(): | for key in unpack_batch.keys(): | ||||
| if key not in self.input_fields and key not in self.ignore_fields: | if key not in self.input_fields and key not in self.ignore_fields: | ||||
| self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | ||||
| elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': | |||||
| self.input_fields[key]['backend'] = self.backend | |||||
| for field_name, setting in self.input_fields.items(): | for field_name, setting in self.input_fields.items(): | ||||
| pad_fn = setting.get('pad_fn', None) | pad_fn = setting.get('pad_fn', None) | ||||
| if callable(pad_fn): | if callable(pad_fn): | ||||
| padder = pad_fn | padder = pad_fn | ||||
| else: | else: | ||||
| backend = self.backend if setting['backend'] == 'auto' else setting['backend'] | |||||
| batch_field = unpack_batch.get(field_name) | batch_field = unpack_batch.get(field_name) | ||||
| padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | ||||
| dtype=setting['dtype'], backend=setting['backend'], | |||||
| dtype=setting['dtype'], backend=backend, | |||||
| field_name=field_name) | field_name=field_name) | ||||
| self.padders[field_name] = padder | self.padders[field_name] = padder | ||||
| if self.batch_data_type == 'l': | if self.batch_data_type == 'l': | ||||