From: @xu_anyue (pull/13703/MERGE)
@@ -45,7 +45,7 @@ int GatherFp16CPUKernel::Init() {
      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
  }
  (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(reinterpret_cast<int *>(in_tensors_.at(2)->data_c()));
  if (!InferShapeDone()) {
    return RET_OK;
  }
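Context for the hunk above: the Gather axis is now read from the third input tensor before the InferShapeDone() early return, so it is captured even when the rest of initialization is deferred. A minimal standalone sketch of that pattern, using toy stand-in types rather than the MindSpore Lite API:

    // Sketch only: Tensor and GatherParameter are simplified stand-ins.
    #include <cassert>
    #include <vector>

    struct Tensor { void *data = nullptr; };
    struct GatherParameter { int axis_ = 0; };

    int Init(const std::vector<Tensor *> &in_tensors, GatherParameter *param,
             bool infer_shape_done) {
      // The axis input (tensor #2) holds a single int32 value; read it
      // unconditionally, before any early return.
      assert(in_tensors.size() >= 3 && in_tensors[2]->data != nullptr);
      param->axis_ = *static_cast<int *>(in_tensors[2]->data);
      if (!infer_shape_done) {
        return 0;  // defer the rest until shapes are known
      }
      return 0;
    }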
@@ -15,7 +15,9 @@
 */
#include "src/runtime/kernel/npu/matmul_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
@@ -24,6 +26,11 @@ using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::kernel {
int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
                               OpParameter *opParameter) {
  if (inputs.size() == 3) {
    if (inputs[2]->shape().size() != 1) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}
@@ -33,7 +40,33 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
  op_->set_input_x1(*npu_inputs[0]);
  op_->set_input_x2(*npu_inputs[1]);
  if (npu_inputs.size() == 3) {
    op_->set_input_bias(*npu_inputs[2]);
    matmul_parameter_->has_bias_ = true;
    add_op_ = new (std::nothrow) hiai::op::Add(name_ + "_add");
    if (add_op_ == nullptr) {
      MS_LOG(ERROR) << "new add op failed.";
      return RET_ERROR;
    }
    add_op_->set_input_x1(*op_);
    auto bias_shape = inputs[2]->shape();
    auto bias_tensor = std::make_shared<ge::Tensor>();
    if (bias_tensor == nullptr) {
      MS_LOG(ERROR) << "new bias_tensor failed.";
      return RET_ERROR;
    }
    ge::TensorDesc bias_tensor_desc(lite::ConverterToNPUShape({1, bias_shape[0], 1, 1}), ge::FORMAT_NCHW,
                                    lite::ConverterToNPUDataType(inputs[2]->data_type()));
    if (outputs[0]->shape().size() == 2) {
      bias_tensor_desc.SetShape(lite::ConverterToNPUShape({1, bias_shape[0]}));
    }
    bias_tensor->SetTensorDesc(bias_tensor_desc);
    bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
    bias_ = new (std::nothrow) hiai::op::Const(name_ + "_bias");
    if (bias_ == nullptr) {
      MS_LOG(ERROR) << "new bias const failed.";
      return RET_ERROR;
    }
    bias_->set_attr_value(bias_tensor);
    add_op_->set_input_x2(*bias_);
  }
  op_->set_attr_transpose_x1(matmul_parameter_->a_transpose_);
@@ -41,13 +74,26 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
  return RET_OK;
}
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; }
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() {
  if (matmul_parameter_->has_bias_) {
    return add_op_;
  }
  return op_;
}
MatMulNPUKernel::~MatMulNPUKernel() {
  if (op_ != nullptr) {
    delete op_;
    op_ = nullptr;
  }
  if (add_op_ != nullptr) {
    delete add_op_;
    add_op_ = nullptr;
  }
  if (bias_ != nullptr) {
    delete bias_;
    bias_ = nullptr;
  }
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>)
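The hiai::op::MatMul operator apparently exposes no bias input, which is presumably why the kernel above chains a hiai::op::Add behind the MatMul, feeds it a Const bias, and reports the Add from GetNPUOp(). The bias-shape logic reduces to the mapping sketched below in plain C++ (not the ge::TensorDesc API): a 1-D bias of length c is broadcast as {1, c, 1, 1} for NCHW graphs, or {1, c} when the MatMul output itself is 2-D.

    // Sketch of the bias-broadcast shape choice in SetNPUInputs above.
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> BiasBroadcastShape(int64_t c, size_t output_rank) {
      if (output_rank == 2) {
        return {1, c};      // matches SetShape({1, bias_shape[0]})
      }
      return {1, c, 1, 1};  // default NCHW broadcast
    }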
@@ -39,6 +39,8 @@ class MatMulNPUKernel : public NPUKernel {
 private:
  hiai::op::MatMul *op_ = nullptr;
  hiai::op::Add *add_op_ = nullptr;
  hiai::op::Const *bias_ = nullptr;
  MatMulParameter *matmul_parameter_;
};
}  // namespace mindspore::kernel
@@ -11,12 +11,12 @@ STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_C_FLAGS "$
STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
if(ENABLE_CONVERTER)
    set(CCSRC_SRC
        ## ccsrc
        ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
        ${CCSRC_DIR}/backend/optimizer/common/visit.cc
        ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
    )
    set(CCSRC_SRC
        ## ccsrc
        ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
        ${CCSRC_DIR}/backend/optimizer/common/visit.cc
        ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
    )
else()
    set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
    add_compile_definitions(USE_ANDROID_LOG)
@@ -38,10 +38,10 @@ file(GLOB KERNEL_OP_SRC
file(GLOB KERNEL_OP_TRAIN_SRC
    ${LITE_DIR}/nnacl/fp32_grad/*.c
    ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
)
)
if(SUPPORT_TRAIN)
    list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
    list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
endif()
if(PLATFORM_ARM64)
    # assembly
@@ -114,9 +114,9 @@ if(SUPPORT_GPU STREQUAL vulkan)
endif()
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
    if(ENABLE_CONVERTER)
        set(BUILD_MINDDATA "off")
    endif()
    if(ENABLE_CONVERTER)
        set(BUILD_MINDDATA "off")
    endif()
endif()
### runtime framework
add_definitions(-DENABLE_V0)
@@ -189,19 +189,19 @@ if(ENABLE_MINDRT)
    include_directories(${CORE_DIR}/mindrt/)
    include_directories(${CORE_DIR}/mindrt/src/)
    set(TEST_LITE_SRC ${TEST_LITE_SRC}
        ${LITE_DIR}/src/lite_mindrt.cc
        ${LITE_DIR}/src/mindrt_executor.cc
        ${CORE_DIR}/mindrt/src/litebus.cc
        ${CORE_DIR}/mindrt/src/actor/actor.cc
        ${CORE_DIR}/mindrt/src/actor/actormgr.cc
        ${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
        ${CORE_DIR}/mindrt/src/actor/actorthread.cc
        ${CORE_DIR}/mindrt/src/actor/aid.cc
        ${CORE_DIR}/mindrt/src/async/async.cc
        ${CORE_DIR}/mindrt/src/async/future.cc
        ${CORE_DIR}/mindrt/src/async/uuid_base.cc
        ${CORE_DIR}/mindrt/src/async/uuid_generator.cc
    )
        ${LITE_DIR}/src/lite_mindrt.cc
        ${LITE_DIR}/src/mindrt_executor.cc
        ${CORE_DIR}/mindrt/src/litebus.cc
        ${CORE_DIR}/mindrt/src/actor/actor.cc
        ${CORE_DIR}/mindrt/src/actor/actormgr.cc
        ${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
        ${CORE_DIR}/mindrt/src/actor/actorthread.cc
        ${CORE_DIR}/mindrt/src/actor/aid.cc
        ${CORE_DIR}/mindrt/src/async/async.cc
        ${CORE_DIR}/mindrt/src/async/future.cc
        ${CORE_DIR}/mindrt/src/async/uuid_base.cc
        ${CORE_DIR}/mindrt/src/async/uuid_generator.cc
    )
endif()
@@ -242,6 +242,7 @@ if(ENABLE_CONVERTER)
        ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
@@ -286,16 +287,16 @@ else()
endif()
### test src
file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
    ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
    ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
    ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
    ${TEST_DIR}/ut/nnacl/infer/*.cc
)
    ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
    ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
    ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
    ${TEST_DIR}/ut/nnacl/infer/*.cc
)
file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
)
    ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
)
set(TEST_SRC
    ${TEST_LITE_SRC}
@@ -306,7 +307,7 @@ set(TEST_SRC
    ${TEST_DIR}/ut/src/infer_test.cc
    ${TEST_DIR}/ut/src/utils_test.cc
    ${TEST_DIR}/ut/src/scheduler_test.cc
)
)
if(ENABLE_CONVERTER)
    set(TEST_SRC
@@ -358,7 +359,7 @@ endif()
if(ENABLE_FP16 AND SUPPORT_TRAIN)
    file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
        ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
        ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
    list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
endif()
@@ -52,6 +52,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        ../optimizer/fusion/tf_lstm_cell_fusion.cc
        ../optimizer/fusion/tf_bidirection_gru_fusion.cc
        ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
        ../optimizer/fusion/matmul_add_fusion.cc
        ../optimizer/graph/weight_format_transform_pass.cc
        ../optimizer/graph/weight_format_hardcode_pass.cc
        ../optimizer/graph/clip_convert_activation_pass.cc
@@ -35,6 +35,7 @@
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/graph/primitive_adjust_pass.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
@@ -107,6 +108,9 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
    fusion_pm->AddPass(remove_unused_transpose_pass);
  }
  fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
  if (!config->trainModel) {
    fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
  }
  optimizer->AddPassManager(fusion_pm);
  return RET_OK;
}
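A note on the registration above: opt::MatMulAddFusion (new pass, full source below) rewrites

    Add(MatMul(x, w), bias)  ->  MatMul(x, w, bias)   // bias must be a constant Parameter

and is added only when config->trainModel is false, so training graphs keep the explicit Add node.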
@@ -1,48 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/onnx/onnx_gemm_parser.h"
#include <vector>
#include <memory>
#include "ops/make_tuple.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
  auto prim = std::make_unique<ops::MakeTuple>();
  auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
  if (node_parser == nullptr) {
    MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
    return nullptr;
  }
  auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node);
  prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive));
  node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
  if (node_parser == nullptr) {
    MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
    return nullptr;
  }
  auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node);
  prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive));
  return prim.release();
}
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());
}  // namespace lite
}  // namespace mindspore
@@ -1,34 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxGemmParser : public OnnxNodeParser {
 public:
  OnnxGemmParser() : OnnxNodeParser("Gemm") {}
  ~OnnxGemmParser() override = default;
  ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
@@ -46,5 +46,6 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con
}
OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxMatmulParser());
}  // namespace lite
}  // namespace mindspore
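Registering "Gemm" against the existing MatMul parser works because ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, which with the default alpha = beta = 1 is a (possibly transposed) MatMul followed by a bias add; the Add can then be folded by the new MatMulAddFusion pass. A toy scalar reference for the common case of a 1-D bias C of length n (illustrative only, not converter code):

    // y[i, j] = alpha * sum_p a[i, p] * b[p, j] + beta * c[j]
    #include <vector>

    std::vector<float> GemmRef(const std::vector<float> &a,  // m x k, row-major
                               const std::vector<float> &b,  // k x n, row-major
                               const std::vector<float> &c,  // n (bias)
                               int m, int k, int n,
                               float alpha = 1.0f, float beta = 1.0f) {
      std::vector<float> y(static_cast<size_t>(m) * n);
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
          float acc = 0.0f;
          for (int p = 0; p < k; ++p) {
            acc += a[i * k + p] * b[p * n + j];
          }
          y[i * n + j] = alpha * acc + beta * c[j];  // bias broadcast per row
        }
      }
      return y;
    }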
@@ -44,7 +44,6 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
  {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
std::set<std::string> SPECIAL_NODE = {"Gemm"};
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
                                    const QuantType &quant_type) {
  NoSupportOp::GetInstance()->SetFmkType("ONNX");
@@ -215,11 +214,6 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
      MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
      continue;
    }
    if (IsSpecialOnnxNode(onnx_node)) {
      auto status_node = ConvertSpecialOnnxNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
      status = status == RET_OK ? status_node : status;
      continue;
    }
    // build CNode
    status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
    if (status != RET_OK) {
@@ -1023,117 +1017,6 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
  return status;
}
STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
                                               std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                               ops::PrimitiveC *primitive_c) {
  if (primitive_c == nullptr || anf_graph == nullptr) {
    MS_LOG(ERROR) << "imitive_c is nullptr.";
    return RET_NULL_PTR;
  }
  STATUS status = RET_OK;
  if (onnx_node.op_type() == "Gemm") {
    status = ConvertOnnxGemmNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
  } else {
    MS_LOG(ERROR) << "the node is not special node.";
    status = RET_ERROR;
  }
  delete primitive_c;
  return status;
}
STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
                                            std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                            ops::PrimitiveC *primitive_c) {
  if (primitive_c == nullptr || anf_graph == nullptr) {
    MS_LOG(ERROR) << "parameter has nullptr.";
    return RET_NULL_PTR;
  }
  if (onnx_node.op_type() != "Gemm") {
    MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type();
    return RET_ERROR;
  }
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr.";
    return RET_NULL_PTR;
  }
  auto status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "MatMul");
  if (status != RET_OK) {
    MS_LOG(ERROR) << "convert gemm node failed.";
    return status;
  }
  status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "BiasAdd");
  if (status != RET_OK) {
    MS_LOG(ERROR) << "convert gemm node failed.";
    return status;
  }
  return RET_OK;
}
STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
                                          std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                          ops::PrimitiveC *primitive_c, const std::string &name) {
  if (primitive_c == nullptr || anf_graph == nullptr) {
    MS_LOG(ERROR) << "parameter has nullptr.";
    return RET_NULL_PTR;
  }
  auto value = primitive_c->GetAttr(name);
  primitive_c->EraseAttr(name);
  if (value == nullptr) {
    MS_LOG(ERROR) << "op parse failed.";
    return RET_NULL_PTR;
  }
  auto prim_ptr = value->cast<std::shared_ptr<ops::PrimitiveC>>();
  if (prim_ptr == nullptr) {
    MS_LOG(ERROR) << "primitive parse failed.";
    return RET_NULL_PTR;
  }
  auto type_ptr = TypeIdToType(kTypeUnknown);
  std::vector<int64_t> shape_vector;
  std::vector<AnfNodePtr> op_inputs;
  auto quant_params_holder = std::make_shared<QuantParamHolder>();
  auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
  if (name == "MatMul") {
    for (int i = 0; i < 2; ++i) {
      if (anf_nodes_map->find(onnx_node.input(i)) == anf_nodes_map->end()) {
        MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
        return RET_ERROR;
      } else {
        op_inputs.push_back(anf_nodes_map->at(onnx_node.input(i)));
        quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(i));
      }
    }
    quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
    auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
    if (new_cnode == nullptr) {
      MS_LOG(ERROR) << "new cnode error";
      return RET_ERROR;
    }
    new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0));
    new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
    anf_nodes_map->emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode);
  } else {
    if (anf_nodes_map->find("Gemm_MatMul_" + onnx_node.output(0)) == anf_nodes_map->end() ||
        anf_nodes_map->find(onnx_node.input(2)) == anf_nodes_map->end()) {
      MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
      return RET_ERROR;
    }
    op_inputs.push_back(anf_nodes_map->at("Gemm_MatMul_" + onnx_node.output(0)));
    op_inputs.push_back(anf_nodes_map->at(onnx_node.input(2)));
    quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
    quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(2));
    quant_params_holder->AddOutputQuantParam(quant_params_holder_origin->output_quant_params().front());
    auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
    if (new_cnode == nullptr) {
      MS_LOG(ERROR) << "new cnode error";
      return RET_ERROR;
    }
    new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0));
    new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
    anf_nodes_map->emplace(onnx_node.output(0), new_cnode);
  }
  return RET_OK;
}
STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
  if (data == nullptr) {
    MS_LOG(ERROR) << "value is nullptr.";
@@ -1281,10 +1164,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t
  return RET_OK;
}
bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) {
  return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end();
}
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
  auto iter = TYPE_MAP.find(onnx_type);
  if (iter == TYPE_MAP.end()) {
@@ -69,21 +69,11 @@ class OnnxModelParser : public ModelParser {
                        ops::PrimitiveC *primitive_c, std::string loop_name);
  static STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
                               std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
  static STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
                                       std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                       ops::PrimitiveC *primitive_c);
  static STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
                                    std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                    ops::PrimitiveC *primitive_c);
  static STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
                                  std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                  ops::PrimitiveC *primitive_c, const std::string &name);
  STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
  STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
  STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
  STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
  STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
  static bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
  STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
                             std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                             const std::string &root_node_name);
@@ -0,0 +1,79 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t AddInputSize = 3;
constexpr size_t MatMulInputSize = 3;
bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
  MS_ASSERT(cnode != nullptr);
  MS_ASSERT(index != nullptr);
  if (cnode->size() != AddInputSize) {
    return false;
  }
  size_t matmul_index = 0;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
      auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
      if (matmul_cnode->size() > MatMulInputSize) {
        continue;
      }
      matmul_index = i;
      break;
    }
  }
  if (matmul_index == 0) {
    return false;
  }
  *index = matmul_index;
  return true;
}
}  // namespace
bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNode>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
      continue;
    }
    size_t index = 0;
    if (!CheckAndGetMatMulIndex(cnode, &index)) {
      continue;
    }
    auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
    auto bias_node = cnode->input(AddInputSize - index);
    if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
      continue;
    }
    matmul_cnode->add_input(bias_node);
    auto manager = func_graph->manager();
    MS_ASSERT(manager != nullptr);
    matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
    manager->Replace(node, matmul_cnode);
  }
  return false;
}
}  // namespace opt
}  // namespace mindspore
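The operand selection in Run() above relies on the ANF convention that input 0 of a CNode is the primitive, so an Add node's two operands sit at indices 1 and 2; with AddInputSize == 3, the non-MatMul operand of input i is therefore at index AddInputSize - i. A standalone sketch of that arithmetic:

    // Maps the MatMul operand index (1 or 2) to the bias operand index.
    #include <cassert>
    #include <cstddef>

    constexpr size_t kAddInputSize = 3;

    size_t BiasOperandIndex(size_t matmul_index) {
      assert(matmul_index == 1 || matmul_index == 2);
      return kAddInputSize - matmul_index;  // 1 -> 2 and 2 -> 1
    }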
@@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_context.h"
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace opt {
class MatMulAddFusion : public Pass {
 public:
  MatMulAddFusion() : Pass("matmul_add_fusion") {}
  ~MatMulAddFusion() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_