|
- import os
- import pytest
- from dataclasses import dataclass
-
- from fastNLP.core.controllers.trainer import Trainer
- from fastNLP.core.metrics.accuracy import Accuracy
- from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW
-
- if _NEED_IMPORT_ONEFLOW:
- from oneflow.optim import Adam
- from oneflow.utils.data import DataLoader
-
-
- from tests.helpers.models.oneflow_model import OneflowNormalModel_Classification_1
- from tests.helpers.datasets.oneflow_data import OneflowArgMaxDataset
- from tests.helpers.utils import magic_argv_env_context
-
@dataclass
class TrainOneflowConfig:
    """Shared hyper-parameters for the oneflow Trainer tests.

    Only ever read as class attributes (e.g. ``TrainOneflowConfig.num_labels``);
    the class is never instantiated in these tests.
    """

    num_labels: int = 3         # number of target classes the model predicts
    feature_dimension: int = 3  # width of each input feature vector

    batch_size: int = 2
    shuffle: bool = True
    # Annotated so this is a real dataclass field like the attributes above —
    # an unannotated assignment would be a plain class attribute, silently
    # excluded from the dataclass machinery.
    evaluate_every: int = 2
-
@pytest.mark.parametrize("device", ["cpu", 1])
@pytest.mark.parametrize("callbacks", [[]])
@pytest.mark.oneflow
@magic_argv_env_context
def test_trainer_oneflow(
    device,
    callbacks,
    n_epochs=2,
):
    """Smoke-test ``Trainer`` with the oneflow driver on a tiny classification task.

    Builds a small model, train/eval dataloaders from ``OneflowArgMaxDataset``,
    and runs ``n_epochs`` epochs of training with accuracy evaluation. Device
    and callbacks are parametrized by pytest.
    """
    model = OneflowNormalModel_Classification_1(
        num_labels=TrainOneflowConfig.num_labels,
        feature_dimension=TrainOneflowConfig.feature_dimension,
    )
    optimizers = Adam(params=model.parameters(), lr=0.0001)
    # Take batch size AND shuffle from the shared config (previously shuffle
    # was hard-coded to True, ignoring TrainOneflowConfig.shuffle).
    train_dataloader = DataLoader(
        dataset=OneflowArgMaxDataset(20, TrainOneflowConfig.feature_dimension),
        batch_size=TrainOneflowConfig.batch_size,
        shuffle=TrainOneflowConfig.shuffle,
    )
    evaluate_dataloaders = DataLoader(
        dataset=OneflowArgMaxDataset(12, TrainOneflowConfig.feature_dimension),
        batch_size=TrainOneflowConfig.batch_size,
        shuffle=TrainOneflowConfig.shuffle,
    )
    metrics = {"acc": Accuracy()}

    trainer = Trainer(
        model=model,
        driver="oneflow",
        device=device,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        evaluate_dataloaders=evaluate_dataloaders,
        evaluate_every=TrainOneflowConfig.evaluate_every,
        input_mapping=None,
        output_mapping=None,
        metrics=metrics,
        n_epochs=n_epochs,
        callbacks=callbacks,
    )
    trainer.run()
|