Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
6d889c8624
4 changed files with 25 additions and 7 deletions
  1. +1
    -2
      fastNLP/core/drivers/jittor_driver/jittor_driver.py
  2. +18
    -2
      fastNLP/core/samplers/reproducible_batch_sampler.py
  3. +4
    -1
      fastNLP/core/samplers/reproducible_sampler.py
  4. +2
    -2
      tests/core/drivers/paddle_driver/test_dist_utils.py

+ 1
- 2
fastNLP/core/drivers/jittor_driver/jittor_driver.py View File

@@ -1,7 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Union, Optional, Dict from typing import Union, Optional, Dict
from contextlib import nullcontext
from dataclasses import dataclass from dataclasses import dataclass


from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -9,7 +8,7 @@ from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils import apply_to_collection, nullcontext
from fastNLP.envs import ( from fastNLP.envs import (
FASTNLP_MODEL_FILENAME, FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME, FASTNLP_CHECKPOINT_FILENAME,


+ 18
- 2
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -318,7 +318,15 @@ class RandomBatchSampler(ReproducibleBatchSampler):


@property @property
def num_samples(self): def num_samples(self):
return getattr(self.dataset, 'total_len', len(self.dataset))
"""
返回样本的总数

:return:
"""
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len


def __len__(self)->int: def __len__(self)->int:
""" """
@@ -473,7 +481,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler):


@property @property
def num_samples(self): def num_samples(self):
return getattr(self.dataset, 'total_len', len(self.dataset))
"""
返回样本的总数

:return:
"""
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len


def __len__(self)->int: def __len__(self)->int:
""" """


+ 4
- 1
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -222,7 +222,10 @@ class RandomSampler(ReproducibleSampler):


:return: :return:
""" """
return getattr(self.dataset, 'total_len', len(self.dataset))
total_len = getattr(self.dataset, 'total_len', None)
if not isinstance(total_len, int):
total_len = len(self.dataset)
return total_len


class SequentialSampler(RandomSampler): class SequentialSampler(RandomSampler):
""" """


+ 2
- 2
tests/core/drivers/paddle_driver/test_dist_utils.py View File

@@ -84,7 +84,7 @@ class TestAllGatherAndBroadCast:


@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
devices = [0,1,2]
devices = [0,1]
output_from_new_proc = "all" output_from_new_proc = "all"


launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc) launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc)
@@ -150,7 +150,7 @@ class TestAllGatherAndBroadCast:
dist.barrier() dist.barrier()


@magic_argv_env_context @magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1, 2]))
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_fastnlp_paddle_broadcast_object(self, src_rank): def test_fastnlp_paddle_broadcast_object(self, src_rank):
if self.local_rank == src_rank: if self.local_rank == src_rank:
obj = { obj = {


Loading…
Cancel
Save