Browse Source

small bug

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
9b0a30c8fb
2 changed files with 3 additions and 4 deletions
  1. +1
    -2
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +2
    -2
      tests/core/drivers/paddle_driver/test_dist_utils.py

+ 1
- 2
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -1,7 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Union, Optional, Dict from typing import Union, Optional, Dict
from contextlib import nullcontext
from dataclasses import dataclass from dataclasses import dataclass


from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils import apply_to_collection, nullcontext
from fastNLP.envs import ( from fastNLP.envs import (
FASTNLP_MODEL_FILENAME, FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME, FASTNLP_CHECKPOINT_FILENAME,


+ 2
- 2
tests/core/drivers/paddle_driver/test_dist_utils.py View File

@@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast:


@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
devices = [0,1,2]
devices = [0,1]
output_from_new_proc = "all" output_from_new_proc = "all"


launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc)
@@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast:
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1, 2]))
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_fastnlp_paddle_broadcast_object(self, src_rank): def test_fastnlp_paddle_broadcast_object(self, src_rank):
if self.local_rank == src_rank: if self.local_rank == src_rank:
obj = { obj = {


Loading…
Cancel
Save