From b97962b8ddd3e80dd8c4c95955e36c380c846e81 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Thu, 14 Apr 2022 16:04:56 +0000 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8C=96paddle=20trainer=E7=9A=84?= =?UTF-8?q?=E5=8D=95=E5=8D=A1=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/controllers/test_trainer_paddle.py | 106 +++++------------- 1 file changed, 27 insertions(+), 79 deletions(-) diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index ed102c99..03ffd1ca 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -1,19 +1,20 @@ import pytest import os +os.environ["FASTNLP_BACKEND"] = "paddle" from typing import Any from dataclasses import dataclass -from paddle.optimizer import Adam -from paddle.io import DataLoader - from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.callbacks.progress_callback import RichCallback from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK +from paddle.optimizer import Adam +from paddle.io import DataLoader + -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification -from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context @@ -48,64 +49,31 @@ class TrainerParameters: output_mapping: Any = None metrics: Any = None -# @pytest.fixture(params=[0], autouse=True) -# def model_and_optimizers(request): -# """ -# 初始化单卡模式的模型和优化器 -# """ -# trainer_params = TrainerParameters() -# print(paddle.device.get_device()) - -# if request.param == 0: -# trainer_params.model = PaddleNormalModel_Classification( -# num_labels=MNISTTrainPaddleConfig.num_labels, -# feature_dimension=MNISTTrainPaddleConfig.feature_dimension -# ) -# trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) -# train_dataloader = DataLoader( -# dataset=PaddleDataset_MNIST("train"), -# batch_size=MNISTTrainPaddleConfig.batch_size, -# shuffle=True -# ) -# val_dataloader = DataLoader( -# dataset=PaddleDataset_MNIST(mode="test"), -# batch_size=MNISTTrainPaddleConfig.batch_size, -# shuffle=True -# ) -# trainer_params.train_dataloader = train_dataloader -# trainer_params.validate_dataloaders = val_dataloader -# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every -# trainer_params.metrics = {"acc": Accuracy()} - -# return trainer_params - - -@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) +@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)]) # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) @magic_argv_env_context def test_trainer_paddle( - # model_and_optimizers: TrainerParameters, driver, device, callbacks, - n_epochs=15, + n_epochs=2, ): trainer_params = TrainerParameters() - trainer_params.model = PaddleNormalModel_Classification( + trainer_params.model = PaddleNormalModel_Classification_1( num_labels=MNISTTrainPaddleConfig.num_labels, feature_dimension=MNISTTrainPaddleConfig.feature_dimension ) trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) train_dataloader = DataLoader( - dataset=PaddleDataset_MNIST("train"), + dataset=PaddleRandomMaxDataset(6400, 10), batch_size=MNISTTrainPaddleConfig.batch_size, shuffle=True ) val_dataloader = DataLoader( - dataset=PaddleDataset_MNIST(mode="test"), + dataset=PaddleRandomMaxDataset(1000, 10), batch_size=MNISTTrainPaddleConfig.batch_size, shuffle=True ) @@ -113,39 +81,19 @@ def test_trainer_paddle( trainer_params.validate_dataloaders = val_dataloader trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every trainer_params.metrics = {"acc": Accuracy(backend="paddle")} - if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as exc: - trainer = Trainer( - model=trainer_params.model, - driver=driver, - device=device, - optimizers=trainer_params.optimizers, - train_dataloader=trainer_params.train_dataloader, - validate_dataloaders=trainer_params.validate_dataloaders, - validate_every=trainer_params.validate_every, - input_mapping=trainer_params.input_mapping, - output_mapping=trainer_params.output_mapping, - metrics=trainer_params.metrics, - - n_epochs=n_epochs, - callbacks=callbacks, - ) - assert exc.value.code == 0 - return - else: - trainer = Trainer( - model=trainer_params.model, - driver=driver, - device=device, - optimizers=trainer_params.optimizers, - train_dataloader=trainer_params.train_dataloader, - validate_dataloaders=trainer_params.validate_dataloaders, - validate_every=trainer_params.validate_every, - input_mapping=trainer_params.input_mapping, - output_mapping=trainer_params.output_mapping, - metrics=trainer_params.metrics, - - n_epochs=n_epochs, - callbacks=callbacks, - ) - trainer.run() \ No newline at end of file + trainer = Trainer( + model=trainer_params.model, + driver=driver, + device=device, + optimizers=trainer_params.optimizers, + train_dataloader=trainer_params.train_dataloader, + validate_dataloaders=trainer_params.validate_dataloaders, + validate_every=trainer_params.validate_every, + input_mapping=trainer_params.input_mapping, + output_mapping=trainer_params.output_mapping, + metrics=trainer_params.metrics, + + n_epochs=n_epochs, + callbacks=callbacks, + ) + trainer.run() \ No newline at end of file