From: @zhengyuanhua Reviewed-by: @youui,@liujunzhu Signed-off-by: @youuitags/v1.3.0
| @@ -25,15 +25,15 @@ | |||||
| namespace ge { | namespace ge { | ||||
| Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | ||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| for (const auto &node : graph->GetAllNodes()) { | |||||
| if (node->GetType() == STREAMSWITCH) { | |||||
| auto sub_graph = node->GetOwnerComputeGraph(); | |||||
| if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { | |||||
| GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); | |||||
| if (graph->GetGraphUnknownFlag()) { | |||||
| for (const auto &node : graph->GetAllNodes()) { | |||||
| if (node->GetType() == STREAMSWITCH) { | |||||
| auto sub_graph = node->GetOwnerComputeGraph(); | |||||
| if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { | |||||
| GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| if (graph->GetGraphUnknownFlag()) { | |||||
| GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -715,6 +715,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
| "graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| ) | ) | ||||
| set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <cstdint> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #define private public | |||||
| #include "graph/passes/memcpy_addr_async_pass.h" | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #undef private | |||||
| namespace ge { | |||||
| class UtestMemcpyAddrAsyncPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(UtestMemcpyAddrAsyncPass, run) { | |||||
| ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default"); | |||||
| ge::OpDescPtr op = std::make_shared<ge::OpDesc>(); | |||||
| op->SetType(STREAMSWITCH); | |||||
| op->SetName("stream_switch"); | |||||
| op->AddOutputDesc(ge::GeTensorDesc()); | |||||
| ge::NodePtr node = graph->AddNode(op); | |||||
| graph->SetGraphUnknownFlag(true); | |||||
| MemcpyAddrAsyncPass pass; | |||||
| Status ret = pass.Run(graph); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||