Browse Source

简化paddle trainer的单卡测试例

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
b97962b8dd
1 changed files with 27 additions and 79 deletions
  1. +27
    -79
      tests/core/controllers/test_trainer_paddle.py

+ 27
- 79
tests/core/controllers/test_trainer_paddle.py View File

@@ -1,19 +1,20 @@
import pytest
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
from typing import Any
from dataclasses import dataclass

from paddle.optimizer import Adam
from paddle.io import DataLoader

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK

from paddle.optimizer import Adam
from paddle.io import DataLoader


from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context

@@ -48,64 +49,31 @@ class TrainerParameters:
output_mapping: Any = None
metrics: Any = None

# @pytest.fixture(params=[0], autouse=True)
# def model_and_optimizers(request):
# """
# 初始化单卡模式的模型和优化器
# """
# trainer_params = TrainerParameters()
# print(paddle.device.get_device())

# if request.param == 0:
# trainer_params.model = PaddleNormalModel_Classification(
# num_labels=MNISTTrainPaddleConfig.num_labels,
# feature_dimension=MNISTTrainPaddleConfig.feature_dimension
# )
# trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
# train_dataloader = DataLoader(
# dataset=PaddleDataset_MNIST("train"),
# batch_size=MNISTTrainPaddleConfig.batch_size,
# shuffle=True
# )
# val_dataloader = DataLoader(
# dataset=PaddleDataset_MNIST(mode="test"),
# batch_size=MNISTTrainPaddleConfig.batch_size,
# shuffle=True
# )
# trainer_params.train_dataloader = train_dataloader
# trainer_params.validate_dataloaders = val_dataloader
# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
# trainer_params.metrics = {"acc": Accuracy()}

# return trainer_params


@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
@magic_argv_env_context
def test_trainer_paddle(
# model_and_optimizers: TrainerParameters,
driver,
device,
callbacks,
n_epochs=15,
n_epochs=2,
):
trainer_params = TrainerParameters()

trainer_params.model = PaddleNormalModel_Classification(
trainer_params.model = PaddleNormalModel_Classification_1(
num_labels=MNISTTrainPaddleConfig.num_labels,
feature_dimension=MNISTTrainPaddleConfig.feature_dimension
)
trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
train_dataloader = DataLoader(
dataset=PaddleDataset_MNIST("train"),
dataset=PaddleRandomMaxDataset(6400, 10),
batch_size=MNISTTrainPaddleConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleDataset_MNIST(mode="test"),
dataset=PaddleRandomMaxDataset(1000, 10),
batch_size=MNISTTrainPaddleConfig.batch_size,
shuffle=True
)
@@ -113,39 +81,19 @@ def test_trainer_paddle(
trainer_params.validate_dataloaders = val_dataloader
trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as exc:
trainer = Trainer(
model=trainer_params.model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,

n_epochs=n_epochs,
callbacks=callbacks,
)
assert exc.value.code == 0
return
else:
trainer = Trainer(
model=trainer_params.model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,

n_epochs=n_epochs,
callbacks=callbacks,
)
trainer.run()
trainer = Trainer(
model=trainer_params.model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,

n_epochs=n_epochs,
callbacks=callbacks,
)
trainer.run()

Loading…
Cancel
Save