
distributed.py

# -*- coding: utf-8 -*-

import json
import os
import time

import tensorflow as tf
from tensorflow.python.training import session_run_hook

from tensorlayer import logging
from tensorlayer.decorators import deprecated
from tensorlayer.lazy_imports import LazyImport

hvd = LazyImport('horovod.tensorflow')

__all__ = ['TaskSpecDef', 'TaskSpec', 'DistributedSession', 'StopAtTimeHook', 'LoadCheckpoint', 'Trainer']

class Trainer(object):
    """Trainer for neural networks in a distributed environment.

    The TensorLayer Trainer is a high-level training interface built on top of TensorFlow MonitoredSession and
    `Horovod <https://github.com/uber/horovod>`__. It transparently scales the training of a TensorLayer model
    from a single GPU to multiple GPUs that can be placed on different machines in a single cluster.

    To run the trainer, you will need to install Horovod on your machine. Check the installation script at
    `tensorlayer/scripts/download_and_install_openmpi3_ubuntu.sh`.

    The minimal inputs to the Trainer are (1) a training dataset defined using the TensorFlow DataSet API,
    (2) a model-building function that takes the inputs of the training dataset and returns the neural network
    to train, the loss function to minimize, and the tensors to log during training, and (3)
    an optimizer and its arguments.

    The default parameter choices of the Trainer are inspired by the Facebook paper:
    `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour <https://arxiv.org/abs/1706.02677>`__

    Parameters
    ----------
    training_dataset : class TensorFlow ``DataSet``
        The training dataset which zips samples and labels. The trainer automatically
        shards the training dataset based on the number of GPUs.
    build_training_func : function
        A function that builds the training operator. It takes the training dataset as an input,
        and returns the neural network, the loss function and a dictionary that maps
        string tags to tensors to log during training.
    optimizer : class TensorFlow ``Optimizer``
        The loss function optimizer. The trainer automatically scales the learning rate linearly with
        the number of GPUs.
    optimizer_args : dict
        The optimizer argument dictionary. It must contain a `learning_rate` field of type float.
        Note that the learning rate is linearly scaled according to the number of GPUs by default.
        You can disable this with the option `scaling_learning_rate`.
    batch_size : int
        The training mini-batch size (i.e., number of samples per batch).
    prefetch_size : int or None
        The dataset prefetch buffer size. Set this parameter to overlap GPU training and data preparation
        if the data preparation is heavy.
    checkpoint_dir : None or str
        The path to the TensorFlow model checkpoint. Note that only the trainer master checkpoints its model.
        If None, checkpointing is disabled.
    log_step_size : int
        The trainer logs training information every N mini-batches (i.e., step size).
    validation_dataset : None or class TensorFlow ``DataSet``
        The optional validation dataset that zips samples and labels. Note that
        only the trainer master runs the validation.
    build_validation_func : None or function
        The function that builds the validation operator. It returns the validation neural network (which
        shares the weights of the training network) and a custom number of validation metrics.
    scaling_learning_rate : boolean
        Linearly scale the learning rate by the number of GPUs. Default is True.
        This `linear scaling rule` is generally effective and is highly recommended by practitioners.
        Check `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour <https://arxiv.org/abs/1706.02677>`__
    max_iteration : int
        The maximum number of iterations (i.e., mini-batches) to train for.
        The default is `math.inf`. You can set it to a small number to end the training earlier. This is
        usually done for testing purposes.

    Attributes
    ----------
    training_network : class TensorLayer ``Layer``
        The training model.
    session : class TensorFlow ``MonitoredTrainingSession``
        The training session that the Trainer wraps.
    global_step : int
        The number of training mini-batches processed so far.
    validation_metrics : list of tuples
        The validation metrics that pair each validation metric tensor with its average value.

    Examples
    --------
    See `tutorial_mnist_distributed_trainer.py
    <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_distributed_trainer.py>`__.
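
    A minimal usage sketch (the dataset and model-building helpers below are hypothetical
    placeholders, not TensorLayer APIs; see the tutorial above for a complete, runnable script):

    >>> import tensorflow as tf
    >>> # `make_mnist_dataset` is a hypothetical helper returning a tf.data.Dataset of (image, label) pairs.
    >>> training_dataset = make_mnist_dataset()
    >>> def build_train(x, y):
    ...     network = build_network(x)            # hypothetical: builds and returns a TensorLayer network
    ...     loss = compute_loss(network, y)       # hypothetical: returns a scalar loss tensor
    ...     return network, loss, {'loss': loss}  # tensors in the dict are logged every `log_step_size` steps
    >>> trainer = Trainer(
    ...     training_dataset=training_dataset, build_training_func=build_train,
    ...     optimizer=tf.train.AdamOptimizer, optimizer_args={'learning_rate': 0.001},
    ...     batch_size=32
    ... )
    >>> while not trainer.session.should_stop():
    ...     trainer.train_on_batch()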
  79. """
    def __init__(
            self, training_dataset, build_training_func, optimizer, optimizer_args, batch_size=32, prefetch_size=None,
            checkpoint_dir=None, scaling_learning_rate=True, log_step_size=1, validation_dataset=None,
            build_validation_func=None, max_iteration=float('inf')
    ):
        # Initialize Horovod.
        hvd.init()
        self.is_master = hvd.rank() == 0
        self._last_global_step = 0

        if prefetch_size is None:
            prefetch_size = batch_size

        # Define the loss for the validation dataset
        if validation_dataset:
            validation_dataset = validation_dataset.shard(num_shards=hvd.size(), index=hvd.rank()).batch(batch_size)
            # `prefetch` returns a new dataset, so keep the result (otherwise the call has no effect).
            validation_dataset = validation_dataset.prefetch(buffer_size=prefetch_size)
            self._validation_iterator = validation_dataset.make_initializable_iterator()
            next_example, next_label = self._validation_iterator.get_next()
            _, self._validation_metrics = build_validation_func(next_example, next_label)
            if not isinstance(self._validation_metrics, list):
                self._validation_metrics = list(self._validation_metrics)
        else:
            self._validation_iterator = None
            self._validation_metrics = None

        # Get the shard of the dataset based on my local rank
        training_dataset = training_dataset.shard(num_shards=hvd.size(), index=hvd.rank()).batch(batch_size)
        training_dataset = training_dataset.prefetch(buffer_size=prefetch_size)
        training_iterator = training_dataset.make_one_shot_iterator()
        self._training_network, loss, log_tensors = build_training_func(*training_iterator.get_next())

        # Adjust learning rate based on number of GPUs.
        lr = optimizer_args['learning_rate']
        optimizer_args['learning_rate'] = lr * hvd.size() if scaling_learning_rate else lr
        opt = optimizer(**optimizer_args)

        # Add Horovod Distributed Optimizer.
        opt = hvd.DistributedOptimizer(opt)

        self._global_step = tf.train.get_or_create_global_step()
        if isinstance(log_tensors, list):
            log_tensors.append(self._global_step)
        else:
            log_tensors['global_step'] = self._global_step
        self._train_op = opt.minimize(loss, global_step=self._global_step)

        hooks = [
            # Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states
            # from rank 0 to all other processes. This is necessary to ensure consistent
            # initialization of all workers when training is started with random weights
            # or restored from a checkpoint.
            hvd.BroadcastGlobalVariablesHook(0),
            # Horovod: adjust number of steps based on number of GPUs.
            tf.train.StopAtStepHook(last_step=max_iteration // hvd.size()),
            tf.train.LoggingTensorHook(tensors=log_tensors, every_n_iter=log_step_size),
        ]

        # Pin the GPU used by this process to its local rank (one GPU per process).
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = str(hvd.local_rank())

        # Save checkpoints only on worker 0 to prevent other workers from
        # corrupting them.
        checkpoint_dir = checkpoint_dir if self.is_master else None

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or when an error occurs.
        self._sess = tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir, hooks=hooks, config=config)

    @property
    def global_step(self):
        if self._sess.should_stop():
            return self._last_global_step
        self._last_global_step = self._sess.run(self._global_step)
        return self._last_global_step

    @property
    def session(self):
        return self._sess

    @property
    def training_network(self):
        return self._training_network

    @property
    def validation_metrics(self):
        """A helper function to compute validation related metrics"""
        if (self._validation_iterator is None) or (self._validation_metrics is None):
            raise AttributeError('Validation is not set up.')

        n = 0.0
        metric_sums = [0.0] * len(self._validation_metrics)
        self._sess.run(self._validation_iterator.initializer)
        while True:
            try:
                metrics = self._sess.run(self._validation_metrics)
                for i, m in enumerate(metrics):
                    metric_sums[i] += m
                n += 1.0
            except tf.errors.OutOfRangeError:
                break
        for i, m in enumerate(metric_sums):
            metric_sums[i] = metric_sums[i] / n
        return zip(self._validation_metrics, metric_sums)

    def train_on_batch(self):
        """Train a mini-batch."""
        self._sess.run(self._train_op)

    def train_and_validate_to_end(self, validate_step_size=50):
        """A helper function that shows how to train and validate a model at the same time.

        Parameters
        ----------
        validate_step_size : int
            Validate the training network every N steps.

        """
        while not self._sess.should_stop():
            self.train_on_batch()  # Run a training step synchronously.
            if self.global_step % validate_step_size == 0:
                # logging.info("Average loss for validation dataset: %s" % self.get_validation_metrics())
                log_str = 'step: %d, ' % self.global_step
                for n, m in self.validation_metrics:
                    log_str += '%s: %f, ' % (n.name, m)
                logging.info(log_str)


@deprecated(date="2018-10-30", instructions="Using the TensorLayer distributed trainer.")
class TaskSpecDef(object):
    """Specification for a distributed task.

    It contains the job name, the index of the task,
    the parameter servers and the worker servers. If you want to use the last worker
    for continuous evaluation you can call the method `use_last_worker_as_evaluator`,
    which returns a new :class:`TaskSpecDef` object without the last worker in the
    cluster specification.

    Parameters
    ----------
    task_type : str
        Task type. One of `master`, `worker` or `ps`.
    index : int
        The zero-based index of the task. Distributed training jobs will have a single
        master task, one or more parameter servers, and one or more workers.
    trial : int
        The identifier of the trial being run.
    ps_hosts : str OR list of str
        A string with a comma-separated list of hosts for the parameter servers
        or a list of hosts.
    worker_hosts : str OR list of str
        A string with a comma-separated list of hosts for the worker servers
        or a list of hosts.
    master : str
        A string with the master hosts.

    Notes
    ----------
    master might not be included in TF_CONFIG and can be None. The shard_index is adjusted
    in any case to assign 0 to the master and >= 1 to the workers.

    This implementation doesn't support sparse arrays in the `TF_CONFIG` variable as the
    official TensorFlow documentation shows, as they are not supported by the JSON
    definition.

    References
    ----------
    - `ML-engine trainer considerations <https://cloud.google.com/ml-engine/docs/trainer-considerations#use_tf_config>`__
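
    Examples
    --------
    A construction sketch (the host addresses below are placeholders):

    >>> task_spec = TaskSpecDef(
    ...     task_type='worker', index=0,
    ...     ps_hosts='ps0:2222', worker_hosts='worker0:2222,worker1:2222'
    ... )
    >>> task_spec.num_workers
    2
    >>> task_spec.is_master()   # worker 0 acts as the master when no explicit master host is given
    True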
  225. """
    def __init__(self, task_type='master', index=0, trial=None, ps_hosts=None, worker_hosts=None, master=None):
        self.type = task_type
        self._index = int(index)
        self._cluster_spec = None
        self.num_workers = 1
        self.num_ps = 0
        self.shard_index = int(index)
        self._master = True
        self.trial = trial
        self.ps_hosts = ps_hosts
        self.worker_hosts = worker_hosts
        self.master = master
        self._server = None

        if ps_hosts and worker_hosts:
            self.ps_hosts = ps_hosts if isinstance(ps_hosts, list) else ps_hosts.split(',')
            self.num_ps = len(self.ps_hosts)
            self.worker_hosts = worker_hosts if isinstance(worker_hosts, list) else worker_hosts.split(',')
            if master is not None and len(master) > 0:
                self._cluster_spec = tf.train.ClusterSpec(
                    {
                        'ps': self.ps_hosts,
                        'worker': self.worker_hosts,
                        'master': master
                    }
                )
                # master is a worker too
                self.num_workers = len(self.worker_hosts) + 1
                if self.type == 'worker':
                    self.shard_index = self._index + 1
                self._master = self.type == 'master'
            else:
                self._cluster_spec = tf.train.ClusterSpec({'ps': self.ps_hosts, 'worker': self.worker_hosts})
                self.num_workers = len(self.worker_hosts)
                if self.type == 'worker':
                    self.shard_index = self._index
                self._master = self.type == 'worker' and self._index == 0

    def is_ps(self):
        """Returns true if this server is a parameter server"""
        return self.type == 'ps'

    def is_worker(self):
        """Returns true if this server is a worker server"""
        return self.type == 'worker'

    def is_master(self):
        """Returns true if this server is the master server"""
        return self._master

    def is_evaluator(self):
        """Returns true if this server is the evaluator server"""
        return self.type == 'worker' and self.num_workers == self._index

    def device_fn(self):
        """Returns the function with the specification to create the graph in this server"""
        current_device = '/job:{}/task:{}'.format(self.type, self._index)
        ps_devices = '/job:ps'
        return tf.train.replica_device_setter(
            ps_device=ps_devices, worker_device=current_device, cluster=self._cluster_spec
        )

    def create_server(self):
        if self._server is None and self.ps_hosts and self.worker_hosts and not self.is_evaluator():
            # create server and join if it is a parameter server
            self._server = tf.train.Server(self._cluster_spec, job_name=self.type, task_index=self._index)
            if self.is_ps():
                self._server.join()

    def target(self):
        if self._server is None:
            self.create_server()
        if self._server is not None:
            return self._server.target
        else:
            return None

    def use_last_worker_as_evaluator(self):
        """Returns a new :class:`TaskSpecDef` where the last worker has been removed from
        the list of worker_hosts, so it is not used for training anymore. You can call
        is_evaluator to know whether this server is the evaluator one or not.

        In case there is only one server for training this method raises an exception, as
        you cannot use any server for evaluation.

        """
        if self.num_workers <= 1:
            raise Exception('You need more than one worker instance to use one as evaluator')

        return TaskSpecDef(
            task_type=self.type, index=self._index, trial=self.trial, ps_hosts=self.ps_hosts,
            worker_hosts=self.worker_hosts[:-1], master=self.master
        )


@deprecated(date="2018-10-30", instructions="Using the TensorLayer distributed trainer.")
def create_task_spec_def():
    """Returns a :class:`TaskSpecDef` based on the environment variables for distributed training.

    References
    ----------
    - `ML-engine trainer considerations <https://cloud.google.com/ml-engine/docs/trainer-considerations#use_tf_config>`__
    - `TensorPort Distributed Computing <https://www.tensorport.com/documentation/code-details/>`__
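
    Examples
    --------
    A sketch of the `TF_CONFIG` layout this function reads (the host addresses are placeholders,
    and the exact cluster keys depend on your platform):

    >>> os.environ['TF_CONFIG'] = json.dumps({
    ...     'cluster': {
    ...         'ps': ['ps0:2222'],
    ...         'worker': ['worker0:2222', 'worker1:2222'],
    ...         'master': ['master0:2222'],
    ...     },
    ...     'task': {'type': 'worker', 'index': 0},
    ... })
    >>> task_spec = create_task_spec_def()
    >>> task_spec.shard_index   # worker shards are shifted by one because the master is also a worker
    1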
  314. """
  315. if 'TF_CONFIG' in os.environ:
  316. # TF_CONFIG is used in ML-engine
  317. env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  318. task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  319. cluster_data = env.get('cluster', None) or {'ps': None, 'worker': None, 'master': None}
  320. return TaskSpecDef(
  321. task_type=task_data['type'], index=task_data['index'],
  322. trial=task_data['trial'] if 'trial' in task_data else None, ps_hosts=cluster_data['ps'],
  323. worker_hosts=cluster_data['worker'], master=cluster_data['master'] if 'master' in cluster_data else None
  324. )
  325. elif 'JOB_NAME' in os.environ:
  326. # JOB_NAME, TASK_INDEX, PS_HOSTS, WORKER_HOSTS and MASTER_HOST are used in TensorPort
  327. return TaskSpecDef(
  328. task_type=os.environ['JOB_NAME'], index=os.environ['TASK_INDEX'], ps_hosts=os.environ.get('PS_HOSTS', None),
  329. worker_hosts=os.environ.get('WORKER_HOSTS', None), master=os.environ.get('MASTER_HOST', None)
  330. )
  331. else:
  332. raise Exception('You need to setup TF_CONFIG or JOB_NAME to define the task.')


@deprecated(date="2018-10-30", instructions="Using the TensorLayer distributed trainer.")
def create_distributed_session(
        task_spec=None, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600,
        save_summaries_steps=object(), save_summaries_secs=object(), config=None, stop_grace_period_secs=120,
        log_step_count_steps=100
):
    """Creates a distributed session.

    It calls `MonitoredTrainingSession` to create a :class:`MonitoredSession` for distributed training.

    Parameters
    ----------
    task_spec : :class:`TaskSpecDef`
        The task spec definition from create_task_spec_def().
    checkpoint_dir : str
        Optional path to a directory where to restore variables.
    scaffold : ``Scaffold``
        A `Scaffold` used for gathering or building supportive ops.
        If not specified, a default one is created. It's used to finalize the graph.
    hooks : list of ``SessionRunHook`` objects
        Optional.
    chief_only_hooks : list of ``SessionRunHook`` objects
        Activate these hooks if `is_chief==True`, ignore otherwise.
    save_checkpoint_secs : int
        The frequency, in seconds, that a checkpoint is saved
        using a default checkpoint saver. If `save_checkpoint_secs` is set to
        `None`, then the default checkpoint saver isn't used.
    save_summaries_steps : int
        The frequency, in number of global steps, that the
        summaries are written to disk using a default summary saver. If both
        `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
        the default summary saver isn't used. Default 100.
    save_summaries_secs : int
        The frequency, in seconds, that the summaries are written
        to disk using a default summary saver. If both `save_summaries_steps` and
        `save_summaries_secs` are set to `None`, then the default summary saver
        isn't used. Default not enabled.
    config : ``tf.ConfigProto``
        An instance of `tf.ConfigProto` used to configure the session.
        It's the `config` argument of the constructor of `tf.Session`.
    stop_grace_period_secs : int
        Number of seconds given to threads to stop after
        `close()` has been called.
    log_step_count_steps : int
        The frequency, in number of global steps, that the
        global step/sec is logged.

    Examples
    --------
    A simple example for distributed training where all the workers use the same dataset:

    >>> task_spec = TaskSpec()
    >>> with tf.device(task_spec.device_fn()):
    ...     tensors = create_graph()
    >>> with tl.DistributedSession(task_spec=task_spec,
    ...                            checkpoint_dir='/tmp/ckpt') as session:
    ...     while not session.should_stop():
    ...         session.run(tensors)

    An example where the dataset is sharded among the workers
    (see https://www.tensorflow.org/programmers_guide/datasets):

    >>> task_spec = TaskSpec()
    >>> # dataset is a :class:`tf.data.Dataset` with the raw data
    >>> dataset = create_dataset()
    >>> if task_spec is not None:
    ...     dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
    >>> # shuffle or apply a map function to the new sharded dataset, for example:
    >>> dataset = dataset.shuffle(buffer_size=10000)
    >>> dataset = dataset.batch(batch_size)
    >>> dataset = dataset.repeat(num_epochs)
    >>> # create the iterator for the dataset and the input tensor
    >>> iterator = dataset.make_one_shot_iterator()
    >>> next_element = iterator.get_next()
    >>> with tf.device(task_spec.device_fn()):
    ...     # next_element is the input for the graph
    ...     tensors = create_graph(next_element)
    >>> with tl.DistributedSession(task_spec=task_spec,
    ...                            checkpoint_dir='/tmp/ckpt') as session:
    ...     while not session.should_stop():
    ...         session.run(tensors)

    References
    ----------
    - `MonitoredTrainingSession <https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession>`__

    """
    target = task_spec.target() if task_spec is not None else None
    is_chief = task_spec.is_master() if task_spec is not None else True
    return tf.train.MonitoredTrainingSession(
        master=target, is_chief=is_chief, checkpoint_dir=checkpoint_dir, scaffold=scaffold,
        save_checkpoint_secs=save_checkpoint_secs, save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs, log_step_count_steps=log_step_count_steps,
        stop_grace_period_secs=stop_grace_period_secs, config=config, hooks=hooks, chief_only_hooks=chief_only_hooks
    )


@deprecated(date="2018-10-30", instructions="Using the TensorLayer distributed trainer.")
class StopAtTimeHook(session_run_hook.SessionRunHook):
    """Hook that requests stop after a specified time.

    Parameters
    ----------
    time_running : int
        Maximum time running in seconds
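
    Examples
    --------
    A usage sketch (assumes a `train_op` built elsewhere):

    >>> stop_hook = StopAtTimeHook(time_running=3600)   # request a stop after one hour
    >>> with tf.train.MonitoredTrainingSession(hooks=[stop_hook]) as sess:
    ...     while not sess.should_stop():
    ...         sess.run(train_op)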
  427. """
  428. def __init__(self, time_running):
  429. self._time_running = time_running
  430. self._end_time = 0
  431. def begin(self):
  432. self._end_time = time.time() + self._time_running
  433. def after_run(self, run_context, run_values):
  434. if time.time() > self._end_time:
  435. run_context.request_stop()


@deprecated(date="2018-10-30", instructions="Using the TensorLayer distributed trainer.")
class LoadCheckpoint(session_run_hook.SessionRunHook):
    """Hook that loads a checkpoint after the session is created.

    >>> from tensorflow.python.ops import variables as tf_variables
    >>> from tensorflow.python.training.monitored_session import SingularMonitoredSession
    >>>
    >>> tensors = create_graph()
    >>> saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
    >>> checkpoint_hook = LoadCheckpoint(saver, my_checkpoint_file)
    >>> with SingularMonitoredSession(hooks=[checkpoint_hook]) as session:
    ...     while not session.should_stop():
    ...         session.run(tensors)

    """

    def __init__(self, saver, checkpoint):
        self._saver = saver
        self._checkpoint = checkpoint
        self._loaded = False

    def after_create_session(self, session, coord):
        if not self._loaded:
            self._loaded = True
            # `Saver.restore` expects the session as its first argument.
            self._saver.restore(session, self._checkpoint)


# Alias
TaskSpec = create_task_spec_def
DistributedSession = create_distributed_session

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as its computing backends. It plans to be compatible with TensorFlow, PyTorch, MindSpore, and Paddle.