@@ -292,7 +292,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
 ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
 }
 ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
-ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>());
 ir_fusion_pm->AddPass(std::make_shared<InsertTranspose>());
 ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
 ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
@@ -0,0 +1,126 @@
+/**
+* Copyright 2021 Huawei Technologies Co., Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
+#include <vector>
+#include <memory>
+#include "abstract/abstract_value.h"
+#include "backend/session/anf_runtime_algorithm.h"
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kFtrlOutputNum = 3;
+constexpr size_t kMomentumOutputNum = 2;
+constexpr size_t kRMSPropOutputNum = 3;
+constexpr size_t kCenteredRMSPropOutputNum = 4;
+CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size) {
+MS_EXCEPTION_IF_NULL(graph);
+MS_EXCEPTION_IF_NULL(node);
+auto cnode_ptr = node->cast<CNodePtr>();
+MS_EXCEPTION_IF_NULL(cnode_ptr);
+auto abstract = cnode_ptr->abstract();
+MS_EXCEPTION_IF_NULL(abstract);
+if (AnfAlgo::HasNodeAttr("optim_output_passed", cnode_ptr) && abstract->isa<abstract::AbstractTuple>()) {
+return nullptr;
+}
+AnfAlgo::SetNodeAttr("optim_output_passed", MakeValue(true), cnode_ptr);
+std::vector<AbstractBasePtr> abstract_list;
+for (size_t i = 0; i < output_size; i++) {
+abstract_list.push_back(abstract->Clone());
+}
+auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+cnode_ptr->set_abstract(abstract_tuple);
+auto index = NewValueNode(static_cast<int64_t>(0));
+auto get_item = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index});
+MS_EXCEPTION_IF_NULL(get_item);
+get_item->set_abstract(abstract->Clone());
+return get_item;
+}
+} // namespace
+const BaseRef FtrlUnifyOutput::DefinePattern() const {
+VarPtr var = std::make_shared<Var>();
+VarPtr accum = std::make_shared<Var>();
+VarPtr linear = std::make_shared<Var>();
+VarPtr grad = std::make_shared<Var>();
+VarPtr lr = std::make_shared<Var>();
+VarPtr l1 = std::make_shared<Var>();
+VarPtr l2 = std::make_shared<Var>();
+VarPtr lr_power = std::make_shared<Var>();
+VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power});
+return pattern;
+}
+const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+return ProcessOutput(graph, node, kFtrlOutputNum);
+}
+const BaseRef MomentumUnifyOutput::DefinePattern() const {
+VarPtr var = std::make_shared<Var>();
+VarPtr accum = std::make_shared<Var>();
+VarPtr lr = std::make_shared<Var>();
+VarPtr grad = std::make_shared<Var>();
+VarPtr momentum = std::make_shared<Var>();
+VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum});
+return pattern;
+}
+const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+const EquivPtr &) const {
+return ProcessOutput(graph, node, kMomentumOutputNum);
+}
+const BaseRef RMSPropUnifyOutput::DefinePattern() const {
+VarPtr inputs = std::make_shared<SeqVar>();
+VectorRef pattern({prim::kPrimApplyRMSProp, inputs});
+return pattern;
+}
+const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+const EquivPtr &) const {
+return ProcessOutput(graph, node, kRMSPropOutputNum);
+}
+const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
+VarPtr var = std::make_shared<Var>();
+VarPtr mg = std::make_shared<Var>();
+VarPtr ms = std::make_shared<Var>();
+VarPtr mom = std::make_shared<Var>();
+VarPtr grad = std::make_shared<Var>();
+VarPtr lr = std::make_shared<Var>();
+VarPtr rho = std::make_shared<Var>();
+VarPtr momentum = std::make_shared<Var>();
+VarPtr epsilon = std::make_shared<Var>();
+VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon});
+return pattern;
+}
+const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+const EquivPtr &) const {
+return ProcessOutput(graph, node, kCenteredRMSPropOutputNum);
+}
+} // namespace opt
+} // namespace mindspore
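Editor's note: for orientation, the sketch below shows how one of the new unify-output passes could be driven in isolation, reusing the `GraphOptimizer`/`PassManager` API that also appears in the unit test removed later in this diff. The helper function name and the standalone setup are illustrative assumptions and are not part of the patch.

```cpp
// Minimal sketch, assuming the optimizer framework API used elsewhere in this
// change; RunFtrlUnifyOutput is a hypothetical helper, not patch content.
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"

namespace mindspore {
namespace opt {
void RunFtrlUnifyOutput(const FuncGraphPtr &kernel_graph) {
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto pm = std::make_shared<PassManager>();
  // After this pass runs, a matched ApplyFtrl CNode carries an AbstractTuple of
  // kFtrlOutputNum abstracts and its users read output 0 through TupleGetItem.
  pm->AddPass(std::make_shared<FtrlUnifyOutput>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(kernel_graph);
}
}  // namespace opt
}  // namespace mindspore
```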
@@ -0,0 +1,58 @@
+/**
+* Copyright 2021 Huawei Technologies Co., Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+namespace mindspore {
+namespace opt {
+class FtrlUnifyOutput : public PatternProcessPass {
+public:
+explicit FtrlUnifyOutput(bool multigraph = true) : PatternProcessPass("ftrl_unify_output", multigraph) {}
+~FtrlUnifyOutput() override = default;
+const BaseRef DefinePattern() const override;
+const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+class MomentumUnifyOutput : public PatternProcessPass {
+public:
+explicit MomentumUnifyOutput(bool multigraph = true) : PatternProcessPass("momentum_unify_output", multigraph) {}
+~MomentumUnifyOutput() override = default;
+const BaseRef DefinePattern() const override;
+const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+class CenteredRMSPropUnifyOutput : public PatternProcessPass {
+public:
+explicit CenteredRMSPropUnifyOutput(bool multigraph = true)
+: PatternProcessPass("centered_rmsprop_unify_output", multigraph) {}
+~CenteredRMSPropUnifyOutput() override = default;
+const BaseRef DefinePattern() const override;
+const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+class RMSPropUnifyOutput : public PatternProcessPass {
+public:
+explicit RMSPropUnifyOutput(bool multigraph = true) : PatternProcessPass("rmsprop_unify_output", multigraph) {}
+~RMSPropUnifyOutput() override = default;
+const BaseRef DefinePattern() const override;
+const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+} // namespace opt
+} // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
@@ -38,6 +38,7 @@
 #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h"
 #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h"
 #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h"
+#include "backend/optimizer/ascend/mindir/optimizer_unify_output.h"
 #include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h"
 #include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h"
 #include "runtime/device/kernel_adjust.h"
@@ -217,6 +218,10 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
 unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
 unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
 unify_mindir_pm->AddPass(std::make_shared<opt::SliceGradUnifyMindIR>());
+unify_mindir_pm->AddPass(std::make_shared<opt::FtrlUnifyOutput>());
+unify_mindir_pm->AddPass(std::make_shared<opt::MomentumUnifyOutput>());
+unify_mindir_pm->AddPass(std::make_shared<opt::RMSPropUnifyOutput>());
+unify_mindir_pm->AddPass(std::make_shared<opt::CenteredRMSPropUnifyOutput>());
 auto ms_context = MsContext::GetInstance();
 MS_EXCEPTION_IF_NULL(ms_context);
 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
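Editor's note: as a rough check of what these passes leave behind (mirroring the assertions in the unit test removed later in this diff), the snippet below tests whether a unify-output pass has already rewritten an optimizer node. The helper name is hypothetical; only the attribute key, `AnfAlgo::HasNodeAttr`, and `abstract::AbstractTuple` come from the patch itself.

```cpp
// Hypothetical helper, not part of the patch: true once a unify-output pass has
// widened the node's abstract to a tuple and tagged it against re-application.
#include "abstract/abstract_value.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
bool OptimizerOutputUnified(const CNodePtr &optimizer_node) {
  MS_EXCEPTION_IF_NULL(optimizer_node);
  auto abstract = optimizer_node->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  return abstract->isa<abstract::AbstractTuple>() &&
         AnfAlgo::HasNodeAttr("optim_output_passed", optimizer_node);
}
}  // namespace opt
}  // namespace mindspore
```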
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -143,13 +143,13 @@ ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<b
 OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}};
 REG_ADPT_DESC(SparseApplyFtrlD, kNameSparseApplyFtrlD, ADPT_DESC(SparseApplyFtrlD))
-// ApplyFtrlD
-INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
-{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
-{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
-ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
-REG_ADPT_DESC(ApplyFtrlD, kNameApplyFtrl, ADPT_DESC(ApplyFtrlD))
+// ApplyFtrl
+INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
+{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
+{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
+ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
+REG_ADPT_DESC(ApplyFtrl, kNameApplyFtrl, ADPT_DESC(ApplyFtrl))
 // ApplyRMSPropD
 INPUT_MAP(ApplyRMSPropD) = {
@@ -161,12 +161,11 @@ ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool
 OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
 REG_ADPT_DESC(ApplyRMSPropD, kNameApplyRMSProp, ADPT_DESC(ApplyRMSPropD))
-// ApplyCenteredRMSPropD
-INPUT_MAP(ApplyCenteredRMSPropD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
-{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
-{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
-ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyCenteredRMSPropD) = {
-{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}};
-REG_ADPT_DESC(ApplyCenteredRMSPropD, kNameApplyCenteredRMSProp, ADPT_DESC(ApplyCenteredRMSPropD))
+// ApplyCenteredRMSProp
+INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
+{4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
+{7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
+ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
+REG_ADPT_DESC(ApplyCenteredRMSProp, kNameApplyCenteredRMSProp, ADPT_DESC(ApplyCenteredRMSProp))
 } // namespace mindspore::transform
@@ -62,8 +62,8 @@ DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD)
 DECLARE_OP_ADAPTER(LarsV2Update)
 DECLARE_OP_USE_OUTPUT(LarsV2Update)
-DECLARE_OP_ADAPTER(ApplyFtrlD)
-DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
+DECLARE_OP_ADAPTER(ApplyFtrl)
+DECLARE_OP_USE_OUTPUT(ApplyFtrl)
 DECLARE_OP_ADAPTER(SparseApplyFtrlD)
 DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
@@ -72,7 +72,7 @@ DECLARE_OP_ADAPTER(ApplyRMSPropD)
 DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
 DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
-DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD)
-DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD)
+DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
+DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
 } // namespace mindspore::transform
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_TRAINING_OPS_DECLARE_H_
@@ -239,6 +239,7 @@ inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
 std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
 inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
 inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
+inline const PrimitivePtr kPrimApplyFtrl = std::make_shared<Primitive>("ApplyFtrl");
 inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
 inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("Lrn");
 inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
@@ -452,7 +453,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
 inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
 inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
-// Other primitve not used by backend but used in core;
+// Other primitive not used by backend but used in core;
 inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
 inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
@@ -308,7 +308,6 @@ def _concat_grad_uniform(input_shapes, input_nums):
 def get_bprop_concat(self):
 """Generate bprop for Concat"""
 axis = self.axis
-is_ascend = context.get_context('device_target') == "Ascend"
 def bprop(x, out, dout):
 dx = ()
@@ -318,7 +317,7 @@ def get_bprop_concat(self):
 for i in range(input_nums):
 input_shapes = input_shapes + (shape_op(x[i]),)
 is_uniform = _concat_grad_uniform(input_shapes, input_nums)
-if is_uniform and is_ascend:
+if is_uniform:
 dx = P.Split(axis, input_nums)(dout)
 else:
 for i in range(input_nums):
@@ -2413,12 +2413,8 @@ class ApplyMomentum(PrimitiveWithInfer):
 validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)
 self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
 outputs=['output'])
-self.is_tbe = context.get_context("device_target") == "Ascend"
-self.is_ge = context.get_context("enable_ge")
 def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-if not self.is_ge and self.is_tbe:
-return v_shape, v_shape
 return v_shape
 def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
@@ -2429,9 +2425,7 @@ class ApplyMomentum(PrimitiveWithInfer):
 validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
 validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
 validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
-if not self.is_ge and self.is_tbe:
-return g_dtype, g_dtype
-return g_dtype
+return v_dtype
 class SmoothL1Loss(PrimitiveWithInfer):
@@ -2763,9 +2757,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
 >>> momentum = 1e-10
 >>> epsilon = 0.001
 >>> output = apply_rms(input_x, mean_square, moment, learning_rate, grad, decay, momentum, epsilon)
->>> print(output)
-(Tensor(shape=[], dtype=Float32, value= 0.100112), Tensor(shape=[], dtype=Float32, value= 4),
-Tensor(shape=[], dtype=Float32, value= 0.899888))
+>>> output
+Tensor(shape=[], dtype=Float32, value= 0.100112)
 """
 @prim_attr_register
@@ -2773,16 +2766,12 @@ class ApplyRMSProp(PrimitiveWithInfer):
 self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
 self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
 'rho', 'momentum', 'epsilon'], outputs=['output'])
-self.is_ge = context.get_context("enable_ge")
-self.is_d = context.get_context("device_target") == "Ascend"
 def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
 momentum_shape, epsilon_shape):
 validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
 validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
 validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-if not self.is_ge and self.is_d:
-return var_shape, var_shape, var_shape
 return var_shape
 def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
@@ -2795,8 +2784,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
 validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name)
 args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
 validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
-if not self.is_ge and self.is_d:
-return var_dtype, var_dtype, var_dtype
 return var_dtype
 def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
@@ -2867,22 +2854,15 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
 >>> epsilon = 0.05
 >>> output = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
 ... learning_rate, decay, momentum, epsilon)
->>> print(output)
-(Tensor(shape=[2, 2], dtype=Float32, value=
+>>> output
+Tensor(shape=[2, 2], dtype=Float32, value=
 [[-2.00000000e+00, -5.02492237e+00],
-[-8.04984474e+00, -1.10747662e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
-[[ 0.00000000e+00, 1.00000000e+00],
-[ 2.00000000e+00, 3.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
-[[ 0.00000000e+00, 1.00000000e+00],
-[ 4.00000000e+00, 9.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
-[[ 0.00000000e+00, 4.02492237e+00],
-[ 8.04984474e+00, 1.20747662e+01]]))
+[-8.04984474e+00, -1.10747662e+01]])
 """
 @prim_attr_register
 def __init__(self, use_locking=False):
 self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-self.is_ascend = context.get_context("device_target") == "Ascend"
 def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
 learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
@@ -2890,8 +2870,6 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
 validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
 validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
 validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-if self.is_ascend:
-return var_shape, mean_gradient_shape, mean_square_shape, moment_shape
 return var_shape
 def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
@@ -2905,8 +2883,6 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
 validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name)
 args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
 validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
-if self.is_ascend:
-return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype
 return var_dtype
@@ -6176,15 +6152,8 @@ class ApplyFtrl(PrimitiveWithInfer):
 Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type.
 Outputs:
-There are three outputs for Ascend environment.
-- **var** (Tensor) - represents the updated `var`.
-- **accum** (Tensor) - represents the updated `accum`.
-- **linear** (Tensor) - represents the updated `linear`.
-There is only one output for GPU environment.
-- **var** (Tensor) - This value is always zero and the input parameters has been updated in-place.
+- **var** (Tensor) - represents the updated `var`. As the input parameters have been updated in-place, this
+value is always zero when the platform is GPU.
 Supported Platforms:
 ``Ascend`` ``GPU``
@@ -6217,26 +6186,10 @@ class ApplyFtrl(PrimitiveWithInfer):
 >>> net = ApplyFtrlNet()
 >>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32)
 >>> output = net(input_x)
->>> is_tbe = context.get_context("device_target") == "Ascend"
->>> if is_tbe:
-... print(output)
-(Tensor(shape=[2, 2], dtype=Float32, value=
+>>> output
+Tensor(shape=[2, 2], dtype=Float32, value=
 [[ 4.61418092e-01, 5.30964255e-01],
-[ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
-[[ 1.64236546e+01, 9.64589405e+00],
-[ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
-[[-1.86994812e+03, -1.64906018e+03],
-[-3.22187836e+02, -1.20163989e+03]]))
-... else:
-... print(net.var.asnumpy())
-[[0.4614181 0.5309642 ]
-[0.2687151 0.38206503]]
-... print(net.accum.asnumpy())
-[[16.423655 9.645894 ]
-[ 1.4375873 9.891773 ]]
-... print(net.linear.asnumpy())
-[[-1869.9479 -1649.0599]
-[ -322.1879 -1201.6399]]
+[ 2.68715084e-01, 3.82065028e-01]])
 """
 @prim_attr_register
@@ -6244,14 +6197,11 @@ class ApplyFtrl(PrimitiveWithInfer):
 self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
 outputs=['output'])
 self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
-self.is_tbe = context.get_context("device_target") == "Ascend"
 def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
 lr_power_shape):
 validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
 validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
-if self.is_tbe:
-return var_shape, var_shape, var_shape
 return var_shape
 def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -6263,8 +6213,6 @@ class ApplyFtrl(PrimitiveWithInfer):
 validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
 validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
 validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
-if self.is_tbe:
-return var_type, var_type, var_type
 return var_type
@@ -1,5 +1,5 @@
 /**
-* Copyright 2020 Huawei Technologies Co., Ltd
+* Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -48,7 +48,8 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
 if (!node->isa<CNode>()) {
 return false;
 }
-return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
+auto node_name = AnfAlgo::GetCNodeName(node->cast<CNodePtr>());
+return node_name == "ApplyMomentum" || node_name == "AssignAdd";
 }
 };
@@ -103,9 +104,9 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond3) {
 get_py_fun_.SetDoResolve(true);
 FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "before");
 ASSERT_TRUE(g != nullptr);
-std::vector<int64_t> shp_x{1, 64, 112, 112};
+std::vector<int64_t> shp_x{3, 2};
 auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
-AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
 auto kg = GetKernelGraph(g, args_spec_list);
 EXPECT_NE(kg, nullptr);
@@ -1,74 +0,0 @@
-/**
-* Copyright 2020 Huawei Technologies Co., Ltd
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-#include "common/backend_common_test.h"
-#include "common/py_func_graph_fetcher.h"
-#include "debug/anf_ir_dump.h"
-#define private public
-#define protected public
-#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
-#undef private
-#undef protected
-namespace mindspore {
-namespace opt {
-class TestHWAddInputToOutput : public BackendCommon {
-public:
-TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {}
-~TestHWAddInputToOutput() override = default;
-public:
-UT::PyFuncGraphFetcher getPyFun_;
-};
-class MockOpFinder : public OpFinder {
-public:
-MockOpFinder() = default;
-~MockOpFinder() override = default;
-int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) override { return 2; }
-};
-TEST_F(TestHWAddInputToOutput, test_add_input_to_output) {
-FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before");
-EXPECT_NE(g, nullptr);
-std::vector<int64_t> shp{2, 32, 224, 224};
-auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
-AbstractBasePtrList args_spec_list;
-for (size_t i = 0; i < 5; ++i) {
-args_spec_list.push_back(x_abstract);
-}
-auto kg = GetKernelGraph(g, args_spec_list);
-EXPECT_NE(kg, nullptr);
-auto ret = kg->get_return();
-EXPECT_NE(ret, nullptr);
-auto make_tuple = ret->input(1);
-EXPECT_NE(make_tuple, nullptr);
-auto momentum = make_tuple->cast<CNodePtr>()->input(1);
-EXPECT_NE(momentum, nullptr);
-EXPECT_NE(momentum->abstract(), nullptr);
-EXPECT_FALSE(momentum->abstract()->isa<abstract::AbstractTuple>());
-auto optimizer = std::make_shared<opt::GraphOptimizer>();
-auto pm = std::make_shared<opt::PassManager>();
-auto pass = std::make_shared<opt::AddInputToOutput>();
-pass->op_finder_ = std::make_shared<MockOpFinder>();
-pm->AddPass(pass);
-optimizer->AddPassManager(pm);
-(void)optimizer->Optimize(kg);
-EXPECT_TRUE(momentum->abstract()->isa<abstract::AbstractTuple>());
-}
-} // namespace opt
-} // namespace mindspore
@@ -1,39 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-from mindspore.ops import operations as P
-ApplyMomentum = P.ApplyMomentum()
-class FnDict:
-def __init__(self):
-self.fnDict = {}
-def __call__(self, fn):
-self.fnDict[fn.__name__] = fn
-def __getitem__(self, name):
-return self.fnDict[name]
-def test_add_input_to_output(tag):
-fns = FnDict()
-@fns
-def before(input0, input1, input2, input3, input4):
-return ApplyMomentum(input0, input1, input2, input3, input4)
-return fns[tag]
@@ -22,7 +22,7 @@ broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive(Constants.kTupleGetItem)
-apply_momentun = P.ApplyMomentum()
+assign_add = P.AssignAdd()
 control_depend = P.ControlDepend()
 relu = P.ReLU()
@@ -84,14 +84,14 @@ def test_insert_memcpy_async_for_hccl_op_cond3(tag):
 fns = FnDict()
 @fns
-def before(a, b, c, d, e):
-res = apply_momentun(a, b, c, d, e)
+def before(a, b):
+res = assign_add(a, b)
 res = all_reduce(res)
 return res
 @fns
-def after(a, b, c, d, e):
-res = apply_momentun(a, b, c, d, e)
+def after(a, b):
+res = assign_add(a, b)
 res = memcpy_async(res)
 res = all_reduce(res)
 return make_tuple(res)
@@ -48,6 +48,6 @@ def test_momentum_lossscale_fusion(tag):
 @fns
 def after(input0, input1, input2, input3, input4):
-return make_tuple(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant))
+return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0))
 return fns[tag]