You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

replicate.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import functools
  2. from torch.nn.parallel.data_parallel import DataParallel
# Names exported by `from <this module> import *`.
__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]
  9. class CallbackContext(object):
  10. pass
  11. def execute_replication_callbacks(modules):
  12. """
  13. Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
  14. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
  15. Note that, as all modules are isomorphism, we assign each sub-module with a context
  16. (shared among multiple copies of this module on different devices).
  17. Through this context, different copies can share some information.
  18. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
  19. of any slave copies.
  20. """
  21. master_copy = modules[0]
  22. nr_modules = len(list(master_copy.modules()))
  23. ctxs = [CallbackContext() for _ in range(nr_modules)]
  24. for i, module in enumerate(modules):
  25. for j, m in enumerate(module.modules()):
  26. if hasattr(m, '__data_parallel_replicate__'):
  27. m.__data_parallel_replicate__(ctxs[j], i)
  28. class DataParallelWithCallback(DataParallel):
  29. """
  30. Data Parallel with a replication callback.
  31. An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
  32. original `replicate` function.
  33. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
  34. Examples:
  35. > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
  36. > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
  37. # sync_bn.__data_parallel_replicate__ will be invoked.
  38. """
  39. def replicate(self, module, device_ids):
  40. modules = super(DataParallelWithCallback,
  41. self).replicate(module, device_ids)
  42. execute_replication_callbacks(modules)
  43. return modules
  44. def patch_replication_callback(data_parallel):
  45. """
  46. Monkey-patch an existing `DataParallel` object. Add the replication callback.
  47. Useful when you have customized `DataParallel` implementation.
  48. Examples:
  49. > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
  50. > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
  51. > patch_replication_callback(sync_bn)
  52. # this is equivalent to
  53. > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
  54. > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
  55. """
  56. assert isinstance(data_parallel, DataParallel)
  57. old_replicate = data_parallel.replicate
  58. @functools.wraps(old_replicate)
  59. def new_replicate(module, device_ids):
  60. modules = old_replicate(module, device_ids)
  61. execute_replication_callbacks(modules)
  62. return modules
  63. data_parallel.replicate = new_replicate