You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

runner.py 12 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import argparse
  2. import yaml
  3. import os
  4. import signal
  5. import multiprocessing
  6. import subprocess
  7. import paramiko
  8. import socket
  9. import psutil
  10. import hetu as ht
  11. _procs = []
  12. def signal_handler(signal, frame):
  13. print("SIGINT signal caught, stop Training")
  14. for proc in _procs:
  15. proc.kill()
  16. global executor_shell
  17. executor_shell.kill()
  18. exit(0)
  19. def start_sched():
  20. os.environ["DMLC_ROLE"] = "scheduler"
  21. ht.scheduler_init()
  22. ht.scheduler_finish()
  23. def start_server():
  24. os.environ["DMLC_ROLE"] = "server"
  25. ht.server_init()
  26. ht.server_finish()
  27. def start_remote_server(host, local_server_num, identify_file):
  28. ssh_directory = os.path.expanduser('~/.ssh') if identify_file == '' else os.path.dirname(
  29. os.path.abspath(os.path.expanduser(identify_file)))
  30. ssh = paramiko.SSHClient()
  31. ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  32. private = paramiko.RSAKey.from_private_key_file(
  33. os.path.join(ssh_directory, 'id_rsa'))
  34. config = paramiko.config.SSHConfig.from_path(
  35. os.path.join(ssh_directory, 'config'))
  36. conf = config.lookup(host)
  37. ssh.connect(hostname=conf['hostname'], port=conf['port'],
  38. username=conf['user'], pkey=private)
  39. sftp = ssh.open_sftp()
  40. sftp.put('/tmp/temp_hetu_config.yml',
  41. '/tmp/temp_hetu_config.yml', confirm=True)
  42. sftp.close()
  43. stdin, stdout, stderr = ssh.exec_command(
  44. 'python -m hetu.launcher /tmp/temp_hetu_config.yml -n %d' % local_server_num)
  45. stdout = stdout.read().decode()
  46. stderr = stderr.read().decode()
  47. if stdout:
  48. print('From remote %s stdout:\n %s' % (host, stdout.strip()))
  49. if stderr:
  50. print('From remote %s stderr:\n %s' % (host, stderr.strip()))
  51. ssh.close()
  52. def get_available_port(localhost):
  53. ports = set()
  54. for conn in psutil.net_connections():
  55. la = conn.laddr
  56. ra = conn.raddr
  57. if len(la) == 2 and la.ip in (localhost, '127.0.0.1'):
  58. ports.add(la.port)
  59. if len(ra) == 2 and ra.ip in (localhost, '127.0.0.1'):
  60. ports.add(ra.port)
  61. for p in range(13100, 13200):
  62. if p not in ports:
  63. return p
  64. def get_nic_names(local_address, remote_hostnames, identify_file):
  65. # get local interface
  66. nics = dict()
  67. for iface, addrs in psutil.net_if_addrs().items():
  68. for addr in addrs:
  69. if addr.family == socket.AF_INET:
  70. nics[addr.address] = iface
  71. local_nic = nics[local_address]
  72. # get remote interfaces
  73. command_prefix = "\"from socket import AF_INET;\nfrom psutil import net_if_addrs;\n" +\
  74. "nics = dict();\nfor iface, addrs in net_if_addrs().items():\n for addr in addrs:" +\
  75. "\n if addr.family == AF_INET:\n nics[addr.address] = iface;\n"
  76. ssh_directory = os.path.expanduser('~/.ssh') if identify_file == '' else os.path.dirname(
  77. os.path.abspath(os.path.expanduser(identify_file)))
  78. remote_nics = set()
  79. for hostname in remote_hostnames:
  80. ssh = paramiko.SSHClient()
  81. ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  82. private = paramiko.RSAKey.from_private_key_file(
  83. os.path.join(ssh_directory, 'id_rsa'))
  84. config = paramiko.config.SSHConfig.from_path(
  85. os.path.join(ssh_directory, 'config'))
  86. conf = config.lookup(hostname)
  87. command = command_prefix + "print(nics[\'%s\'])\"" % (conf['hostname'])
  88. ssh.connect(hostname=conf['hostname'], port=conf['port'],
  89. username=conf['user'], pkey=private)
  90. stdin, stdout, stderr = ssh.exec_command('python -c %s' % command)
  91. stdout = stdout.read().decode()
  92. stderr = stderr.read().decode()
  93. remote_nics.add(stdout.strip())
  94. if stderr:
  95. print('From remote %s stderr:\n %s' % (hostname, stderr.strip()))
  96. ssh.close()
  97. remote_nics.add(local_nic)
  98. return list(remote_nics)
  99. def get_subnet(local_address, remote_hostnames, identify_file=''):
  100. ssh_directory = os.path.expanduser('~/.ssh') if identify_file == '' else os.path.dirname(
  101. os.path.abspath(os.path.expanduser(identify_file)))
  102. config = paramiko.config.SSHConfig.from_path(
  103. os.path.join(ssh_directory, 'config'))
  104. remote_address = [config.lookup(hostname)['hostname']
  105. for hostname in remote_hostnames]
  106. remote_address.append(local_address)
  107. address_pool = set()
  108. for addr in remote_address:
  109. binary_repr = int(''.join([format(int(part), '08b')
  110. for part in addr.split('.')]), 2)
  111. address_pool.add(format(binary_repr+1, '032b'))
  112. address_pool.add(format(binary_repr-1, '032b'))
  113. address_pool = list(address_pool)
  114. longestCommonPrefix = 0
  115. for item in zip(*address_pool):
  116. if len(set(item)) > 1:
  117. break
  118. longestCommonPrefix += 1
  119. if longestCommonPrefix > 30:
  120. longestCommonPrefix = 30
  121. assert longestCommonPrefix >= 16, 'Hosts not in the same subnet!'
  122. commonAddress = address_pool[0][:longestCommonPrefix] + \
  123. '0' * (32 - longestCommonPrefix)
  124. parts = [commonAddress[:8], commonAddress[8:16],
  125. commonAddress[16:24], commonAddress[24:]]
  126. subnet = '.'.join([str(int(part, 2))
  127. for part in parts]) + '/%d' % longestCommonPrefix
  128. return subnet
  129. def main():
  130. signal.signal(signal.SIGINT, signal_handler)
  131. parser = argparse.ArgumentParser()
  132. parser.add_argument('-c', '--config', required=True,
  133. help='Configuration file.')
  134. parser.add_argument('-i', '--identify', default='',
  135. help='SSH identify file.')
  136. parser.add_argument('command', nargs=argparse.REMAINDER,
  137. help='Command to be executed.')
  138. args = parser.parse_args()
  139. settings = yaml.load(open(args.config).read(), Loader=yaml.FullLoader)
  140. attributes = set(['host', 'servers', 'workers', 'chief'])
  141. hosts = []
  142. servers, workers = {}, {}
  143. chief = None
  144. chief_address = socket.gethostbyname(socket.gethostname())
  145. port = get_available_port(chief_address)
  146. for node in settings['nodes']:
  147. assert set(node.keys(
  148. )) <= attributes, 'Attributes of nodes invalid, %s / %s.' % (set(node.keys()), attributes)
  149. hosts.append(node['host'])
  150. if node.get('servers', 0):
  151. servers[node['host']] = node['servers']
  152. if node.get('workers', 0):
  153. workers[node['host']] = node['workers']
  154. if node.get('chief', False):
  155. assert chief is None, 'There should be only one chief.'
  156. chief = node['host']
  157. assert chief, 'There should be one chief.'
  158. num_servers = sum(servers.values())
  159. num_workers = sum(workers.values())
  160. enable_PS = (num_servers > 0)
  161. print('Cluster: {')
  162. print(' Chief: %s,' % chief)
  163. print(' Servers(%d): %s,' % (num_servers, servers))
  164. print(' Workers(%d): %s,' % (num_workers, workers))
  165. print('}')
  166. if enable_PS:
  167. os.environ['DMLC_PS_ROOT_URI'] = chief_address
  168. os.environ['DMLC_PS_ROOT_PORT'] = str(port)
  169. os.environ['DMLC_PS_VAN_TYPE'] = 'p3'
  170. os.environ['DMLC_NUM_SERVER'] = str(num_servers)
  171. os.environ['DMLC_NUM_WORKER'] = str(num_workers)
  172. global executor_shell
  173. if len(hosts) == 1:
  174. # single machine
  175. # TODO: add hostdress validation check
  176. if enable_PS:
  177. proc = multiprocessing.Process(target=start_sched)
  178. _procs.append(proc)
  179. for i in range(num_servers):
  180. proc = multiprocessing.Process(target=start_server)
  181. _procs.append(proc)
  182. for proc in _procs:
  183. proc.start()
  184. mpi_command = 'mpirun --allow-run-as-root --tag-output -np %d %s' % (
  185. num_workers, ' '.join(args.command))
  186. env = dict(os.environ)
  187. if enable_PS:
  188. env["DMLC_ROLE"] = "worker"
  189. executor_shell = subprocess.Popen(
  190. mpi_command, shell=True, env=env, stdout=None, stderr=None)
  191. for proc in _procs:
  192. proc.join()
  193. executor_shell.wait()
  194. else:
  195. # multi machines
  196. #! nic names not used currently, use subnets instead; nccl_socket_name please specified in /etc/bash.bashrc
  197. #! nic methods cannot support different nic name on different machines
  198. # nics = get_nic_names(chief_address, set(hosts) - {chief}, args.identify)
  199. # joined_nics = ','.join(nics)
  200. subnet = get_subnet(chief_address, set(hosts) - {chief}, args.identify)
  201. if enable_PS:
  202. with open('/tmp/temp_hetu_config.yml', 'w') as fw:
  203. yaml.dump({'shared': {'DMLC_PS_ROOT_URI': chief_address, 'DMLC_PS_ROOT_PORT': port,
  204. 'DMLC_NUM_WORKER': num_workers, 'DMLC_NUM_SERVER': num_servers, 'DMLC_PS_VAN_TYPE': 'p3'}}, fw)
  205. proc = multiprocessing.Process(target=start_sched)
  206. _procs.append(proc)
  207. for node in hosts:
  208. if node == chief:
  209. for i in range(servers.get(node, 0)):
  210. proc = multiprocessing.Process(target=start_server)
  211. _procs.append(proc)
  212. else:
  213. if servers.get(node, 0):
  214. proc = multiprocessing.Process(target=start_remote_server, args=[
  215. node, servers[node], args.identify])
  216. _procs.append(proc)
  217. for proc in _procs:
  218. proc.start()
  219. basic_args = '--allow-run-as-root --tag-output'
  220. hosts_in_command = ','.join(
  221. ['%s:%d' % (node, nworkers) for node, nworkers in workers.items()])
  222. mpi_ssh_args = '' if args.identify == '' else '-bootstrap=ssh -bootstrap-exec-args -i %s' % args.identify
  223. tcp_intf_arg = '-mca btl_tcp_if_include %s' % subnet
  224. # tcp_intf_arg = '-mca btl_tcp_if_include %s' % joined_nics
  225. # nccl_socket_intf_arg = '-x NCCL_SOCKET_IFNAME=%s' % joined_nics
  226. env_list = '-x DMLC_PS_ROOT_URI=%s -x DMLC_PS_ROOT_PORT=%s -x DMLC_PS_VAN_TYPE=p3 -x DMLC_NUM_SERVER=%s -x DMLC_NUM_WORKER=%s -x DMLC_ROLE=worker' %\
  227. (chief_address, str(port), str(num_servers),
  228. str(num_workers)) if enable_PS else ''
  229. mpi_command = (
  230. 'mpirun {basic_args} '
  231. '--host {hosts} '
  232. '{mpi_ssh_args} '
  233. '{tcp_intf_arg} '
  234. # '{nccl_socket_intf_arg} '
  235. '{env} '
  236. '{command}'
  237. .format(basic_args=basic_args,
  238. hosts=hosts_in_command,
  239. mpi_ssh_args=mpi_ssh_args,
  240. tcp_intf_arg=tcp_intf_arg,
  241. # nccl_socket_intf_arg=nccl_socket_intf_arg,
  242. env=env_list,
  243. command=' '.join(args.command))
  244. )
  245. executor_shell = subprocess.Popen(
  246. mpi_command, shell=True, stdout=None, stderr=None)
  247. for proc in _procs:
  248. proc.join()
  249. executor_shell.wait()
  250. if __name__ == '__main__':
  251. #! need to modify /etc/bash.bashrc on other machines for:
  252. # * specify NCCL_SOCKET_IFNAME
  253. # * specify PATH for mpirun support
  254. # * activate conda environment
  255. # * specify PYTHONPATH for hetu support
  256. #! ssh process to other machines for server CANNOT receive SIGINT from Ctrl+C on this machine, please kill on other machines
  257. main()