|
|
@@ -138,32 +138,12 @@ def set_env(global_seed=None): |
|
|
|
backend = os.environ.get(FASTNLP_BACKEND, '') |
|
|
|
if backend == 'paddle': |
|
|
|
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." |
|
|
|
if os.environ.get(FASTNLP_GLOBAL_SEED, None) is not None: |
|
|
|
seed_paddle_global_seed(int(os.environ.get(FASTNLP_GLOBAL_SEED))) |
|
|
|
|
|
|
|
if backend == 'jittor': |
|
|
|
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." |
|
|
|
if os.environ.get(FASTNLP_GLOBAL_SEED, None) is not None: |
|
|
|
seed_jittor_global_seed(int(os.environ.get(FASTNLP_GLOBAL_SEED))) |
|
|
|
|
|
|
|
if backend == 'torch': |
|
|
|
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." |
|
|
|
if os.environ.get(FASTNLP_GLOBAL_SEED, None) is not None: |
|
|
|
seed_torch_global_seed(int(os.environ.get(FASTNLP_GLOBAL_SEED))) |
|
|
|
|
|
|
|
|
|
|
|
def seed_torch_global_seed(global_seed): |
|
|
|
# @yxg |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def seed_paddle_global_seed(global_seed): |
|
|
|
# @xsh |
|
|
|
pass |
|
|
|
|
|
|
|
def seed_jittor_global_seed(global_seed): |
|
|
|
# @xsh |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def dump_fastnlp_backend(default:bool = False, backend=None): |
|
|
|