Browse Source

add ps filter

tags/v0.6.0-beta
jinyaohui 5 years ago
parent
commit
e97bf5b8ec
1 changed file with 11 additions and 4 deletions
  1. +11
    -4
      mindspore/nn/wrap/grad_reducer.py

+ 11
- 4
mindspore/nn/wrap/grad_reducer.py View File

@@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
allreduce (Primitive): The communication operator for gradients. allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allreduce would apply. allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation. grad (Tensor): The gradient tensor before operation.
ps_parameter(Bool): Use parameter server or not.
ps_parameter (bool): Use parameter server or not.


Returns: Returns:
Tensor, the gradient tensor after operation. Tensor, the gradient tensor after operation.
""" """
if not ps_parameter and allreduce_filter:
if ps_parameter:
return grad

if allreduce_filter:
grad = allreduce(grad) grad = allreduce(grad)
if mean: if mean:
degree = F.scalar_cast(degree, F.dtype(grad)) degree = F.scalar_cast(degree, F.dtype(grad))
@@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
return grad return grad




@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
""" """
Apply allgather on gradient instead of allreduce for sparse feature. Apply allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning. Allgather is a communication operation used for distributed deep learning.
@@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
allreduce (Primitive): The communication operator for gradients. allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allgather would apply. allreduce_filter (bool): When it is true, allgather would apply.
grad (tuple): The indices, gradient tensor and tensor_shape before operation. grad (tuple): The indices, gradient tensor and tensor_shape before operation.
ps_parameter (bool): Use parameter server or not.


Returns: Returns:
IndexedSlices, the gradient after operation. IndexedSlices, the gradient after operation.
""" """
if ps_parameter:
return grad

if allreduce_filter: if allreduce_filter:
indices = allgather(grad.indices()) indices = allgather(grad.indices())
dout = allgather(grad.values()) dout = allgather(grad.values())


Loading…
Cancel
Save