|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- import os
- import tensorflow as tf
- import multiprocessing
- import signal
- import json
- import argparse
-
-
def pop_env():
    """Strip proxy settings from the environment and blank out the GPU list.

    Removes the lowercase ``https_proxy`` and ``http_proxy`` variables when
    they are present, then sets ``CUDA_VISIBLE_DEVICES`` to an empty string.
    """
    for proxy_var in ('https_proxy', 'http_proxy'):
        # pop() with a default is a no-op when the variable is not set.
        os.environ.pop(proxy_var, None)
    os.environ['CUDA_VISIBLE_DEVICES'] = ''


# Executed at import time, so the environment is already adjusted by the
# time main() spawns the server process.
pop_env()
-
-
def start_server(cluster, task_id):
    """Create the TensorFlow server for one ``ps`` task and wait on it.

    Args:
        cluster: Cluster description handed straight to ``tf.train.Server``.
        task_id: Index of this task within the ``ps`` job.

    The function only returns when ``Server.join()`` returns.
    """
    tf.train.Server(
        cluster,
        task_index=task_id,
        job_name='ps',
    ).join()
-
-
def main():
    """Parse the command line and run one ``ps`` server in a child process.

    Command-line arguments:
        --config: Path to a JSON file whose content is passed to
            ``tf.train.ClusterSpec`` (default
            ``./settings/tf_dist_s4_w2.json``).
        --id: Required integer, forwarded to ``start_server`` as the task id.

    The child process is stored in the module-level ``proc`` so that
    ``signal_handler`` can reach it. The call blocks until the child exits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, default='./settings/tf_dist_s4_w2.json',
        help="config file path")
    parser.add_argument("--id", type=int, required=True)
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed right away
    # instead of being left to the garbage collector.
    with open(args.config) as config_file:
        config = json.load(config_file)
    cluster = tf.train.ClusterSpec(config)

    global proc
    proc = multiprocessing.Process(
        target=start_server, args=(cluster, args.id))
    proc.start()
    # NOTE(review): the handler is installed only after the child has been
    # started, presumably so the child does not inherit it -- confirm before
    # reordering these two calls.
    signal.signal(signal.SIGINT, signal_handler)
    proc.join()
-
-
def signal_handler(signal, frame):
    """Handle SIGINT: report it, kill the server process, and exit.

    Args:
        signal: Signal number passed in by the interpreter (unused).
        frame: Current stack frame passed in by the interpreter (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    # Imported locally so this function does not depend on a file-level
    # ``import sys``.
    import sys

    print("SIGINT signal caught, stop Training")
    global proc
    proc.kill()
    # sys.exit() rather than the builtin exit(): the latter is injected by
    # the ``site`` module for interactive use and is not guaranteed to exist
    # (for example under ``python -S``).
    sys.exit(0)
-
-
# Start the launcher only when this file is executed as a script.
if __name__ == '__main__':
    main()
|