- import os
- import pytest
- from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context
- from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context
@rank_zero_call
def write_something():
    """Print this process's ``RANK`` (default ``'0'``) repeated five times.

    Gated by ``rank_zero_call``: the assertions in ``TestTorch`` expect the call
    to produce no output on non-zero local ranks unless it is executed inside
    ``all_rank_call_context``.
    """
    rank = os.environ.get('RANK', '0')
    print(rank * 5, flush=True)
def write_other_thing():
    """Print this process's ``RANK`` (default ``'0'``) repeated five times.

    Undecorated twin of ``write_something``; the tests apply
    ``rank_zero_call`` at the call site instead of at definition time.
    """
    rank = os.environ.get('RANK', '0')
    print(rank * 5, flush=True)
class PaddleTest:
    """Placeholder for the Paddle variants of the distributed-helper tests.

    NOTE(review): the name does not start with ``Test``, so pytest's default
    class collection skips it; both methods are empty stubs.
    """

    # Owner tag carried over from the original: @x54-729 (stub not implemented yet).
    def test_rank_zero_call(self):
        """Stub: Paddle check for ``rank_zero_call`` is not implemented."""

    def test_all_rank_run(self):
        """Stub: Paddle check for ``all_rank_call_context`` is not implemented."""
class JittorTest:
    """Placeholder for the Jittor variants of the distributed-helper tests.

    NOTE(review): the name does not start with ``Test``, so pytest's default
    class collection skips it; both methods are empty stubs.
    """

    # Owner tag carried over from the original: @x54-729 (stub not implemented yet).
    def test_rank_zero_call(self):
        """Stub: Jittor check for ``rank_zero_call`` is not implemented."""

    def test_all_rank_run(self):
        """Stub: Jittor check for ``all_rank_call_context`` is not implemented."""
@pytest.mark.torch
class TestTorch:
    """Multi-process checks of ``rank_zero_call`` / ``all_rank_call_context`` on torch.

    Each test exports the distributed environment variables and then calls
    ``re_run_current_cmd_for_torch(1, output_from_new_proc='all')``; presumably
    this launches one extra copy of the current test so that local ranks 0 and 1
    both run the assertions (TODO confirm against the helper). The helper
    functions print the rank five times, so ``'00000'`` marks rank-0 output and
    ``'11111'`` marks rank-1 output.
    """

    @staticmethod
    def _prepare_distributed_env():
        """Export rendezvous variables and, if no rank info is present, pose as rank 0 of 2."""
        # An empty MASTER_ADDR is not a usable host for torch's env:// rendezvous;
        # point it at loopback (assumes all ranks run on this machine).
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        # None of the rank variables set -> this is the original launcher process.
        if not any(key in os.environ for key in ('LOCAL_RANK', 'RANK', 'WORLD_SIZE')):
            os.environ['LOCAL_RANK'] = '0'
            os.environ['RANK'] = '0'
            os.environ['WORLD_SIZE'] = '2'

    @staticmethod
    def _assert_rank_zero_only(output):
        """Rank 0 must see its own marker and no rank-1 marker; other ranks must see neither."""
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output and '11111' not in output
        else:
            assert '00000' not in output and '11111' not in output

    @staticmethod
    def _assert_own_rank_printed(output):
        """The current process must have printed its own rank marker."""
        if os.environ['LOCAL_RANK'] == '0':
            assert '00000' in output
        else:
            assert '11111' in output

    @magic_argv_env_context
    def test_rank_zero_call(self):
        """``rank_zero_call`` (as decorator or as wrapper) only runs the function on rank 0."""
        self._prepare_distributed_env()
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')

        with Capturing() as captured:
            write_something()
        self._assert_rank_zero_only(captured[0])

        with Capturing() as captured:
            rank_zero_call(write_other_thing)()
        self._assert_rank_zero_only(captured[0])

    @magic_argv_env_context
    def test_all_rank_run(self):
        """Inside ``all_rank_call_context`` rank-zero-gated functions run on every rank."""
        self._prepare_distributed_env()
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')

        with all_rank_call_context():
            with Capturing(no_del=True) as captured:
                write_something()
            output = captured[0]
        self._assert_own_rank_printed(output)

        with all_rank_call_context():
            with Capturing(no_del=True) as captured:
                rank_zero_call(write_other_thing)()
            output = captured[0]
        self._assert_own_rank_printed(output)