Browse Source

fix streamswitch

tags/v1.3.0
yangwei 3 years ago
parent
commit
3f2e8be1dc
5 changed files with 47 additions and 9 deletions
  1. +6
    -4
      ge/graph/load/model_manager/davinci_model.cc
  2. +1
    -1
      ge/graph/load/model_manager/task_info/stream_switch_task_info.cc
  3. +38
    -2
      ge/graph/passes/memcpy_addr_async_pass.cc
  4. +2
    -0
      ge/graph/passes/memcpy_addr_async_pass.h
  5. +0
    -2
      tests/ut/ge/CMakeLists.txt

+ 6
- 4
ge/graph/load/model_manager/davinci_model.cc View File

@@ -883,6 +883,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
continue; continue;
} }


// for dynamic shape with control flow
SetLabelForDynamic(node);
auto it = op_desc_handle.find(op_desc->GetType()); auto it = op_desc_handle.find(op_desc->GetType());
if (it != op_desc_handle.end()) { if (it != op_desc_handle.end()) {
if ((this->*it->second)(op_desc) != SUCCESS) { if ((this->*it->second)(op_desc) != SUCCESS) {
@@ -891,8 +893,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
} }
continue; continue;
} }
// for dynamic shape with control flow
SetLabelForDynamic(node);

if (IsNoTaskAndDumpNeeded(op_desc)) { if (IsNoTaskAndDumpNeeded(op_desc)) {
GELOGD("node[%s] without task, and save op_desc and addr for dump", op_desc->GetName().c_str()); GELOGD("node[%s] without task, and save op_desc and addr for dump", op_desc->GetName().c_str());
const RuntimeParam &rts_param = GetRuntimeParam(); const RuntimeParam &rts_param = GetRuntimeParam();
@@ -936,11 +937,12 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
} }


void DavinciModel::SetLabelForDynamic(const NodePtr &node) { void DavinciModel::SetLabelForDynamic(const NodePtr &node) {
if (known_node_ && node->GetOpDesc()->GetType() == LABELSWITCHBYINDEX) {
if (known_node_ && (node->GetType() == LABELSWITCHBYINDEX || node->GetType() == STREAMSWITCH)) {
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor != nullptr) { if (peer_out_data_anchor != nullptr) {
string tensor_name = node->GetName();
// name+index as the label of switch input
string tensor_name = node->GetName() + std::to_string(in_data_anchor->GetIdx());
auto peer_node = peer_out_data_anchor->GetOwnerNode(); auto peer_node = peer_out_data_anchor->GetOwnerNode();
(void)AttrUtils::SetStr(peer_node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); (void)AttrUtils::SetStr(peer_node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name);
(void)AttrUtils::SetInt(peer_node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, 0); (void)AttrUtils::SetInt(peer_node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, 0);


+ 1
- 1
ge/graph/load/model_manager/task_info/stream_switch_task_info.cc View File

@@ -148,7 +148,7 @@ Status StreamSwitchTaskInfo::CalculateArgs(const domi::TaskDef &task_def, Davinc
return FAILED; return FAILED;
} }
for (uint32_t i = 0; i < STREAM_SWITCH_INPUT_NUM; ++i) { for (uint32_t i = 0; i < STREAM_SWITCH_INPUT_NUM; ++i) {
string input_tensor_name = op_desc->GetInputNameByIndex(i);
string input_tensor_name = op_desc->GetName() + std::to_string(i);
int64_t fixed_addr_offset = davinci_model->GetFixedAddrsSize(input_tensor_name); int64_t fixed_addr_offset = davinci_model->GetFixedAddrsSize(input_tensor_name);
fixed_addr_offset_.emplace_back(fixed_addr_offset); fixed_addr_offset_.emplace_back(fixed_addr_offset);
auto tensor_desc = op_desc->GetInputDesc(i); auto tensor_desc = op_desc->GetInputDesc(i);


+ 38
- 2
ge/graph/passes/memcpy_addr_async_pass.cc View File

@@ -25,6 +25,14 @@
namespace ge { namespace ge {
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);
for (const auto &node : graph->GetAllNodes()) {
if (node->GetType() == STREAMSWITCH) {
auto sub_graph = node->GetOwnerComputeGraph();
if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) {
GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph.");
}
}
}
if (graph->GetGraphUnknownFlag()) { if (graph->GetGraphUnknownFlag()) {
GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str());
return SUCCESS; return SUCCESS;
@@ -63,6 +71,28 @@ Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
return SUCCESS; return SUCCESS;
} }


// Puts a newly created memcpy node in front of every connected data input of `node`.
//
// For each input data anchor that has a peer output anchor, a node is built with
// CreateMemcpyAddrAsyncNode() in the graph that owns `node`, and then handed to
// InsertMemcpyAddrAsyncNode() together with the two anchors. Input anchors without
// a peer are skipped.
//
// Side effect: sets known_sub_graph_ to true before creating any node.
// NOTE(review): nothing in this function resets known_sub_graph_ afterwards, so the
// flag stays set for later node creation by the same pass object — confirm that is
// intended.
//
// Returns SUCCESS when every connected input was processed, INTERNAL_ERROR when a
// node could not be created, or the status reported by InsertMemcpyAddrAsyncNode()
// when insertion fails. A null `node` is rejected by GE_CHECK_NOTNULL (presumably
// with an early error return — verify against the macro definition).
Status MemcpyAddrAsyncPass::AddMemcpyAsyncNode(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GELOGI("Start add memcpyasync node in front of node %s", node->GetName().c_str());
  known_sub_graph_ = true;
  auto owner_graph = node->GetOwnerComputeGraph();
  for (auto &input_anchor : node->GetAllInDataAnchors()) {
    OutDataAnchorPtr src_anchor = input_anchor->GetPeerOutAnchor();
    if (src_anchor == nullptr) {
      // Unconnected input: nothing to copy from.
      continue;
    }

    auto copy_node = CreateMemcpyAddrAsyncNode(owner_graph, src_anchor, node);
    if (copy_node == nullptr) {
      GELOGE(INTERNAL_ERROR, "Create memcpyasync node failed.");
      return INTERNAL_ERROR;
    }

    const Status insert_status = InsertMemcpyAddrAsyncNode(src_anchor, input_anchor, copy_node);
    if (insert_status != SUCCESS) {
      GELOGE(insert_status, "Insert memcpyasync node failed.");
      return insert_status;
    }
  }
  return SUCCESS;
}

Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) { Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) {
GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str()); GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str());
for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
@@ -208,9 +238,15 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr
static uint32_t new_node_index = 0; static uint32_t new_node_index = 0;
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid.");
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++);


OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC);
OpDescPtr op_desc = nullptr;
if (known_sub_graph_) { // insert memcpyasync node when known sub graph
string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(new_node_index++);
op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC);
} else {
string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++);
op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC);
}
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);


if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) {


+ 2
- 0
ge/graph/passes/memcpy_addr_async_pass.h View File

@@ -27,6 +27,7 @@ class MemcpyAddrAsyncPass : public GraphPass {


private: private:
Status AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node); Status AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node);
Status AddMemcpyAsyncNode(const NodePtr &node);
void FindUserData(const NodePtr &node, uint32_t &parent_index); void FindUserData(const NodePtr &node, uint32_t &parent_index);
void FindUserDataForKnown(const NodePtr &parent_node, uint32_t &parent_index); void FindUserDataForKnown(const NodePtr &parent_node, uint32_t &parent_index);
void FindUserDataForNonDynamic(const ge::NodePtr &parent_node, uint32_t &parent_index); void FindUserDataForNonDynamic(const ge::NodePtr &parent_node, uint32_t &parent_index);
@@ -48,6 +49,7 @@ class MemcpyAddrAsyncPass : public GraphPass {
OutDataAnchorPtr peer_out_anchor_for_known_; OutDataAnchorPtr peer_out_anchor_for_known_;
InDataAnchorPtr in_anchor_for_known_; InDataAnchorPtr in_anchor_for_known_;
bool find_user_data_for_known_ = false; bool find_user_data_for_known_ = false;
bool known_sub_graph_ = false;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_MEMCPY_ADDR_ASYNC_PASS_H_ #endif // GE_GRAPH_PASSES_MEMCPY_ADDR_ASYNC_PASS_H_

+ 0
- 2
tests/ut/ge/CMakeLists.txt View File

@@ -303,7 +303,6 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" "${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc"
"${GE_CODE_DIR}/ge/ir_build/attr_options/weight_compress_option.cc" "${GE_CODE_DIR}/ge/ir_build/attr_options/weight_compress_option.cc"
"${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" "${GE_CODE_DIR}/ge/graph/build/label_allocator.cc"
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc"
"${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc"
"${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc" "${GE_CODE_DIR}/ge/graph/partition/dynamic_shape_partition.cc"
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc"
@@ -430,7 +429,6 @@ set(DISTINCT_GRAPH_LOAD_SRC_FILES
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/memcpy_async_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/memcpy_async_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/stream_active_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/stream_active_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/stream_switch_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/end_graph_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/end_graph_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/model_exit_task_info.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/model_exit_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc" "${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/super_kernel/super_kernel.cc"


Loading…
Cancel
Save