diff --git a/ge/graph/passes/memcpy_addr_async_pass.cc b/ge/graph/passes/memcpy_addr_async_pass.cc index aff89f35..e8e4ebd8 100755 --- a/ge/graph/passes/memcpy_addr_async_pass.cc +++ b/ge/graph/passes/memcpy_addr_async_pass.cc @@ -25,15 +25,15 @@ namespace ge { Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); - for (const auto &node : graph->GetAllNodes()) { - if (node->GetType() == STREAMSWITCH) { - auto sub_graph = node->GetOwnerComputeGraph(); - if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { - GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); + if (graph->GetGraphUnknownFlag()) { + for (const auto &node : graph->GetAllNodes()) { + if (node->GetType() == STREAMSWITCH) { + auto sub_graph = node->GetOwnerComputeGraph(); + if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { + GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); + } } } - } - if (graph->GetGraphUnknownFlag()) { GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); return SUCCESS; } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index afe2f2b9..e957d119 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -715,6 +715,7 @@ set(PASS_TEST_FILES "graph/passes/mark_node_unknown_shape_pass_unittest.cc" "graph/passes/reshape_recovery_pass_unittest.cc" "graph/passes/cast_remove_pass_unittest.cc" + "graph/passes/memcpy_addr_async_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/passes/memcpy_addr_async_unittest.cc b/tests/ut/ge/graph/passes/memcpy_addr_async_unittest.cc new file mode 100644 index 00000000..e5bc450e --- /dev/null +++ b/tests/ut/ge/graph/passes/memcpy_addr_async_unittest.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#define private public +#include "graph/passes/memcpy_addr_async_pass.h" +#include "common/ge_inner_error_codes.h" +#include "inc/pass_manager.h" +#undef private + +namespace ge { +class UtestMemcpyAddrAsyncPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +TEST_F(UtestMemcpyAddrAsyncPass, run) { + ge::ComputeGraphPtr graph = std::make_shared("default"); + ge::OpDescPtr op = std::make_shared(); + op->SetType(STREAMSWITCH); + op->SetName("stream_switch"); + op->AddOutputDesc(ge::GeTensorDesc()); + ge::NodePtr node = graph->AddNode(op); + graph->SetGraphUnknownFlag(true); + MemcpyAddrAsyncPass pass; + Status ret = pass.Run(graph); + EXPECT_EQ(ret, SUCCESS); +} +} // namespace ge