diff --git a/tests/core/callbacks/test_callback_events.py b/tests/core/callbacks/test_callback_events.py index 8712b469..37f7047f 100644 --- a/tests/core/callbacks/test_callback_events.py +++ b/tests/core/callbacks/test_callback_events.py @@ -4,8 +4,8 @@ from functools import reduce from fastNLP.core.callbacks.callback_events import Events, Filter -class TestFilter: +class TestFilter: def test_params_check(self): # 顺利通过 _filter1 = Filter(every=10) @@ -80,35 +80,6 @@ class TestFilter: _res.append(cu_res) assert _res == [9] - def test_filter_fn(self): - from torch.optim import SGD - from torch.utils.data import DataLoader - from fastNLP.core.controllers.trainer import Trainer - from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 - from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification - - model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) - optimizer = SGD(model.parameters(), lr=0.0001) - dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) - dataloader = DataLoader(dataset=dataset, batch_size=4) - - trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) - def filter_fn(filter, trainer): - if trainer.__heihei_test__ == 10: - return True - return False - - @Filter(filter_fn=filter_fn) - def _fn(trainer, data): - return data - - _res = [] - for i in range(100): - trainer.__heihei_test__ = i - cu_res = _fn(trainer, i) - if cu_res is not None: - _res.append(cu_res) - assert _res == [10] def test_extract_filter_from_fn(self): @Filter(every=10) @@ -155,3 +126,119 @@ class TestFilter: assert _res == [w - 1 for w in range(60, 101, 10)] +@pytest.mark.torch +def test_filter_fn_torch(): + from torch.optim import SGD + from torch.utils.data import DataLoader + from fastNLP.core.controllers.trainer import Trainer + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + optimizer = SGD(model.parameters(), lr=0.0001) + dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) + dataloader = DataLoader(dataset=dataset, batch_size=4) + + trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) + def filter_fn(filter, trainer): + if trainer.__heihei_test__ == 10: + return True + return False + + @Filter(filter_fn=filter_fn) + def _fn(trainer, data): + return data + + _res = [] + for i in range(100): + trainer.__heihei_test__ = i + cu_res = _fn(trainer, i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [10] + + +class TestCallbackEvents: + def test_every(self): + + # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; + event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; + @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == list(range(100)) + + event_state = Events.on_train_begin(every=10) + @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [w - 1 for w in range(10, 101, 10)] + + def test_once(self): + event_state = Events.on_train_begin(once=10) + + @Filter(once=event_state.once) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [9] + + +@pytest.mark.torch +def test_callback_events_torch(): + from torch.optim import SGD + from torch.utils.data import DataLoader + from fastNLP.core.controllers.trainer import Trainer + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + optimizer = SGD(model.parameters(), lr=0.0001) + dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) + dataloader = DataLoader(dataset=dataset, batch_size=4) + + trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) + def filter_fn(filter, trainer): + if trainer.__heihei_test__ == 10: + return True + return False + + event_state = Events.on_train_begin(filter_fn=filter_fn) + + @Filter(filter_fn=event_state.filter_fn) + def _fn(trainer, data): + return data + + _res = [] + for i in range(100): + trainer.__heihei_test__ = i + cu_res = _fn(trainer, i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [10] + + + + + + + + + diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index fab07b3c..1a90a96d 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -221,124 +221,6 @@ def test_trainer_event_trigger_2( -@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) -@pytest.mark.torch -@magic_argv_env_context -def test_trainer_event_trigger_3( - model_and_optimizers: TrainerParameters, - driver, - device, - n_epochs=2, -): - - @Trainer.on(Events.on_after_trainer_initialized) - def on_after_trainer_initialized(trainer, driver): - print("on_after_trainer_initialized") - - @Trainer.on(Events.on_sanity_check_begin) - def on_sanity_check_begin(trainer): - print("on_sanity_check_begin") - - @Trainer.on(Events.on_sanity_check_end) - def on_sanity_check_end(trainer, sanity_check_res): - print("on_sanity_check_end") - - @Trainer.on(Events.on_train_begin) - def on_train_begin(trainer): - print("on_train_begin") - - @Trainer.on(Events.on_train_end) - def on_train_end(trainer): - print("on_train_end") - - @Trainer.on(Events.on_train_epoch_begin) - def on_train_epoch_begin(trainer): - if trainer.cur_epoch_idx >= 1: - # 触发 on_exception; - raise Exception - print("on_train_epoch_begin") - - @Trainer.on(Events.on_train_epoch_end) - def on_train_epoch_end(trainer): - print("on_train_epoch_end") - - @Trainer.on(Events.on_fetch_data_begin) - def on_fetch_data_begin(trainer): - print("on_fetch_data_begin") - - @Trainer.on(Events.on_fetch_data_end) - def on_fetch_data_end(trainer): - print("on_fetch_data_end") - - @Trainer.on(Events.on_train_batch_begin) - def on_train_batch_begin(trainer, batch, indices=None): - print("on_train_batch_begin") - - @Trainer.on(Events.on_train_batch_end) - def on_train_batch_end(trainer): - print("on_train_batch_end") - - @Trainer.on(Events.on_exception) - def on_exception(trainer, exception): - print("on_exception") - - @Trainer.on(Events.on_before_backward) - def on_before_backward(trainer, outputs): - print("on_before_backward") - - @Trainer.on(Events.on_after_backward) - def on_after_backward(trainer): - print("on_after_backward") - - @Trainer.on(Events.on_before_optimizers_step) - def on_before_optimizers_step(trainer, optimizers): - print("on_before_optimizers_step") - - @Trainer.on(Events.on_after_optimizers_step) - def on_after_optimizers_step(trainer, optimizers): - print("on_after_optimizers_step") - - @Trainer.on(Events.on_before_zero_grad) - def on_before_zero_grad(trainer, optimizers): - print("on_before_zero_grad") - - @Trainer.on(Events.on_after_zero_grad) - def on_after_zero_grad(trainer, optimizers): - print("on_after_zero_grad") - - @Trainer.on(Events.on_evaluate_begin) - def on_evaluate_begin(trainer): - print("on_evaluate_begin") - - @Trainer.on(Events.on_evaluate_end) - def on_evaluate_end(trainer, results): - print("on_evaluate_end") - - with pytest.raises(Exception): - with Capturing() as output: - trainer = Trainer( - model=model_and_optimizers.model, - driver=driver, - device=device, - optimizers=model_and_optimizers.optimizers, - train_dataloader=model_and_optimizers.train_dataloader, - evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, - input_mapping=model_and_optimizers.input_mapping, - output_mapping=model_and_optimizers.output_mapping, - metrics=model_and_optimizers.metrics, - - n_epochs=n_epochs, - ) - - trainer.run() - - if dist.is_initialized(): - dist.destroy_process_group() - - for name, member in Events.__members__.items(): - assert member.value in output[0] - -