|
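- """
- This file tests the case where the user launches training with python -m paddle.distributed.launch.
- TODO: check whether there is a way to run it through pytest as well.
- python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
- """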
- import os
- os.environ["FASTNLP_BACKEND"] = "paddle"
- import sys
- sys.path.append("../../../")
-
- from dataclasses import dataclass
-
- from fastNLP.core.controllers.trainer import Trainer
- from fastNLP.core.metrics.accuracy import Accuracy
- from fastNLP.core.callbacks.progress_callback import RichCallback
- from fastNLP.core.callbacks import Callback
-
- import paddle
- from paddle.optimizer import Adam
- from paddle.io import DataLoader
-
- from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
- from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
- from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback
-
@dataclass
class MNISTTrainFleetConfig:
    """Hyper-parameters for the fleet trainer test in this file.

    ``test_trainer_fleet`` reads these values directly from the class
    (``MNISTTrainFleetConfig.<name>``), so the defaults double as the
    effective configuration.
    """

    # Passed to the classification model as ``num_labels``.
    num_labels: int = 10
    # Passed to both the model and the dataset constructor.
    feature_dimension: int = 10

    # Batch size of the train / validation dataloaders.
    batch_size: int = 32
    # Shuffle setting intended for the dataloaders.
    shuffle: bool = True
    # Forwarded to ``Trainer(validate_every=...)``. Annotated so that it is a
    # real dataclass field (part of __init__/repr/eq) like its siblings,
    # instead of a bare class attribute that ``@dataclass`` silently ignores.
    validate_every: int = -1
-
def test_trainer_fleet(
    driver,
    device,
    callbacks,
    n_epochs,
):
    """Run a fastNLP ``Trainer`` on a small paddle classification model.

    Builds the model, an Adam optimizer and train / validation dataloaders
    over ``PaddleRandomMaxDataset`` using the values in
    ``MNISTTrainFleetConfig``, then trains with an ``Accuracy`` metric
    registered under the key ``"acc"``.

    Args:
        driver: Forwarded to ``Trainer(driver=...)``, e.g. ``"fleet"``.
        device: Forwarded to ``Trainer(device=...)``.
        callbacks: Forwarded to ``Trainer(callbacks=...)``.
        n_epochs: Forwarded to ``Trainer(n_epochs=...)``.
    """
    # The config is used as a namespace of class-level defaults; no instance
    # is needed.
    config = MNISTTrainFleetConfig

    model = PaddleNormalModel_Classification_1(
        num_labels=config.num_labels,
        feature_dimension=config.feature_dimension,
    )
    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)

    # Both loaders honour the config's shuffle flag (True by default, which
    # matches the value that used to be hard-coded here).
    train_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(6400, config.feature_dimension),
        batch_size=config.batch_size,
        shuffle=config.shuffle,
    )
    validate_dataloaders = DataLoader(
        dataset=PaddleRandomMaxDataset(1280, config.feature_dimension),
        batch_size=config.batch_size,
        shuffle=config.shuffle,
    )

    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        validate_dataloaders=validate_dataloaders,
        validate_every=config.validate_every,
        input_mapping=None,
        output_mapping=None,
        metrics={"acc": Accuracy()},
        n_epochs=n_epochs,
        callbacks=callbacks,
        output_from_new_proc="logs",
    )
    trainer.run()
-
if __name__ == "__main__":
    # Multi-card run through the "fleet" driver on GPUs 0, 2 and 3 (matches
    # the --gpus=0,2,3 launch command in the module docstring).
    # For a single-card run, pass driver="paddle" and device=2 instead.
    test_trainer_fleet(
        driver="fleet",
        device=[0, 2, 3],
        callbacks=[
            # Optionally also add:
            # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True)
            RichCallback(5),
        ],
        n_epochs=5,
    )
|