#!/usr/bin/env python
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import pickle
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import unittest

import requests
from datasets import Dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE

from modelscope.msdatasets import MsDataset
from .torch_utils import _find_free_port

TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'


def test_level():
    """Return the effective test level, preferring the TEST_LEVEL env var."""
    global TEST_LEVEL
    if TEST_LEVEL_STR in os.environ:
        TEST_LEVEL = int(os.environ[TEST_LEVEL_STR])

    return TEST_LEVEL

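# Illustrative usage (a sketch; the decorator pattern is an assumption about
# how callers consume test_level(), not something defined in this module):
#
#   @unittest.skipUnless(test_level() >= 1,
#                        'skip test in current test level')
#   def test_something_expensive(self):
#       ...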

def require_tf(test_case):
    if not TF_AVAILABLE:
        test_case = unittest.skip('test requires TensorFlow')(test_case)
    return test_case


def require_torch(test_case):
    if not TORCH_AVAILABLE:
        test_case = unittest.skip('test requires PyTorch')(test_case)
    return test_case

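# Illustrative usage of the skip decorators (a sketch; ``MyTest`` is a
# hypothetical test case):
#
#   class MyTest(unittest.TestCase):
#
#       @require_torch
#       def test_needs_torch(self):
#           import torch  # only runs when torch is available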

def set_test_level(level: int):
    global TEST_LEVEL
    TEST_LEVEL = level


def create_dummy_test_dataset(feat, label, num):
    return MsDataset.from_hf_dataset(
        Dataset.from_dict(dict(feat=[feat] * num, labels=[label] * num)))

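# Illustrative usage (hypothetical values): builds an MsDataset of ``num``
# identical rows with columns ``feat`` and ``labels``.
#
#   dataset = create_dummy_test_dataset([1.0, 2.0, 3.0], 0, 10)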

def download_and_untar(fpath, furl, dst) -> str:
    """Download ``furl`` to ``fpath`` (if missing) and untar it into ``dst``.

    Returns the path of the extracted directory, assuming an archive named
    ``foo.tar.gz`` unpacks into a top-level directory named ``foo``.
    """
    if not os.path.exists(fpath):
        r = requests.get(furl)
        r.raise_for_status()
        with open(fpath, 'wb') as f:
            f.write(r.content)

    file_name = os.path.basename(fpath)
    # Strip both extensions of e.g. ``foo.tar.gz`` to get ``foo``.
    target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0]

    # untar the file into dst and return the resulting directory
    with tarfile.open(fpath) as t:
        t.extractall(path=dst)

    return os.path.join(dst, target_dir_name)

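# Illustrative usage (hypothetical path and URL): ``foo.tar.gz`` is assumed
# to unpack into a top-level ``foo`` directory.
#
#   data_dir = download_and_untar(
#       '/tmp/foo.tar.gz', 'https://example.com/foo.tar.gz', '/tmp')
#   # -> '/tmp/foo'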

# Template for a standalone launcher script; the four ``{}`` slots are filled
# by ``DistributedTestCase._start`` with (module, module, function, params).
_DIST_SCRIPT_TEMPLATE = """
import ast
import argparse
import pickle
import torch
from torch import distributed as dist
from modelscope.utils.torch_utils import get_dist_info
import {}

parser = argparse.ArgumentParser()
parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save results from all ranks')
parser.add_argument('--save_file', type=str, help='file to save results to')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()


def main():
    results = {}.{}({})  # module.func(params)
    if args.save_all_ranks:
        save_file = args.save_file + str(dist.get_rank())
        with open(save_file, 'wb') as f:
            pickle.dump(results, f)
    else:
        rank, _ = get_dist_info()
        if rank == 0:
            with open(args.save_file, 'wb') as f:
                pickle.dump(results, f)


if __name__ == '__main__':
    main()
"""


class DistributedTestCase(unittest.TestCase):
    """TestCase for running a test function in distributed mode.

    Examples:
        import torch
        from torch import distributed as dist
        from modelscope.utils.torch_utils import init_dist

        def _test_func(*args, **kwargs):
            init_dist(launcher='pytorch')
            rank = dist.get_rank()
            if rank == 0:
                value = torch.tensor(1.0).cuda()
            else:
                value = torch.tensor(2.0).cuda()
            dist.all_reduce(value)
            return value.cpu().numpy()

        class DistTest(DistributedTestCase):

            def test_function_dist(self):
                args = ()  # args should be python builtin types
                kwargs = {}  # kwargs should be python builtin types
                self.start(
                    _test_func,
                    num_gpus=2,
                    assert_callback=lambda x: self.assertEqual(x, 3.0),
                    *args,
                    **kwargs,
                )
    """

    def _start(self,
               dist_start_cmd,
               func,
               num_gpus,
               assert_callback=None,
               save_all_ranks=False,
               *args,
               **kwargs):
        script_path = func.__code__.co_filename
        script_dir, script_name = os.path.split(script_path)
        script_name = os.path.splitext(script_name)[0]
        func_name = func.__qualname__

        # Render the call arguments as python source text; strings are quoted.
        func_params = []
        for arg in args:
            if isinstance(arg, str):
                arg = '\'{}\''.format(arg)
            func_params.append(str(arg))

        for k, v in kwargs.items():
            if isinstance(v, str):
                v = '\'{}\''.format(v)
            func_params.append('{}={}'.format(k, v))

        func_params = ','.join(func_params)

        tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name
        tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name
        print('save temporary run file to: {}'.format(tmp_run_file))
        print('save results to: {}'.format(tmp_res_file))

        with open(tmp_run_file, 'w') as f:
            run_file_content = _DIST_SCRIPT_TEMPLATE.format(
                script_name, script_name, func_name, func_params)
            f.write(run_file_content)

        tmp_res_files = []
        if save_all_ranks:
            for i in range(num_gpus):
                tmp_res_files.append(tmp_res_file + str(i))
        else:
            tmp_res_files = [tmp_res_file]
        self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files)

        # Use a plain dict copy: deep-copying os.environ yields an object
        # whose mutation also changes the real process environment via putenv.
        tmp_env = os.environ.copy()
        tmp_env['PYTHONPATH'] = ':'.join(
            (tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':')
        script_params = '--save_all_ranks=%s --save_file=%s' % (
            save_all_ranks, tmp_res_file)
        script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params)
        print('script command: %s' % script_cmd)
        res = subprocess.call(script_cmd, shell=True, env=tmp_env)

        script_res = []
        for res_file in tmp_res_files:
            with open(res_file, 'rb') as f:
                script_res.append(pickle.load(f))
        if not save_all_ranks:
            script_res = script_res[0]

        if assert_callback:
            assert_callback(script_res)

        self.assertEqual(
            res,
            0,
            msg='The test function ``{}`` in ``{}`` failed to run!'.format(
                func_name, script_name))

        return script_res

    def start(self,
              func,
              num_gpus,
              assert_callback=None,
              save_all_ranks=False,
              *args,
              **kwargs):
        ip = socket.gethostbyname(socket.gethostname())
        # Note: torch.distributed.launch is deprecated in recent PyTorch in
        # favor of torchrun; it is kept here because the generated script
        # consumes the --local_rank argument it passes.
        dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % (
            sys.executable, num_gpus, ip, _find_free_port())

        # Forward positionally so extra positional args flow into _start's
        # *args instead of clashing with its named parameters.
        return self._start(dist_start_cmd, func, num_gpus, assert_callback,
                           save_all_ranks, *args, **kwargs)

    def clean_tmp(self, tmp_file_list):
        for file in tmp_file_list:
            if os.path.exists(file):
                if os.path.isdir(file):
                    shutil.rmtree(file)
                else:
                    os.remove(file)