You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf_launch_server.py 1.1 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os
  2. import tensorflow as tf
  3. import multiprocessing
  4. import signal
  5. import json
  6. import argparse
  7. def pop_env():
  8. for k in ['https_proxy', 'http_proxy']:
  9. if k in os.environ:
  10. os.environ.pop(k)
  11. os.environ['CUDA_VISIBLE_DEVICES'] = ''
  12. pop_env()
  13. def start_server(cluster, task_id):
  14. server = tf.train.Server(cluster, job_name='ps', task_index=task_id)
  15. server.join()
  16. def main():
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument(
  19. "--config", type=str, default='./settings/tf_dist_s4_w2.json', help="config file path")
  20. parser.add_argument("--id", type=int, required=True)
  21. args = parser.parse_args()
  22. raw_config = args.config
  23. config = json.load(open(raw_config))
  24. cluster = tf.train.ClusterSpec(config)
  25. global proc
  26. proc = multiprocessing.Process(
  27. target=start_server, args=[cluster, args.id, ])
  28. proc.start()
  29. signal.signal(signal.SIGINT, signal_handler)
  30. proc.join()
  31. def signal_handler(signal, frame):
  32. print("SIGINT signal caught, stop Training")
  33. global proc
  34. proc.kill()
  35. exit(0)
  36. if __name__ == '__main__':
  37. main()