@@ -4,8 +4,8 @@ from functools import reduce | |||
from fastNLP.core.callbacks.callback_events import Events, Filter | |||
class TestFilter: | |||
class TestFilter: | |||
def test_params_check(self): | |||
# 顺利通过 | |||
_filter1 = Filter(every=10) | |||
@@ -80,35 +80,6 @@ class TestFilter: | |||
_res.append(cu_res) | |||
assert _res == [9] | |||
def test_filter_fn(self): | |||
from torch.optim import SGD | |||
from torch.utils.data import DataLoader | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||
optimizer = SGD(model.parameters(), lr=0.0001) | |||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||
def filter_fn(filter, trainer): | |||
if trainer.__heihei_test__ == 10: | |||
return True | |||
return False | |||
@Filter(filter_fn=filter_fn) | |||
def _fn(trainer, data): | |||
return data | |||
_res = [] | |||
for i in range(100): | |||
trainer.__heihei_test__ = i | |||
cu_res = _fn(trainer, i) | |||
if cu_res is not None: | |||
_res.append(cu_res) | |||
assert _res == [10] | |||
def test_extract_filter_from_fn(self): | |||
@Filter(every=10) | |||
@@ -155,3 +126,119 @@ class TestFilter: | |||
assert _res == [w - 1 for w in range(60, 101, 10)] | |||
@pytest.mark.torch | |||
def test_filter_fn_torch(): | |||
from torch.optim import SGD | |||
from torch.utils.data import DataLoader | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||
optimizer = SGD(model.parameters(), lr=0.0001) | |||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||
def filter_fn(filter, trainer): | |||
if trainer.__heihei_test__ == 10: | |||
return True | |||
return False | |||
@Filter(filter_fn=filter_fn) | |||
def _fn(trainer, data): | |||
return data | |||
_res = [] | |||
for i in range(100): | |||
trainer.__heihei_test__ = i | |||
cu_res = _fn(trainer, i) | |||
if cu_res is not None: | |||
_res.append(cu_res) | |||
assert _res == [10] | |||
class TestCallbackEvents: | |||
def test_every(self): | |||
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; | |||
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; | |||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||
def _fn(data): | |||
return data | |||
_res = [] | |||
for i in range(100): | |||
cu_res = _fn(i) | |||
if cu_res is not None: | |||
_res.append(cu_res) | |||
assert _res == list(range(100)) | |||
event_state = Events.on_train_begin(every=10) | |||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||
def _fn(data): | |||
return data | |||
_res = [] | |||
for i in range(100): | |||
cu_res = _fn(i) | |||
if cu_res is not None: | |||
_res.append(cu_res) | |||
assert _res == [w - 1 for w in range(10, 101, 10)] | |||
def test_once(self): | |||
event_state = Events.on_train_begin(once=10) | |||
@Filter(once=event_state.once) | |||
def _fn(data): | |||
return data | |||
_res = [] | |||
for i in range(100): | |||
cu_res = _fn(i) | |||
if cu_res is not None: | |||
_res.append(cu_res) | |||
assert _res == [9] | |||
@pytest.mark.torch | |||
def test_callback_events_torch(): | |||
from torch.optim import SGD | |||
from torch.utils.data import DataLoader | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||
optimizer = SGD(model.parameters(), lr=0.0001) | |||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||
def filter_fn(filter, trainer): | |||
if trainer.__heihei_test__ == 10: | |||
return True | |||
return False | |||
event_state = Events.on_train_begin(filter_fn=filter_fn) | |||
@Filter(filter_fn=event_state.filter_fn) | |||
def _fn(trainer, data): | |||
return data | |||
_res = [] | |||
for i in range(100): | |||
trainer.__heihei_test__ = i | |||
cu_res = _fn(trainer, i) | |||
if cu_res is not None: | |||
_res.append(cu_res) | |||
assert _res == [10] | |||
@@ -221,124 +221,6 @@ def test_trainer_event_trigger_2( | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||
@pytest.mark.torch | |||
@magic_argv_env_context | |||
def test_trainer_event_trigger_3( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
n_epochs=2, | |||
): | |||
@Trainer.on(Events.on_after_trainer_initialized) | |||
def on_after_trainer_initialized(trainer, driver): | |||
print("on_after_trainer_initialized") | |||
@Trainer.on(Events.on_sanity_check_begin) | |||
def on_sanity_check_begin(trainer): | |||
print("on_sanity_check_begin") | |||
@Trainer.on(Events.on_sanity_check_end) | |||
def on_sanity_check_end(trainer, sanity_check_res): | |||
print("on_sanity_check_end") | |||
@Trainer.on(Events.on_train_begin) | |||
def on_train_begin(trainer): | |||
print("on_train_begin") | |||
@Trainer.on(Events.on_train_end) | |||
def on_train_end(trainer): | |||
print("on_train_end") | |||
@Trainer.on(Events.on_train_epoch_begin) | |||
def on_train_epoch_begin(trainer): | |||
if trainer.cur_epoch_idx >= 1: | |||
# 触发 on_exception; | |||
raise Exception | |||
print("on_train_epoch_begin") | |||
@Trainer.on(Events.on_train_epoch_end) | |||
def on_train_epoch_end(trainer): | |||
print("on_train_epoch_end") | |||
@Trainer.on(Events.on_fetch_data_begin) | |||
def on_fetch_data_begin(trainer): | |||
print("on_fetch_data_begin") | |||
@Trainer.on(Events.on_fetch_data_end) | |||
def on_fetch_data_end(trainer): | |||
print("on_fetch_data_end") | |||
@Trainer.on(Events.on_train_batch_begin) | |||
def on_train_batch_begin(trainer, batch, indices=None): | |||
print("on_train_batch_begin") | |||
@Trainer.on(Events.on_train_batch_end) | |||
def on_train_batch_end(trainer): | |||
print("on_train_batch_end") | |||
@Trainer.on(Events.on_exception) | |||
def on_exception(trainer, exception): | |||
print("on_exception") | |||
@Trainer.on(Events.on_before_backward) | |||
def on_before_backward(trainer, outputs): | |||
print("on_before_backward") | |||
@Trainer.on(Events.on_after_backward) | |||
def on_after_backward(trainer): | |||
print("on_after_backward") | |||
@Trainer.on(Events.on_before_optimizers_step) | |||
def on_before_optimizers_step(trainer, optimizers): | |||
print("on_before_optimizers_step") | |||
@Trainer.on(Events.on_after_optimizers_step) | |||
def on_after_optimizers_step(trainer, optimizers): | |||
print("on_after_optimizers_step") | |||
@Trainer.on(Events.on_before_zero_grad) | |||
def on_before_zero_grad(trainer, optimizers): | |||
print("on_before_zero_grad") | |||
@Trainer.on(Events.on_after_zero_grad) | |||
def on_after_zero_grad(trainer, optimizers): | |||
print("on_after_zero_grad") | |||
@Trainer.on(Events.on_evaluate_begin) | |||
def on_evaluate_begin(trainer): | |||
print("on_evaluate_begin") | |||
@Trainer.on(Events.on_evaluate_end) | |||
def on_evaluate_end(trainer, results): | |||
print("on_evaluate_end") | |||
with pytest.raises(Exception): | |||
with Capturing() as output: | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
input_mapping=model_and_optimizers.input_mapping, | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
n_epochs=n_epochs, | |||
) | |||
trainer.run() | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
for name, member in Events.__members__.items(): | |||
assert member.value in output[0] | |||