Browse Source

merge

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
c373105674
2 changed files with 119 additions and 149 deletions
  1. +117
    -30
      tests/core/callbacks/test_callback_event.py
  2. +2
    -119
      tests/core/controllers/test_trainer_event_trigger.py

tests/core/callbacks/test_callback_events.py → tests/core/callbacks/test_callback_event.py View File

@@ -4,8 +4,8 @@ from functools import reduce
from fastNLP.core.callbacks.callback_event import Event, Filter


class TestFilter:

class TestFilter:
def test_params_check(self):
# 顺利通过
_filter1 = Filter(every=10)
@@ -80,35 +80,6 @@ class TestFilter:
_res.append(cu_res)
assert _res == [9]

def test_filter_fn(self):
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)

trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False

@Filter(filter_fn=filter_fn)
def _fn(trainer, data):
return data

_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]

def test_extract_filter_from_fn(self):
@Filter(every=10)
@@ -155,3 +126,119 @@ class TestFilter:
assert _res == [w - 1 for w in range(60, 101, 10)]


@pytest.mark.torch
def test_filter_fn_torch():
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)

trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False

@Filter(filter_fn=filter_fn)
def _fn(trainer, data):
return data

_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]


class TestCallbackEvents:
def test_every(self):

# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试;
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1;
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))

event_state = Events.on_train_begin(every=10)
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 101, 10)]

def test_once(self):
event_state = Events.on_train_begin(once=10)

@Filter(once=event_state.once)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]


@pytest.mark.torch
def test_callback_events_torch():
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)

trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False

event_state = Events.on_train_begin(filter_fn=filter_fn)

@Filter(filter_fn=event_state.filter_fn)
def _fn(trainer, data):
return data

_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]










+ 2
- 119
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -62,10 +62,9 @@ def model_and_optimizers():

return trainer_params

@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_1(
model_and_optimizers: TrainerParameters,
@@ -102,125 +101,10 @@ def test_trainer_event_trigger_1(
if isinstance(v, staticmethod):
assert k in output[0]

@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):

@Trainer.on(Event.on_after_trainer_initialized())
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")

@Trainer.on(Event.on_sanity_check_begin())
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")

@Trainer.on(Event.on_sanity_check_end())
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")

@Trainer.on(Event.on_train_begin())
def on_train_begin(trainer):
print("on_train_begin")

@Trainer.on(Event.on_train_end())
def on_train_end(trainer):
print("on_train_end")

@Trainer.on(Event.on_train_epoch_begin())
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
# 触发 on_exception;
raise Exception
print("on_train_epoch_begin")

@Trainer.on(Event.on_train_epoch_end())
def on_train_epoch_end(trainer):
print("on_train_epoch_end")

@Trainer.on(Event.on_fetch_data_begin())
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")

@Trainer.on(Event.on_fetch_data_end())
def on_fetch_data_end(trainer):
print("on_fetch_data_end")

@Trainer.on(Event.on_train_batch_begin())
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")

@Trainer.on(Event.on_train_batch_end())
def on_train_batch_end(trainer):
print("on_train_batch_end")

@Trainer.on(Event.on_exception())
def on_exception(trainer, exception):
print("on_exception")

@Trainer.on(Event.on_before_backward())
def on_before_backward(trainer, outputs):
print("on_before_backward")

@Trainer.on(Event.on_after_backward())
def on_after_backward(trainer):
print("on_after_backward")

@Trainer.on(Event.on_before_optimizers_step())
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")

@Trainer.on(Event.on_after_optimizers_step())
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")

@Trainer.on(Event.on_before_zero_grad())
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")

@Trainer.on(Event.on_after_zero_grad())
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")

@Trainer.on(Event.on_evaluate_begin())
def on_evaluate_begin(trainer):
print("on_evaluate_begin")

@Trainer.on(Event.on_evaluate_end())
def on_evaluate_end(trainer, results):
print("on_evaluate_end")

with pytest.raises(Exception):
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]

@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
driver,
device,
@@ -327,7 +211,6 @@ def test_trainer_event_trigger_3(
)

trainer.run()

Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):


Loading…
Cancel
Save