|
@@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context |
|
|
from fastNLP.envs.distributed import rank_zero_rm |
|
|
from fastNLP.envs.distributed import rank_zero_rm |
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 |
|
|
from tests.helpers.datasets.torch_data import TorchArgMaxDataset |
|
|
from tests.helpers.datasets.torch_data import TorchArgMaxDataset |
|
|
|
|
|
from tests.helpers.utils import Capturing |
|
|
from torchmetrics import Accuracy |
|
|
from torchmetrics import Accuracy |
|
|
from fastNLP.core.log import logger |
|
|
from fastNLP.core.log import logger |
|
|
|
|
|
|
|
@@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1( |
|
|
dist.destroy_process_group() |
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
|
|
def test_load_state(model_and_optimizers): |
|
|
|
|
|
try: |
|
|
|
|
|
path = Path.cwd().joinpath(f"test_model_checkpoint") |
|
|
|
|
|
path.mkdir(exist_ok=True, parents=True) |
|
|
|
|
|
from fastNLP import Event, Callback |
|
|
|
|
|
@Trainer.on(Event.on_before_backward(every=3), marker='all') |
|
|
|
|
|
def print_outputs(*args): |
|
|
|
|
|
print("????") |
|
|
|
|
|
|
|
|
|
|
|
class StateCallback(Callback): |
|
|
|
|
|
def __init__(self, name): |
|
|
|
|
|
self.name = name |
|
|
|
|
|
|
|
|
|
|
|
def on_save_checkpoint(self, trainer): |
|
|
|
|
|
return {'name': self.name} |
|
|
|
|
|
|
|
|
|
|
|
def on_load_checkpoint(self, trainer, states): |
|
|
|
|
|
self.name = states['name'] |
|
|
|
|
|
|
|
|
|
|
|
def on_train_end(self, trainer): |
|
|
|
|
|
print(self.name) |
|
|
|
|
|
|
|
|
|
|
|
callbacks = [StateCallback('old_callback1'), StateCallback('old_callback2'), |
|
|
|
|
|
CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer')] |
|
|
|
|
|
|
|
|
|
|
|
trainer = Trainer( |
|
|
|
|
|
model=model_and_optimizers.model, |
|
|
|
|
|
driver='torch', |
|
|
|
|
|
device='cpu', |
|
|
|
|
|
optimizers=model_and_optimizers.optimizers, |
|
|
|
|
|
train_dataloader=model_and_optimizers.train_dataloader, |
|
|
|
|
|
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, |
|
|
|
|
|
input_mapping=model_and_optimizers.input_mapping, |
|
|
|
|
|
output_mapping=model_and_optimizers.output_mapping, |
|
|
|
|
|
metrics=model_and_optimizers.metrics, |
|
|
|
|
|
n_epochs=3, |
|
|
|
|
|
callbacks=callbacks, |
|
|
|
|
|
output_from_new_proc="all" |
|
|
|
|
|
) |
|
|
|
|
|
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) |
|
|
|
|
|
|
|
|
|
|
|
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} |
|
|
|
|
|
epoch_2_path = all_saved_model_paths['trainer-epoch_2'] |
|
|
|
|
|
|
|
|
|
|
|
callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')] |
|
|
|
|
|
trainer = Trainer( |
|
|
|
|
|
model=model_and_optimizers.model, |
|
|
|
|
|
driver='torch', |
|
|
|
|
|
device='cpu', |
|
|
|
|
|
optimizers=model_and_optimizers.optimizers, |
|
|
|
|
|
train_dataloader=model_and_optimizers.train_dataloader, |
|
|
|
|
|
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, |
|
|
|
|
|
input_mapping=model_and_optimizers.input_mapping, |
|
|
|
|
|
output_mapping=model_and_optimizers.output_mapping, |
|
|
|
|
|
metrics=model_and_optimizers.metrics, |
|
|
|
|
|
n_epochs=3, |
|
|
|
|
|
callbacks=callbacks, |
|
|
|
|
|
output_from_new_proc="all" |
|
|
|
|
|
) |
|
|
|
|
|
trainer.load(folder=epoch_2_path) |
|
|
|
|
|
with Capturing() as output: |
|
|
|
|
|
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) |
|
|
|
|
|
|
|
|
|
|
|
assert 'old_callback1' in output[0] |
|
|
|
|
|
assert 'new_callback2' in output[0] |
|
|
|
|
|
assert output[0].count('???')==1 |
|
|
|
|
|
|
|
|
|
|
|
finally: |
|
|
|
|
|
rank_zero_rm(path) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
@pytest.mark.torch |
|
|
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; |
|
|
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; |
|
|
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) |
|
|
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) |
|
|