|
@@ -4,8 +4,8 @@ from functools import reduce |
|
|
from fastNLP.core.callbacks.callback_events import Events, Filter |
|
|
from fastNLP.core.callbacks.callback_events import Events, Filter |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestFilter: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestFilter: |
|
|
def test_params_check(self): |
|
|
def test_params_check(self): |
|
|
# 顺利通过 |
|
|
# 顺利通过 |
|
|
_filter1 = Filter(every=10) |
|
|
_filter1 = Filter(every=10) |
|
@@ -80,35 +80,6 @@ class TestFilter: |
|
|
_res.append(cu_res) |
|
|
_res.append(cu_res) |
|
|
assert _res == [9] |
|
|
assert _res == [9] |
|
|
|
|
|
|
|
|
def test_filter_fn(self): |
|
|
|
|
|
from torch.optim import SGD |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from fastNLP.core.controllers.trainer import Trainer |
|
|
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification |
|
|
|
|
|
|
|
|
|
|
|
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) |
|
|
|
|
|
optimizer = SGD(model.parameters(), lr=0.0001) |
|
|
|
|
|
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) |
|
|
|
|
|
dataloader = DataLoader(dataset=dataset, batch_size=4) |
|
|
|
|
|
|
|
|
|
|
|
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) |
|
|
|
|
|
def filter_fn(filter, trainer): |
|
|
|
|
|
if trainer.__heihei_test__ == 10: |
|
|
|
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
@Filter(filter_fn=filter_fn) |
|
|
|
|
|
def _fn(trainer, data): |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
_res = [] |
|
|
|
|
|
for i in range(100): |
|
|
|
|
|
trainer.__heihei_test__ = i |
|
|
|
|
|
cu_res = _fn(trainer, i) |
|
|
|
|
|
if cu_res is not None: |
|
|
|
|
|
_res.append(cu_res) |
|
|
|
|
|
assert _res == [10] |
|
|
|
|
|
|
|
|
|
|
|
def test_extract_filter_from_fn(self): |
|
|
def test_extract_filter_from_fn(self): |
|
|
@Filter(every=10) |
|
|
@Filter(every=10) |
|
@@ -155,3 +126,119 @@ class TestFilter: |
|
|
assert _res == [w - 1 for w in range(60, 101, 10)] |
|
|
assert _res == [w - 1 for w in range(60, 101, 10)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
|
|
def test_filter_fn_torch(): |
|
|
|
|
|
from torch.optim import SGD |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from fastNLP.core.controllers.trainer import Trainer |
|
|
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification |
|
|
|
|
|
|
|
|
|
|
|
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) |
|
|
|
|
|
optimizer = SGD(model.parameters(), lr=0.0001) |
|
|
|
|
|
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) |
|
|
|
|
|
dataloader = DataLoader(dataset=dataset, batch_size=4) |
|
|
|
|
|
|
|
|
|
|
|
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) |
|
|
|
|
|
def filter_fn(filter, trainer): |
|
|
|
|
|
if trainer.__heihei_test__ == 10: |
|
|
|
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
@Filter(filter_fn=filter_fn) |
|
|
|
|
|
def _fn(trainer, data): |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
_res = [] |
|
|
|
|
|
for i in range(100): |
|
|
|
|
|
trainer.__heihei_test__ = i |
|
|
|
|
|
cu_res = _fn(trainer, i) |
|
|
|
|
|
if cu_res is not None: |
|
|
|
|
|
_res.append(cu_res) |
|
|
|
|
|
assert _res == [10] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCallbackEvents: |
|
|
|
|
|
def test_every(self): |
|
|
|
|
|
|
|
|
|
|
|
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; |
|
|
|
|
|
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; |
|
|
|
|
|
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) |
|
|
|
|
|
def _fn(data): |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
_res = [] |
|
|
|
|
|
for i in range(100): |
|
|
|
|
|
cu_res = _fn(i) |
|
|
|
|
|
if cu_res is not None: |
|
|
|
|
|
_res.append(cu_res) |
|
|
|
|
|
assert _res == list(range(100)) |
|
|
|
|
|
|
|
|
|
|
|
event_state = Events.on_train_begin(every=10) |
|
|
|
|
|
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) |
|
|
|
|
|
def _fn(data): |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
_res = [] |
|
|
|
|
|
for i in range(100): |
|
|
|
|
|
cu_res = _fn(i) |
|
|
|
|
|
if cu_res is not None: |
|
|
|
|
|
_res.append(cu_res) |
|
|
|
|
|
assert _res == [w - 1 for w in range(10, 101, 10)] |
|
|
|
|
|
|
|
|
|
|
|
def test_once(self): |
|
|
|
|
|
event_state = Events.on_train_begin(once=10) |
|
|
|
|
|
|
|
|
|
|
|
@Filter(once=event_state.once) |
|
|
|
|
|
def _fn(data): |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
_res = [] |
|
|
|
|
|
for i in range(100): |
|
|
|
|
|
cu_res = _fn(i) |
|
|
|
|
|
if cu_res is not None: |
|
|
|
|
|
_res.append(cu_res) |
|
|
|
|
|
assert _res == [9] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
|
|
def test_callback_events_torch(): |
|
|
|
|
|
from torch.optim import SGD |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from fastNLP.core.controllers.trainer import Trainer |
|
|
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification |
|
|
|
|
|
|
|
|
|
|
|
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) |
|
|
|
|
|
optimizer = SGD(model.parameters(), lr=0.0001) |
|
|
|
|
|
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) |
|
|
|
|
|
dataloader = DataLoader(dataset=dataset, batch_size=4) |
|
|
|
|
|
|
|
|
|
|
|
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) |
|
|
|
|
|
def filter_fn(filter, trainer): |
|
|
|
|
|
if trainer.__heihei_test__ == 10: |
|
|
|
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
event_state = Events.on_train_begin(filter_fn=filter_fn) |
|
|
|
|
|
|
|
|
|
|
|
@Filter(filter_fn=event_state.filter_fn) |
|
|
|
|
|
def _fn(trainer, data): |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
_res = [] |
|
|
|
|
|
for i in range(100): |
|
|
|
|
|
trainer.__heihei_test__ = i |
|
|
|
|
|
cu_res = _fn(trainer, i) |
|
|
|
|
|
if cu_res is not None: |
|
|
|
|
|
_res.append(cu_res) |
|
|
|
|
|
assert _res == [10] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|