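# Distributed launcher for Hetu training jobs: reads a YAML cluster
# description, optionally spawns a parameter-server scheduler and servers
# (locally and, over SSH, on remote machines), and runs the user command on
# all workers through mpirun.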
import argparse
import yaml
import os
import signal
import multiprocessing
import subprocess
import paramiko
import socket
import psutil
import hetu as ht

_procs = []

def signal_handler(signum, frame):
    print("SIGINT caught, stopping training")
    for proc in _procs:
        proc.kill()
    global executor_shell
    executor_shell.kill()
    exit(0)

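# The scheduler and server entry points below just set DMLC_ROLE and hand
# control to the corresponding Hetu PS routines; main() runs each of them in
# its own multiprocessing.Process.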
def start_sched():
    os.environ["DMLC_ROLE"] = "scheduler"
    ht.scheduler_init()
    ht.scheduler_finish()

def start_server():
    os.environ["DMLC_ROLE"] = "server"
    ht.server_init()
    ht.server_finish()

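# Launch parameter servers on a remote host over SSH: the temporary config file
# written by main() is copied to the remote /tmp, then `python -m hetu.launcher`
# is started there. The private key is assumed to be `id_rsa` and host aliases
# are resolved through the `config` file in the SSH directory (the directory of
# the identity file if one is given, otherwise ~/.ssh).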
def start_remote_server(host, local_server_num, identify_file):
    ssh_directory = os.path.expanduser('~/.ssh') if identify_file == '' else os.path.dirname(
        os.path.abspath(os.path.expanduser(identify_file)))
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    private = paramiko.RSAKey.from_private_key_file(
        os.path.join(ssh_directory, 'id_rsa'))
    config = paramiko.config.SSHConfig.from_path(
        os.path.join(ssh_directory, 'config'))
    conf = config.lookup(host)
    ssh.connect(hostname=conf['hostname'], port=conf['port'],
                username=conf['user'], pkey=private)
    sftp = ssh.open_sftp()
    sftp.put('/tmp/temp_hetu_config.yml',
             '/tmp/temp_hetu_config.yml', confirm=True)
    sftp.close()
    stdin, stdout, stderr = ssh.exec_command(
        'python -m hetu.launcher /tmp/temp_hetu_config.yml -n %d' % local_server_num)
    stdout = stdout.read().decode()
    stderr = stderr.read().decode()
    if stdout:
        print('From remote %s stdout:\n %s' % (host, stdout.strip()))
    if stderr:
        print('From remote %s stderr:\n %s' % (host, stderr.strip()))
    ssh.close()

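# Pick a free TCP port for the PS scheduler: collect the ports already in use
# on this host and return the first gap in the range 13100-13199.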
def get_available_port(localhost):
    ports = set()
    for conn in psutil.net_connections():
        la = conn.laddr
        ra = conn.raddr
        if len(la) == 2 and la.ip in (localhost, '127.0.0.1'):
            ports.add(la.port)
        if len(ra) == 2 and ra.ip in (localhost, '127.0.0.1'):
            ports.add(ra.port)
    for p in range(13100, 13200):
        if p not in ports:
            return p

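# Resolve the network interface name bound to each host's address, locally via
# psutil and remotely by running a small psutil snippet over SSH. This helper
# is currently unused by main(), which selects interfaces by subnet instead.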
def get_nic_names(local_address, remote_hostnames, identify_file):
    # get local interface
    nics = dict()
    for iface, addrs in psutil.net_if_addrs().items():
        for addr in addrs:
            if addr.family == socket.AF_INET:
                nics[addr.address] = iface
    local_nic = nics[local_address]

    # get remote interfaces
    command_prefix = "\"from socket import AF_INET;\nfrom psutil import net_if_addrs;\n" +\
        "nics = dict();\nfor iface, addrs in net_if_addrs().items():\n for addr in addrs:" +\
        "\n  if addr.family == AF_INET:\n   nics[addr.address] = iface;\n"
    ssh_directory = os.path.expanduser('~/.ssh') if identify_file == '' else os.path.dirname(
        os.path.abspath(os.path.expanduser(identify_file)))
    remote_nics = set()
    for hostname in remote_hostnames:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        private = paramiko.RSAKey.from_private_key_file(
            os.path.join(ssh_directory, 'id_rsa'))
        config = paramiko.config.SSHConfig.from_path(
            os.path.join(ssh_directory, 'config'))
        conf = config.lookup(hostname)
        command = command_prefix + "print(nics[\'%s\'])\"" % (conf['hostname'])
        ssh.connect(hostname=conf['hostname'], port=conf['port'],
                    username=conf['user'], pkey=private)
        stdin, stdout, stderr = ssh.exec_command('python -c %s' % command)
        stdout = stdout.read().decode()
        stderr = stderr.read().decode()
        remote_nics.add(stdout.strip())
        if stderr:
            print('From remote %s stderr:\n %s' % (hostname, stderr.strip()))
        ssh.close()

    remote_nics.add(local_nic)
    return list(remote_nics)

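# Derive a CIDR subnet covering the chief and all remote hosts, used for Open
# MPI's btl_tcp_if_include below. The +/-1 neighbours of every address are
# added to the pool so the common prefix never degenerates into a single /32
# host. For example (illustrative addresses), a chief at 192.168.1.10 and a
# remote host at 192.168.1.12 yield the subnet 192.168.1.8/29.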
def get_subnet(local_address, remote_hostnames, identify_file=''):
    ssh_directory = os.path.expanduser('~/.ssh') if identify_file == '' else os.path.dirname(
        os.path.abspath(os.path.expanduser(identify_file)))
    config = paramiko.config.SSHConfig.from_path(
        os.path.join(ssh_directory, 'config'))
    remote_address = [config.lookup(hostname)['hostname']
                      for hostname in remote_hostnames]
    remote_address.append(local_address)
    address_pool = set()
    for addr in remote_address:
        binary_repr = int(''.join([format(int(part), '08b')
                                   for part in addr.split('.')]), 2)
        address_pool.add(format(binary_repr + 1, '032b'))
        address_pool.add(format(binary_repr - 1, '032b'))
    address_pool = list(address_pool)
    longestCommonPrefix = 0
    for item in zip(*address_pool):
        if len(set(item)) > 1:
            break
        longestCommonPrefix += 1
    if longestCommonPrefix > 30:
        longestCommonPrefix = 30
    assert longestCommonPrefix >= 16, 'Hosts not in the same subnet!'
    commonAddress = address_pool[0][:longestCommonPrefix] + \
        '0' * (32 - longestCommonPrefix)
    parts = [commonAddress[:8], commonAddress[8:16],
             commonAddress[16:24], commonAddress[24:]]
    subnet = '.'.join([str(int(part, 2))
                       for part in parts]) + '/%d' % longestCommonPrefix
    return subnet

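# main() reads the cluster description from the YAML file passed with -c,
# spawns the requested PS scheduler/servers, and launches all workers through
# a single mpirun invocation. An illustrative config (hostnames are
# placeholders; the field names are exactly the ones parsed below):
#
#   nodes:
#     - host: node1
#       chief: true
#       servers: 1
#       workers: 2
#     - host: node2
#       servers: 1
#       workers: 2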
def main():
    signal.signal(signal.SIGINT, signal_handler)
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=True,
                        help='Configuration file.')
    parser.add_argument('-i', '--identify', default='',
                        help='SSH identity file.')
    parser.add_argument('command', nargs=argparse.REMAINDER,
                        help='Command to be executed.')
    args = parser.parse_args()
    settings = yaml.load(open(args.config).read(), Loader=yaml.FullLoader)
    attributes = set(['host', 'servers', 'workers', 'chief'])
    hosts = []
    servers, workers = {}, {}
    chief = None
    chief_address = socket.gethostbyname(socket.gethostname())
    port = get_available_port(chief_address)
    for node in settings['nodes']:
        assert set(node.keys()) <= attributes, \
            'Invalid node attributes: %s is not a subset of %s.' % (set(node.keys()), attributes)
        hosts.append(node['host'])
        if node.get('servers', 0):
            servers[node['host']] = node['servers']
        if node.get('workers', 0):
            workers[node['host']] = node['workers']
        if node.get('chief', False):
            assert chief is None, 'There should be only one chief.'
            chief = node['host']
    assert chief, 'There should be exactly one chief.'
    num_servers = sum(servers.values())
    num_workers = sum(workers.values())
    enable_PS = (num_servers > 0)
    print('Cluster: {')
    print('  Chief: %s,' % chief)
    print('  Servers(%d): %s,' % (num_servers, servers))
    print('  Workers(%d): %s,' % (num_workers, workers))
    print('}')
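    # The DMLC_* variables below configure the (ps-lite style) parameter-server
    # backend: DMLC_PS_ROOT_URI/DMLC_PS_ROOT_PORT point every node at the
    # scheduler running on the chief, and DMLC_NUM_SERVER/DMLC_NUM_WORKER give
    # the expected cluster size. Workers additionally get DMLC_ROLE=worker
    # before mpirun is invoked below.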
    if enable_PS:
        os.environ['DMLC_PS_ROOT_URI'] = chief_address
        os.environ['DMLC_PS_ROOT_PORT'] = str(port)
        os.environ['DMLC_PS_VAN_TYPE'] = 'p3'
        os.environ['DMLC_NUM_SERVER'] = str(num_servers)
        os.environ['DMLC_NUM_WORKER'] = str(num_workers)

    global executor_shell
    if len(hosts) == 1:
        # single machine
        # TODO: add host address validation check
        if enable_PS:
            proc = multiprocessing.Process(target=start_sched)
            _procs.append(proc)
            for i in range(num_servers):
                proc = multiprocessing.Process(target=start_server)
                _procs.append(proc)
        for proc in _procs:
            proc.start()
        mpi_command = 'mpirun --allow-run-as-root --tag-output -np %d %s' % (
            num_workers, ' '.join(args.command))
        env = dict(os.environ)
        if enable_PS:
            env["DMLC_ROLE"] = "worker"
        executor_shell = subprocess.Popen(
            mpi_command, shell=True, env=env, stdout=None, stderr=None)
        for proc in _procs:
            proc.join()
        executor_shell.wait()
    else:
        # multiple machines

        #! NIC names are not used currently; subnets are used instead.
        #! Please specify NCCL_SOCKET_IFNAME in /etc/bash.bashrc on every machine.
        #! The NIC-name approach cannot handle different NIC names on different machines.
        # nics = get_nic_names(chief_address, set(hosts) - {chief}, args.identify)
        # joined_nics = ','.join(nics)
        subnet = get_subnet(chief_address, set(hosts) - {chief}, args.identify)
        if enable_PS:
            with open('/tmp/temp_hetu_config.yml', 'w') as fw:
                yaml.dump({'shared': {'DMLC_PS_ROOT_URI': chief_address, 'DMLC_PS_ROOT_PORT': port,
                                      'DMLC_NUM_WORKER': num_workers, 'DMLC_NUM_SERVER': num_servers,
                                      'DMLC_PS_VAN_TYPE': 'p3'}}, fw)
            proc = multiprocessing.Process(target=start_sched)
            _procs.append(proc)
            for node in hosts:
                if node == chief:
                    for i in range(servers.get(node, 0)):
                        proc = multiprocessing.Process(target=start_server)
                        _procs.append(proc)
                else:
                    if servers.get(node, 0):
                        proc = multiprocessing.Process(target=start_remote_server, args=[
                            node, servers[node], args.identify])
                        _procs.append(proc)
        for proc in _procs:
            proc.start()
        basic_args = '--allow-run-as-root --tag-output'
        hosts_in_command = ','.join(
            ['%s:%d' % (node, nworkers) for node, nworkers in workers.items()])
        mpi_ssh_args = '' if args.identify == '' else '-bootstrap=ssh -bootstrap-exec-args -i %s' % args.identify
        tcp_intf_arg = '-mca btl_tcp_if_include %s' % subnet
        # tcp_intf_arg = '-mca btl_tcp_if_include %s' % joined_nics
        # nccl_socket_intf_arg = '-x NCCL_SOCKET_IFNAME=%s' % joined_nics
        env_list = '-x DMLC_PS_ROOT_URI=%s -x DMLC_PS_ROOT_PORT=%s -x DMLC_PS_VAN_TYPE=p3 -x DMLC_NUM_SERVER=%s -x DMLC_NUM_WORKER=%s -x DMLC_ROLE=worker' %\
            (chief_address, str(port), str(num_servers),
             str(num_workers)) if enable_PS else ''
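        # The assembled command looks roughly like the following (illustrative
        # values for a hypothetical two-node, PS-enabled cluster; the trailing
        # command is whatever was passed on this launcher's command line):
        #   mpirun --allow-run-as-root --tag-output --host node1:2,node2:2 \
        #       -mca btl_tcp_if_include 192.168.1.8/29 \
        #       -x DMLC_PS_ROOT_URI=<chief ip> -x DMLC_PS_ROOT_PORT=<port> -x DMLC_PS_VAN_TYPE=p3 \
        #       -x DMLC_NUM_SERVER=2 -x DMLC_NUM_WORKER=4 -x DMLC_ROLE=worker \
        #       python train.py ...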
        mpi_command = (
            'mpirun {basic_args} '
            '--host {hosts} '
            '{mpi_ssh_args} '
            '{tcp_intf_arg} '
            # '{nccl_socket_intf_arg} '
            '{env} '
            '{command}'
            .format(basic_args=basic_args,
                    hosts=hosts_in_command,
                    mpi_ssh_args=mpi_ssh_args,
                    tcp_intf_arg=tcp_intf_arg,
                    # nccl_socket_intf_arg=nccl_socket_intf_arg,
                    env=env_list,
                    command=' '.join(args.command))
        )
        executor_shell = subprocess.Popen(
            mpi_command, shell=True, stdout=None, stderr=None)
        for proc in _procs:
            proc.join()
        executor_shell.wait()

if __name__ == '__main__':
    #! /etc/bash.bashrc on the other machines needs to be modified to:
    # * specify NCCL_SOCKET_IFNAME
    # * specify PATH so that mpirun can be found
    # * activate the conda environment
    # * specify PYTHONPATH so that hetu can be imported
    #! Server processes started over SSH on other machines CANNOT receive the
    #! SIGINT triggered by Ctrl+C on this machine; kill them on those machines.
    main()