Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
6d889c8624
4 changed files with 25 additions and 7 deletions
  1. +1
    -2
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +18
    -2
      fastNLP/core/samplers/reproducible_batch_sampler.py
  3. +4
    -1
      fastNLP/core/samplers/reproducible_sampler.py
  4. +2
    -2
      tests/core/drivers/paddle_driver/test_dist_utils.py

+ 1
- 2
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -1,7 +1,6 @@
import os
from pathlib import Path
from typing import Union, Optional, Dict
from contextlib import nullcontext
from dataclasses import dataclass

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils import apply_to_collection, nullcontext
from fastNLP.envs import (
FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME,


+ 18
- 2
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler):

@property
def num_samples(self):
return getattr(self.dataset, 'total_len', len(self.dataset))
"""
返回样本的总数

:return:
"""
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len

def __len__(self)->int:
"""
@@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler):

@property
def num_samples(self):
return getattr(self.dataset, 'total_len', len(self.dataset))
"""
返回样本的总数

:return:
"""
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len

def __len__(self)->int:
"""


+ 4
- 1
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler):

:return:
"""
return getattr(self.dataset, 'total_len', len(self.dataset))
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len

class SequentialSampler(RandomSampler):
"""


+ 2
- 2
tests/core/drivers/paddle_driver/test_dist_utils.py View File

@@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast:

@classmethod
def setup_class(cls):
devices = [0,1,2]
devices = [0,1]
output_from_new_proc = "all"

launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc)
@@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast:
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1, 2]))
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_fastnlp_paddle_broadcast_object(self, src_rank):
if self.local_rank == src_rank:
obj = {


Loading…
Cancel
Save