@@ -217,7 +217,7 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { | |||||
std::string unsupported_reason; | std::string unsupported_reason; | ||||
// It will be replaced by engine' checksupport | // It will be replaced by engine' checksupport | ||||
uint64_t start_time = GetCurrentTimestamp(); | uint64_t start_time = GetCurrentTimestamp(); | ||||
if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||||
if (kernel_info_store->second->CheckSupported(node_ptr, unsupported_reason)) { | |||||
checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; | checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; | ||||
op_desc->SetOpEngineName(it.engine); | op_desc->SetOpEngineName(it.engine); | ||||
op_desc->SetOpKernelLibName(kernel_name); | op_desc->SetOpKernelLibName(kernel_name); | ||||
@@ -66,7 +66,8 @@ bool ContainsDynamicInpus(const ge::OpDesc &op_desc) { | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) { | |||||
static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) { | |||||
const OpDescPtr &op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | ||||
if (engine_type == ENGINE_SYS) { | if (engine_type == ENGINE_SYS) { | ||||
GELOGI("CheckEngineType: use default engine."); | GELOGI("CheckEngineType: use default engine."); | ||||
@@ -123,7 +124,7 @@ static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engi | |||||
auto kernel_info_store = kernel_map.find(kernel_name); | auto kernel_info_store = kernel_map.find(kernel_name); | ||||
if (kernel_info_store != kernel_map.end()) { | if (kernel_info_store != kernel_map.end()) { | ||||
std::string unsupported_reason; | std::string unsupported_reason; | ||||
if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||||
if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) { | |||||
op_desc->SetOpEngineName(op_engine_name); | op_desc->SetOpEngineName(op_engine_name); | ||||
op_desc->SetOpKernelLibName(kernel_name); | op_desc->SetOpKernelLibName(kernel_name); | ||||
GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), | GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), | ||||
@@ -692,22 +693,23 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); | OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); | ||||
GE_CHECK_NOTNULL(op_desc_tmp); | GE_CHECK_NOTNULL(op_desc_tmp); | ||||
// 1. check engine type when compile online | |||||
// 1. Create ComputeGraph. | |||||
string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||||
Graph graph; | |||||
GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "make graph fail."); | |||||
// 2. check engine type when compile online | |||||
if (model_file_name == kFileNameSuffix) { | if (model_file_name == kFileNameSuffix) { | ||||
Status ret = CheckEngineTypeSupport(op_desc, engine_type); | |||||
auto comp_graph = GraphUtils::GetComputeGraph(graph); | |||||
GE_CHECK_NOTNULL(comp_graph); | |||||
auto node = comp_graph->FindNode(op_desc->GetName()); | |||||
Status ret = CheckEngineTypeSupport(node, engine_type); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "check engine type failed."); | GELOGE(ret, "check engine type failed."); | ||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
// 2. Create ComputeGraph. | |||||
string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||||
Graph graph; | |||||
if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "make graph fail."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
GELOGI("ATC parser success in single op build."); | GELOGI("ATC parser success in single op build."); | ||||
GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
@@ -167,7 +167,7 @@ bool CastTranslatePass::IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans | |||||
trans_op_outdesc->SetDataType(cast_out_datatype); | trans_op_outdesc->SetDataType(cast_out_datatype); | ||||
} | } | ||||
if (!TranslateCheckAccuracySupported(trans_op_desc)) { | |||||
if (!TranslateCheckAccuracySupported(trans_node)) { | |||||
if (is_src_cast) { | if (is_src_cast) { | ||||
trans_op_desc->MutableInputDesc(0)->SetDataType(trans_in_datatype); | trans_op_desc->MutableInputDesc(0)->SetDataType(trans_in_datatype); | ||||
} else { | } else { | ||||
@@ -271,7 +271,8 @@ Status CastTranslatePass::FuseDstNTranslates(NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc) { | |||||
bool CastTranslatePass::TranslateCheckAccuracySupported(NodePtr &node) { | |||||
const OpDescPtr &op_desc = node->GetOpDesc(); | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | ||||
GELOGW("GE is not initialized or is finalized."); | GELOGW("GE is not initialized or is finalized."); | ||||
@@ -293,7 +294,7 @@ bool CastTranslatePass::TranslateCheckAccuracySupported(const OpDescPtr &op_desc | |||||
auto kernel_info_store = kernel_map.find(kernel_name); | auto kernel_info_store = kernel_map.find(kernel_name); | ||||
if (kernel_info_store != kernel_map.end()) { | if (kernel_info_store != kernel_map.end()) { | ||||
if (kernel_info_store->second != nullptr && | if (kernel_info_store->second != nullptr && | ||||
kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) { | |||||
kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason)) { | |||||
return true; | return true; | ||||
} | } | ||||
} | } | ||||
@@ -35,7 +35,7 @@ class CastTranslatePass : public BaseNodePass { | |||||
bool IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans_node, bool &is_src_cast); | bool IsOpSupportedOptimize(NodePtr &cast_node, NodePtr &trans_node, bool &is_src_cast); | ||||
bool CheckOpSupportOptimize(NodePtr &node, bool &is_src_cast); | bool CheckOpSupportOptimize(NodePtr &node, bool &is_src_cast); | ||||
Status FuseDstNTranslates(NodePtr &node); | Status FuseDstNTranslates(NodePtr &node); | ||||
bool TranslateCheckAccuracySupported(const OpDescPtr &op_desc); | |||||
bool TranslateCheckAccuracySupported(NodePtr &node); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_CAST_TRANSLATE_PASS_H_ | #endif // GE_GRAPH_PASSES_CAST_TRANSLATE_PASS_H_ |
@@ -110,7 +110,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||||
return ge::GE_GRAPH_PARAM_NULLPTR; | return ge::GE_GRAPH_PARAM_NULLPTR; | ||||
} | } | ||||
// begin accuracy supported check | // begin accuracy supported check | ||||
if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { | |||||
if (!CheckAccuracySupport(kernel_info, instance, node)) { | |||||
// if check accuracy support failed , try to go to other engine. | // if check accuracy support failed , try to go to other engine. | ||||
GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", | GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", | ||||
op_desc->GetName().c_str()); | op_desc->GetName().c_str()); | ||||
@@ -123,7 +123,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||||
continue; | continue; | ||||
} | } | ||||
OpsKernelInfoStorePtr tmp_kernel_info = it->second; | OpsKernelInfoStorePtr tmp_kernel_info = it->second; | ||||
if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { | |||||
if (CheckAccuracySupport(tmp_kernel_info, instance, node)) { | |||||
kernel_lib_name = tmp_kernel_name; | kernel_lib_name = tmp_kernel_name; | ||||
GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), | GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), | ||||
node->GetName().c_str(), op_desc->GetType().c_str()); | node->GetName().c_str(), op_desc->GetType().c_str()); | ||||
@@ -138,14 +138,9 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: | |||||
} | } | ||||
bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, | bool CompileNodesPass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, | ||||
const std::shared_ptr<GELib> instance, OpDescPtr &op_desc) { | |||||
auto ge_desc = MakeShared<ge::OpDescPtr>(op_desc); | |||||
if (ge_desc == nullptr) { | |||||
GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc."); | |||||
return false; | |||||
} | |||||
const std::shared_ptr<GELib> instance, const NodePtr &node) { | |||||
string reason; | string reason; | ||||
if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { | |||||
if (!(kernel_info->CheckAccuracySupported(node, reason, true))) { | |||||
return false; | return false; | ||||
} | } | ||||
return true; | return true; | ||||
@@ -39,7 +39,7 @@ class CompileNodesPass : public GraphPass { | |||||
private: | private: | ||||
graphStatus GetSupportedKernel(const NodePtr &node, const std::shared_ptr<GELib> instance, string &kernel_lib_name); | graphStatus GetSupportedKernel(const NodePtr &node, const std::shared_ptr<GELib> instance, string &kernel_lib_name); | ||||
bool CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, const std::shared_ptr<GELib> instance, | bool CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, const std::shared_ptr<GELib> instance, | ||||
OpDescPtr &op_desc); | |||||
const NodePtr &node); | |||||
graphStatus CompileNodes(const std::shared_ptr<GELib> instance, | graphStatus CompileNodes(const std::shared_ptr<GELib> instance, | ||||
std::unordered_map<string, vector<NodePtr>> &kernel_to_compile_nodes); | std::unordered_map<string, vector<NodePtr>> &kernel_to_compile_nodes); | ||||
}; | }; | ||||
@@ -86,7 +86,7 @@ Status TransposeTransDataPass::Run(NodePtr &node) { | |||||
if (CheckOneInAndOneOutDataAnchor(out_node)) { | if (CheckOneInAndOneOutDataAnchor(out_node)) { | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (!FusionIfNeed(op_desc, out_op_desc)) { | |||||
if (!FusionIfNeed(op_desc, out_node)) { | |||||
continue; | continue; | ||||
} | } | ||||
CopyInputEdges(node, out_node); | CopyInputEdges(node, out_node); | ||||
@@ -152,7 +152,8 @@ Status TransposeTransDataPass::RemoveTranspose(NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc) { | |||||
bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, NodePtr &node) { | |||||
auto transdata_op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
GE_CHECK_NOTNULL(transdata_op_desc); | GE_CHECK_NOTNULL(transdata_op_desc); | ||||
auto out_input_desc = transdata_op_desc->MutableInputDesc(0); | auto out_input_desc = transdata_op_desc->MutableInputDesc(0); | ||||
@@ -187,7 +188,7 @@ bool TransposeTransDataPass::FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transda | |||||
out_input_desc->SetFormat(src_format); | out_input_desc->SetFormat(src_format); | ||||
out_input_desc->SetShape(src_shape); | out_input_desc->SetShape(src_shape); | ||||
if (!TransDataCheckAccuracySupported(transdata_op_desc)) { | |||||
if (!TransDataCheckAccuracySupported(node)) { | |||||
out_input_desc->SetFormat(out_input_format); | out_input_desc->SetFormat(out_input_format); | ||||
out_input_desc->SetShape(out_input_shape); | out_input_desc->SetShape(out_input_shape); | ||||
return false; | return false; | ||||
@@ -224,7 +225,8 @@ void TransposeTransDataPass::CopyInputEdges(NodePtr &origin_node, NodePtr &new_n | |||||
GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return); | GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return); | ||||
} | } | ||||
bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op_desc) { | |||||
bool TransposeTransDataPass::TransDataCheckAccuracySupported(NodePtr &node) { | |||||
const OpDescPtr &op_desc = node->GetOpDesc(); | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | ||||
GELOGW("GELib not initialized"); | GELOGW("GELib not initialized"); | ||||
@@ -244,7 +246,7 @@ bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op | |||||
auto &kernel_name = it.opKernelLib; | auto &kernel_name = it.opKernelLib; | ||||
auto kernel_info_store = kernel_map.find(kernel_name); | auto kernel_info_store = kernel_map.find(kernel_name); | ||||
if (kernel_info_store != kernel_map.end()) { | if (kernel_info_store != kernel_map.end()) { | ||||
if (kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason, true)) { | |||||
if (kernel_info_store->second->CheckAccuracySupported(node, unsupported_reason, true)) { | |||||
return true; | return true; | ||||
} | } | ||||
} | } | ||||
@@ -26,9 +26,9 @@ class TransposeTransDataPass : public BaseNodePass { | |||||
private: | private: | ||||
Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; | Status CheckOneInAndOneOutDataAnchor(NodePtr &node) const; | ||||
Status RemoveTranspose(NodePtr &node); | Status RemoveTranspose(NodePtr &node); | ||||
bool FusionIfNeed(OpDescPtr &op_desc, OpDescPtr &transdata_op_desc); | |||||
bool FusionIfNeed(OpDescPtr &op_desc, NodePtr &node); | |||||
void CopyInputEdges(NodePtr &origin_node, NodePtr &new_node); | void CopyInputEdges(NodePtr &origin_node, NodePtr &new_node); | ||||
bool TransDataCheckAccuracySupported(const OpDescPtr &op_desc); | |||||
bool TransDataCheckAccuracySupported(NodePtr &node); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_TRANSPOSE_TRANSDATA_PASS_H_ | #endif // GE_GRAPH_PASSES_TRANSPOSE_TRANSDATA_PASS_H_ | ||||
@@ -690,6 +690,7 @@ set(PASS_TEST_FILES | |||||
"graph/passes/infershape_pass_unittest.cc" | "graph/passes/infershape_pass_unittest.cc" | ||||
"graph/passes/multi_batch_clone_pass_unittest.cc" | "graph/passes/multi_batch_clone_pass_unittest.cc" | ||||
"graph/passes/replace_with_empty_const_pass_unittest.cc" | "graph/passes/replace_with_empty_const_pass_unittest.cc" | ||||
"graph/passes/transpose_transdata_pass_unittest.cc" | |||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -53,26 +53,20 @@ TEST_F(UtestGeGenerator, test_build_single_op_offline) { | |||||
EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, "offline_"), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | ||||
} | } | ||||
/* | |||||
TEST_F(UtestGeGenerator, test_build_single_op_online) { | TEST_F(UtestGeGenerator, test_build_single_op_online) { | ||||
GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
TensorUtils::SetSize(tensor_desc, 512); | |||||
GeTensorDesc tensor_desc; | |||||
shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("Add", "add"); | shared_ptr<OpDesc> op_desc = make_shared<OpDesc>("Add", "add"); | ||||
EXPECT_EQ(op_desc->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
EXPECT_EQ(op_desc->AddInputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
EXPECT_EQ(op_desc->AddOutputDesc(tensor_desc), GRAPH_SUCCESS); | |||||
op_desc->AddInputDesc(tensor_desc); | |||||
op_desc->AddInputDesc(tensor_desc); | |||||
op_desc->AddOutputDesc(tensor_desc); | |||||
GeTensor tensor(tensor_desc); | GeTensor tensor(tensor_desc); | ||||
const vector<GeTensor> inputs = { tensor, tensor }; | const vector<GeTensor> inputs = { tensor, tensor }; | ||||
const vector<GeTensor> outputs = { tensor }; | const vector<GeTensor> outputs = { tensor }; | ||||
// not Initialize, impl is null. | |||||
GeGenerator generator; | GeGenerator generator; | ||||
generator.Initialize({}); | generator.Initialize({}); | ||||
ModelBufferData model_buffer; | ModelBufferData model_buffer; | ||||
EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_SYS, model_buffer), GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED); | |||||
EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED); | |||||
} | } | ||||
*/ | |||||
} // namespace ge | } // namespace ge |
@@ -0,0 +1,67 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <vector> | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/transpose_transdata_pass.h" | |||||
#include "graph_builder_utils.h" | |||||
#undef private | |||||
#undef protected | |||||
#include "graph/graph.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "common/types.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
namespace ge { | |||||
class UtestGraphPassesTransposeTransdataPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
static ComputeGraphPtr BuildGraphTransposeD() { | |||||
auto builder = ut::GraphBuilder("g1"); | |||||
auto transdata1 = builder.AddNode("transdata1", "TransData", 1, 1, FORMAT_NC1HWC0, DT_FLOAT, std::vector<int64_t>({1, 1, 224, 224, 16})); | |||||
transdata1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NHWC); | |||||
transdata1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 224, 224, 3}))); | |||||
auto transpose1 = builder.AddNode("transpose1", "TransposeD", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | |||||
transpose1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NHWC); | |||||
transpose1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 224, 224, 3}))); | |||||
auto transdata2 = builder.AddNode("transdata2", "TransData", 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({1, 3, 224, 224})); | |||||
transdata2->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); | |||||
transdata2->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape(std::vector<int64_t>({1, 1, 224, 224, 16}))); | |||||
builder.AddDataEdge(transdata1, 0, transpose1, 0); | |||||
builder.AddDataEdge(transpose1, 0, transdata2, 0); | |||||
return builder.GetGraph(); | |||||
} | |||||
TEST_F(UtestGraphPassesTransposeTransdataPass, test_run) { | |||||
auto compute_graph = BuildGraphTransposeD(); | |||||
compute_graph->SetSessionID(0); | |||||
auto transpose = compute_graph->FindNode("transpose1"); | |||||
TransposeTransDataPass pass; | |||||
EXPECT_EQ(pass.Run(transpose), SUCCESS); | |||||
} | |||||
} // namespace ge |