@@ -388,6 +388,9 @@ set(TRAIN_SRC_LIST | |||||
"client/ge_api.cc" | "client/ge_api.cc" | ||||
"analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
"ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
"ir_build/attr_options/utils.cc" | |||||
"ir_build/attr_options/keep_dtype_option.cc" | |||||
"ir_build/attr_options/weight_compress_option.cc" | |||||
"ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
"graph/build/memory/memory_assigner.cc" | "graph/build/memory/memory_assigner.cc" | ||||
"graph/build/memory/graph_mem_assigner.cc" | "graph/build/memory/graph_mem_assigner.cc" | ||||
@@ -641,6 +644,9 @@ set(INFER_SRC_LIST | |||||
"graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | "graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | ||||
"hybrid/hybrid_davinci_model_stub.cc" | "hybrid/hybrid_davinci_model_stub.cc" | ||||
"ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
"ir_build/attr_options/utils.cc" | |||||
"ir_build/attr_options/keep_dtype_option.cc" | |||||
"ir_build/attr_options/weight_compress_option.cc" | |||||
"ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
@@ -70,6 +70,9 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||||
BUILER_SRC_FILES := \ | BUILER_SRC_FILES := \ | ||||
ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
ir_build/attr_options/utils.cc \ | |||||
ir_build/attr_options/keep_dtype_option.cc \ | |||||
ir_build/attr_options/weight_compress_option.cc \ | |||||
ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
ANALYZER_SRC_FILES:= \ | ANALYZER_SRC_FILES:= \ | ||||
@@ -312,6 +312,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
executor/ge_executor.cc \ | executor/ge_executor.cc \ | ||||
analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
ir_build/attr_options/utils.cc \ | |||||
ir_build/attr_options/keep_dtype_option.cc \ | |||||
ir_build/attr_options/weight_compress_option.cc \ | |||||
ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
@@ -13,14 +13,17 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef KEEP_DTYPE_OPTION_H_ | |||||
#define KEEP_DTYPE_OPTION_H_ | |||||
#ifndef ATTR_OPTIONS_H_ | |||||
#define ATTR_OPTIONS_H_ | |||||
#include <string> | #include <string> | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | |||||
#include "graph/ge_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype); | |||||
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | |||||
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
} // namespace | } // namespace | ||||
#endif // KEEP_DTYPE_OPTION_H_ | |||||
#endif // ATTR_OPTIONS_H_ |
@@ -13,7 +13,7 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "keep_dtype_option.h" | |||||
#include "attr_options.h" | |||||
#include <fstream> | #include <fstream> | ||||
#include <iostream> | #include <iostream> | ||||
#include <sstream> | #include <sstream> | ||||
@@ -26,20 +26,6 @@ namespace ge { | |||||
namespace { | namespace { | ||||
const size_t kMaxOpsNum = 10; | const size_t kMaxOpsNum = 10; | ||||
} // namespace | } // namespace | ||||
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
std::vector<std::string> original_op_names; | |||||
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
return false; | |||||
} | |||||
for (auto &origin_name : original_op_names) { | |||||
if (origin_name == op_name) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | ||||
std::stringstream err_msg; | std::stringstream err_msg; | ||||
@@ -67,20 +53,20 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | |||||
GELOGE(FAILED, "%s", err_msg.str().c_str()); | GELOGE(FAILED, "%s", err_msg.str().c_str()); | ||||
} | } | ||||
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) { | |||||
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
if (keep_dtype.empty()) { | |||||
return SUCCESS; | |||||
if (cfg_path.empty()) { | |||||
return GRAPH_SUCCESS; | |||||
} | } | ||||
std::string real_path = RealPath(keep_dtype.c_str()); | |||||
std::string real_path = RealPath(cfg_path.c_str()); | |||||
if (real_path.empty()) { | if (real_path.empty()) { | ||||
GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str()); | |||||
return PARAM_INVALID; | |||||
GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | } | ||||
std::ifstream ifs(real_path); | std::ifstream ifs(real_path); | ||||
if (!ifs.is_open()) { | if (!ifs.is_open()) { | ||||
GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str()); | |||||
return FAILED; | |||||
GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | } | ||||
std::string op_name; | std::string op_name; | ||||
@@ -108,9 +94,9 @@ Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep | |||||
if (!invalid_list.empty()) { | if (!invalid_list.empty()) { | ||||
KeepDtypeReportError(invalid_list); | KeepDtypeReportError(invalid_list); | ||||
return PARAM_INVALID; | |||||
return GRAPH_PARAM_INVALID; | |||||
} | } | ||||
return SUCCESS; | |||||
return GRAPH_SUCCESS; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "attr_options.h" | |||||
#include <vector> | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
namespace ge { | |||||
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
std::vector<std::string> original_op_names; | |||||
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
return false; | |||||
} | |||||
for (auto &origin_name : original_op_names) { | |||||
if (origin_name == op_name) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,64 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "attr_options.h" | |||||
#include <fstream> | |||||
#include <iostream> | |||||
#include <sstream> | |||||
#include <vector> | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "framework/common/util.h" | |||||
#include "common/util/error_manager/error_manager.h" | |||||
namespace ge { | |||||
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const string &cfg_path) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
if (cfg_path.empty()) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
std::string real_path = RealPath(cfg_path.c_str()); | |||||
if (real_path.empty()) { | |||||
GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str()); | |||||
return GRAPH_PARAM_INVALID; | |||||
} | |||||
std::ifstream ifs(real_path); | |||||
if (!ifs.is_open()) { | |||||
GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
std::string compress_nodes; | |||||
ifs >> compress_nodes; | |||||
ifs.close(); | |||||
GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
for (auto &node_ptr : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node_ptr); | |||||
auto op_desc = node_ptr->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if ((op_desc->GetName() == compress_node_vec[i]) || IsOriginalOpFind(op_desc, compress_node_vec[i])) { | |||||
if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
GELOGE(GRAPH_FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -39,6 +39,7 @@ | |||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
#include "graph/passes/net_output_pass.h" | #include "graph/passes/net_output_pass.h" | ||||
#include "graph/passes/data_pass.h" | #include "graph/passes/data_pass.h" | ||||
#include "ir_build/attr_options/attr_options.h" | |||||
using std::string; | using std::string; | ||||
using namespace std; | using namespace std; | ||||
@@ -52,8 +53,28 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default"; | |||||
const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | ||||
const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | ||||
const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | ||||
const std::string KEEP_DTYPE_OPTION = "keep_dtype"; | |||||
const std::string kInputShape = "input_shape"; | const std::string kInputShape = "input_shape"; | ||||
const std::string kInputFormat = "input_format"; | const std::string kInputFormat = "input_format"; | ||||
/** | |||||
* @name SetOpAttrFun | |||||
* @brief set attribute for operators in the configuration file | |||||
* @param graph [IN/OUT] compute graph | |||||
* @param cfg_path [IN] the config file path | |||||
* @return graphStatus | |||||
*/ | |||||
typedef graphStatus (*SetOpAttrFun)(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
const std::map<aclgrphAttrType, SetOpAttrFun> kAttrTypeFuncMap = { | |||||
{ATTR_TYPE_KEEP_DTYPE, KeepDtypeFunc}, | |||||
{ATTR_TYPE_WEIGHT_COMPRESS, WeightCompressFunc} | |||||
}; | |||||
const std::map<aclgrphAttrType, std::string> kAttrTypeToStringMap = { | |||||
{ATTR_TYPE_KEEP_DTYPE, KEEP_DTYPE_OPTION}, | |||||
{ATTR_TYPE_WEIGHT_COMPRESS, ge::ir_option::COMPRESS_WEIGHT_CONF} | |||||
}; | |||||
} // namespace | } // namespace | ||||
static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | ||||
@@ -703,4 +724,33 @@ graphStatus aclgrphGenerateForOp(const AscendString &op_type, const vector<Tenso | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
static std::string AttrTypeToSerialString(aclgrphAttrType attr_type) { | |||||
auto it = kAttrTypeToStringMap.find(attr_type); | |||||
if (it != kAttrTypeToStringMap.end()) { | |||||
return it->second; | |||||
} else { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
{"AttrTypeToSerialString", "attr_type[" + std::to_string(attr_type) + "] is not support"}); | |||||
GELOGE(GRAPH_FAILED, "AttrTypeToSerialString: attr_type not support %u", attr_type); | |||||
return "UNDEFINED"; | |||||
} | |||||
} | |||||
graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path) { | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
if (cfg_path == nullptr) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
auto iter = kAttrTypeFuncMap.find(attr_type); | |||||
if (iter == kAttrTypeFuncMap.end()) { | |||||
GELOGE(GRAPH_FAILED, "attr type: %s is not support", AttrTypeToSerialString(attr_type).c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
std::string path = cfg_path; | |||||
return iter->second(compute_graph, path); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -10,7 +10,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
set(SRC_LIST | set(SRC_LIST | ||||
"main.cc" | "main.cc" | ||||
"single_op_parser.cc" | "single_op_parser.cc" | ||||
"keep_dtype_option.cc" | |||||
"../session/omg.cc" | "../session/omg.cc" | ||||
"../ir_build/atc_ir_common.cc" | "../ir_build/atc_ir_common.cc" | ||||
) | ) | ||||
@@ -43,7 +43,7 @@ | |||||
#include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "single_op_parser.h" | #include "single_op_parser.h" | ||||
#include "keep_dtype_option.h" | |||||
#include "external/ge/ge_ir_build.h" | |||||
using domi::BuildMode; | using domi::BuildMode; | ||||
using domi::OpRegistrationData; | using domi::OpRegistrationData; | ||||
@@ -913,6 +913,22 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s | |||||
return ret; | return ret; | ||||
} | } | ||||
static Status SetAttrOptions(ge::Graph &graph) { | |||||
if (!FLAGS_keep_dtype.empty()) { | |||||
if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_KEEP_DTYPE, FLAGS_keep_dtype.c_str()) != ge::GRAPH_SUCCESS) { | |||||
return ge::FAILED; | |||||
} | |||||
} | |||||
if (!FLAGS_compress_weight_conf.empty()) { | |||||
if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_WEIGHT_COMPRESS, FLAGS_compress_weight_conf.c_str()) | |||||
!= ge::GRAPH_SUCCESS) { | |||||
return ge::FAILED; | |||||
} | |||||
} | |||||
return ge::SUCCESS; | |||||
} | |||||
domi::Status GenerateModel(std::map<string, string> &options, std::string output) { | domi::Status GenerateModel(std::map<string, string> &options, std::string output) { | ||||
ge::GeGenerator ge_generator; | ge::GeGenerator ge_generator; | ||||
ge::Status geRet = ge::SUCCESS; | ge::Status geRet = ge::SUCCESS; | ||||
@@ -969,7 +985,6 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes)); | atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes)); | ||||
atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); | atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); | ||||
atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); | atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); | ||||
atc_params.insert(std::pair<string, string>("compress_weight_conf", FLAGS_compress_weight_conf)); | |||||
atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); | atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); | ||||
atc_params.insert(std::pair<string, string>("output", output)); | atc_params.insert(std::pair<string, string>("output", output)); | ||||
@@ -1003,11 +1018,10 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
} | } | ||||
} | } | ||||
Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype); | |||||
if (ret != SUCCESS) { | |||||
if (SetAttrOptions(graph) != ge::SUCCESS) { | |||||
(void)ge_generator.Finalize(); | (void)ge_generator.Finalize(); | ||||
(void)ge::GELib::GetInstance()->Finalize(); | (void)ge::GELib::GetInstance()->Finalize(); | ||||
return ret; | |||||
return domi::FAILED; | |||||
} | } | ||||
geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | ||||
@@ -10,7 +10,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
main.cc \ | main.cc \ | ||||
keep_dtype_option.cc \ | |||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
@@ -64,7 +63,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
main.cc \ | main.cc \ | ||||
keep_dtype_option.cc \ | |||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
@@ -118,7 +116,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
main.cc \ | main.cc \ | ||||
keep_dtype_option.cc \ | |||||
single_op_parser.cc \ | single_op_parser.cc \ | ||||
../session/omg.cc \ | ../session/omg.cc \ | ||||
../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
@@ -193,44 +193,6 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
static Status SetWeightCompressNodes(const ComputeGraphPtr &graph, const string &compress_weight_conf) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
if (compress_weight_conf.empty()) { | |||||
return SUCCESS; | |||||
} | |||||
std::string real_path = RealPath(compress_weight_conf.c_str()); | |||||
if (real_path.empty()) { | |||||
GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::ifstream ifs(real_path); | |||||
if (!ifs.is_open()) { | |||||
GELOGE(domi::FAILED, "Open file %s failed", compress_weight_conf.c_str()); | |||||
return domi::FAILED; | |||||
} | |||||
std::string compress_nodes; | |||||
ifs >> compress_nodes; | |||||
ifs.close(); | |||||
GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
ge::NodePtr node = graph->FindNode(compress_node_vec[i]); | |||||
if (node == nullptr) { | |||||
GELOGW("node %s is not in graph", compress_node_vec[i].c_str()); | |||||
continue; | |||||
} | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
GELOGE(domi::FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
return domi::FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | ||||
if (is_output_fp16.empty()) { | if (is_output_fp16.empty()) { | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -800,10 +762,6 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri | |||||
GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); | GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); | ||||
std::string compress_weight_conf; | |||||
ParseAtcParms(atc_params, "compress_weight_conf", compress_weight_conf); | |||||
GE_RETURN_IF_ERROR(SetWeightCompressNodes(compute_graph, compress_weight_conf)); | |||||
// Verify the contents of the op_name_map | // Verify the contents of the op_name_map | ||||
if (op_conf != nullptr && *op_conf != '\0') { | if (op_conf != nullptr && *op_conf != '\0') { | ||||
GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), | GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), | ||||
@@ -50,6 +50,8 @@ struct ModelBufferData { | |||||
uint64_t length; | uint64_t length; | ||||
}; | }; | ||||
enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS }; | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
* @brief build model.Notice the model is stored in buffer | * @brief build model.Notice the model is stored in buffer | ||||
@@ -80,13 +82,16 @@ GE_FUNC_VISIBILITY void aclgrphBuildFinalize(); | |||||
* @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &, | |||||
ModelBufferData &)) | |||||
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||||
ModelBufferData &model); | |||||
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, | |||||
const std::map<AscendString, AscendString> &, | |||||
ModelBufferData &)) | |||||
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, | |||||
const std::map<std::string, std::string> &build_options, | |||||
ModelBufferData &model); | |||||
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
ModelBufferData &model); | |||||
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, | |||||
const std::map<AscendString, AscendString> &build_options, | |||||
ModelBufferData &model); | |||||
/** | /** | ||||
* @ingroup AscendCL | * @ingroup AscendCL | ||||
@@ -138,7 +143,17 @@ GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const ch | |||||
* @retval OtherValues Failure | * @retval OtherValues Failure | ||||
*/ | */ | ||||
GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs, | GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs, | ||||
const std::vector<TensorDesc> &outputs, Graph &graph); | |||||
const std::vector<TensorDesc> &outputs, Graph &graph); | |||||
/** | |||||
* @name aclgrphSetOpAttr | |||||
* @brief set attribute for operators in the configuration file | |||||
* @param graph [IN/OUT] compute graph | |||||
* @param attr_type [In] attribute type | |||||
* @param cfg_path [IN] the config file path | |||||
* @return graphStatus | |||||
*/ | |||||
GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path); | |||||
}; // namespace ge | }; // namespace ge | ||||
#endif // INC_EXTERNAL_GE_IR_BUILD_H_ | #endif // INC_EXTERNAL_GE_IR_BUILD_H_ |
@@ -274,6 +274,9 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | ||||
"${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | ||||
"${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | ||||
"${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" | |||||
"${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" | |||||
"${GE_CODE_DIR}/ge/ir_build/attr_options/weight_compress_option.cc" | |||||
"${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" | "${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" | ||||
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | ||||
"${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | ||||