diff --git a/fastNLP/core/drivers/jittor_driver/utils.py b/fastNLP/core/drivers/jittor_driver/utils.py
index 046603d0..c75526df 100644
--- a/fastNLP/core/drivers/jittor_driver/utils.py
+++ b/fastNLP/core/drivers/jittor_driver/utils.py
@@ -1,14 +1,61 @@
 import inspect
+import os
+import random
 from copy import deepcopy
 from typing import Union
 
+import numpy as np
+
 from fastNLP.core.dataloaders import JittorDataLoader
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+    get_global_rank,
+    FASTNLP_BACKEND_LAUNCH,
+    FASTNLP_GLOBAL_SEED,
+)
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
+    import jittor as jt
     from jittor.dataset import Dataset
 
-__all__ = []
+__all__ = [
+    "jittor_seed_everything",
+]
+
+def jittor_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
+    r"""
+    Set the seed for the **jittor**, **numpy** and **python.random** pseudo-random number generators.
+
+    :param seed: integer seed for the global random state. If ``None``, a seed is generated from
+        the current timestamp.
+    :param add_global_rank_to_seed: whether to use different random numbers on different **ranks**
+        in distributed training. When ``True``, **FastNLP** adds the current ``global_rank`` to the seed.
+    """
+    max_seed_value = np.iinfo(np.uint32).max
+    min_seed_value = np.iinfo(np.uint32).min
+
+    if seed is None:
+        if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1":
+            seed = 42
+        else:
+            seed = get_global_seed()
+        logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
+    if not isinstance(seed, int):
+        seed = int(seed)
+
+    if not (min_seed_value <= seed <= max_seed_value):
+        logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
+        seed %= max_seed_value
+
+    os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
+    if add_global_rank_to_seed:
+        seed += get_global_rank()
+
+    random.seed(seed)
+    np.random.seed(seed)
+    jt.set_global_seed(seed)
+    return seed
 
 def replace_batch_sampler(dataloader, batch_sampler):
     raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index 342ae8f2..98c07495 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -71,7 +71,6 @@ from .paddle_driver import PaddleDriver
 from .fleet_launcher import FleetLauncher
 from .utils import (
     _FleetWrappingModel,
-    reset_seed,
     replace_sampler,
     replace_batch_sampler,
 )
@@ -238,7 +237,6 @@ class PaddleFleetDriver(PaddleDriver):
         # dist.get_world_size() can only be called after initialization;
         self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM"))
         self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID"))
-        reset_seed()
         logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
         if not parallel_helper._is_parallel_ctx_initialized():
             fleet.init(self.role_maker, self.is_collective, self.strategy)
diff --git a/fastNLP/core/drivers/paddle_driver/fleet_launcher.py b/fastNLP/core/drivers/paddle_driver/fleet_launcher.py
index 4df795ef..f6d99333 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet_launcher.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet_launcher.py
@@ -15,7 +15,6 @@ from fastNLP.envs.env import (
 from fastNLP.core.utils import get_paddle_device_id
 from .utils import (
     find_free_ports,
-    reset_seed,
 )
 
 __all__ = []
@@ -62,7 +61,6 @@ class FleetLauncher:
         # set the environment variables
         self.global_envs = self.get_global_env()
         self.open_subprocess()
-        reset_seed()
 
     def open_subprocess(self):
         """
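A minimal usage sketch (not from the patch) of the new helper under its default add_global_rank_to_seed=True, assuming a two-process launch in which get_global_rank() returns 0 and 1:

    from fastNLP.core.drivers.jittor_driver.utils import jittor_seed_everything

    # Each rank seeds python.random, numpy and jittor with base_seed + rank, so
    # per-rank randomness (e.g. data augmentation) differs while staying reproducible.
    effective = jittor_seed_everything(1234)
    print(effective)  # 1234 on rank 0, 1235 on rank 1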
diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
index a3fde3af..4527f1ed 100644
--- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -13,7 +13,6 @@ from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
 from fastNLP.core.utils.paddle_utils import _convert_data_device
 from fastNLP.envs import (
-    FASTNLP_SEED_WORKERS,
     FASTNLP_MODEL_FILENAME,
     FASTNLP_CHECKPOINT_FILENAME,
     FASTNLP_GLOBAL_RANK,
@@ -396,7 +395,7 @@ class PaddleDriver(Driver):
             random.seed(stdlib_seed)
 
     def set_deterministic_dataloader(self, dataloader):
-        if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
+        if dataloader.worker_init_fn is None:
             dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank)
 
     def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index 1a324c97..b1815fbd 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -9,8 +9,13 @@ from contextlib import ExitStack, closing
 from typing import Dict, Optional
 
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+    get_global_rank,
+    FASTNLP_BACKEND_LAUNCH,
+    FASTNLP_GLOBAL_SEED,
+)
 from fastNLP.core.utils import auto_param_call, paddle_to
-from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS
 from fastNLP.core.log import logger
 
@@ -28,64 +33,40 @@ __all__ = [
     "paddle_seed_everything",
 ]
 
-def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
-    return random.randint(min_seed_value, max_seed_value)
-
-def paddle_seed_everything(seed: Optional[int], workers: bool = False) -> int:
+def paddle_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
     r"""
     Set the seed for the **paddle**, **numpy** and **python.random** pseudo-random number generators.
 
-    :param seed: integer seed for the global random state. If ``None``, the seed is read from the
-        environment variable ``FASTNLP_GLOBAL_SEED`` or chosen at random;
-    :param workers: if ``True``, the environment variable ``FASTNLP_SEED_WORKERS`` is set. It is used
-        by :class:`~fastNLP.core.Trainer` to set ``worker_init_fn`` when configuring the ``dataloader``.
-        If the user has already provided a ``worker_init_fn`` for the ``dataloader``, setting this
-        parameter has no effect;
+    :param seed: integer seed for the global random state. If ``None``, a seed is generated from
+        the current timestamp.
+    :param add_global_rank_to_seed: whether to use different random numbers on different **ranks**
+        in distributed training. When ``True``, **FastNLP** adds the current ``global_rank`` to the seed.
     """
-
     max_seed_value = np.iinfo(np.uint32).max
     min_seed_value = np.iinfo(np.uint32).min
 
     if seed is None:
-        env_seed = os.environ.get("GLOBAL_SEED")
-        if env_seed is None:
-            seed = _select_seed_randomly(min_seed_value, max_seed_value)
-            # rank_zero_warn(f"No seed found, seed set to {seed}")
+        if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1":
+            seed = 42
         else:
-            try:
-                seed = int(env_seed)
-            except ValueError:
-                seed = _select_seed_randomly(min_seed_value, max_seed_value)
-                # rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
-    elif not isinstance(seed, int):
+            seed = get_global_seed()
+        logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
+    if not isinstance(seed, int):
         seed = int(seed)
 
     if not (min_seed_value <= seed <= max_seed_value):
-        logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for "
-                                 "you.")
-        # rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
-        seed = _select_seed_randomly(min_seed_value, max_seed_value)
+        logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
+        seed %= max_seed_value
 
-    # using `log.info` instead of `rank_zero_info`,
-    # so users can verify the seed is properly set in distributed training.
-    # log.info(f"Global seed set to {seed}")
-    os.environ[FASTNLP_GLOBAL_SEED] = str(seed)
+    os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
+    if add_global_rank_to_seed:
+        seed += get_global_rank()
+
     random.seed(seed)
     np.random.seed(seed)
     # paddle's seed function checks on its own whether a GPU environment is present and, if so, also sets the GPU seed
     paddle.seed(seed)
-    os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}"
     return seed
 
-def reset_seed() -> None:
-    """
-    ``fleet`` spawns multiple processes, so when the user calls ``seed_everything`` in the script,
-    the random seed has to be set again inside every launched process;
-    """
-    seed = os.environ.get(FASTNLP_GLOBAL_SEED, None)
-    workers = os.environ.get(FASTNLP_SEED_WORKERS, "0")
-    if seed is not None:
-        paddle_seed_everything(int(seed), workers=bool(int(workers)))
-
 class _FleetWrappingModel(Layer):
     """
     See :class:`fastNLP.core.drivers.torch_driver.utils._DDPWrappingModel`: **PaddlePaddle** distributed
     training likewise needs the model wrapped with :class:`paddle.nn.DataParallel`, following the same
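A minimal sketch of the paddle_seed_everything round-trip, assuming a single-process run (so get_global_rank() is 0 and the returned seed equals the cached base seed):

    import os
    from fastNLP.core.drivers.paddle_driver.utils import paddle_seed_everything

    returned = paddle_seed_everything(None)  # base seed comes from get_global_seed()
    # The base seed (before any rank offset) is cached for later subprocesses:
    assert os.environ["FASTNLP_GLOBAL_SEED"] == str(returned)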
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 364c3a0b..43c6bc36 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -148,7 +148,6 @@ __all__ = [
 from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import (
     _DDPWrappingModel,
-    reset_seed,
     replace_sampler,
     replace_batch_sampler
 )
@@ -339,7 +338,6 @@ class TorchDDPDriver(TorchDriver):
         # dist.get_world_size() can only be called after dist.init_process_group has run;
         self.world_size = int(os.environ.get("WORLD_SIZE"))
         self.global_rank = int(os.environ.get("RANK"))
-        reset_seed()
         logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
 
         if not dist.is_initialized():
@@ -359,7 +357,6 @@ class TorchDDPDriver(TorchDriver):
             self.world_size = len(self.parallel_device)
             self.open_subprocess()
             self.global_rank = self.local_rank  # the rank is always read from the environment variable;
-            reset_seed()
             dist.init_process_group(
                 backend="nccl", rank=self.global_rank, world_size=self.world_size
             )
diff --git a/fastNLP/core/drivers/torch_driver/fairscale.py b/fastNLP/core/drivers/torch_driver/fairscale.py
index ece78f5e..02dda6a6 100644
--- a/fastNLP/core/drivers/torch_driver/fairscale.py
+++ b/fastNLP/core/drivers/torch_driver/fairscale.py
@@ -18,7 +18,7 @@ if _NEED_IMPORT_FAIRSCALE:
     from fairscale.nn.wrap import auto_wrap, enable_wrap, default_auto_wrap_policy
 
 from ...log import logger
-from .utils import reset_seed, _DDPWrappingModel
+from .utils import _DDPWrappingModel
 from .ddp import TorchDDPDriver
 from .torch_driver import TorchDriver
 
@@ -114,7 +113,6 @@ class FairScaleDriver(TorchDDPDriver):
         # dist.get_world_size() can only be called after dist.init_process_group has run;
         self.world_size = int(os.environ.get("WORLD_SIZE"))
         self.global_rank = int(os.environ.get("RANK"))
-        reset_seed()
         logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
 
         if not dist.is_initialized():
@@ -129,7 +128,6 @@ class FairScaleDriver(TorchDDPDriver):
             self.world_size = len(self.parallel_device)
             self.open_subprocess()
             self.global_rank = self.local_rank  # the rank is always read from the environment variable;
-            reset_seed()
             dist.init_process_group(
                 backend="nccl", rank=self.global_rank, world_size=self.world_size
             )
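Why dropping reset_seed() is safe here, as a rough sketch under the assumption that the ranks are spawned by the driver (FASTNLP_BACKEND_LAUNCH unset): the base seed now travels through an environment variable that child processes inherit.

    import os

    # Parent process: *_seed_everything() has already cached the base seed
    # (before any rank offset) in the environment.
    os.environ["FASTNLP_GLOBAL_SEED"] = "1234"

    # Child rank r inherits the variable, so torch_seed_everything(None) there
    # resolves get_global_seed() to 1234 and re-seeds with 1234 + r -- no
    # explicit reset_seed() after init_process_group is required.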
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 841e6614..93f607d6 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -28,7 +28,7 @@ from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
 from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
-from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
+from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
 
@@ -370,7 +370,7 @@ class TorchDriver(Driver):
             random.seed(stdlib_seed)
 
     def set_deterministic_dataloader(self, dataloader: "DataLoader"):
-        if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
+        if dataloader.worker_init_fn is None:
             dataloader.worker_init_fn = partial(self.worker_init_function,
                                                 rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)))
 
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index 14f5b9f3..2d13a8e8 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -8,7 +8,15 @@ import numpy as np
 import inspect
 
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.utils import get_global_seed
+from fastNLP.envs import (
+    get_global_rank,
+    FASTNLP_BACKEND_LAUNCH,
+    FASTNLP_GLOBAL_SEED,
+)
 from fastNLP.core.samplers import re_instantiate_sampler
+from fastNLP.core.utils import auto_param_call
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_TORCH:
     import torch
@@ -25,64 +33,41 @@ __all__ = [
     'optimizer_state_to_device'
 ]
 
-from fastNLP.core.utils import auto_param_call
-from fastNLP.envs import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS
-from fastNLP.core.log import logger
-
-
-def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
-    return random.randint(min_seed_value, max_seed_value)
-
-
-def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
+def torch_seed_everything(seed: int = None, add_global_rank_to_seed: bool = True) -> int:
     r"""
-    Sets the seed for the pytorch, numpy and python.random pseudo-random number generators and,
-    in addition, sets the following environment variables:
+    Set the seed for the **torch**, **numpy** and **python.random** pseudo-random number generators.
 
-    :param seed: integer seed for the global random state. If ``None``, the seed is read from the
-        ``FASTNLP_GLOBAL_SEED`` environment variable or chosen at random.
-    :param workers: if ``True``, all dataloaders passed to the trainer are properly configured with a
-        ``worker_init_fn``; if the user has already provided such a function for their dataloader,
-        setting this parameter has no effect;
+    :param seed: integer seed for the global random state. If ``None``, a seed is generated from
+        the current timestamp.
+    :param add_global_rank_to_seed: whether to use different random numbers on different **ranks**
+        in distributed training. When ``True``, **FastNLP** adds the current ``global_rank`` to the seed.
     """
     max_seed_value = np.iinfo(np.uint32).max
     min_seed_value = np.iinfo(np.uint32).min
 
     if seed is None:
-        env_seed = os.environ.get(FASTNLP_GLOBAL_SEED)
-        if env_seed is None:
-            seed = _select_seed_randomly(min_seed_value, max_seed_value)
+        if os.getenv(FASTNLP_BACKEND_LAUNCH) == "1":
+            seed = 42
         else:
-            try:
-                seed = int(env_seed)
-            except ValueError:
-                seed = _select_seed_randomly(min_seed_value, max_seed_value)
-                # rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
-    elif not isinstance(seed, int):
+            seed = get_global_seed()
+        logger.info(f"'FASTNLP_GLOBAL_SEED' is set to {seed} automatically.")
+    if not isinstance(seed, int):
         seed = int(seed)
 
     if not (min_seed_value <= seed <= max_seed_value):
-        logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.")
-        seed = _select_seed_randomly(min_seed_value, max_seed_value)
+        logger.rank_zero_warning("Your seed value is too big or too small for numpy, we will choose a random seed for you.")
+        seed %= max_seed_value
+
+    os.environ[FASTNLP_GLOBAL_SEED] = f"{seed}"
+    if add_global_rank_to_seed:
+        seed += get_global_rank()
 
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}"
     return seed
 
-def reset_seed() -> None:
-    r"""
-    This function is mainly for ddp: since ddp spawns multiple processes, the seed chosen by the
-    user's seed_everything call has to be applied again inside every spawned process;
-    """
-    seed = os.environ.get(FASTNLP_GLOBAL_SEED, None)
-    workers = os.environ.get(FASTNLP_SEED_WORKERS, "0")
-    if seed is not None:
-        torch_seed_everything(int(seed), workers=bool(int(workers)))
-
-
 class ForwardState(IntEnum):
     TRAIN = 0
     VALIDATE = 1
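A quick sketch of the new out-of-range handling (clamping with seed %= max instead of discarding the user's seed), assuming a single-process run so no rank offset is added:

    import numpy as np
    from fastNLP.core.drivers.torch_driver.utils import torch_seed_everything

    big = 2 ** 40                          # exceeds numpy's uint32 range
    returned = torch_seed_everything(big)  # warns, then clamps the seed
    assert returned == big % np.iinfo(np.uint32).max  # == 256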
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index edb8a67f..f522f997 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -4,14 +4,16 @@ __all__ = [
     "RandomBatchSampler"
 ]
 
+import os
 import math
 from copy import deepcopy
-from typing import Dict, Union, List, Sequence
+from typing import Dict, Union, List
 from itertools import chain
 
 import numpy as np
 
 from fastNLP.core.dataset import DataSet
+from fastNLP.envs.utils import get_global_seed
 from fastNLP.core.log import logger
 from .utils import create_array
 from abc import abstractmethod
@@ -169,7 +171,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for fastNLP internal use
     """
     def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
-                 drop_last: bool = False, seed: int = 0, **kwargs):
+                 drop_last: bool = False, seed: int = None, **kwargs):
         super().__init__()
 
         self.dataset = dataset
@@ -177,7 +179,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.seed = seed
+        self.seed = get_global_seed() if seed is None else seed
 
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been consumed so far, including those produced on the other cards in the multi-card case
@@ -396,7 +398,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
     :param kwargs: reserved for fastNLP internal use
     """
     def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
-                 shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
+                 shuffle: bool = True, drop_last: bool = False, seed: int = None, **kwargs):
         super().__init__()
         if isinstance(dataset, DataSet) and isinstance(length, str):
             length = dataset.get_field(length).content
@@ -421,7 +423,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.num_batch_per_bucket = num_batch_per_bucket
         self.shuffle = shuffle
         self.drop_last = drop_last
-        self.seed = seed
+        self.seed = get_global_seed() if seed is None else seed
 
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been consumed so far, including those produced on the other cards in the multi-card case
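A small sketch of the changed default: an explicit seed behaves exactly as before, while seed=None now falls back to the process-wide get_global_seed() value, so independently built samplers still agree within one run (a plain list stands in for a dataset here, assuming the sampler only needs len()):

    from fastNLP.core.samplers import RandomBatchSampler

    data = list(range(100))
    a = list(RandomBatchSampler(data, batch_size=8, seed=0))
    b = list(RandomBatchSampler(data, batch_size=8, seed=0))
    assert a == b          # explicit seed: fully reproducible, as before

    c = RandomBatchSampler(data, batch_size=8)  # seed=None
    d = RandomBatchSampler(data, batch_size=8)
    assert c.seed == d.seed  # both fall back to get_global_seed()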
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index fe38a808..dc396851 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -12,6 +12,7 @@ import numpy as np
 
 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
+from fastNLP.envs.utils import get_global_seed
 
 
 class ReproducibleSampler:
@@ -65,11 +66,11 @@ class RandomSampler(ReproducibleSampler):
     :param seed: random seed.
    :param kwargs: not needed by users; used internally by fastNLP
     """
-    def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
+    def __init__(self, dataset, shuffle: bool = True, seed: int = None, **kwargs):
         super(RandomSampler, self).__init__()
 
         self.dataset = dataset
         self.shuffle = shuffle
-        self.seed = seed
+        self.seed = get_global_seed() if seed is None else seed
         self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been consumed so far, including those produced on the other cards in the multi-card case
 
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
index e94215a6..22207274 100644
--- a/fastNLP/core/samplers/unrepeated_sampler.py
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
@@ -7,6 +7,7 @@ __all__ = [
 ]
 
 from typing import List, Union
 from fastNLP.core.dataset import DataSet
+from fastNLP.envs.utils import get_global_seed
 
 import numpy as np
@@ -27,10 +28,10 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
     :param seed: the random seed to use
     :param kwargs: reserved for fastNLP internal use
     """
-    def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
+    def __init__(self, dataset, shuffle: bool = False, seed: int = None, **kwargs):
         self.dataset = dataset
         self.shuffle = shuffle
-        self.seed = seed
+        self.seed = get_global_seed() if seed is None else seed
 
         # parameters related to the multi-card setup
         self.num_replicas = kwargs.get('num_replicas', 1)
diff --git a/fastNLP/envs/env.py b/fastNLP/envs/env.py
index 9cc05a02..12dcc392 100644
--- a/fastNLP/envs/env.py
+++ b/fastNLP/envs/env.py
@@ -30,9 +30,6 @@ FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME"
 
 # FASTNLP_GLOBAL_SEED is used to set the random seed correctly in every subprocess;
 FASTNLP_GLOBAL_SEED = "FASTNLP_GLOBAL_SEED"
 
-# FASTNLP_SEED_WORKERS is used to set the pytorch dataloader worker_init_fn correctly;
-FASTNLP_SEED_WORKERS = "FASTNLP_SEED_WORKERS"
-
 # sets the backend framework used by fastNLP
 FASTNLP_BACKEND = 'FASTNLP_BACKEND'
diff --git a/fastNLP/envs/utils.py b/fastNLP/envs/utils.py
index 541bfba7..ff5663f2 100644
--- a/fastNLP/envs/utils.py
+++ b/fastNLP/envs/utils.py
@@ -1,3 +1,6 @@
+import os
+import math
+import time
 from importlib.util import find_spec
 from typing import Callable
 import importlib
@@ -79,4 +82,13 @@ def get_gpu_count() -> int:
         # after splitting, the leading and trailing newlines must also be discarded
         return len(lines.split(b"\n")) - 2
     except:
-        return -1
\ No newline at end of file
+        return -1
+
+def get_global_seed():
+    seed = os.getenv("FASTNLP_GLOBAL_SEED", None)
+    if seed is not None:
+        return int(seed)
+    seed = int(math.modf(time.time())[0] * 1000000)
+    os.environ["FASTNLP_GLOBAL_SEED"] = f"{seed}"
+
+    return seed
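A minimal sketch of get_global_seed's contract, taken directly from the implementation above: the first call derives a seed from the fractional part of time.time() and caches it in the environment, so later calls in the same process (and inheriting child processes) agree:

    import os
    from fastNLP.envs.utils import get_global_seed

    os.environ.pop("FASTNLP_GLOBAL_SEED", None)  # start clean for the sketch
    first = get_global_seed()                    # int(modf(time.time())[0] * 1000000)
    second = get_global_seed()                   # read back from the environment
    assert first == second and 0 <= first < 1_000_000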