Browse Source

!1041 set weight compress for original operators

From: @li-lei0106
Reviewed-by: @ni100die,@xchu42,@j00107162
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
e4d892cbda
14 changed files with 225 additions and 88 deletions
  1. +6
    -0
      ge/CMakeLists.txt
  2. +3
    -0
      ge/ge_inference.mk
  3. +3
    -0
      ge/ge_runner.mk
  4. +8
    -5
      ge/ir_build/attr_options/attr_options.h
  5. +11
    -25
      ge/ir_build/attr_options/keep_dtype_option.cc
  6. +36
    -0
      ge/ir_build/attr_options/utils.cc
  7. +64
    -0
      ge/ir_build/attr_options/weight_compress_option.cc
  8. +50
    -0
      ge/ir_build/ge_ir_build.cc
  9. +0
    -1
      ge/offline/CMakeLists.txt
  10. +19
    -5
      ge/offline/main.cc
  11. +0
    -3
      ge/offline/module.mk
  12. +0
    -42
      ge/session/omg.cc
  13. +22
    -7
      inc/external/ge/ge_ir_build.h
  14. +3
    -0
      tests/ut/ge/CMakeLists.txt

+ 6
- 0
ge/CMakeLists.txt View File

@@ -388,6 +388,9 @@ set(TRAIN_SRC_LIST
"client/ge_api.cc"
"analyzer/analyzer.cc"
"ir_build/ge_ir_build.cc"
"ir_build/attr_options/utils.cc"
"ir_build/attr_options/keep_dtype_option.cc"
"ir_build/attr_options/weight_compress_option.cc"
"ir_build/atc_ir_common.cc"
"graph/build/memory/memory_assigner.cc"
"graph/build/memory/graph_mem_assigner.cc"
@@ -640,6 +643,9 @@ set(INFER_SRC_LIST
"graph/load/model_manager/task_info/super_kernel/super_kernel.cc"
"hybrid/hybrid_davinci_model_stub.cc"
"ir_build/ge_ir_build.cc"
"ir_build/attr_options/utils.cc"
"ir_build/attr_options/keep_dtype_option.cc"
"ir_build/attr_options/weight_compress_option.cc"
"ir_build/atc_ir_common.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
"graph/preprocess/insert_op/util_insert_aipp_op.cc"


+ 3
- 0
ge/ge_inference.mk View File

@@ -70,6 +70,9 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \

BUILER_SRC_FILES := \
ir_build/ge_ir_build.cc \
ir_build/attr_options/utils.cc \
ir_build/attr_options/keep_dtype_option.cc \
ir_build/attr_options/weight_compress_option.cc \
ir_build/atc_ir_common.cc \

ANALYZER_SRC_FILES:= \


+ 3
- 0
ge/ge_runner.mk View File

@@ -311,6 +311,9 @@ LIBGE_LOCAL_SRC_FILES := \
executor/ge_executor.cc \
analyzer/analyzer.cc \
ir_build/ge_ir_build.cc \
ir_build/attr_options/utils.cc \
ir_build/attr_options/keep_dtype_option.cc \
ir_build/attr_options/weight_compress_option.cc \
ir_build/atc_ir_common.cc \

LIBCLIENT_LOCAL_SRC_FILES := \


ge/offline/keep_dtype_option.h → ge/ir_build/attr_options/attr_options.h View File

@@ -13,14 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef KEEP_DTYPE_OPTION_H_
#define KEEP_DTYPE_OPTION_H_
#ifndef ATTR_OPTIONS_H_
#define ATTR_OPTIONS_H_
#include <string>
#include "graph/compute_graph.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/ge_error_codes.h"
namespace ge {
// Status-returning keep-dtype entry point; takes the path of the keep_dtype config file.
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype);
// Returns true when op_name is listed in op_desc's ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES attribute.
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name);
// Applies the keep-dtype config file at cfg_path to graph; an empty cfg_path is a no-op that returns GRAPH_SUCCESS.
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
// Sets ATTR_NAME_COMPRESS_WEIGHT on the nodes of graph that are named in the config file at cfg_path.
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path);
} // namespace ge
#endif // KEEP_DTYPE_OPTION_H_
#endif // ATTR_OPTIONS_H_

ge/offline/keep_dtype_option.cc → ge/ir_build/attr_options/keep_dtype_option.cc View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "keep_dtype_option.h"
#include "attr_options.h"
#include <fstream>
#include <iostream>
#include <sstream>
@@ -26,20 +26,6 @@ namespace ge {
namespace {
const size_t kMaxOpsNum = 10;
} // namespace
// Tells whether op_name appears among the origin op names stored on op_desc
// under ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES. Returns false when the attribute
// cannot be read.
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
  std::vector<std::string> recorded_names;
  if (AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, recorded_names)) {
    for (auto it = recorded_names.begin(); it != recorded_names.end(); ++it) {
      if (*it == op_name) {
        return true;
      }
    }
  }
  return false;
}
void KeepDtypeReportError(const std::vector<std::string> &invalid_list) {
std::stringstream err_msg;
@@ -67,20 +53,20 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list) {
GELOGE(FAILED, "%s", err_msg.str().c_str());
}
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) {
graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
GE_CHECK_NOTNULL(graph);
if (keep_dtype.empty()) {
return SUCCESS;
if (cfg_path.empty()) {
return GRAPH_SUCCESS;
}
std::string real_path = RealPath(keep_dtype.c_str());
std::string real_path = RealPath(cfg_path.c_str());
if (real_path.empty()) {
GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str());
return PARAM_INVALID;
GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str());
return GRAPH_PARAM_INVALID;
}
std::ifstream ifs(real_path);
if (!ifs.is_open()) {
GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str());
return FAILED;
GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str());
return GRAPH_FAILED;
}
std::string op_name;
@@ -108,9 +94,9 @@ Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep
if (!invalid_list.empty()) {
KeepDtypeReportError(invalid_list);
return PARAM_INVALID;
return GRAPH_PARAM_INVALID;
}
return SUCCESS;
return GRAPH_SUCCESS;
}
} // namespace ge

+ 36
- 0
ge/ir_build/attr_options/utils.cc View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "attr_options.h"
#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "common/util/error_manager/error_manager.h"
namespace ge {
/**
 * @name   IsOriginalOpFind
 * @brief  check whether op_name is one of the origin op names recorded on op_desc
 * @param  op_desc  [IN] operator description carrying ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES
 * @param  op_name  [IN] operator name to look for
 * @return true if op_name is in the recorded list; false if it is absent or the
 *         attribute cannot be read
 */
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
  std::vector<std::string> origin_names;
  const bool has_origin_names =
      AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, origin_names);
  if (!has_origin_names) {
    return false;
  }
  for (size_t idx = 0; idx < origin_names.size(); ++idx) {
    if (origin_names[idx] == op_name) {
      return true;
    }
  }
  return false;
}
} // namespace ge

+ 64
- 0
ge/ir_build/attr_options/weight_compress_option.cc View File

@@ -0,0 +1,64 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "attr_options.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "framework/common/util.h"
#include "common/util/error_manager/error_manager.h"
namespace ge {
/**
 * @name   WeightCompressFunc
 * @brief  set ATTR_NAME_COMPRESS_WEIGHT on the operators named in the configuration file
 * @param  graph     [IN/OUT] compute graph whose direct nodes are searched
 * @param  cfg_path  [IN] the config file path; an empty path is a no-op
 * @return GRAPH_SUCCESS on success (including when some names match no node),
 *         GRAPH_PARAM_INVALID if the path cannot be resolved,
 *         GRAPH_FAILED if the file cannot be opened or the attribute cannot be set
 */
graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path) {
  GE_CHECK_NOTNULL(graph);
  if (cfg_path.empty()) {
    return GRAPH_SUCCESS;
  }
  std::string real_path = RealPath(cfg_path.c_str());
  if (real_path.empty()) {
    GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str());
    return GRAPH_PARAM_INVALID;
  }
  std::ifstream ifs(real_path);
  if (!ifs.is_open()) {
    GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str());
    return GRAPH_FAILED;
  }

  // operator>> reads a single whitespace-delimited token, so the whole
  // ';'-separated node list is expected to be the first token of the file.
  std::string compress_nodes;
  ifs >> compress_nodes;
  ifs.close();
  GELOGI("Compress weight of nodes: %s", compress_nodes.c_str());

  const std::vector<std::string> compress_node_vec = StringUtils::Split(compress_nodes, ';');
  for (const auto &node_name : compress_node_vec) {
    bool matched = false;
    for (auto &node_ptr : graph->GetDirectNode()) {
      GE_CHECK_NOTNULL(node_ptr);
      auto op_desc = node_ptr->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      // A configured name may refer to the node itself or to one of the
      // original operators recorded on it.
      if ((op_desc->GetName() == node_name) || IsOriginalOpFind(op_desc, node_name)) {
        matched = true;
        if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) {
          GELOGE(GRAPH_FAILED, "node %s SetBool failed.", node_name.c_str());
          return GRAPH_FAILED;
        }
      }
    }
    // Do not silently drop names that match nothing: a typo in the config
    // file would otherwise go unnoticed.
    if (!matched) {
      GELOGW("node %s is not in graph", node_name.c_str());
    }
  }
  return GRAPH_SUCCESS;
}
} // namespace ge

+ 50
- 0
ge/ir_build/ge_ir_build.cc View File

@@ -39,6 +39,7 @@
#include "inc/pass_manager.h"
#include "graph/passes/net_output_pass.h"
#include "graph/passes/data_pass.h"
#include "ir_build/attr_options/attr_options.h"

using std::string;
using namespace std;
@@ -52,8 +53,28 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default";
const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize";
const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0";
const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false";
const std::string KEEP_DTYPE_OPTION = "keep_dtype";
const std::string kInputShape = "input_shape";
const std::string kInputFormat = "input_format";

/**
* @name SetOpAttrFun
* @brief set attribute for operators in the configuration file
* @param graph [IN/OUT] compute graph
* @param cfg_path [IN] the config file path
* @return graphStatus
*/
typedef graphStatus (*SetOpAttrFun)(ComputeGraphPtr &graph, const std::string &cfg_path);

// Dispatch table: attribute type -> function that applies that attribute to a
// graph from a config file. Looked up by aclgrphSetOpAttr.
const std::map<aclgrphAttrType, SetOpAttrFun> kAttrTypeFuncMap = {
{ATTR_TYPE_KEEP_DTYPE, KeepDtypeFunc},
{ATTR_TYPE_WEIGHT_COMPRESS, WeightCompressFunc}
};

// Attribute type -> option name string, used when rendering an attribute type
// as text (see AttrTypeToSerialString).
const std::map<aclgrphAttrType, std::string> kAttrTypeToStringMap = {
{ATTR_TYPE_KEEP_DTYPE, KEEP_DTYPE_OPTION},
{ATTR_TYPE_WEIGHT_COMPRESS, ge::ir_option::COMPRESS_WEIGHT_CONF}
};
} // namespace

static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) {
@@ -703,4 +724,33 @@ graphStatus aclgrphGenerateForOp(const AscendString &op_type, const vector<Tenso
return GRAPH_SUCCESS;
}

/**
 * @name   AttrTypeToSerialString
 * @brief  translate an attribute type into its option name
 * @param  attr_type  [IN] attribute type
 * @return the name registered in kAttrTypeToStringMap, or "UNDEFINED" (after
 *         reporting error E19012 and logging) when the type is not registered
 */
static std::string AttrTypeToSerialString(aclgrphAttrType attr_type) {
  const auto found = kAttrTypeToStringMap.find(attr_type);
  if (found == kAttrTypeToStringMap.end()) {
    ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
        {"AttrTypeToSerialString", "attr_type[" + std::to_string(attr_type) + "] is not support"});
    GELOGE(GRAPH_FAILED, "AttrTypeToSerialString: attr_type not support %u", attr_type);
    return "UNDEFINED";
  }
  return found->second;
}

/**
 * @name   aclgrphSetOpAttr
 * @brief  apply one attribute type to the operators listed in a configuration file
 * @param  graph      [IN/OUT] graph to modify
 * @param  attr_type  [IN] attribute type; selects the handler in kAttrTypeFuncMap
 * @param  cfg_path   [IN] the config file path; nullptr is a no-op
 * @return GRAPH_SUCCESS when cfg_path is nullptr, GRAPH_FAILED when attr_type has
 *         no registered handler, otherwise the handler's result
 */
graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path) {
  auto compute_graph = GraphUtils::GetComputeGraph(graph);
  GE_CHECK_NOTNULL(compute_graph);
  if (cfg_path == nullptr) {
    return GRAPH_SUCCESS;
  }

  const auto handler = kAttrTypeFuncMap.find(attr_type);
  if (handler != kAttrTypeFuncMap.end()) {
    const std::string path(cfg_path);
    return handler->second(compute_graph, path);
  }

  GELOGE(GRAPH_FAILED, "attr type: %s is not support", AttrTypeToSerialString(attr_type).c_str());
  return GRAPH_FAILED;
}

} // namespace ge

+ 0
- 1
ge/offline/CMakeLists.txt View File

@@ -10,7 +10,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
set(SRC_LIST
"main.cc"
"single_op_parser.cc"
"keep_dtype_option.cc"
"../session/omg.cc"
"../ir_build/atc_ir_common.cc"
)


+ 19
- 5
ge/offline/main.cc View File

@@ -43,7 +43,7 @@
#include "parser/common/register_tbe.h"
#include "register/op_registry.h"
#include "single_op_parser.h"
#include "keep_dtype_option.h"
#include "external/ge/ge_ir_build.h"

using domi::BuildMode;
using domi::OpRegistrationData;
@@ -913,6 +913,22 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s
return ret;
}

// Applies the --keep_dtype and --compress_weight_conf config files to graph via
// ge::aclgrphSetOpAttr. A flag that is empty is skipped. Returns ge::FAILED as
// soon as one of the calls does not return ge::GRAPH_SUCCESS, ge::SUCCESS otherwise.
static Status SetAttrOptions(ge::Graph &graph) {
  const bool keep_dtype_failed =
      !FLAGS_keep_dtype.empty() &&
      (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_KEEP_DTYPE, FLAGS_keep_dtype.c_str()) != ge::GRAPH_SUCCESS);
  if (keep_dtype_failed) {
    return ge::FAILED;
  }

  const bool compress_weight_failed =
      !FLAGS_compress_weight_conf.empty() &&
      (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_WEIGHT_COMPRESS, FLAGS_compress_weight_conf.c_str()) !=
       ge::GRAPH_SUCCESS);
  if (compress_weight_failed) {
    return ge::FAILED;
  }

  return ge::SUCCESS;
}

domi::Status GenerateModel(std::map<string, string> &options, std::string output) {
ge::GeGenerator ge_generator;
ge::Status geRet = ge::SUCCESS;
@@ -969,7 +985,6 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes));
atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout));
atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout));
atc_params.insert(std::pair<string, string>("compress_weight_conf", FLAGS_compress_weight_conf));
atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type));
atc_params.insert(std::pair<string, string>("output", output));

@@ -1003,11 +1018,10 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output
}
}

Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype);
if (ret != SUCCESS) {
if (SetAttrOptions(graph) != ge::SUCCESS) {
(void)ge_generator.Finalize();
(void)ge::GELib::GetInstance()->Finalize();
return ret;
return domi::FAILED;
}

geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);


+ 0
- 3
ge/offline/module.mk View File

@@ -10,7 +10,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg

LOCAL_SRC_FILES := \
main.cc \
keep_dtype_option.cc \
single_op_parser.cc \
../session/omg.cc \
../ir_build/atc_ir_common.cc \
@@ -64,7 +63,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg

LOCAL_SRC_FILES := \
main.cc \
keep_dtype_option.cc \
single_op_parser.cc \
../session/omg.cc \
../ir_build/atc_ir_common.cc \
@@ -118,7 +116,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg

LOCAL_SRC_FILES := \
main.cc \
keep_dtype_option.cc \
single_op_parser.cc \
../session/omg.cc \
../ir_build/atc_ir_common.cc \


+ 0
- 42
ge/session/omg.cc View File

@@ -193,44 +193,6 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in
return SUCCESS;
}

// Reads a ';'-separated list of node names from the file compress_weight_conf
// and sets ATTR_NAME_COMPRESS_WEIGHT to true on each node found in graph by
// name. Names with no matching node only produce a warning. An empty path is a
// no-op.
static Status SetWeightCompressNodes(const ComputeGraphPtr &graph, const string &compress_weight_conf) {
  GE_CHECK_NOTNULL(graph);
  if (compress_weight_conf.empty()) {
    return SUCCESS;
  }

  const std::string real_path = RealPath(compress_weight_conf.c_str());
  if (real_path.empty()) {
    GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str());
    return PARAM_INVALID;
  }

  std::ifstream ifs(real_path);
  if (!ifs.is_open()) {
    GELOGE(domi::FAILED, "Open file %s failed", compress_weight_conf.c_str());
    return domi::FAILED;
  }

  // Only the first whitespace-delimited token of the file is consumed.
  std::string compress_nodes;
  ifs >> compress_nodes;
  ifs.close();
  GELOGI("Compress weight of nodes: %s", compress_nodes.c_str());

  vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';');
  for (const auto &node_name : compress_node_vec) {
    ge::NodePtr node = graph->FindNode(node_name);
    if (node == nullptr) {
      GELOGW("node %s is not in graph", node_name.c_str());
      continue;
    }
    auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    const bool attr_set = ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true);
    if (!attr_set) {
      GELOGE(domi::FAILED, "node %s SetBool failed.", node_name.c_str());
      return domi::FAILED;
    }
  }
  return SUCCESS;
}

static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) {
if (is_output_fp16.empty()) {
return SUCCESS;
@@ -800,10 +762,6 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri

GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode));

std::string compress_weight_conf;
ParseAtcParms(atc_params, "compress_weight_conf", compress_weight_conf);
GE_RETURN_IF_ERROR(SetWeightCompressNodes(compute_graph, compress_weight_conf));

// Verify the contents of the op_name_map
if (op_conf != nullptr && *op_conf != '\0') {
GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf),


+ 22
- 7
inc/external/ge/ge_ir_build.h View File

@@ -50,6 +50,8 @@ struct ModelBufferData {
uint64_t length;
};

// Attribute types that aclgrphSetOpAttr can apply to operators from a config file.
enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS };

/**
* @ingroup AscendCL
* @brief build model.Notice the model is stored in buffer
@@ -80,13 +82,16 @@ GE_FUNC_VISIBILITY void aclgrphBuildFinalize();
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &,
ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options,
ModelBufferData &model);
ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &,
const std::map<AscendString, AscendString> &,
ModelBufferData &))
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<std::string, std::string> &build_options,
ModelBufferData &model);

GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model);
GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph,
const std::map<AscendString, AscendString> &build_options,
ModelBufferData &model);

/**
* @ingroup AscendCL
@@ -138,7 +143,17 @@ GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const ch
* @retval OtherValues Failure
*/
GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs,
const std::vector<TensorDesc> &outputs, Graph &graph);
const std::vector<TensorDesc> &outputs, Graph &graph);

/**
* @name aclgrphSetOpAttr
* @brief set attribute for operators in the configuration file
* @param graph [IN/OUT] compute graph
* @param attr_type [IN] attribute type
* @param cfg_path [IN] the config file path
* @return graphStatus
*/
GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path);

}; // namespace ge
#endif // INC_EXTERNAL_GE_IR_BUILD_H_

+ 3
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -289,6 +289,9 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc"
"${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc"
"${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc"
"${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc"
"${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc"
"${GE_CODE_DIR}/ge/ir_build/attr_options/weight_compress_option.cc"
"${GE_CODE_DIR}/ge/graph/build/label_allocator.cc"
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc"
"${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc"


Loading…
Cancel
Save