Browse Source

small bug

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
9b0a30c8fb
2 changed files with 3 additions and 4 deletions
  1. +1
    -2
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +2
    -2
      tests/core/drivers/paddle_driver/test_dist_utils.py

+ 1
- 2
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -1,7 +1,6 @@
import os
from pathlib import Path
from typing import Union, Optional, Dict
from contextlib import nullcontext
from dataclasses import dataclass

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils import apply_to_collection, nullcontext
from fastNLP.envs import (
FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME,


+ 2
- 2
tests/core/drivers/paddle_driver/test_dist_utils.py View File

@@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast:

@classmethod
def setup_class(cls):
devices = [0,1,2]
devices = [0,1]
output_from_new_proc = "all"

launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc)
@@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast:
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1, 2]))
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_fastnlp_paddle_broadcast_object(self, src_rank):
if self.local_rank == src_rank:
obj = {


Loading…
Cancel
Save