| @@ -57,12 +57,15 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra | |||||
| allreduce (Primitive): The communication operator for gradients. | allreduce (Primitive): The communication operator for gradients. | ||||
| allreduce_filter (bool): When it is true, allreduce would apply. | allreduce_filter (bool): When it is true, allreduce would apply. | ||||
| grad (Tensor): The gradient tensor before operation. | grad (Tensor): The gradient tensor before operation. | ||||
| ps_parameter(Bool): Use parameter server or not. | |||||
| ps_parameter (bool): Use parameter server or not. | |||||
| Returns: | Returns: | ||||
| Tensor, the gradient tensor after operation. | Tensor, the gradient tensor after operation. | ||||
| """ | """ | ||||
| if not ps_parameter and allreduce_filter: | |||||
| if ps_parameter: | |||||
| return grad | |||||
| if allreduce_filter: | |||||
| grad = allreduce(grad) | grad = allreduce(grad) | ||||
| if mean: | if mean: | ||||
| degree = F.scalar_cast(degree, F.dtype(grad)) | degree = F.scalar_cast(degree, F.dtype(grad)) | ||||
| @@ -73,8 +76,8 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra | |||||
| return grad | return grad | ||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") | |||||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): | |||||
| @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") | |||||
| def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): | |||||
| """ | """ | ||||
| Apply allgather on gradient instead of allreduce for sparse feature. | Apply allgather on gradient instead of allreduce for sparse feature. | ||||
| Allgather is a communication operation used for distributed deep learning. | Allgather is a communication operation used for distributed deep learning. | ||||
| @@ -86,10 +89,14 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce | |||||
| allreduce (Primitive): The communication operator for gradients. | allreduce (Primitive): The communication operator for gradients. | ||||
| allreduce_filter (bool): When it is true, allgather would apply. | allreduce_filter (bool): When it is true, allgather would apply. | ||||
| grad (tuple): The indices, gradient tensor and tensor_shape before operation. | grad (tuple): The indices, gradient tensor and tensor_shape before operation. | ||||
| ps_parameter (bool): Use parameter server or not. | |||||
| Returns: | Returns: | ||||
| IndexedSlices, the gradient after operation. | IndexedSlices, the gradient after operation. | ||||
| """ | """ | ||||
| if ps_parameter: | |||||
| return grad | |||||
| if allreduce_filter: | if allreduce_filter: | ||||
| indices = allgather(grad.indices()) | indices = allgather(grad.indices()) | ||||
| dout = allgather(grad.values()) | dout = allgather(grad.values()) | ||||