| @@ -388,6 +388,9 @@ set(TRAIN_SRC_LIST | |||||
| "client/ge_api.cc" | "client/ge_api.cc" | ||||
| "analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
| "ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
| "ir_build/attr_options/utils.cc" | |||||
| "ir_build/attr_options/keep_dtype_option.cc" | |||||
| "ir_build/attr_options/weight_compress_option.cc" | |||||
| "ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
| "graph/build/memory/memory_assigner.cc" | "graph/build/memory/memory_assigner.cc" | ||||
| "graph/build/memory/graph_mem_assigner.cc" | "graph/build/memory/graph_mem_assigner.cc" | ||||
| @@ -641,6 +644,9 @@ set(INFER_SRC_LIST | |||||
| "graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | "graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | ||||
| "hybrid/hybrid_davinci_model_stub.cc" | "hybrid/hybrid_davinci_model_stub.cc" | ||||
| "ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
| "ir_build/attr_options/utils.cc" | |||||
| "ir_build/attr_options/keep_dtype_option.cc" | |||||
| "ir_build/attr_options/weight_compress_option.cc" | |||||
| "ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
| @@ -70,6 +70,9 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||||
| BUILER_SRC_FILES := \ | BUILER_SRC_FILES := \ | ||||
| ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
| ir_build/attr_options/utils.cc \ | |||||
| ir_build/attr_options/keep_dtype_option.cc \ | |||||
| ir_build/attr_options/weight_compress_option.cc \ | |||||
| ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
| ANALYZER_SRC_FILES:= \ | ANALYZER_SRC_FILES:= \ | ||||
| @@ -312,6 +312,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| executor/ge_executor.cc \ | executor/ge_executor.cc \ | ||||
| analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
| ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
| ir_build/attr_options/utils.cc \ | |||||
| ir_build/attr_options/keep_dtype_option.cc \ | |||||
| ir_build/attr_options/weight_compress_option.cc \ | |||||
| ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
| LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
| @@ -13,14 +13,17 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef KEEP_DTYPE_OPTION_H_ | |||||
| #define KEEP_DTYPE_OPTION_H_ | |||||
| #ifndef ATTR_OPTIONS_H_ | |||||
| #define ATTR_OPTIONS_H_ | |||||
| #include <string> | #include <string> | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| namespace ge { | namespace ge { | ||||
| Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype); | |||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | |||||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
| } // namespace | } // namespace | ||||
| #endif // KEEP_DTYPE_OPTION_H_ | |||||
| #endif // ATTR_OPTIONS_H_ | |||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "keep_dtype_option.h" | |||||
| #include "attr_options.h" | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <sstream> | #include <sstream> | ||||
| @@ -26,20 +26,6 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const size_t kMaxOpsNum = 10; | const size_t kMaxOpsNum = 10; | ||||
| } // namespace | } // namespace | ||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
| std::vector<std::string> original_op_names; | |||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
| return false; | |||||
| } | |||||
| for (auto &origin_name : original_op_names) { | |||||
| if (origin_name == op_name) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | ||||
| std::stringstream err_msg; | std::stringstream err_msg; | ||||
| @@ -67,20 +53,20 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | |||||
| GELOGE(FAILED, "%s", err_msg.str().c_str()); | GELOGE(FAILED, "%s", err_msg.str().c_str()); | ||||
| } | } | ||||
| Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) { | |||||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| if (keep_dtype.empty()) { | |||||
| return SUCCESS; | |||||
| if (cfg_path.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | } | ||||
| std::string real_path = RealPath(keep_dtype.c_str()); | |||||
| std::string real_path = RealPath(cfg_path.c_str()); | |||||
| if (real_path.empty()) { | if (real_path.empty()) { | ||||
| GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str()); | |||||
| return PARAM_INVALID; | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | } | ||||
| std::ifstream ifs(real_path); | std::ifstream ifs(real_path); | ||||
| if (!ifs.is_open()) { | if (!ifs.is_open()) { | ||||
| GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str()); | |||||
| return FAILED; | |||||
| GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | } | ||||
| std::string op_name; | std::string op_name; | ||||
| @@ -108,9 +94,9 @@ Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep | |||||
| if (!invalid_list.empty()) { | if (!invalid_list.empty()) { | ||||
| KeepDtypeReportError(invalid_list); | KeepDtypeReportError(invalid_list); | ||||
| return PARAM_INVALID; | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return GRAPH_SUCCESS; | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "attr_options.h" | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| namespace ge { | |||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
| std::vector<std::string> original_op_names; | |||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
| return false; | |||||
| } | |||||
| for (auto &origin_name : original_op_names) { | |||||
| if (origin_name == op_name) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "attr_options.h" | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <sstream> | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| namespace ge { | |||||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const string &cfg_path) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| if (cfg_path.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::string real_path = RealPath(cfg_path.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| std::ifstream ifs(real_path); | |||||
| if (!ifs.is_open()) { | |||||
| GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string compress_nodes; | |||||
| ifs >> compress_nodes; | |||||
| ifs.close(); | |||||
| GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
| vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
| for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
| for (auto &node_ptr : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node_ptr); | |||||
| auto op_desc = node_ptr->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if ((op_desc->GetName() == compress_node_vec[i]) || IsOriginalOpFind(op_desc, compress_node_vec[i])) { | |||||
| if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
| GELOGE(GRAPH_FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -39,6 +39,7 @@ | |||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| #include "graph/passes/net_output_pass.h" | #include "graph/passes/net_output_pass.h" | ||||
| #include "graph/passes/data_pass.h" | #include "graph/passes/data_pass.h" | ||||
| #include "ir_build/attr_options/attr_options.h" | |||||
| using std::string; | using std::string; | ||||
| using namespace std; | using namespace std; | ||||
| @@ -52,8 +53,28 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default"; | |||||
| const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | ||||
| const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | ||||
| const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | ||||
| const std::string KEEP_DTYPE_OPTION = "keep_dtype"; | |||||
| const std::string kInputShape = "input_shape"; | const std::string kInputShape = "input_shape"; | ||||
| const std::string kInputFormat = "input_format"; | const std::string kInputFormat = "input_format"; | ||||
| /** | |||||
| * @name SetOpAttrFun | |||||
| * @brief set attribute for operators in the configuration file | |||||
| * @param graph [IN/OUT] compute graph | |||||
| * @param cfg_path [IN] the config file path | |||||
| * @return graphStatus | |||||
| */ | |||||
| typedef graphStatus (*SetOpAttrFun)(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
| const std::map<aclgrphAttrType, SetOpAttrFun> kAttrTypeFuncMap = { | |||||
| {ATTR_TYPE_KEEP_DTYPE, KeepDtypeFunc}, | |||||
| {ATTR_TYPE_WEIGHT_COMPRESS, WeightCompressFunc} | |||||
| }; | |||||
| const std::map<aclgrphAttrType, std::string> kAttrTypeToStringMap = { | |||||
| {ATTR_TYPE_KEEP_DTYPE, KEEP_DTYPE_OPTION}, | |||||
| {ATTR_TYPE_WEIGHT_COMPRESS, ge::ir_option::COMPRESS_WEIGHT_CONF} | |||||
| }; | |||||
| } // namespace | } // namespace | ||||
| static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | ||||
| @@ -703,4 +724,33 @@ graphStatus aclgrphGenerateForOp(const AscendString &op_type, const vector<Tenso | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| static std::string AttrTypeToSerialString(aclgrphAttrType attr_type) { | |||||
| auto it = kAttrTypeToStringMap.find(attr_type); | |||||
| if (it != kAttrTypeToStringMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
| {"AttrTypeToSerialString", "attr_type[" + std::to_string(attr_type) + "] is not support"}); | |||||
| GELOGE(GRAPH_FAILED, "AttrTypeToSerialString: attr_type not support %u", attr_type); | |||||
| return "UNDEFINED"; | |||||
| } | |||||
| } | |||||
| graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path) { | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| if (cfg_path == nullptr) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto iter = kAttrTypeFuncMap.find(attr_type); | |||||
| if (iter == kAttrTypeFuncMap.end()) { | |||||
| GELOGE(GRAPH_FAILED, "attr type: %s is not support", AttrTypeToSerialString(attr_type).c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string path = cfg_path; | |||||
| return iter->second(compute_graph, path); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -10,7 +10,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| set(SRC_LIST | set(SRC_LIST | ||||
| "main.cc" | "main.cc" | ||||
| "single_op_parser.cc" | "single_op_parser.cc" | ||||
| "keep_dtype_option.cc" | |||||
| "../session/omg.cc" | "../session/omg.cc" | ||||
| "../ir_build/atc_ir_common.cc" | "../ir_build/atc_ir_common.cc" | ||||
| ) | ) | ||||
| @@ -43,7 +43,7 @@ | |||||
| #include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "single_op_parser.h" | #include "single_op_parser.h" | ||||
| #include "keep_dtype_option.h" | |||||
| #include "external/ge/ge_ir_build.h" | |||||
| using domi::BuildMode; | using domi::BuildMode; | ||||
| using domi::OpRegistrationData; | using domi::OpRegistrationData; | ||||
| @@ -913,6 +913,22 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| static Status SetAttrOptions(ge::Graph &graph) { | |||||
| if (!FLAGS_keep_dtype.empty()) { | |||||
| if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_KEEP_DTYPE, FLAGS_keep_dtype.c_str()) != ge::GRAPH_SUCCESS) { | |||||
| return ge::FAILED; | |||||
| } | |||||
| } | |||||
| if (!FLAGS_compress_weight_conf.empty()) { | |||||
| if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_WEIGHT_COMPRESS, FLAGS_compress_weight_conf.c_str()) | |||||
| != ge::GRAPH_SUCCESS) { | |||||
| return ge::FAILED; | |||||
| } | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| domi::Status GenerateModel(std::map<string, string> &options, std::string output) { | domi::Status GenerateModel(std::map<string, string> &options, std::string output) { | ||||
| ge::GeGenerator ge_generator; | ge::GeGenerator ge_generator; | ||||
| ge::Status geRet = ge::SUCCESS; | ge::Status geRet = ge::SUCCESS; | ||||
| @@ -969,7 +985,6 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
| atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes)); | atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes)); | ||||
| atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); | atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); | ||||
| atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); | atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); | ||||
| atc_params.insert(std::pair<string, string>("compress_weight_conf", FLAGS_compress_weight_conf)); | |||||
| atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); | atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); | ||||
| atc_params.insert(std::pair<string, string>("output", output)); | atc_params.insert(std::pair<string, string>("output", output)); | ||||
| @@ -1003,11 +1018,10 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
| } | } | ||||
| } | } | ||||
| Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype); | |||||
| if (ret != SUCCESS) { | |||||
| if (SetAttrOptions(graph) != ge::SUCCESS) { | |||||
| (void)ge_generator.Finalize(); | (void)ge_generator.Finalize(); | ||||
| (void)ge::GELib::GetInstance()->Finalize(); | (void)ge::GELib::GetInstance()->Finalize(); | ||||
| return ret; | |||||
| return domi::FAILED; | |||||
| } | } | ||||
| geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | ||||
| @@ -10,7 +10,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| main.cc \ | main.cc \ | ||||
| keep_dtype_option.cc \ | |||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
| @@ -64,7 +63,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| main.cc \ | main.cc \ | ||||
| keep_dtype_option.cc \ | |||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
| @@ -118,7 +116,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| main.cc \ | main.cc \ | ||||
| keep_dtype_option.cc \ | |||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
| @@ -193,44 +193,6 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| static Status SetWeightCompressNodes(const ComputeGraphPtr &graph, const string &compress_weight_conf) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| if (compress_weight_conf.empty()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| std::string real_path = RealPath(compress_weight_conf.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| std::ifstream ifs(real_path); | |||||
| if (!ifs.is_open()) { | |||||
| GELOGE(domi::FAILED, "Open file %s failed", compress_weight_conf.c_str()); | |||||
| return domi::FAILED; | |||||
| } | |||||
| std::string compress_nodes; | |||||
| ifs >> compress_nodes; | |||||
| ifs.close(); | |||||
| GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
| vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
| for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
| ge::NodePtr node = graph->FindNode(compress_node_vec[i]); | |||||
| if (node == nullptr) { | |||||
| GELOGW("node %s is not in graph", compress_node_vec[i].c_str()); | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
| GELOGE(domi::FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
| return domi::FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | ||||
| if (is_output_fp16.empty()) { | if (is_output_fp16.empty()) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -800,10 +762,6 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri | |||||
| GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); | GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); | ||||
| std::string compress_weight_conf; | |||||
| ParseAtcParms(atc_params, "compress_weight_conf", compress_weight_conf); | |||||
| GE_RETURN_IF_ERROR(SetWeightCompressNodes(compute_graph, compress_weight_conf)); | |||||
| // Verify the contents of the op_name_map | // Verify the contents of the op_name_map | ||||
| if (op_conf != nullptr && *op_conf != '\0') { | if (op_conf != nullptr && *op_conf != '\0') { | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), | GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), | ||||
| @@ -50,6 +50,8 @@ struct ModelBufferData { | |||||
| uint64_t length; | uint64_t length; | ||||
| }; | }; | ||||
| enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS }; | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| * @brief build model.Notice the model is stored in buffer | * @brief build model.Notice the model is stored in buffer | ||||
| @@ -80,13 +82,16 @@ GE_FUNC_VISIBILITY void aclgrphBuildFinalize(); | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &, | |||||
| ModelBufferData &)) | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||||
| ModelBufferData &model); | |||||
| ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, | |||||
| const std::map<AscendString, AscendString> &, | |||||
| ModelBufferData &)) | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, | |||||
| const std::map<std::string, std::string> &build_options, | |||||
| ModelBufferData &model); | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
| ModelBufferData &model); | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, | |||||
| const std::map<AscendString, AscendString> &build_options, | |||||
| ModelBufferData &model); | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| @@ -138,7 +143,17 @@ GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const ch | |||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs, | GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs, | ||||
| const std::vector<TensorDesc> &outputs, Graph &graph); | |||||
| const std::vector<TensorDesc> &outputs, Graph &graph); | |||||
| /** | |||||
| * @name aclgrphSetOpAttr | |||||
| * @brief set attribute for operators in the configuration file | |||||
| * @param graph [IN/OUT] compute graph | |||||
| * @param attr_type [In] attribute type | |||||
| * @param cfg_path [IN] the config file path | |||||
| * @return graphStatus | |||||
| */ | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path); | |||||
| }; // namespace ge | }; // namespace ge | ||||
| #endif // INC_EXTERNAL_GE_IR_BUILD_H_ | #endif // INC_EXTERNAL_GE_IR_BUILD_H_ | ||||
| @@ -274,6 +274,9 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | ||||
| "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | ||||
| "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | ||||
| "${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" | |||||
| "${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" | |||||
| "${GE_CODE_DIR}/ge/ir_build/attr_options/weight_compress_option.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" | "${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | ||||