| @@ -18,11 +18,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| namespace ge { | namespace ge { | ||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | ||||
| bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type); | |||||
| bool IsContainOpType(const std::string &cfg_line, std::string &op_type); | |||||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | ||||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | ||||
| } // namespace | } // namespace | ||||
| @@ -32,18 +32,24 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list, const st | |||||
| size_t list_size = invalid_list.size(); | size_t list_size = invalid_list.size(); | ||||
| err_msg << "config file contains " << list_size; | err_msg << "config file contains " << list_size; | ||||
| if (list_size == 1) { | if (list_size == 1) { | ||||
| err_msg << " operator not in the graph, op name:"; | |||||
| err_msg << " operator not in the graph, "; | |||||
| } else { | } else { | ||||
| err_msg << " operators not in the graph, op names:"; | |||||
| err_msg << " operators not in the graph, "; | |||||
| } | } | ||||
| std::string cft_type; | |||||
| for (size_t i = 0; i < list_size; i++) { | for (size_t i = 0; i < list_size; i++) { | ||||
| if (i == kMaxOpsNum) { | if (i == kMaxOpsNum) { | ||||
| err_msg << ".."; | err_msg << ".."; | ||||
| break; | break; | ||||
| } | } | ||||
| err_msg << invalid_list[i]; | |||||
| if (i != list_size - 1) { | |||||
| bool istype = IsContainOpType(invalid_list[i], cft_type); | |||||
| if (!istype) { | |||||
| err_msg << "op name:"; | |||||
| } else { | |||||
| err_msg << "op type:"; | |||||
| } | |||||
| err_msg << cft_type; | |||||
| if (i != (list_size - 1)) { | |||||
| err_msg << " "; | err_msg << " "; | ||||
| } | } | ||||
| } | } | ||||
| @@ -72,7 +78,7 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| std::string op_name; | |||||
| std::string op_name, op_type; | |||||
| std::vector<std::string> invalid_list; | std::vector<std::string> invalid_list; | ||||
| while (std::getline(ifs, op_name)) { | while (std::getline(ifs, op_name)) { | ||||
| if (op_name.empty()) { | if (op_name.empty()) { | ||||
| @@ -80,13 +86,20 @@ graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
| } | } | ||||
| op_name = StringUtils::Trim(op_name); | op_name = StringUtils::Trim(op_name); | ||||
| bool is_find = false; | bool is_find = false; | ||||
| for (auto &node_ptr : graph->GetDirectNode()) { | |||||
| bool is_type = IsContainOpType(op_name, op_type); | |||||
| for (auto &node_ptr : graph->GetAllNodes()) { | |||||
| auto op_desc = node_ptr->GetOpDesc(); | auto op_desc = node_ptr->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) { | |||||
| is_find = true; | |||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||||
| if (is_type) { | |||||
| if (IsOpTypeEqual(node_ptr, op_type)) { | |||||
| is_find = true; | |||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||||
| } | |||||
| } else { | |||||
| if (op_desc->GetName() == op_name || IsOriginalOpFind(op_desc, op_name)) { | |||||
| is_find = true; | |||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| if (!is_find) { | if (!is_find) { | ||||
| @@ -16,9 +16,12 @@ | |||||
| #include "ir_build/attr_options/attr_options.h" | #include "ir_build/attr_options/attr_options.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/common/omg_util.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const std::string CFG_PRE_OPTYPE = "OpType::"; | |||||
| } | |||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | ||||
| std::vector<std::string> original_op_names; | std::vector<std::string> original_op_names; | ||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | ||||
| @@ -33,4 +36,36 @@ bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool IsOpTypeEqual(const ge::NodePtr &node, const std::string &op_type) { | |||||
| if (op_type != node->GetOpDesc()->GetType()) { | |||||
| return false; | |||||
| } | |||||
| std::string origin_type; | |||||
| auto ret = GetOriginalType(node, origin_type); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("[Get][OriginalType] from op:%s failed.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (op_type != origin_type) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool IsContainOpType(const std::string &cfg_line, std::string &op_type) { | |||||
| op_type = cfg_line; | |||||
| size_t pos = op_type.find(CFG_PRE_OPTYPE); | |||||
| if (pos != std::string::npos) { | |||||
| if (pos == 0) { | |||||
| op_type = cfg_line.substr(CFG_PRE_OPTYPE.length()); | |||||
| return true; | |||||
| } else { | |||||
| GELOGW("[Check][Param] %s must be at zero pos of %s", CFG_PRE_OPTYPE.c_str(), cfg_line.c_str()); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| GELOGW("[Check][Param] %s not contain optype", cfg_line.c_str()); | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <stdio.h> | |||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include "ir_build/option_utils.h" | #include "ir_build/option_utils.h" | ||||
| #include "graph/testcase/ge_graph/graph_builder_utils.h" | #include "graph/testcase/ge_graph/graph_builder_utils.h" | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "ge/ge_ir_build.h" | #include "ge/ge_ir_build.h" | ||||
| #include "graph/ops_stub.h" | #include "graph/ops_stub.h" | ||||
| #include "ge/ir_build/attr_options/attr_options.h" | |||||
| #define protected public | #define protected public | ||||
| #define private public | #define private public | ||||
| @@ -70,6 +70,22 @@ static ComputeGraphPtr BuildComputeGraph() { | |||||
| return builder.GetGraph(); | return builder.GetGraph(); | ||||
| } | } | ||||
| static ComputeGraphPtr BuildComputeGraph1() { | |||||
| auto builder = ut::GraphBuilder("test"); | |||||
| auto data1 = builder.AddNode("input1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2, 3}); | |||||
| auto data2 = builder.AddNode("input2", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {4, 10}); | |||||
| auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1); | |||||
| auto node1 = builder.AddNode("addd", "Mul", 2, 1); | |||||
| auto node2 = builder.AddNode("ffm", "FrameworkOp", 2, 1); | |||||
| auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
| builder.AddDataEdge(data1, 0, addn1, 0); | |||||
| builder.AddDataEdge(data2, 0, addn1, 1); | |||||
| builder.AddDataEdge(addn1, 0,netoutput, 0); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| // data not set attr index; | // data not set attr index; | ||||
| // but becasue of op proto, register attr index. so all data index is zero; | // but becasue of op proto, register attr index. so all data index is zero; | ||||
| static Graph BuildIrGraph() { | static Graph BuildIrGraph() { | ||||
| @@ -89,10 +105,12 @@ static Graph BuildIrGraph1() { | |||||
| auto data1 = op::Data("data1").set_attr_index(0); | auto data1 = op::Data("data1").set_attr_index(0); | ||||
| auto data2 = op::Data("data2").set_attr_index(1); | auto data2 = op::Data("data2").set_attr_index(1); | ||||
| auto data3 = op::Data("data3"); | auto data3 = op::Data("data3"); | ||||
| std::vector<Operator> inputs {data1, data2, data3}; | |||||
| auto data4 = op::Data("Test"); | |||||
| std::vector<Operator> inputs {data1, data2, data3, data4}; | |||||
| std::vector<Operator> outputs; | std::vector<Operator> outputs; | ||||
| Graph graph("test_graph"); | Graph graph("test_graph"); | ||||
| graph.AddNodeByOp(Operator("gg", "Mul")); | |||||
| graph.SetInputs(inputs).SetOutputs(outputs); | graph.SetInputs(inputs).SetOutputs(outputs); | ||||
| return graph; | return graph; | ||||
| } | } | ||||
| @@ -373,9 +391,16 @@ TEST(UtestIrBuild, check_modify_mixlist_param) { | |||||
| EXPECT_EQ(ret, GRAPH_PARAM_INVALID); | EXPECT_EQ(ret, GRAPH_PARAM_INVALID); | ||||
| } | } | ||||
| TEST(UtestIrCommon, check_dynamic_imagesize_input_shape_valid_format_empty) { | |||||
| std::map<std::string, std::vector<int64_t>> shape_map; | |||||
| std::string dynamic_image_size = ""; | |||||
| bool ret = CheckDynamicImagesizeInputShapeValid(shape_map, "123", dynamic_image_size); | |||||
| EXPECT_EQ(ret, false); | |||||
| TEST(UtestIrBuild, atc_cfg_optype_param) { | |||||
| ComputeGraphPtr graph = BuildComputeGraph1(); | |||||
| FILE *fp = fopen("./keep.txt", "w+"); | |||||
| if (fp) { | |||||
| fprintf(fp, "Test\n"); | |||||
| fprintf(fp, "OpType::Mul\n"); | |||||
| fprintf(fp, "Optype::Sub\n"); | |||||
| fclose(fp); | |||||
| } | |||||
| auto ret = KeepDtypeFunc(graph, "./keep.txt"); | |||||
| (void)remove("./keep.txt"); | |||||
| EXPECT_EQ(ret, GRAPH_PARAM_INVALID); | |||||
| } | } | ||||