diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 3f3691e3..7252398c 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -4,8 +4,6 @@ from types import DynamicClassAttribute from functools import wraps -import fastNLP - __all__ = [ 'Events', 'EventsList', diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py index 1d8636e3..cee713f2 100644 --- a/fastNLP/core/collators/new_collator.py +++ b/fastNLP/core/collators/new_collator.py @@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend -def _get_backend(): +def _get_backend() -> str: """ 当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: (1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 @@ -57,7 +57,7 @@ def _get_backend(): else: break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") + logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") return catch_backend[0] # 方式 (2) @@ -66,7 +66,7 @@ def _get_backend(): if catch_backend: break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") + logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") return catch_backend[0] return 'numpy' @@ -80,7 +80,7 @@ class Collator: 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 - 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad + 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 """ self.unpack_batch_func = None @@ -144,15 +144,18 @@ class Collator: for key in unpack_batch.keys(): if key not in self.input_fields and key not in self.ignore_fields: self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} + elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': + self.input_fields[key]['backend'] = self.backend for field_name, setting in self.input_fields.items(): pad_fn = setting.get('pad_fn', None) if callable(pad_fn): padder = pad_fn else: + backend = self.backend if setting['backend'] == 'auto' else setting['backend'] batch_field = unpack_batch.get(field_name) padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], - dtype=setting['dtype'], backend=setting['backend'], + dtype=setting['dtype'], backend=backend, field_name=field_name) self.padders[field_name] = padder if self.batch_data_type == 'l': diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index bc5a590c..3cf850c3 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -13,7 +13,6 @@ if _NEED_IMPORT_PADDLE: from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset -from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context @dataclass diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index d8dd7d73..891626b5 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -100,17 +100,16 @@ def model_and_optimizers(request): # 测试一下普通的情况; @pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) -@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) @pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) @magic_argv_env_context def test_trainer_torch_with_evaluator( model_and_optimizers: TrainerParameters, driver, device, - callbacks, evaluate_every, n_epochs=10, ): + callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)] trainer = Trainer( model=model_and_optimizers.model, driver=driver, @@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( if dist.is_initialized(): dist.destroy_process_group() - +@pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_validate_every( @@ -184,9 +183,7 @@ def test_trainer_validate_every( def validate_every(trainer): if trainer.global_forward_batches % 10 == 0: - print(trainer) print("\nfastNLP test validate every.\n") - print(trainer.global_forward_batches) return True trainer = Trainer( diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 7e02ca0d..463f144d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -30,12 +30,12 @@ def recover_logger(fn): return wrapper -def magic_argv_env_context(fn=None, timeout=600): +def magic_argv_env_context(fn=None, timeout=300): """ 用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确; 会丢掉 pytest 中的 arg 参数。 - :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒; + :param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 5 分钟,单位为秒; :return: """ # 说明是通过 @magic_argv_env_context(timeout=600) 调用;