Browse Source

简化paddle trainer的单卡测试例

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
b97962b8dd
1 changed files with 27 additions and 79 deletions
  1. +27
    -79
      tests/core/controllers/test_trainer_paddle.py

+ 27
- 79
tests/core/controllers/test_trainer_paddle.py View File

@@ -1,19 +1,20 @@
import pytest import pytest
import os import os
os.environ["FASTNLP_BACKEND"] = "paddle"
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass


from paddle.optimizer import Adam
from paddle.io import DataLoader

from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK


from paddle.optimizer import Adam
from paddle.io import DataLoader



from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context


@@ -48,64 +49,31 @@ class TrainerParameters:
output_mapping: Any = None output_mapping: Any = None
metrics: Any = None metrics: Any = None


# @pytest.fixture(params=[0], autouse=True)
# def model_and_optimizers(request):
# """
# 初始化单卡模式的模型和优化器
# """
# trainer_params = TrainerParameters()
# print(paddle.device.get_device())

# if request.param == 0:
# trainer_params.model = PaddleNormalModel_Classification(
# num_labels=MNISTTrainPaddleConfig.num_labels,
# feature_dimension=MNISTTrainPaddleConfig.feature_dimension
# )
# trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
# train_dataloader = DataLoader(
# dataset=PaddleDataset_MNIST("train"),
# batch_size=MNISTTrainPaddleConfig.batch_size,
# shuffle=True
# )
# val_dataloader = DataLoader(
# dataset=PaddleDataset_MNIST(mode="test"),
# batch_size=MNISTTrainPaddleConfig.batch_size,
# shuffle=True
# )
# trainer_params.train_dataloader = train_dataloader
# trainer_params.validate_dataloaders = val_dataloader
# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
# trainer_params.metrics = {"acc": Accuracy()}

# return trainer_params


# NOTE(review): the committed line read `[("paddle", "cpu")("paddle", 1)]` —
# the missing comma calls the tuple and raises TypeError at collection time.
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
                                        RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
@magic_argv_env_context
def test_trainer_paddle(
        driver,
        device,
        callbacks,
        n_epochs=2,
):
    """Single-card smoke test for the paddle-backend Trainer.

    Builds a small classification model and random max-index datasets,
    wires them into ``TrainerParameters``, then runs the Trainer
    end-to-end for ``n_epochs`` epochs on the parametrized
    driver/device with the given callbacks.
    """
    trainer_params = TrainerParameters()

    trainer_params.model = PaddleNormalModel_Classification_1(
        num_labels=MNISTTrainPaddleConfig.num_labels,
        feature_dimension=MNISTTrainPaddleConfig.feature_dimension
    )
    trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
    # Synthetic data replaces MNIST so the test needs no downloaded fixtures.
    train_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(6400, 10),
        batch_size=MNISTTrainPaddleConfig.batch_size,
        shuffle=True
    )
    val_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(1000, 10),
        batch_size=MNISTTrainPaddleConfig.batch_size,
        shuffle=True
    )
    trainer_params.train_dataloader = train_dataloader
    trainer_params.validate_dataloaders = val_dataloader
    trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
    trainer_params.metrics = {"acc": Accuracy(backend="paddle")}

    trainer = Trainer(
        model=trainer_params.model,
        driver=driver,
        device=device,
        optimizers=trainer_params.optimizers,
        train_dataloader=trainer_params.train_dataloader,
        validate_dataloaders=trainer_params.validate_dataloaders,
        validate_every=trainer_params.validate_every,
        input_mapping=trainer_params.input_mapping,
        output_mapping=trainer_params.output_mapping,
        metrics=trainer_params.metrics,

        n_epochs=n_epochs,
        callbacks=callbacks,
    )
    trainer.run()

Loading…
Cancel
Save