@@ -4,8 +4,8 @@ from functools import reduce | |||||
from fastNLP.core.callbacks.callback_event import Event, Filter | from fastNLP.core.callbacks.callback_event import Event, Filter | ||||
class TestFilter: | |||||
class TestFilter: | |||||
def test_params_check(self): | def test_params_check(self): | ||||
# 顺利通过 | # 顺利通过 | ||||
_filter1 = Filter(every=10) | _filter1 = Filter(every=10) | ||||
@@ -80,35 +80,6 @@ class TestFilter: | |||||
_res.append(cu_res) | _res.append(cu_res) | ||||
assert _res == [9] | assert _res == [9] | ||||
def test_filter_fn(self): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
def test_extract_filter_from_fn(self): | def test_extract_filter_from_fn(self): | ||||
@Filter(every=10) | @Filter(every=10) | ||||
@@ -155,3 +126,119 @@ class TestFilter: | |||||
assert _res == [w - 1 for w in range(60, 101, 10)] | assert _res == [w - 1 for w in range(60, 101, 10)] | ||||
@pytest.mark.torch | |||||
def test_filter_fn_torch(): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
class TestCallbackEvents: | |||||
def test_every(self): | |||||
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; | |||||
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == list(range(100)) | |||||
event_state = Events.on_train_begin(every=10) | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(10, 101, 10)] | |||||
def test_once(self): | |||||
event_state = Events.on_train_begin(once=10) | |||||
@Filter(once=event_state.once) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [9] | |||||
@pytest.mark.torch | |||||
def test_callback_events_torch(): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
event_state = Events.on_train_begin(filter_fn=filter_fn) | |||||
@Filter(filter_fn=event_state.filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
@@ -62,10 +62,9 @@ def model_and_optimizers(): | |||||
return trainer_params | return trainer_params | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger_1( | def test_trainer_event_trigger_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -102,125 +101,10 @@ def test_trainer_event_trigger_1( | |||||
if isinstance(v, staticmethod): | if isinstance(v, staticmethod): | ||||
assert k in output[0] | assert k in output[0] | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@magic_argv_env_context | |||||
def test_trainer_event_trigger_2( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=2, | |||||
): | |||||
@Trainer.on(Event.on_after_trainer_initialized()) | |||||
def on_after_trainer_initialized(trainer, driver): | |||||
print("on_after_trainer_initialized") | |||||
@Trainer.on(Event.on_sanity_check_begin()) | |||||
def on_sanity_check_begin(trainer): | |||||
print("on_sanity_check_begin") | |||||
@Trainer.on(Event.on_sanity_check_end()) | |||||
def on_sanity_check_end(trainer, sanity_check_res): | |||||
print("on_sanity_check_end") | |||||
@Trainer.on(Event.on_train_begin()) | |||||
def on_train_begin(trainer): | |||||
print("on_train_begin") | |||||
@Trainer.on(Event.on_train_end()) | |||||
def on_train_end(trainer): | |||||
print("on_train_end") | |||||
@Trainer.on(Event.on_train_epoch_begin()) | |||||
def on_train_epoch_begin(trainer): | |||||
if trainer.cur_epoch_idx >= 1: | |||||
# 触发 on_exception; | |||||
raise Exception | |||||
print("on_train_epoch_begin") | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
def on_train_epoch_end(trainer): | |||||
print("on_train_epoch_end") | |||||
@Trainer.on(Event.on_fetch_data_begin()) | |||||
def on_fetch_data_begin(trainer): | |||||
print("on_fetch_data_begin") | |||||
@Trainer.on(Event.on_fetch_data_end()) | |||||
def on_fetch_data_end(trainer): | |||||
print("on_fetch_data_end") | |||||
@Trainer.on(Event.on_train_batch_begin()) | |||||
def on_train_batch_begin(trainer, batch, indices=None): | |||||
print("on_train_batch_begin") | |||||
@Trainer.on(Event.on_train_batch_end()) | |||||
def on_train_batch_end(trainer): | |||||
print("on_train_batch_end") | |||||
@Trainer.on(Event.on_exception()) | |||||
def on_exception(trainer, exception): | |||||
print("on_exception") | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def on_before_backward(trainer, outputs): | |||||
print("on_before_backward") | |||||
@Trainer.on(Event.on_after_backward()) | |||||
def on_after_backward(trainer): | |||||
print("on_after_backward") | |||||
@Trainer.on(Event.on_before_optimizers_step()) | |||||
def on_before_optimizers_step(trainer, optimizers): | |||||
print("on_before_optimizers_step") | |||||
@Trainer.on(Event.on_after_optimizers_step()) | |||||
def on_after_optimizers_step(trainer, optimizers): | |||||
print("on_after_optimizers_step") | |||||
@Trainer.on(Event.on_before_zero_grad()) | |||||
def on_before_zero_grad(trainer, optimizers): | |||||
print("on_before_zero_grad") | |||||
@Trainer.on(Event.on_after_zero_grad()) | |||||
def on_after_zero_grad(trainer, optimizers): | |||||
print("on_after_zero_grad") | |||||
@Trainer.on(Event.on_evaluate_begin()) | |||||
def on_evaluate_begin(trainer): | |||||
print("on_evaluate_begin") | |||||
@Trainer.on(Event.on_evaluate_end()) | |||||
def on_evaluate_end(trainer, results): | |||||
print("on_evaluate_end") | |||||
with pytest.raises(Exception): | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
) | |||||
trainer.run() | |||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger_3( | |||||
def test_trainer_event_trigger_2( | |||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
@@ -327,7 +211,6 @@ def test_trainer_event_trigger_3( | |||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
Event_attrs = Event.__dict__ | Event_attrs = Event.__dict__ | ||||
for k, v in Event_attrs.items(): | for k, v in Event_attrs.items(): | ||||
if isinstance(v, staticmethod): | if isinstance(v, staticmethod): | ||||