| @@ -299,6 +299,18 @@ REG_OP(ApplyMomentumCCE) | |||||
| .ATTR(use_locking, Bool, false) | .ATTR(use_locking, Bool, false) | ||||
| .OP_END_FACTORY_REG(ApplyMomentumCCE) | .OP_END_FACTORY_REG(ApplyMomentumCCE) | ||||
| REG_OP(ApplyMomentumD) | |||||
| .INPUT(var, TensorType::NumberType()) | |||||
| .INPUT(accum, TensorType::NumberType()) | |||||
| .INPUT(lr, TensorType::NumberType()) | |||||
| .INPUT(grad, TensorType::NumberType()) | |||||
| .INPUT(momentum, TensorType::NumberType()) | |||||
| .OUTPUT(var, TensorType::NumberType()) | |||||
| .OUTPUT(accum, TensorType::NumberType()) | |||||
| .ATTR(use_nesterov, Bool, false) | |||||
| .ATTR(use_locking, Bool, false) | |||||
| .OP_END_FACTORY_REG(ApplyMomentumD) | |||||
| /** | /** | ||||
| *@brief Updates "var" according to the AddSign update.\n | *@brief Updates "var" according to the AddSign update.\n | ||||
| * t-1 mean previous period. | * t-1 mean previous period. | ||||