@@ -25,15 +25,15 @@ | |||||
namespace ge { | namespace ge { | ||||
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | ||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
for (const auto &node : graph->GetAllNodes()) { | |||||
if (node->GetType() == STREAMSWITCH) { | |||||
auto sub_graph = node->GetOwnerComputeGraph(); | |||||
if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { | |||||
GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); | |||||
if (graph->GetGraphUnknownFlag()) { | |||||
for (const auto &node : graph->GetAllNodes()) { | |||||
if (node->GetType() == STREAMSWITCH) { | |||||
auto sub_graph = node->GetOwnerComputeGraph(); | |||||
if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { | |||||
GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | |||||
if (graph->GetGraphUnknownFlag()) { | |||||
GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -715,6 +715,7 @@ set(PASS_TEST_FILES | |||||
"graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
"graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
"graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
"graph/passes/memcpy_addr_async_unittest.cc" | |||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -0,0 +1,47 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <cstdint> | |||||
#include <memory> | |||||
#include <string> | |||||
#define private public | |||||
#include "graph/passes/memcpy_addr_async_pass.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "inc/pass_manager.h" | |||||
#undef private | |||||
namespace ge { | |||||
class UtestMemcpyAddrAsyncPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
TEST_F(UtestMemcpyAddrAsyncPass, run) { | |||||
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default"); | |||||
ge::OpDescPtr op = std::make_shared<ge::OpDesc>(); | |||||
op->SetType(STREAMSWITCH); | |||||
op->SetName("stream_switch"); | |||||
op->AddOutputDesc(ge::GeTensorDesc()); | |||||
ge::NodePtr node = graph->AddNode(op); | |||||
graph->SetGraphUnknownFlag(true); | |||||
MemcpyAddrAsyncPass pass; | |||||
Status ret = pass.Run(graph); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
} // namespace ge |