import pytest
from functools import reduce

from fastNLP.core.callbacks.callback_event import Event, Filter


class TestFilter:
    def test_every_filter(self):
        # every = 10
        @Filter(every=10)
        def _fn(data):
            return data

        _res = []
        for i in range(100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == [w - 1 for w in range(10, 101, 10)]

        # every = 1
        @Filter(every=1)
        def _fn(data):
            return data

        _res = []
        for i in range(100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == list(range(100))

    def test_once_filter(self):
        # once = 10
        @Filter(once=10)
        def _fn(data):
            return data

        _res = []
        for i in range(100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == [9]

    def test_extract_filter_from_fn(self):
        @Filter(every=10)
        def _fn(data):
            return data

        _filter_num_called = []
        _filter_num_executed = []
        for i in range(100):
            cu_res = _fn(i)
            _filter = _fn.__fastNLP_filter__
            _filter_num_called.append(_filter.num_called)
            _filter_num_executed.append(_filter.num_executed)
        assert _filter_num_called == list(range(1, 101))
        assert _filter_num_executed == [0] * 9 + reduce(lambda x, y: x + y, [[w] * 10 for w in range(1, 10)]) + [10]

        def _fn(data):
            return data
        assert not hasattr(_fn, "__fastNLP_filter__")

    def test_filter_state_dict(self):
        # every = 10
        @Filter(every=10)
        def _fn(data):
            return data

        _res = []
        for i in range(50):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == [w - 1 for w in range(10, 51, 10)]

        # save the filter state
        state = _fn.__fastNLP_filter__.state_dict()
        # load the state back
        _fn.__fastNLP_filter__.load_state_dict(state)

        _res = []
        for i in range(50, 100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == [w - 1 for w in range(60, 101, 10)]


@pytest.mark.torch
def test_filter_fn_torch():
    from torch.optim import SGD
    from torch.utils.data import DataLoader
    from fastNLP.core.controllers.trainer import Trainer
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
    from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    optimizer = SGD(model.parameters(), lr=0.0001)
    dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
    dataloader = DataLoader(dataset=dataset, batch_size=4)

    trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)

    def filter_fn(filter, trainer):
        if trainer.__heihei_test__ == 10:
            return True
        return False

    @Filter(filter_fn=filter_fn)
    def _fn(trainer, data):
        return data

    _res = []
    for i in range(100):
        trainer.__heihei_test__ = i
        cu_res = _fn(trainer, i)
        if cu_res is not None:
            _res.append(cu_res)
    assert _res == [10]


class TestCallbackEvents:
    def test_every(self):
        # Which event is used here does not matter, since the filter is tested separately from the Trainer;
        event_state = Event.on_train_begin()  # passing nothing should default to every=1;

        @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
        def _fn(data):
            return data

        _res = []
        for i in range(100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == list(range(100))

        event_state = Event.on_train_begin(every=10)

        @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
        def _fn(data):
            return data

        _res = []
        for i in range(100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == [w - 1 for w in range(10, 101, 10)]

    def test_once(self):
        event_state = Event.on_train_begin(once=10)

        @Filter(once=event_state.once)
        def _fn(data):
            return data

        _res = []
        for i in range(100):
            cu_res = _fn(i)
            if cu_res is not None:
                _res.append(cu_res)
        assert _res == [9]


@pytest.mark.torch
def test_callback_events_torch():
    from torch.optim import SGD
    from torch.utils.data import DataLoader
    from fastNLP.core.controllers.trainer import Trainer
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
    from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    optimizer = SGD(model.parameters(), lr=0.0001)
    dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
    dataloader = DataLoader(dataset=dataset, batch_size=4)

    trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)

    def filter_fn(filter, trainer):
        if trainer.__heihei_test__ == 10:
            return True
        return False

    event_state = Event.on_train_begin(filter_fn=filter_fn)

    @Filter(filter_fn=event_state.filter_fn)
    def _fn(trainer, data):
        return data

    _res = []
    for i in range(100):
        trainer.__heihei_test__ = i
        cu_res = _fn(trainer, i)
        if cu_res is not None:
            _res.append(cu_res)
    assert _res == [10]