|
- import functools
-
- from torch.nn.parallel.data_parallel import DataParallel
-
# Public names exported by a star-import of this module.
__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback',
]
-
-
class CallbackContext(object):
    """
    Empty attribute container handed to replication callbacks.

    `execute_replication_callbacks` creates one instance per sub-module position and passes the same
    instance to the callback of every copy at that position, so callbacks can store attributes on it
    to share information between copies.
    """
-
-
def execute_replication_callbacks(modules):
    """
    Execute the replication callback `__data_parallel_replicate__` on each module created by the
    original replication.

    The callback is invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all copies are isomorphic, each sub-module position is assigned one context that is
    shared among the copies of this sub-module on different devices. Through this context, different
    copies can share some information.

    The callbacks on the master copy (the first copy) are guaranteed to be called ahead of the
    callbacks of any slave copy.

    Args:
        modules: sequence of module copies; the first element is treated as the master copy. Each
            element must provide a `modules()` iterator over its sub-modules. An empty sequence is
            accepted and results in no callback being invoked.
    """
    # An empty sequence has no master copy to index, so there is nothing to call back.
    if not modules:
        return

    master_copy = modules[0]
    # One shared context per sub-module position, sized from the master copy.
    nr_modules = sum(1 for _ in master_copy.modules())
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    # Copies are visited in order, so every callback of copy 0 runs before any callback of later copies.
    for copy_id, replica in enumerate(modules):
        for j, sub_module in enumerate(replica.modules()):
            if hasattr(sub_module, '__data_parallel_replicate__'):
                sub_module.__data_parallel_replicate__(ctxs[j], copy_id)
-
-
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` variant that runs replication callbacks.

    Once the original `replicate` function has created the copies, the `__data_parallel_replicate__`
    callback of each module that defines one is invoked with arguments
    `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` via the parent class, then run the replication callbacks on the copies."""
        parent = super(DataParallelWithCallback, self)
        replicas = parent.replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
-
-
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so that it runs the replication callbacks.

    The instance's `replicate` attribute is replaced by a wrapper that calls the previous `replicate`
    and then `execute_replication_callbacks` on the result. Useful when you have a customized
    `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

    Args:
        data_parallel: the `DataParallel` instance to patch in place.
    """

    assert isinstance(data_parallel, DataParallel)

    # Capture the currently bound method so the wrapper can delegate to it.
    original_replicate = data_parallel.replicate

    @functools.wraps(original_replicate)
    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # Assigned on the instance, so only this object is affected.
    data_parallel.replicate = replicate_with_callbacks
|