diff --git a/ge/ir_build/attr_options/attr_options.h b/ge/ir_build/attr_options/attr_options.h index b1b794c0..9ea2b9a1 100644 --- a/ge/ir_build/attr_options/attr_options.h +++ b/ge/ir_build/attr_options/attr_options.h @@ -18,11 +18,12 @@ #include #include "graph/compute_graph.h" -#include "external/graph/ge_error_codes.h" +#include "graph/ge_error_codes.h" namespace ge { bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); - +bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type); +bool IsContainOpType(const std::string &cfg_line, std::string &op_type); graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); } // namespace diff --git a/ge/ir_build/attr_options/keep_dtype_option.cc b/ge/ir_build/attr_options/keep_dtype_option.cc index 9da08cc0..88f238c0 100644 --- a/ge/ir_build/attr_options/keep_dtype_option.cc +++ b/ge/ir_build/attr_options/keep_dtype_option.cc @@ -32,18 +32,24 @@ void KeepDtypeReportError(const std::vector &invalid_list, const st size_t list_size = invalid_list.size(); err_msg << "config file contains " << list_size; if (list_size == 1) { - err_msg << " operator not in the graph, op name:"; + err_msg << " operator not in the graph, "; } else { - err_msg << " operators not in the graph, op names:"; + err_msg << " operators not in the graph, "; } - + std::string cft_type; for (size_t i = 0; i < list_size; i++) { if (i == kMaxOpsNum) { err_msg << ".."; break; } - err_msg << invalid_list[i]; - if (i != list_size - 1) { + bool istype = IsContainOpType(invalid_list[i], cft_type); + if (!istype) { + err_msg << "op name:"; + } else { + err_msg << "op type:"; + } + err_msg << cft_type; + if (i != (list_size - 1)) { err_msg << " "; } } @@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { return GRAPH_FAILED; } - std::string op_name; + std::string op_name, op_type; std::vector invalid_list; while (std::getline(ifs, op_name)) { if (op_name.empty()) { @@ -80,13 +86,20 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { } op_name = StringUtils::Trim(op_name); bool is_find = false; - for (auto &node_ptr : graph->GetDirectNode()) { + bool is_type = IsContainOpType(op_name, op_type); + for (auto &node_ptr : graph->GetAllNodes()) { auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - - if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { - is_find = true; - (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); + if (is_type) { + if (IsOpTypeEqual(node_ptr, op_type)) { + is_find = true; + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); + } + } else { + if (op_desc->GetName() == op_name || IsOriginalOpFind(op_desc, op_name)) { + is_find = true; + (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); + } } } if (!is_find) { diff --git a/ge/ir_build/attr_options/utils.cc b/ge/ir_build/attr_options/utils.cc index ed63ffe3..5398c220 100644 --- a/ge/ir_build/attr_options/utils.cc +++ b/ge/ir_build/attr_options/utils.cc @@ -16,9 +16,12 @@ #include "ir_build/attr_options/attr_options.h" #include #include "graph/debug/ge_attr_define.h" -#include "common/util/error_manager/error_manager.h" - +#include "framework/common/debug/ge_log.h" +#include "graph/common/omg_util.h" namespace ge { + namespace { + const std::string CFG_PRE_OPTYPE = "OpType::"; +} bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { std::vector original_op_names; if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { @@ -33,4 +36,36 @@ bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { return false; } + +bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type) { + if (op_type != node->GetOpDesc()->GetType()) { + return false; + } + std::string origin_type; + auto ret = GetOriginalType(node, origin_type); + if (ret != SUCCESS) { + GELOGW("[Get][OriginalType] from op:%s failed.", node->GetName().c_str()); + return false; + } + if (op_type != origin_type) { + return false; + } + return true; +} + +bool IsContainOpType(const std::string &cfg_line, std::string &op_type) { + op_type = cfg_line; + size_t pos = op_type.find(CFG_PRE_OPTYPE); + if (pos != std::string::npos) { + if (pos == 0) { + op_type = cfg_line.substr(CFG_PRE_OPTYPE.length()); + return true; + } else { + GELOGW("[Check][Param] %s must be at zero pos of %s", CFG_PRE_OPTYPE.c_str(), cfg_line.c_str()); + } + return false; + } + GELOGW("[Check][Param] %s not contain optype", cfg_line.c_str()); + return false; +} } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc index 047c9e1d..197c9300 100644 --- a/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc +++ b/tests/ut/ge/graph_ir/ge_ir_build_unittest.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#include #include #include "ir_build/option_utils.h" #include "graph/testcase/ge_graph/graph_builder_utils.h" @@ -21,7 +21,7 @@ #include "graph/utils/graph_utils.h" #include "ge/ge_ir_build.h" #include "graph/ops_stub.h" - +#include "ge/ir_build/attr_options/attr_options.h" #define protected public #define private public @@ -70,6 +70,22 @@ static ComputeGraphPtr BuildComputeGraph() { return builder.GetGraph(); } +static ComputeGraphPtr BuildComputeGraph1() { + auto builder = ut::GraphBuilder("test"); + auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); + auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); + auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1); + auto node1 = builder.AddNode("addd", "Mul", 2, 1); + auto node2 = builder.AddNode("ffm", "FrameworkOp", 2, 1); + auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(data1, 0, addn1, 0); + builder.AddDataEdge(data2, 0, addn1, 1); + builder.AddDataEdge(addn1, 0,netoutput, 0); + + return builder.GetGraph(); +} + // data not set attr index; // but becasue of op proto, register attr index. so all data index is zero; static Graph BuildIrGraph() { @@ -89,10 +105,12 @@ static Graph BuildIrGraph1() { auto data1 = op::Data("data1").set_attr_index(0); auto data2 = op::Data("data2").set_attr_index(1); auto data3 = op::Data("data3"); - std::vector inputs {data1, data2, data3}; + auto data4 = op::Data("Test"); + std::vector inputs {data1, data2, data3, data4}; std::vector outputs; Graph graph("test_graph"); + graph.AddNodeByOp(Operator("gg", "Mul")); graph.SetInputs(inputs).SetOutputs(outputs); return graph; } @@ -373,9 +391,16 @@ TEST(UtestIrBuild, check_modify_mixlist_param) { EXPECT_EQ(ret, GRAPH_PARAM_INVALID); } -TEST(UtestIrCommon, check_dynamic_imagesize_input_shape_valid_format_empty) { - std::map> shape_map; - std::string dynamic_image_size = ""; - bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, "123", dynamic_image_size); - EXPECT_EQ(ret, false); +TEST(UtestIrBuild, atc_cfg_optype_param) { + ComputeGraphPtr graph = BuildComputeGraph1(); + FILE *fp = fopen("./keep.txt", "w+"); + if (fp) { + fprintf(fp, "Test\n"); + fprintf(fp, "OpType::Mul\n"); + fprintf(fp, "Optype::Sub\n"); + fclose(fp); + } + auto ret = KeepDtypeFunc(graph, "./keep.txt"); + (void)remove("./keep.txt"); + EXPECT_EQ(ret, GRAPH_PARAM_INVALID); } \ No newline at end of file