@@ -4,8 +4,6 @@ from types import DynamicClassAttribute | |||||
from functools import wraps | from functools import wraps | ||||
import fastNLP | |||||
__all__ = [ | __all__ = [ | ||||
'Events', | 'Events', | ||||
'EventsList', | 'EventsList', | ||||
@@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | |||||
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | ||||
def _get_backend(): | |||||
def _get_backend() -> str: | |||||
""" | """ | ||||
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: | 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: | ||||
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 | (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 | ||||
@@ -57,7 +57,7 @@ def _get_backend(): | |||||
else: | else: | ||||
break | break | ||||
if len(catch_backend): | if len(catch_backend): | ||||
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") | |||||
logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") | |||||
return catch_backend[0] | return catch_backend[0] | ||||
# 方式 (2) | # 方式 (2) | ||||
@@ -66,7 +66,7 @@ def _get_backend(): | |||||
if catch_backend: | if catch_backend: | ||||
break | break | ||||
if len(catch_backend): | if len(catch_backend): | ||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") | |||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") | |||||
return catch_backend[0] | return catch_backend[0] | ||||
return 'numpy' | return 'numpy' | ||||
@@ -80,7 +80,7 @@ class Collator: | |||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | ||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad | |||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||||
的数据返回一定是 list 。 | 的数据返回一定是 list 。 | ||||
""" | """ | ||||
self.unpack_batch_func = None | self.unpack_batch_func = None | ||||
@@ -144,15 +144,18 @@ class Collator: | |||||
for key in unpack_batch.keys(): | for key in unpack_batch.keys(): | ||||
if key not in self.input_fields and key not in self.ignore_fields: | if key not in self.input_fields and key not in self.ignore_fields: | ||||
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | ||||
elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': | |||||
self.input_fields[key]['backend'] = self.backend | |||||
for field_name, setting in self.input_fields.items(): | for field_name, setting in self.input_fields.items(): | ||||
pad_fn = setting.get('pad_fn', None) | pad_fn = setting.get('pad_fn', None) | ||||
if callable(pad_fn): | if callable(pad_fn): | ||||
padder = pad_fn | padder = pad_fn | ||||
else: | else: | ||||
backend = self.backend if setting['backend'] == 'auto' else setting['backend'] | |||||
batch_field = unpack_batch.get(field_name) | batch_field = unpack_batch.get(field_name) | ||||
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | ||||
dtype=setting['dtype'], backend=setting['backend'], | |||||
dtype=setting['dtype'], backend=backend, | |||||
field_name=field_name) | field_name=field_name) | ||||
self.padders[field_name] = padder | self.padders[field_name] = padder | ||||
if self.batch_data_type == 'l': | if self.batch_data_type == 'l': | ||||
@@ -13,7 +13,6 @@ if _NEED_IMPORT_PADDLE: | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset | ||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@dataclass | @dataclass | ||||
@@ -100,17 +100,16 @@ def model_and_optimizers(request): | |||||
# 测试一下普通的情况; | # 测试一下普通的情况; | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | |||||
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | @pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_with_evaluator( | def test_trainer_torch_with_evaluator( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | |||||
evaluate_every, | evaluate_every, | ||||
n_epochs=10, | n_epochs=10, | ||||
): | ): | ||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)] | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_validate_every( | def test_trainer_validate_every( | ||||
@@ -184,9 +183,7 @@ def test_trainer_validate_every( | |||||
def validate_every(trainer): | def validate_every(trainer): | ||||
if trainer.global_forward_batches % 10 == 0: | if trainer.global_forward_batches % 10 == 0: | ||||
print(trainer) | |||||
print("\nfastNLP test validate every.\n") | print("\nfastNLP test validate every.\n") | ||||
print(trainer.global_forward_batches) | |||||
return True | return True | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -30,12 +30,12 @@ def recover_logger(fn): | |||||
return wrapper | return wrapper | ||||
def magic_argv_env_context(fn=None, timeout=600): | |||||
def magic_argv_env_context(fn=None, timeout=300): | |||||
""" | """ | ||||
用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; | 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; | ||||
会丢掉 pytest 中的 arg 参数。 | 会丢掉 pytest 中的 arg 参数。 | ||||
:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; | |||||
:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 5 分钟,单位为秒; | |||||
:return: | :return: | ||||
""" | """ | ||||
# 说明是通过 @magic_argv_env_context(timeout=600) 调用; | # 说明是通过 @magic_argv_env_context(timeout=600) 调用; | ||||