|
|
|
@@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra |
|
|
|
allreduce (Primitive): The communication operator for gradients. |
|
|
|
allreduce_filter (bool): When it is true, allreduce would apply. |
|
|
|
grad (Tensor): The gradient tensor before operation. |
|
|
|
ps_parameter(Bool): Use parameter server or not. |
|
|
|
ps_parameter (bool): Use parameter server or not. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Tensor, the gradient tensor after operation. |
|
|
|
""" |
|
|
|
if not ps_parameter and allreduce_filter: |
|
|
|
if ps_parameter: |
|
|
|
return grad |
|
|
|
|
|
|
|
if allreduce_filter: |
|
|
|
grad = allreduce(grad) |
|
|
|
if mean: |
|
|
|
degree = F.scalar_cast(degree, F.dtype(grad)) |
|
|
|
@@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra |
|
|
|
return grad |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") |
|
|
|
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): |
|
|
|
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") |
|
|
|
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): |
|
|
|
""" |
|
|
|
Apply allgather on gradient instead of allreduce for sparse feature. |
|
|
|
Allgather is a communication operation used for distributed deep learning. |
|
|
|
@@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce |
|
|
|
allreduce (Primitive): The communication operator for gradients. |
|
|
|
allreduce_filter (bool): When it is true, allgather would apply. |
|
|
|
grad (tuple): The indices, gradient tensor and tensor_shape before operation. |
|
|
|
ps_parameter (bool): Use parameter server or not. |
|
|
|
|
|
|
|
Returns: |
|
|
|
IndexedSlices, the gradient after operation. |
|
|
|
""" |
|
|
|
if ps_parameter: |
|
|
|
return grad |
|
|
|
|
|
|
|
if allreduce_filter: |
|
|
|
indices = allgather(grad.indices()) |
|
|
|
dout = allgather(grad.values()) |
|
|
|
|