| @@ -22,7 +22,8 @@ | |||
| namespace ge { | |||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | |||
| bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type); | |||
| bool IsContainOpType(const std::string &cfg_line, std::string &op_type); | |||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||
| } // namespace | |||
| @@ -32,18 +32,24 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list, const st | |||
| size_t list_size = invalid_list.size(); | |||
| err_msg << "config file contains " << list_size; | |||
| if (list_size == 1) { | |||
| err_msg << " operator not in the graph, op name:"; | |||
| err_msg << " operator not in the graph, "; | |||
| } else { | |||
| err_msg << " operators not in the graph, op names:"; | |||
| err_msg << " operators not in the graph, "; | |||
| } | |||
| std::string cft_type; | |||
| for (size_t i = 0; i < list_size; i++) { | |||
| if (i == kMaxOpsNum) { | |||
| err_msg << ".."; | |||
| break; | |||
| } | |||
| err_msg << invalid_list[i]; | |||
| if (i != list_size - 1) { | |||
| bool istype = IsContainOpType(invalid_list[i], cft_type); | |||
| if (!istype) { | |||
| err_msg << "op name:"; | |||
| } else { | |||
| err_msg << "op type:"; | |||
| } | |||
| err_msg << cft_type; | |||
| if (i != (list_size - 1)) { | |||
| err_msg << " "; | |||
| } | |||
| } | |||
| @@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string op_name; | |||
| std::string op_name, op_type; | |||
| std::vector<std::string> invalid_list; | |||
| while (std::getline(ifs, op_name)) { | |||
| if (op_name.empty()) { | |||
| @@ -80,13 +86,20 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||
| } | |||
| op_name = StringUtils::Trim(op_name); | |||
| bool is_find = false; | |||
| for (auto &node_ptr : graph->GetDirectNode()) { | |||
| bool is_type = IsContainOpType(op_name, op_type); | |||
| for (auto &node_ptr : graph->GetAllNodes()) { | |||
| auto op_desc = node_ptr->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { | |||
| is_find = true; | |||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||
| if (is_type) { | |||
| if (IsOpTypeEqual(node_ptr, op_type)) { | |||
| is_find = true; | |||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||
| } | |||
| } else { | |||
| if (op_desc->GetName() == op_name || IsOriginalOpFind(op_desc, op_name)) { | |||
| is_find = true; | |||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||
| } | |||
| } | |||
| } | |||
| if (!is_find) { | |||
| @@ -16,9 +16,12 @@ | |||
| #include "attr_options.h" | |||
| #include <vector> | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/common/omg_util.h" | |||
| namespace ge { | |||
| namespace { | |||
| const std::string CFG_PRE_OPTYPE = "OpType::"; | |||
| } | |||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||
| std::vector<std::string> original_op_names; | |||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||
| @@ -33,4 +36,47 @@ bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||
| return false; | |||
| } | |||
| bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type) { | |||
| if (node == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "param node is nullptr, check invalid"); | |||
| GELOGE(FAILED, "[Check][Param] node is nullptr"); | |||
| return false; | |||
| } | |||
| auto op_desc = node->GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| REPORT_INNER_ERROR("E19999", "param node's op desc is nullptr, check invalid"); | |||
| GELOGE(FAILED, "[Check][Param] node op desc is nullptr"); | |||
| return false; | |||
| } | |||
| if (op_type != op_desc->GetType()) { | |||
| return false; | |||
| } | |||
| std::string origin_type; | |||
| auto ret = GetOriginalType(node, origin_type); | |||
| if (ret != SUCCESS) { | |||
| GELOGW("[Get][OriginalType] from op:%s failed.", node->GetName().c_str()); | |||
| return false; | |||
| } | |||
| if (op_type != origin_type) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool IsContainOpType(const std::string &cfg_line, std::string &op_type) { | |||
| op_type = cfg_line; | |||
| size_t pos = op_type.find(CFG_PRE_OPTYPE); | |||
| if (pos != std::string::npos) { | |||
| if (pos == 0) { | |||
| op_type = cfg_line.substr(CFG_PRE_OPTYPE.length()); | |||
| return true; | |||
| } else { | |||
| GELOGW("[Check][Param] %s must be at zero pos of %s", CFG_PRE_OPTYPE.c_str(), cfg_line.c_str()); | |||
| } | |||
| return false; | |||
| } | |||
| GELOGW("[Check][Param] %s not contain optype", cfg_line.c_str()); | |||
| return false; | |||
| } | |||
| } // namespace ge | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdio.h> | |||
| #include <gtest/gtest.h> | |||
| #include "ir_build/option_utils.h" | |||
| #include "graph/testcase/ge_graph/graph_builder_utils.h" | |||
| @@ -371,4 +371,18 @@ TEST(UtestIrBuild, check_modify_mixlist_param) { | |||
| auto ret = aclgrphBuildModel(graph, build_options, model); | |||
| EXPECT_EQ(ret, GRAPH_PARAM_INVALID); | |||
| } | |||
| TEST(UtestIrBuild, check_cfg_optype_param) { | |||
| Graph graph = BuildIrGraph1(); | |||
| FILE *fp = fopen("./keep.txt", "w+"); | |||
| if (fp) { | |||
| fprintf(fp, "Test\n"); | |||
| fprintf(fp, "OpType::Mul\n"); | |||
| fprintf(fp, "Optype::Sub\n"); | |||
| fclose(fp); | |||
| } | |||
| auto ret = aclgrphSetOpAttr(graph, ATTR_TYPE_KEEP_DTYPE, "./keep.txt"); | |||
| (void)remove("./keep.txt"); | |||
| EXPECT_EQ(ret, GRAPH_PARAM_INVALID); | |||
| } | |||