|
|
@@ -19,18 +19,20 @@ for folder in list(folders[::-1]): |
|
|
|
path = os.sep.join(folders) |
|
|
|
sys.path.extend([path, os.path.join(path, 'fastNLP')]) |
|
|
|
|
|
|
|
import torch |
|
|
|
from torch.nn.parallel import DistributedDataParallel |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torch.optim import SGD |
|
|
|
import torch.distributed as dist |
|
|
|
from dataclasses import dataclass |
|
|
|
from torchmetrics import Accuracy |
|
|
|
|
|
|
|
from fastNLP.core.controllers.trainer import Trainer |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH |
|
|
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 |
|
|
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification |
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import torch |
|
|
|
from torch.nn.parallel import DistributedDataParallel |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torch.optim import SGD |
|
|
|
import torch.distributed as dist |
|
|
|
from torchmetrics import Accuracy |
|
|
|
|
|
|
|
@dataclass |
|
|
|
class NormalClassificationTrainTorchConfig: |
|
|
|