From: @xu_anyue Reviewed-by: Signed-off-by:pull/13703/MERGE
| @@ -45,7 +45,7 @@ int GatherFp16CPUKernel::Init() { | |||||
| reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); | reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); | ||||
| Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); | Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); | ||||
| } | } | ||||
| (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(reinterpret_cast<int *>(in_tensors_.at(2)->data_c())); | |||||
| if (!InferShapeDone()) { | if (!InferShapeDone()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -15,7 +15,9 @@ | |||||
| */ | */ | ||||
| #include "src/runtime/kernel/npu/matmul_npu.h" | #include "src/runtime/kernel/npu/matmul_npu.h" | ||||
| #include <memory> | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/agent/npu/npu_converter_utils.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kNPU; | using mindspore::kernel::KERNEL_ARCH::kNPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -24,6 +26,11 @@ using mindspore::schema::PrimitiveType_MatMul; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | ||||
| OpParameter *opParameter) { | OpParameter *opParameter) { | ||||
| if (inputs.size() == 3) { | |||||
| if (inputs[2]->shape().size() != 1) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -33,7 +40,33 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con | |||||
| op_->set_input_x1(*npu_inputs[0]); | op_->set_input_x1(*npu_inputs[0]); | ||||
| op_->set_input_x2(*npu_inputs[1]); | op_->set_input_x2(*npu_inputs[1]); | ||||
| if (npu_inputs.size() == 3) { | if (npu_inputs.size() == 3) { | ||||
| op_->set_input_bias(*npu_inputs[2]); | |||||
| matmul_parameter_->has_bias_ = true; | |||||
| add_op_ = new (std::nothrow) hiai::op::Add(name_ + "_add"); | |||||
| if (add_op_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new add op failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| add_op_->set_input_x1(*op_); | |||||
| auto bias_shape = inputs[2]->shape(); | |||||
| auto bias_tensor = std::make_shared<ge::Tensor>(); | |||||
| if (bias_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new bias_tensor failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| ge::TensorDesc bias_tensor_desc(lite::ConverterToNPUShape({1, bias_shape[0], 1, 1}), ge::FORMAT_NCHW, | |||||
| lite::ConverterToNPUDataType(inputs[2]->data_type())); | |||||
| if (outputs[0]->shape().size() == 2) { | |||||
| bias_tensor_desc.SetShape(lite::ConverterToNPUShape({1, bias_shape[0]})); | |||||
| } | |||||
| bias_tensor->SetTensorDesc(bias_tensor_desc); | |||||
| bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size()); | |||||
| bias_ = new (std::nothrow) hiai::op::Const(name_ + "_bias"); | |||||
| if (bias_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new bias const failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| bias_->set_attr_value(bias_tensor); | |||||
| add_op_->set_input_x2(*bias_); | |||||
| } | } | ||||
| op_->set_attr_transpose_x1(matmul_parameter_->a_transpose_); | op_->set_attr_transpose_x1(matmul_parameter_->a_transpose_); | ||||
| @@ -41,13 +74,26 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; } | |||||
| ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { | |||||
| if (matmul_parameter_->has_bias_) { | |||||
| return add_op_; | |||||
| } | |||||
| return op_; | |||||
| } | |||||
| MatMulNPUKernel::~MatMulNPUKernel() { | MatMulNPUKernel::~MatMulNPUKernel() { | ||||
| if (op_ != nullptr) { | if (op_ != nullptr) { | ||||
| delete op_; | delete op_; | ||||
| op_ = nullptr; | op_ = nullptr; | ||||
| } | } | ||||
| if (add_op_ != nullptr) { | |||||
| delete add_op_; | |||||
| add_op_ = nullptr; | |||||
| } | |||||
| if (bias_ != nullptr) { | |||||
| delete bias_; | |||||
| bias_ = nullptr; | |||||
| } | |||||
| } | } | ||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>) | REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>) | ||||
| @@ -39,6 +39,8 @@ class MatMulNPUKernel : public NPUKernel { | |||||
| private: | private: | ||||
| hiai::op::MatMul *op_ = nullptr; | hiai::op::MatMul *op_ = nullptr; | ||||
| hiai::op::Add *add_op_ = nullptr; | |||||
| hiai::op::Const *bias_ = nullptr; | |||||
| MatMulParameter *matmul_parameter_; | MatMulParameter *matmul_parameter_; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -11,12 +11,12 @@ STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_C_FLAGS "$ | |||||
| STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") | STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") | ||||
| if(ENABLE_CONVERTER) | if(ENABLE_CONVERTER) | ||||
| set(CCSRC_SRC | |||||
| ## ccsrc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/visit.cc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc | |||||
| ) | |||||
| set(CCSRC_SRC | |||||
| ## ccsrc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/visit.cc | |||||
| ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc | |||||
| ) | |||||
| else() | else() | ||||
| set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc) | set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc) | ||||
| add_compile_definitions(USE_ANDROID_LOG) | add_compile_definitions(USE_ANDROID_LOG) | ||||
| @@ -38,10 +38,10 @@ file(GLOB KERNEL_OP_SRC | |||||
| file(GLOB KERNEL_OP_TRAIN_SRC | file(GLOB KERNEL_OP_TRAIN_SRC | ||||
| ${LITE_DIR}/nnacl/fp32_grad/*.c | ${LITE_DIR}/nnacl/fp32_grad/*.c | ||||
| ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc | ||||
| ) | |||||
| ) | |||||
| if(SUPPORT_TRAIN) | if(SUPPORT_TRAIN) | ||||
| list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC}) | |||||
| list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC}) | |||||
| endif() | endif() | ||||
| if(PLATFORM_ARM64) | if(PLATFORM_ARM64) | ||||
| # assembly | # assembly | ||||
| @@ -114,9 +114,9 @@ if(SUPPORT_GPU STREQUAL vulkan) | |||||
| endif() | endif() | ||||
| if(PLATFORM_ARM32 OR PLATFORM_ARM64) | if(PLATFORM_ARM32 OR PLATFORM_ARM64) | ||||
| if(ENABLE_CONVERTER) | |||||
| set(BUILD_MINDDATA "off") | |||||
| endif() | |||||
| if(ENABLE_CONVERTER) | |||||
| set(BUILD_MINDDATA "off") | |||||
| endif() | |||||
| endif() | endif() | ||||
| ### runtime framework | ### runtime framework | ||||
| add_definitions(-DENABLE_V0) | add_definitions(-DENABLE_V0) | ||||
| @@ -189,19 +189,19 @@ if(ENABLE_MINDRT) | |||||
| include_directories(${CORE_DIR}/mindrt/) | include_directories(${CORE_DIR}/mindrt/) | ||||
| include_directories(${CORE_DIR}/mindrt/src/) | include_directories(${CORE_DIR}/mindrt/src/) | ||||
| set(TEST_LITE_SRC ${TEST_LITE_SRC} | set(TEST_LITE_SRC ${TEST_LITE_SRC} | ||||
| ${LITE_DIR}/src/lite_mindrt.cc | |||||
| ${LITE_DIR}/src/mindrt_executor.cc | |||||
| ${CORE_DIR}/mindrt/src/litebus.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actor.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actormgr.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actorpolicy.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actorthread.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/aid.cc | |||||
| ${CORE_DIR}/mindrt/src/async/async.cc | |||||
| ${CORE_DIR}/mindrt/src/async/future.cc | |||||
| ${CORE_DIR}/mindrt/src/async/uuid_base.cc | |||||
| ${CORE_DIR}/mindrt/src/async/uuid_generator.cc | |||||
| ) | |||||
| ${LITE_DIR}/src/lite_mindrt.cc | |||||
| ${LITE_DIR}/src/mindrt_executor.cc | |||||
| ${CORE_DIR}/mindrt/src/litebus.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actor.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actormgr.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actorpolicy.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/actorthread.cc | |||||
| ${CORE_DIR}/mindrt/src/actor/aid.cc | |||||
| ${CORE_DIR}/mindrt/src/async/async.cc | |||||
| ${CORE_DIR}/mindrt/src/async/future.cc | |||||
| ${CORE_DIR}/mindrt/src/async/uuid_base.cc | |||||
| ${CORE_DIR}/mindrt/src/async/uuid_generator.cc | |||||
| ) | |||||
| endif() | endif() | ||||
| @@ -242,6 +242,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc | |||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc | ||||
| @@ -286,16 +287,16 @@ else() | |||||
| endif() | endif() | ||||
| ### test src | ### test src | ||||
| file(GLOB_RECURSE TEST_CASE_KERNEL_SRC | file(GLOB_RECURSE TEST_CASE_KERNEL_SRC | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc | |||||
| ${TEST_DIR}/ut/nnacl/infer/*.cc | |||||
| ) | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc | |||||
| ${TEST_DIR}/ut/nnacl/infer/*.cc | |||||
| ) | |||||
| file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC | file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc | |||||
| ) | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc | |||||
| ) | |||||
| set(TEST_SRC | set(TEST_SRC | ||||
| ${TEST_LITE_SRC} | ${TEST_LITE_SRC} | ||||
| @@ -306,7 +307,7 @@ set(TEST_SRC | |||||
| ${TEST_DIR}/ut/src/infer_test.cc | ${TEST_DIR}/ut/src/infer_test.cc | ||||
| ${TEST_DIR}/ut/src/utils_test.cc | ${TEST_DIR}/ut/src/utils_test.cc | ||||
| ${TEST_DIR}/ut/src/scheduler_test.cc | ${TEST_DIR}/ut/src/scheduler_test.cc | ||||
| ) | |||||
| ) | |||||
| if(ENABLE_CONVERTER) | if(ENABLE_CONVERTER) | ||||
| set(TEST_SRC | set(TEST_SRC | ||||
| @@ -358,7 +359,7 @@ endif() | |||||
| if(ENABLE_FP16 AND SUPPORT_TRAIN) | if(ENABLE_FP16 AND SUPPORT_TRAIN) | ||||
| file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD | file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc) | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc) | |||||
| list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD}) | list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD}) | ||||
| endif() | endif() | ||||
| @@ -52,6 +52,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/fusion/tf_lstm_cell_fusion.cc | ../optimizer/fusion/tf_lstm_cell_fusion.cc | ||||
| ../optimizer/fusion/tf_bidirection_gru_fusion.cc | ../optimizer/fusion/tf_bidirection_gru_fusion.cc | ||||
| ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ||||
| ../optimizer/fusion/matmul_add_fusion.cc | |||||
| ../optimizer/graph/weight_format_transform_pass.cc | ../optimizer/graph/weight_format_transform_pass.cc | ||||
| ../optimizer/graph/weight_format_hardcode_pass.cc | ../optimizer/graph/weight_format_hardcode_pass.cc | ||||
| ../optimizer/graph/clip_convert_activation_pass.cc | ../optimizer/graph/clip_convert_activation_pass.cc | ||||
| @@ -35,6 +35,7 @@ | |||||
| #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" | #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" | ||||
| #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" | #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" | ||||
| #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h" | #include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h" | ||||
| #include "tools/optimizer/fusion/matmul_add_fusion.h" | |||||
| #include "tools/optimizer/graph/primitive_adjust_pass.h" | #include "tools/optimizer/graph/primitive_adjust_pass.h" | ||||
| #include "tools/optimizer/graph/mindir_adjust_pass.h" | #include "tools/optimizer/graph/mindir_adjust_pass.h" | ||||
| #include "tools/optimizer/graph/redundant_op_remove_pass.h" | #include "tools/optimizer/graph/redundant_op_remove_pass.h" | ||||
| @@ -107,6 +108,9 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti | |||||
| fusion_pm->AddPass(remove_unused_transpose_pass); | fusion_pm->AddPass(remove_unused_transpose_pass); | ||||
| } | } | ||||
| fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); | ||||
| if (!config->trainModel) { | |||||
| fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>()); | |||||
| } | |||||
| optimizer->AddPassManager(fusion_pm); | optimizer->AddPassManager(fusion_pm); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/onnx/onnx_gemm_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "ops/make_tuple.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { | |||||
| auto prim = std::make_unique<ops::MakeTuple>(); | |||||
| auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul"); | |||||
| if (node_parser == nullptr) { | |||||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node); | |||||
| prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive)); | |||||
| node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd"); | |||||
| if (node_parser == nullptr) { | |||||
| MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node); | |||||
| prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive)); | |||||
| return prim.release(); | |||||
| } | |||||
| OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,34 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser.h" | |||||
| #include "tools/converter/parser/onnx/onnx_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class OnnxGemmParser : public OnnxNodeParser { | |||||
| public: | |||||
| OnnxGemmParser() : OnnxNodeParser("Gemm") {} | |||||
| ~OnnxGemmParser() override = default; | |||||
| ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H | |||||
| @@ -46,5 +46,6 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con | |||||
| } | } | ||||
| OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser()); | ||||
| OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxMatmulParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,7 +44,6 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = { | |||||
| {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, | {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, | ||||
| {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; | {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; | ||||
| std::set<std::string> SPECIAL_NODE = {"Gemm"}; | |||||
| FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, | FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, | ||||
| const QuantType &quant_type) { | const QuantType &quant_type) { | ||||
| NoSupportOp::GetInstance()->SetFmkType("ONNX"); | NoSupportOp::GetInstance()->SetFmkType("ONNX"); | ||||
| @@ -215,11 +214,6 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F | |||||
| MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; | MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsSpecialOnnxNode(onnx_node)) { | |||||
| auto status_node = ConvertSpecialOnnxNode(onnx_node, anf_graph, anf_nodes_map, primitive_c); | |||||
| status = status == RET_OK ? status_node : status; | |||||
| continue; | |||||
| } | |||||
| // build CNode | // build CNode | ||||
| status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name); | status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -1023,117 +1017,6 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf | |||||
| return status; | return status; | ||||
| } | } | ||||
| STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph, | |||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||||
| ops::PrimitiveC *primitive_c) { | |||||
| if (primitive_c == nullptr || anf_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "imitive_c is nullptr."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| STATUS status = RET_OK; | |||||
| if (onnx_node.op_type() == "Gemm") { | |||||
| status = ConvertOnnxGemmNode(onnx_node, anf_graph, anf_nodes_map, primitive_c); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "the node is not special node."; | |||||
| status = RET_ERROR; | |||||
| } | |||||
| delete primitive_c; | |||||
| return status; | |||||
| } | |||||
| STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph, | |||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||||
| ops::PrimitiveC *primitive_c) { | |||||
| if (primitive_c == nullptr || anf_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "parameter has nullptr."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (onnx_node.op_type() != "Gemm") { | |||||
| MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is nullptr."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "MatMul"); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert gemm node failed."; | |||||
| return status; | |||||
| } | |||||
| status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "BiasAdd"); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert gemm node failed."; | |||||
| return status; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph, | |||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||||
| ops::PrimitiveC *primitive_c, const std::string &name) { | |||||
| if (primitive_c == nullptr || anf_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "parameter has nullptr."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto value = primitive_c->GetAttr(name); | |||||
| primitive_c->EraseAttr(name); | |||||
| if (value == nullptr) { | |||||
| MS_LOG(ERROR) << "op parse failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto prim_ptr = value->cast<std::shared_ptr<ops::PrimitiveC>>(); | |||||
| if (prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive parse failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto type_ptr = TypeIdToType(kTypeUnknown); | |||||
| std::vector<int64_t> shape_vector; | |||||
| std::vector<AnfNodePtr> op_inputs; | |||||
| auto quant_params_holder = std::make_shared<QuantParamHolder>(); | |||||
| auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>(); | |||||
| if (name == "MatMul") { | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| if (anf_nodes_map->find(onnx_node.input(i)) == anf_nodes_map->end()) { | |||||
| MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed."; | |||||
| return RET_ERROR; | |||||
| } else { | |||||
| op_inputs.push_back(anf_nodes_map->at(onnx_node.input(i))); | |||||
| quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(i)); | |||||
| } | |||||
| } | |||||
| quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1)); | |||||
| auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs); | |||||
| if (new_cnode == nullptr) { | |||||
| MS_LOG(ERROR) << "new cnode error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0)); | |||||
| new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| anf_nodes_map->emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode); | |||||
| } else { | |||||
| if (anf_nodes_map->find("Gemm_MatMul_" + onnx_node.output(0)) == anf_nodes_map->end() || | |||||
| anf_nodes_map->find(onnx_node.input(2)) == anf_nodes_map->end()) { | |||||
| MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| op_inputs.push_back(anf_nodes_map->at("Gemm_MatMul_" + onnx_node.output(0))); | |||||
| op_inputs.push_back(anf_nodes_map->at(onnx_node.input(2))); | |||||
| quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1)); | |||||
| quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(2)); | |||||
| quant_params_holder->AddOutputQuantParam(quant_params_holder_origin->output_quant_params().front()); | |||||
| auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs); | |||||
| if (new_cnode == nullptr) { | |||||
| MS_LOG(ERROR) << "new cnode error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0)); | |||||
| new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| anf_nodes_map->emplace(onnx_node.output(0), new_cnode); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) { | STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) { | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| MS_LOG(ERROR) << "value is nullptr."; | MS_LOG(ERROR) << "value is nullptr."; | ||||
| @@ -1281,10 +1164,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) { | |||||
| return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end(); | |||||
| } | |||||
| TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { | TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) { | ||||
| auto iter = TYPE_MAP.find(onnx_type); | auto iter = TYPE_MAP.find(onnx_type); | ||||
| if (iter == TYPE_MAP.end()) { | if (iter == TYPE_MAP.end()) { | ||||
| @@ -69,21 +69,11 @@ class OnnxModelParser : public ModelParser { | |||||
| ops::PrimitiveC *primitive_c, std::string loop_name); | ops::PrimitiveC *primitive_c, std::string loop_name); | ||||
| static STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | static STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | ||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode); | std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode); | ||||
| static STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||||
| ops::PrimitiveC *primitive_c); | |||||
| static STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||||
| ops::PrimitiveC *primitive_c); | |||||
| static STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr, | |||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | |||||
| ops::PrimitiveC *primitive_c, const std::string &name); | |||||
| STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); | STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c); | ||||
| STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); | STATUS ParseQuantParam(const onnx::NodeProto &onnx_node); | ||||
| STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | ||||
| STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params); | ||||
| STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); | STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not); | ||||
| static bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node); | |||||
| STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node, | STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node, | ||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, | ||||
| const std::string &root_node_name); | const std::string &root_node_name); | ||||
| @@ -0,0 +1,79 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/fusion/matmul_add_fusion.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| constexpr size_t AddInputSize = 3; | |||||
| constexpr size_t MatMulInputSize = 3; | |||||
| bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) { | |||||
| MS_ASSERT(cnode != nullptr); | |||||
| MS_ASSERT(index != nullptr); | |||||
| if (cnode->size() != AddInputSize) { | |||||
| return false; | |||||
| } | |||||
| size_t matmul_index = 0; | |||||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||||
| if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) { | |||||
| auto matmul_cnode = cnode->input(i)->cast<CNodePtr>(); | |||||
| if (matmul_cnode->size() > MatMulInputSize) { | |||||
| continue; | |||||
| } | |||||
| matmul_index = i; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (matmul_index == 0) { | |||||
| return false; | |||||
| } | |||||
| *index = matmul_index; | |||||
| return true; | |||||
| } | |||||
| } // namespace | |||||
| bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_ASSERT(func_graph != nulltr); | |||||
| auto node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (!utils::isa<CNode>(node)) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) { | |||||
| continue; | |||||
| } | |||||
| size_t index = 0; | |||||
| if (!CheckAndGetMatMulIndex(cnode, &index)) { | |||||
| continue; | |||||
| } | |||||
| auto matmul_cnode = cnode->input(index)->cast<CNodePtr>(); | |||||
| auto bias_node = cnode->input(AddInputSize - index); | |||||
| if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) { | |||||
| continue; | |||||
| } | |||||
| matmul_cnode->add_input(bias_node); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_ASSERT(manager != nullptr); | |||||
| matmul_cnode->set_fullname_with_scope(node->fullname_with_scope()); | |||||
| manager->Replace(node, matmul_cnode); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_ | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "tools/converter/converter_context.h" | |||||
| #include "backend/optimizer/common/pass.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class MatMulAddFusion : public Pass { | |||||
| public: | |||||
| MatMulAddFusion() : Pass("matmul_add_fusion") {} | |||||
| ~MatMulAddFusion() override = default; | |||||
| bool Run(const FuncGraphPtr &func_graph) override; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_ | |||||