| @@ -26,12 +26,17 @@ | |||||
| namespace ge { | namespace ge { | ||||
| LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {} | LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {} | ||||
| Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { | |||||
| Status LabelAllocator::AssignFunctionalLabels() { | |||||
| if (compute_graph_ == nullptr) { | if (compute_graph_ == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "ComputeGraph not set, Assign labels failed."); | GELOGE(INTERNAL_ERROR, "ComputeGraph not set, Assign labels failed."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (compute_graph_->GetGraphUnknownFlag()) { | |||||
| GELOGD("Graph[%s] is unknown graph, skip label allocator.", compute_graph_->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Add label task for sub graph. | // Add label task for sub graph. | ||||
| GELOGI("AssignFunctionalLabels start: %s.", compute_graph_->GetName().c_str()); | GELOGI("AssignFunctionalLabels start: %s.", compute_graph_->GetName().c_str()); | ||||
| std::set<NodePtr> functional_nodes; | std::set<NodePtr> functional_nodes; | ||||
| @@ -42,7 +47,7 @@ Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { | |||||
| } | } | ||||
| // Add label for functional op. | // Add label for functional op. | ||||
| label_index = 0; | |||||
| uint32_t label_index = 0; | |||||
| for (auto node : functional_nodes) { | for (auto node : functional_nodes) { | ||||
| LabelMakerPtr maker = LabelMakerFactory::Instance().Create(node->GetType(), compute_graph_, node); | LabelMakerPtr maker = LabelMakerFactory::Instance().Create(node->GetType(), compute_graph_, node); | ||||
| if (maker == nullptr) { | if (maker == nullptr) { | ||||
| @@ -56,6 +61,7 @@ Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { | |||||
| } | } | ||||
| } | } | ||||
| (void)AttrUtils::SetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_index); | |||||
| GELOGI("AssignFunctionalLabels success."); | GELOGI("AssignFunctionalLabels success."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ class LabelAllocator { | |||||
| explicit LabelAllocator(const ComputeGraphPtr &graph); | explicit LabelAllocator(const ComputeGraphPtr &graph); | ||||
| ~LabelAllocator() = default; | ~LabelAllocator() = default; | ||||
| Status AssignFunctionalLabels(uint32_t &label_index); | |||||
| Status AssignFunctionalLabels(); | |||||
| private: | private: | ||||
| bool CollectFunctionalNode(ComputeGraphPtr &graph, std::set<NodePtr> &functional_nodes); | bool CollectFunctionalNode(ComputeGraphPtr &graph, std::set<NodePtr> &functional_nodes); | ||||
| @@ -348,7 +348,11 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr | |||||
| auto compute_graph = subgraph->subgraph_info.GetSubGraph(); | auto compute_graph = subgraph->subgraph_info.GetSubGraph(); | ||||
| for (NodePtr &node : compute_graph->GetDirectNode()) { | for (NodePtr &node : compute_graph->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| if (IsEngineSkip(*subgraph) && node->GetInNodes().empty()) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_RTS_LABEL_NODE)) { | |||||
| node->GetOpDesc()->SetStreamId(context.default_stream); | |||||
| GELOGD("Node %s of type %s in subgraph %s is assigned parent stream %ld (engine: %s).", node->GetName().c_str(), | |||||
| node->GetType().c_str(), subgraph->name.c_str(), context.default_stream, engine_name.c_str()); | |||||
| } else if (IsEngineSkip(*subgraph) && node->GetInNodes().empty()) { | |||||
| GELOGD("Node %s of type %s in subgraph %s doesn't need to assign a stream (engine: %s).", | GELOGD("Node %s of type %s in subgraph %s doesn't need to assign a stream (engine: %s).", | ||||
| node->GetName().c_str(), node->GetType().c_str(), subgraph->name.c_str(), engine_name.c_str()); | node->GetName().c_str(), node->GetType().c_str(), subgraph->name.c_str(), engine_name.c_str()); | ||||
| } else { | } else { | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
| #include "graph/buffer.h" | #include "graph/buffer.h" | ||||
| #include "graph/build/label_allocator.h" | |||||
| #include "graph/build/stream_allocator.h" | #include "graph/build/stream_allocator.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| @@ -42,7 +41,6 @@ | |||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/passes/memcpy_addr_async_pass.h" | |||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "memory/memory_assigner.h" | #include "memory/memory_assigner.h" | ||||
| #include "omg/version.h" | #include "omg/version.h" | ||||
| @@ -692,25 +690,8 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { | |||||
| GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); | GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); | ||||
| // Assign functional op labels. | // Assign functional op labels. | ||||
| GE_TIMESTAMP_START(AssignFunctionalLabels); | |||||
| LabelAllocator label_allocator(compute_graph_); | |||||
| GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); | |||||
| GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); | |||||
| // Add memcpy_addr_async node. | |||||
| rtFeatureType_t feature_type = FEATURE_TYPE_MEMCPY; | |||||
| int32_t feature_info = MEMCPY_INFO_SUPPORT_ZEROCOPY; | |||||
| int64_t value = 0; | |||||
| rtError_t rt_ret = rtGetRtCapability(feature_type, feature_info, &value); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtGetRtCapability failed."); | |||||
| return RT_FAILED; | |||||
| } else { | |||||
| GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode); | |||||
| MemcpyAddrAsyncPass memcpy_addr; | |||||
| GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph_), "Add memcpy_addr_async node failed."); | |||||
| GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | |||||
| } | |||||
| label_num_ = 0; | |||||
| (void)AttrUtils::GetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_num_); | |||||
| GE_TIMESTAMP_START(AssignMemory); | GE_TIMESTAMP_START(AssignMemory); | ||||
| MemoryAssigner mem_assigner(compute_graph_); | MemoryAssigner mem_assigner(compute_graph_); | ||||
| @@ -23,75 +23,65 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| namespace { | |||||
| const int64_t kInvalidStreamId = -1; | |||||
| } // namespace | |||||
| namespace ge { | namespace ge { | ||||
| /** | /** | ||||
| * @ingroup ge | * @ingroup ge | ||||
| * @brief Set stream id for head node. | |||||
| * @brief Link node to graph head. | |||||
| * @param [in] graph: graph for add node. | * @param [in] graph: graph for add node. | ||||
| * @param [in] op_desc: OpDesc for set logical stream id. | |||||
| * @param [in] node: Node add to graph head. | |||||
| * @return: void | * @return: void | ||||
| */ | */ | ||||
| void LabelMaker::SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
| int64_t stream_id = kInvalidStreamId; | |||||
| const auto &node_list = graph->GetDirectNode(); | |||||
| for (size_t i = 0; i < node_list.size(); ++i) { | |||||
| const auto &node = node_list.at(i); | |||||
| GE_CHECK_NOTNULL_EXEC(node, continue); | |||||
| void LabelMaker::LinkToGraphHead(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
| static const std::set<std::string> non_calc_types = { DATA, CONSTANT, CONSTANTOP, VARIABLE }; | |||||
| for (auto &n : graph->GetDirectNode()) { | |||||
| if (non_calc_types.count(n->GetType()) > 0) { | |||||
| continue; | |||||
| } | |||||
| stream_id = node->GetOpDesc()->GetStreamId(); | |||||
| if (stream_id != kInvalidStreamId) { | |||||
| break; | |||||
| const auto nodes = n->GetInDataNodes(); | |||||
| if (nodes.empty()) { | |||||
| continue; | |||||
| } | } | ||||
| } | |||||
| GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
| op_desc->SetStreamId(stream_id); | |||||
| } | |||||
| bool is_head_node = true; | |||||
| for (auto &in_node : nodes) { | |||||
| if (non_calc_types.count(in_node->GetType()) == 0) { | |||||
| is_head_node = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief Set stream id for tail node. | |||||
| * @param [in] graph: graph for add node. | |||||
| * @param [in] op_desc: OpDesc for set logical stream id. | |||||
| * @return: void | |||||
| */ | |||||
| void LabelMaker::SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
| int64_t stream_id = kInvalidStreamId; | |||||
| const auto &node_list = graph->GetDirectNode(); | |||||
| for (size_t i = node_list.size(); i > 0; --i) { | |||||
| const auto &node = node_list.at(i - 1); // i from list size, need shift 1. | |||||
| GE_CHECK_NOTNULL_EXEC(node, continue); | |||||
| if (!is_head_node) { | |||||
| continue; | |||||
| } | |||||
| stream_id = node->GetOpDesc()->GetStreamId(); | |||||
| if (stream_id != kInvalidStreamId) { | |||||
| break; | |||||
| if (GraphUtils::AddEdge(node->GetOutControlAnchor(), n->GetInControlAnchor()) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", node->GetName().c_str(), n->GetName().c_str()); | |||||
| } | } | ||||
| } | } | ||||
| GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
| op_desc->SetStreamId(stream_id); | |||||
| } | } | ||||
| /** | /** | ||||
| * @ingroup ge | * @ingroup ge | ||||
| * @brief Set stream id for parent node. | |||||
| * @brief Link node to graph tail. | |||||
| * @param [in] graph: graph for add node. | * @param [in] graph: graph for add node. | ||||
| * @param [in] op_desc: OpDesc for set logical stream id. | |||||
| * @param [in] node: Node add to graph tail. | |||||
| * @return: void | * @return: void | ||||
| */ | */ | ||||
| void LabelMaker::SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
| int64_t stream_id = kInvalidStreamId; | |||||
| const auto &node = graph->GetParentNode(); | |||||
| if (node != nullptr) { | |||||
| stream_id = node->GetOpDesc()->GetStreamId(); | |||||
| } | |||||
| void LabelMaker::LinkToGraphTail(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
| auto tail = graph->FindFirstNodeMatchType(NETOUTPUT); | |||||
| while (tail != nullptr) { | |||||
| auto nodes = tail->GetOutControlNodes(); | |||||
| if (!nodes.empty()) { | |||||
| tail = nodes.at(0); | |||||
| continue; | |||||
| } | |||||
| GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
| op_desc->SetStreamId(stream_id); | |||||
| if (GraphUtils::AddEdge(tail->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", tail->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| /** | /** | ||||
| @@ -112,7 +102,7 @@ NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::str | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE); | OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str()); | GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str()); | ||||
| vector<uint32_t> active_streams; | vector<uint32_t> active_streams; | ||||
| @@ -122,6 +112,7 @@ NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::str | |||||
| NodePtr stream_active = graph->AddNodeFront(op_desc); | NodePtr stream_active = graph->AddNodeFront(op_desc); | ||||
| GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); | GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); | ||||
| LinkToGraphHead(graph, stream_active); | |||||
| return stream_active; | return stream_active; | ||||
| } | } | ||||
| @@ -146,7 +137,7 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
| @@ -173,19 +164,9 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st | |||||
| NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | ||||
| GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | ||||
| const auto &node_list = graph->GetDirectNode(); | |||||
| auto it = node_list.end(); | |||||
| if (it == node_list.begin()) { | |||||
| GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| --it; | |||||
| const NodePtr &node = *it; | |||||
| GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
| @@ -194,11 +175,7 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st | |||||
| GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | ||||
| // Link control edge to graph tail. | // Link control edge to graph tail. | ||||
| if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_set->GetInControlAnchor()) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", node->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| LinkToGraphTail(graph, label_set); | |||||
| return label_set; | return label_set; | ||||
| } | } | ||||
| @@ -222,7 +199,7 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | ||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
| @@ -246,32 +223,17 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s | |||||
| NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | ||||
| GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | ||||
| const auto &node_list = graph->GetDirectNode(); | |||||
| auto it = node_list.end(); | |||||
| if (it == node_list.begin()) { | |||||
| GELOGE(INTERNAL_ERROR, "LabelGoto: Graph %s node is empty.", graph->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| --it; | |||||
| const NodePtr &node = *it; | |||||
| GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdLeave(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | ||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
| NodePtr label_goto = graph->AddNode(op_desc); | NodePtr label_goto = graph->AddNode(op_desc); | ||||
| GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| // Link control edge to graph tail. | // Link control edge to graph tail. | ||||
| if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_goto->GetInControlAnchor()) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "LabelGoto: Add ctrl edge to %s failed.", node->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| LinkToGraphTail(graph, label_goto); | |||||
| return label_goto; | return label_goto; | ||||
| } | } | ||||
| @@ -297,7 +259,7 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | ||||
| if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | ||||
| @@ -332,19 +294,9 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: | |||||
| const std::vector<uint32_t> &labels) { | const std::vector<uint32_t> &labels) { | ||||
| GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | ||||
| const auto &node_list = graph->GetDirectNode(); | |||||
| auto it = node_list.end(); | |||||
| if (it == node_list.begin()) { | |||||
| GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| --it; | |||||
| const NodePtr &node = *it; | |||||
| GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdOwner(graph, op_desc); | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
| GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | ||||
| if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | ||||
| @@ -361,11 +313,7 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: | |||||
| GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr); | ||||
| // Link control edge to graph tail. | // Link control edge to graph tail. | ||||
| if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_switch->GetInControlAnchor()) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", node->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| LinkToGraphTail(graph, label_switch); | |||||
| return label_switch; | return label_switch; | ||||
| } | } | ||||
| @@ -385,7 +333,6 @@ NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std: | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| op_desc->SetStreamId(kInvalidStreamId); | |||||
| GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | ||||
| if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | ||||
| @@ -60,9 +60,8 @@ class LabelMaker { | |||||
| ComputeGraphPtr parent_graph_; | ComputeGraphPtr parent_graph_; | ||||
| private: | private: | ||||
| void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
| void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
| void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
| void LinkToGraphHead(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
| void LinkToGraphTail(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_LABEL_MAKER_H_ | #endif // GE_GRAPH_PASSES_LABEL_MAKER_H_ | ||||
| @@ -100,6 +100,8 @@ | |||||
| #include "graph/passes/subgraph_const_migration_pass.h" | #include "graph/passes/subgraph_const_migration_pass.h" | ||||
| #include "graph/passes/unused_args_clean_pass.h" | #include "graph/passes/unused_args_clean_pass.h" | ||||
| #include "graph/passes/global_step_insert_pass.h" | #include "graph/passes/global_step_insert_pass.h" | ||||
| #include "graph/passes/memcpy_addr_async_pass.h" | |||||
| #include "graph/build/label_allocator.h" | |||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/graph_util.h" | #include "graph/graph_util.h" | ||||
| @@ -634,6 +636,13 @@ Status GraphManager::PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, | |||||
| GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", | GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", | ||||
| GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, | GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, | ||||
| compute_graph); | compute_graph); | ||||
| Status ret = compute_graph->TopologicalSorting(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Graph topological sort failed, ret:%d.", ret); | |||||
| return ret; | |||||
| } | |||||
| GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); | GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); | ||||
| GELOGI("PreRun:PreRunAfterOptimizeSubGraph success."); | GELOGI("PreRun:PreRunAfterOptimizeSubGraph success."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -2180,6 +2189,18 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // Assign functional op labels. | |||||
| GE_TIMESTAMP_START(AssignFunctionalLabels); | |||||
| LabelAllocator label_allocator(compute_graph); | |||||
| GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(), "Assign label failed."); | |||||
| GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); | |||||
| // Add memcpy addr asynchronous node. | |||||
| GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode); | |||||
| MemcpyAddrAsyncPass memcpy_addr; | |||||
| GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); | |||||
| GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | |||||
| // After while sub graph handle, mark all node rw type | // After while sub graph handle, mark all node rw type | ||||
| auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); | auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); | ||||
| if (result != SUCCESS) { | if (result != SUCCESS) { | ||||
| @@ -2190,11 +2211,6 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||||
| ChangeConstTypeWhenTraining(compute_graph); | ChangeConstTypeWhenTraining(compute_graph); | ||||
| ret = compute_graph->TopologicalSorting(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Graph topological sort failed, ret:%d.", ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("End optimize after merge sub graph."); | GELOGI("End optimize after merge sub graph."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -202,7 +202,7 @@ Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | |||||
| GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | |||||
| base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | |||||
| mem_size = rdma_mem_size_; | mem_size = rdma_mem_size_; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ namespace ge { | |||||
| Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
| for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
| auto node_type = NodeUtils::GetNodeType(*node); | auto node_type = NodeUtils::GetNodeType(*node); | ||||
| if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | |||||
| if (node_type == SWITCH || node_type == SWITCHN) { | |||||
| GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | ||||
| const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
| const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | ||||
| @@ -37,10 +37,15 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
| if (node_type == IDENTITY) { | if (node_type == IDENTITY) { | ||||
| GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | ||||
| AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
| continue; | |||||
| } | |||||
| if (node_type == REFMERGE || node_type == REFSWITCH) { | |||||
| GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); | |||||
| AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (node_type == MERGE || node_type == REFMERGE) { | |||||
| if (node_type == MERGE) { | |||||
| GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | ||||
| const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
| const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | ||||
| @@ -25,6 +25,14 @@ | |||||
| namespace ge { | namespace ge { | ||||
| Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | ||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| int64_t value = 0; | |||||
| rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtGetRtCapability failed, error=0x%x.", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| for (auto &node : graph->GetAllNodes()) { | for (auto &node : graph->GetAllNodes()) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | ||||
| @@ -210,9 +218,18 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| int64_t stream_id = out_of_user_data->GetOpDesc()->GetStreamId(); | |||||
| op_desc->SetStreamId(stream_id); | |||||
| GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
| string stream_label; | |||||
| if (AttrUtils::GetStr(out_of_user_data->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { | |||||
| (void)AttrUtils::SetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label); | |||||
| GELOGD("Node %s set stream label: %s", op_desc->GetName().c_str(), stream_label.c_str()); | |||||
| } | |||||
| bool rts_label_node = false; | |||||
| if (AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_RTS_LABEL_NODE, rts_label_node)) { | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, rts_label_node); | |||||
| GELOGD("Node %s set rts label node attribute", op_desc->GetName().c_str()); | |||||
| } | |||||
| bool labeled_input = false; | bool labeled_input = false; | ||||
| (void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input); | (void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input); | ||||
| if (labeled_input) { | if (labeled_input) { | ||||
| @@ -79,6 +79,13 @@ Status MergePass::Run(NodePtr &node) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| auto in_node = in_data_nodes.at(0); | |||||
| if (IsMergeInputNeedOptimized(in_node)) { | |||||
| if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) { | |||||
| GELOGE(FAILED, "Isolate and delete node %s failed.", in_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return IsolateAndDeleteNode(node, merge_io_map); | return IsolateAndDeleteNode(node, merge_io_map); | ||||
| } | } | ||||
| default: { | default: { | ||||
| @@ -172,4 +179,27 @@ Status MergePass::CreateConstByValue(NodePtr &node, int value_index, OpDescPtr & | |||||
| GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); | GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| // node is not inserted by MergeInputMemcpyPass | |||||
| if ((node->GetType() != MEMCPYASYNC) && (node->GetType() != MEMCPYADDRASYNC)) { | |||||
| return false; | |||||
| } | |||||
| if (node->GetInDataNodes().size() != 1) { | |||||
| return false; | |||||
| } | |||||
| auto in_node = node->GetInDataNodes().at(0); | |||||
| if (in_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| // in_node may be global_step var | |||||
| if ((in_node->GetType() == VARIABLE) || (in_node->GetType() == VARIABLEV2)) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -28,6 +28,7 @@ class MergePass : public BaseNodePass { | |||||
| bool IsNeedChangeIndexToConstant(NodePtr &node) const; | bool IsNeedChangeIndexToConstant(NodePtr &node) const; | ||||
| Status ChangeIndexToConstant(NodePtr &node, int &value_index); | Status ChangeIndexToConstant(NodePtr &node, int &value_index); | ||||
| Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); | Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); | ||||
| bool IsMergeInputNeedOptimized(NodePtr &node) const; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_PASSES_MERGE_PASS_H_ | #endif // GE_GRAPH_PASSES_MERGE_PASS_H_ | ||||
| @@ -173,14 +173,17 @@ Status NextIterationPass::FindWhileGroups() { | |||||
| NodePtr next_node = nullptr; | NodePtr next_node = nullptr; | ||||
| if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Get NextIteration node failed."); | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", | |||||
| out_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | ||||
| NodePtr switch_node = nullptr; | NodePtr switch_node = nullptr; | ||||
| if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Get Switch node failed."); | |||||
| GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", | |||||
| out_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (switch_node == nullptr) { | if (switch_node == nullptr) { | ||||
| @@ -189,7 +192,9 @@ Status NextIterationPass::FindWhileGroups() { | |||||
| NodePtr loop_cond = nullptr; | NodePtr loop_cond = nullptr; | ||||
| if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Get LoopCond node failed."); | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", | |||||
| switch_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (batch_iter.second->loop_cond == nullptr) { | if (batch_iter.second->loop_cond == nullptr) { | ||||
| @@ -117,6 +117,7 @@ | |||||
| #include "graph/passes/variable_op_pass.h" | #include "graph/passes/variable_op_pass.h" | ||||
| #include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
| #include "graph/passes/variable_ref_delete_op_pass.h" | #include "graph/passes/variable_ref_delete_op_pass.h" | ||||
| #include "graph/passes/mark_agnostic_pass.h" | |||||
| namespace ge { | namespace ge { | ||||
| @@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() { | |||||
| try { | try { | ||||
| (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | ||||
| (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | ||||
| (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); | |||||
| } catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
| GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -40,8 +40,6 @@ using domi::AippOpParams; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const char *const kMbatchSwitchnName = "mbatch-switch-name"; | const char *const kMbatchSwitchnName = "mbatch-switch-name"; | ||||
| const int64_t kFormatAgnosticSwitch = 1; | |||||
| const int64_t kFormatDependInputIndex = 1; | |||||
| } // namespace | } // namespace | ||||
| static void ConvertShape2Nhwc(Format &format, vector<int64_t> &shape_vec) { | static void ConvertShape2Nhwc(Format &format, vector<int64_t> &shape_vec) { | ||||
| if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast<size_t>(NORMAL_TENSOR_SIZE))) { | if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast<size_t>(NORMAL_TENSOR_SIZE))) { | ||||
| @@ -269,23 +267,6 @@ Status InsertNewOpUtil::GetAippParams(const std::unique_ptr<domi::AippOpParams> | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status InsertNewOpUtil::AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node) { | |||||
| GE_CHECK_NOTNULL(aipp_node); | |||||
| auto next_nodes = aipp_node->GetOutDataNodes(); | |||||
| for (const auto next_node : next_nodes) { | |||||
| GE_CHECK_NOTNULL(next_node); | |||||
| auto op_desc = next_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (op_desc->GetType() == SWITCHN) { | |||||
| GELOGI("Find switchn node [%s] after aipp [%s]", op_desc->GetName().c_str(), aipp_node->GetName().c_str()); | |||||
| (void)AttrUtils::SetInt(op_desc, "_format_agnostic", kFormatAgnosticSwitch); | |||||
| (void)AttrUtils::SetListInt(op_desc, "_format_agnostic_except_input", | |||||
| std::vector<int64_t>({kFormatDependInputIndex})); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | ||||
| std::map<std::string, NodePtr> switchn_names_to_data; | std::map<std::string, NodePtr> switchn_names_to_data; | ||||
| std::set<NodePtr> updated_switchn; | std::set<NodePtr> updated_switchn; | ||||
| @@ -300,9 +281,6 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | |||||
| } | } | ||||
| if (node->GetType() == AIPP) { | if (node->GetType() == AIPP) { | ||||
| GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); | GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); | ||||
| // In dynamic batch/HW and dynamic aipp scend, switchn should be set format agnostic, otherwise transdata maybe | |||||
| // inserted between aipp and switchn which introduce performance and memory increase problem. | |||||
| GE_RETURN_IF_ERROR(AddFormatAgnosticAttrToSwitchn(node)); | |||||
| } | } | ||||
| if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | ||||
| multbatch_case = node; | multbatch_case = node; | ||||
| @@ -68,7 +68,6 @@ class InsertNewOpUtil { | |||||
| void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); | void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); | ||||
| Status UpdatePrevNodeByAipp(NodePtr &node, std::set<NodePtr> &switchns); | Status UpdatePrevNodeByAipp(NodePtr &node, std::set<NodePtr> &switchns); | ||||
| Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); | Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); | ||||
| Status AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node); | |||||
| Status GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std::set<NodePtr>> &data_next_node_map); | Status GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std::set<NodePtr>> &data_next_node_map); | ||||
| Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector<NodePtr> &aipps); | Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector<NodePtr> &aipps); | ||||
| Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); | Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); | ||||
| @@ -45,16 +45,9 @@ NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { | |||||
| NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} | NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} | ||||
| void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | ||||
| void *try_reuse_addr = nullptr; | |||||
| size_t allocate_size = size; | size_t allocate_size = size; | ||||
| MemStorageType mem_type = HBM; | MemStorageType mem_type = HBM; | ||||
| if (attr != nullptr) { | if (attr != nullptr) { | ||||
| try_reuse_addr = attr->try_reuse_addr_; | |||||
| if (attr->padding_ != 0) { | |||||
| // padding up to multiple of attr->padding, and add extra attr->padding_ | |||||
| allocate_size = (size + 2 * attr->padding_ - 1) / attr->padding_ * attr->padding_; | |||||
| GELOGD("Padding size %ld by %d. final size = %zu.", size, attr->padding_, allocate_size); | |||||
| } | |||||
| mem_type = attr->mem_type_; | mem_type = attr->mem_type_; | ||||
| } | } | ||||
| @@ -69,6 +62,17 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | |||||
| } else if (mem_type == HOST_DDR) { | } else if (mem_type == HOST_DDR) { | ||||
| buffer = malloc(allocate_size); | buffer = malloc(allocate_size); | ||||
| } else { | } else { | ||||
| void *try_reuse_addr = nullptr; | |||||
| int padding = kDefaultPadding; | |||||
| if (attr != nullptr) { | |||||
| try_reuse_addr = attr->try_reuse_addr_; | |||||
| if (attr->padding_ > 0) { | |||||
| padding = attr->padding_; | |||||
| } | |||||
| } | |||||
| // padding up to multiple of padding, and add extra padding | |||||
| allocate_size = (size + 2 * padding - 1) / padding * padding; | |||||
| GELOGD("Padding size %ld by %d. final size = %zu.", size, padding, allocate_size); | |||||
| buffer = MemManager::Instance() | buffer = MemManager::Instance() | ||||
| .CachingInstance(RT_MEMORY_HBM) | .CachingInstance(RT_MEMORY_HBM) | ||||
| .Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); | .Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); | ||||
| @@ -120,11 +120,13 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| output_idx, | output_idx, | ||||
| output_tensor->GetSize()); | output_tensor->GetSize()); | ||||
| GE_CHK_RT_RET(rtMemcpy(host_buffer.data(), | |||||
| tensor_size, | |||||
| output_tensor->GetData(), | |||||
| tensor_size, | |||||
| RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| if (tensor_size > 0) { | |||||
| GE_CHK_RT_RET(rtMemcpy(host_buffer.data(), | |||||
| tensor_size, | |||||
| output_tensor->GetData(), | |||||
| tensor_size, | |||||
| RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| } | |||||
| tensor.SetData(std::move(host_buffer)); | tensor.SetData(std::move(host_buffer)); | ||||
| string session_id = std::to_string(context_->GetSessionId()); | string session_id = std::to_string(context_->GetSessionId()); | ||||
| RuntimeInferenceContext *runtime_infer_ctx = nullptr; | RuntimeInferenceContext *runtime_infer_ctx = nullptr; | ||||
| @@ -257,7 +257,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| } | } | ||||
| // cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
| if (node_item.node_type == IF || node_item.node_type == CASE) { | |||||
| if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | |||||
| const auto &in_anchor = ge_node->GetInDataAnchor(0); | const auto &in_anchor = ge_node->GetInDataAnchor(0); | ||||
| GE_CHECK_NOTNULL(in_anchor); | GE_CHECK_NOTNULL(in_anchor); | ||||
| const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | ||||
| @@ -701,6 +701,9 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | ||||
| "[%s] Failed to identify ref outputs.", | "[%s] Failed to identify ref outputs.", | ||||
| parent_node_item->NodeName().c_str()); | parent_node_item->NodeName().c_str()); | ||||
| GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), | |||||
| "[%s] Failed to identify same outputs.", | |||||
| parent_node_item->NodeName().c_str()); | |||||
| // if parent is function control op. need add a virtual partitioned call | // if parent is function control op. need add a virtual partitioned call | ||||
| if (parent_node_item->IsControlOp()) { | if (parent_node_item->IsControlOp()) { | ||||
| @@ -917,7 +920,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr | |||||
| auto parent_node = sub_graph.GetParentNode(); | auto parent_node = sub_graph.GetParentNode(); | ||||
| GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
| auto op_type = parent_node->GetType(); | auto op_type = parent_node->GetType(); | ||||
| if (op_type == IF || op_type == CASE || op_type == WHILE) { | |||||
| if (IsControlOp(op_type)) { | |||||
| GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", | GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", | ||||
| sub_graph.GetName().c_str(), | sub_graph.GetName().c_str(), | ||||
| ge_model->GetModelTaskDefPtr()->task_size()); | ge_model->GetModelTaskDefPtr()->task_size()); | ||||
| @@ -1162,6 +1165,46 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { | |||||
| GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); | |||||
| auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||||
| if (net_output_node == nullptr) { | |||||
| GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto net_output_desc = net_output_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(net_output_desc); | |||||
| std::map<std::string, int> connected_inputs; | |||||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto src_node = out_data_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto op_desc = src_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
| auto it = connected_inputs.find(input_key); | |||||
| if (it == connected_inputs.end()) { | |||||
| connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); | |||||
| } else { | |||||
| GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
| in_data_anchor->GetIdx(), | |||||
| it->second, | |||||
| src_node->GetName().c_str(), | |||||
| out_data_anchor->GetIdx()); | |||||
| node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | ||||
| GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | ||||
| auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | ||||
| @@ -59,6 +59,7 @@ class HybridModelBuilder { | |||||
| Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
| Status LoadTasks(); | Status LoadTasks(); | ||||
| Status IdentifyVariableOutputs(NodeItem &node_item); | Status IdentifyVariableOutputs(NodeItem &node_item); | ||||
| Status IdentifySameInputs(NodeItem &node_item); | |||||
| Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
| Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
| Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
| @@ -28,6 +28,9 @@ namespace hybrid { | |||||
| namespace { | namespace { | ||||
| const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
| const char * const kNodeTypeRetVal = "_RetVal"; | const char * const kNodeTypeRetVal = "_RetVal"; | ||||
| std::set<std::string> kControlOpTypes { | |||||
| IF, STATELESSIF, CASE, WHILE, STATELESSWHILE | |||||
| }; | |||||
| Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
| uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
| @@ -102,6 +105,11 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool IsControlOp(const std::string &op_type) { | |||||
| return kControlOpTypes.count(op_type) > 0; | |||||
| } | |||||
| NodeItem::NodeItem(NodePtr node): node(std::move(node)) { | NodeItem::NodeItem(NodePtr node): node(std::move(node)) { | ||||
| this->op_desc = this->node->GetOpDesc().get(); | this->op_desc = this->node->GetOpDesc().get(); | ||||
| this->node_id = this->op_desc->GetId(); | this->node_id = this->op_desc->GetId(); | ||||
| @@ -153,8 +161,7 @@ Status NodeItem::Init() { | |||||
| } | } | ||||
| bool NodeItem::IsControlOp() const { | bool NodeItem::IsControlOp() const { | ||||
| auto op_type = op_desc->GetType(); | |||||
| return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR; | |||||
| return ge::hybrid::IsControlOp(op_desc->GetType()); | |||||
| } | } | ||||
| std::string NodeItem::DebugString() const { | std::string NodeItem::DebugString() const { | ||||
| @@ -36,6 +36,8 @@ struct FusedSubgraph { | |||||
| ComputeGraphPtr graph; | ComputeGraphPtr graph; | ||||
| }; | }; | ||||
| bool IsControlOp(const std::string &op_type); | |||||
| // for caching static information across execution | // for caching static information across execution | ||||
| struct NodeItem { | struct NodeItem { | ||||
| explicit NodeItem(NodePtr node); | explicit NodeItem(NodePtr node); | ||||
| @@ -83,6 +85,7 @@ struct NodeItem { | |||||
| const NodeExecutor *node_executor = nullptr; | const NodeExecutor *node_executor = nullptr; | ||||
| std::map<int, ge::NodePtr> ref_outputs; | std::map<int, ge::NodePtr> ref_outputs; | ||||
| std::map<int, int> reuse_inputs; | std::map<int, int> reuse_inputs; | ||||
| std::map<int, int> reuse_outputs; | |||||
| std::vector<bool> is_input_shape_static; | std::vector<bool> is_input_shape_static; | ||||
| bool is_output_shape_static = true; | bool is_output_shape_static = true; | ||||
| @@ -156,6 +156,13 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||||
| Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] Start"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] Start"); | ||||
| if (IsNoOp(context)) { | |||||
| GELOGD("[%s] Skipping execution for op with empty outputs", context.GetNodeName()); | |||||
| auto ret = context.TryExecuteCallback(done_callback); | |||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] End"); | |||||
| return ret; | |||||
| } | |||||
| auto op_desc = context.GetNodeItem().op_desc; | auto op_desc = context.GetNodeItem().op_desc; | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); | GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); | ||||
| @@ -219,5 +226,18 @@ bool AiCoreNodeTask::IsSupportDynamicShape() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) { | |||||
| for (int i = 0; i < task_context.NumOutputs(); ++i) { | |||||
| const auto &tensor_desc = task_context.MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| const auto &shape = tensor_desc->MutableShape(); | |||||
| if (shape.IsScalar() || shape.GetShapeSize() > 0) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -52,6 +52,7 @@ class AiCoreNodeTask : public NodeTask { | |||||
| Status UpdateArgs(TaskContext &context) override; | Status UpdateArgs(TaskContext &context) override; | ||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
| private: | private: | ||||
| static bool IsNoOp(TaskContext &task_context); | |||||
| std::vector<std::unique_ptr<AiCoreOpTask>> tasks_; | std::vector<std::unique_ptr<AiCoreOpTask>> tasks_; | ||||
| }; | }; | ||||
| @@ -404,11 +404,11 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, | |||||
| unique_ptr<ControlOpNodeTask> node_task; | unique_ptr<ControlOpNodeTask> node_task; | ||||
| auto node_type = node->GetType(); | auto node_type = node->GetType(); | ||||
| if (node_type == IF) { | |||||
| if (node_type == IF || node_type == STATELESSIF) { | |||||
| node_task.reset(new(std::nothrow) IfOpNodeTask()); | node_task.reset(new(std::nothrow) IfOpNodeTask()); | ||||
| } else if (node_type == CASE) { | } else if (node_type == CASE) { | ||||
| node_task.reset(new(std::nothrow) CaseOpNodeTask()); | node_task.reset(new(std::nothrow) CaseOpNodeTask()); | ||||
| } else if (node_type == WHILE) { | |||||
| } else if (node_type == WHILE || node_type == STATELESSWHILE) { | |||||
| node_task.reset(new(std::nothrow) WhileOpNodeTask()); | node_task.reset(new(std::nothrow) WhileOpNodeTask()); | ||||
| } else { | } else { | ||||
| GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); | GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); | ||||
| @@ -189,13 +189,20 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
| } | } | ||||
| GE_CHECK_NOTNULL(tv); | GE_CHECK_NOTNULL(tv); | ||||
| auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); | auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); | ||||
| addr_infos.resize(dims.front()); | |||||
| for (auto idx = 0; idx < dims.front(); ++idx) { | |||||
| auto row_num = dims.front(); | |||||
| addr_infos.resize(row_num); | |||||
| auto device_len = tv->GetSize() / row_num; | |||||
| if (device_len <= 0 || device_len > data[kVarTableIdxLen]) { | |||||
| GELOGE(FAILED, "Local embedding length is out of range."); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto idx = 0; idx < row_num; ++idx) { | |||||
| FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | FMK_INT64_MULCHECK(idx, kVarTableRowCnt); | ||||
| auto line_idx = idx * kVarTableRowCnt; | auto line_idx = idx * kVarTableRowCnt; | ||||
| addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr, | addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr, | ||||
| data[line_idx + kVarTableIdxLen]}; | |||||
| local_addr += data[line_idx + kVarTableIdxLen]; | |||||
| device_len}; | |||||
| local_addr += device_len; | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node | |||||
| return ExecutorType::GE_LOCAL; | return ExecutorType::GE_LOCAL; | ||||
| } | } | ||||
| if (op_type == IF || op_type == CASE || op_type == WHILE) { | |||||
| if (IsControlOp(op_type)) { | |||||
| return ExecutorType::CONTROL_OP; | return ExecutorType::CONTROL_OP; | ||||
| } | } | ||||
| @@ -221,16 +221,22 @@ Status TaskContext::AllocateOutput(int index, | |||||
| GE_CHECK_NOTNULL(ref_tensor); | GE_CHECK_NOTNULL(ref_tensor); | ||||
| outputs_start_[index] = *ref_tensor; | outputs_start_[index] = *ref_tensor; | ||||
| } else { | } else { | ||||
| auto reuse_input = node_item_->reuse_inputs.find(index); | |||||
| if (reuse_input != node_item_->reuse_inputs.end()) { | |||||
| GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); | |||||
| outputs_start_[index] = inputs_start_[reuse_input->second]; | |||||
| auto reuse_output_it = node_item_->reuse_outputs.find(index); | |||||
| if (reuse_output_it != node_item_->reuse_outputs.end()) { | |||||
| GELOGD("[%s] reuse output [%d] with output [%d]", GetNodeName(), index, reuse_output_it->second); | |||||
| outputs_start_[index] = outputs_start_[reuse_output_it->second]; | |||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); | |||||
| GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", | |||||
| node_item_->NodeName().c_str(), | |||||
| index, | |||||
| outputs_start_[index].GetSize()); | |||||
| auto reuse_input = node_item_->reuse_inputs.find(index); | |||||
| if (reuse_input != node_item_->reuse_inputs.end()) { | |||||
| GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); | |||||
| outputs_start_[index] = inputs_start_[reuse_input->second]; | |||||
| } else { | |||||
| GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); | |||||
| GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", | |||||
| node_item_->NodeName().c_str(), | |||||
| index, | |||||
| outputs_start_[index].GetSize()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -201,6 +201,10 @@ DEFINE_string(op_compiler_cache_dir, "", "Optional; the path to cache operator c | |||||
| DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode"); | DEFINE_string(op_compiler_cache_mode, "", "Optional; choose the operator compiler cache mode"); | ||||
| DEFINE_string(mdl_bank_path, "", "Optional; model bank path"); | |||||
| DEFINE_string(op_bank_path, "", "Optional; op bank path"); | |||||
| class GFlagUtils { | class GFlagUtils { | ||||
| public: | public: | ||||
| /** | /** | ||||
| @@ -300,7 +304,11 @@ class GFlagUtils { | |||||
| " --save_original_model Control whether to output original model. E.g.: true: output original model\n" | " --save_original_model Control whether to output original model. E.g.: true: output original model\n" | ||||
| " --log Generate log with level. Support debug, info, warning, error, null\n" | " --log Generate log with level. Support debug, info, warning, error, null\n" | ||||
| " --dump_mode The switch of dump json with shape, to be used with mode 1. " | " --dump_mode The switch of dump json with shape, to be used with mode 1. " | ||||
| "0(default): disable; 1: enable."); | |||||
| "0(default): disable; 1: enable.\n" | |||||
| " --debug_dir Set the save path of operator compilation intermediate files. Default value: ./\n" | |||||
| " --op_compiler_cache_dir Set the save path of operator compilation cache files. Default value: ./\n" | |||||
| " --op_compiler_cache_mode Set the operator compilation cache mode." | |||||
| "Options are disable(default), enable and force(force to refresh the cache)"); | |||||
| gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); | gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true); | ||||
| // Using gflags to analyze input parameters | // Using gflags to analyze input parameters | ||||
| @@ -1013,6 +1021,8 @@ static void SetEnvForSingleOp(std::map<string, string> &options) { | |||||
| options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir); | options.emplace(ge::DEBUG_DIR, FLAGS_debug_dir); | ||||
| options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir); | options.emplace(ge::OP_COMPILER_CACHE_DIR, FLAGS_op_compiler_cache_dir); | ||||
| options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode); | options.emplace(ge::OP_COMPILER_CACHE_MODE, FLAGS_op_compiler_cache_mode); | ||||
| options.emplace(ge::MDL_BANK_PATH_FLAG, FLAGS_mdl_bank_path); | |||||
| options.emplace(ge::OP_BANK_PATH_FLAG, FLAGS_op_bank_path); | |||||
| } | } | ||||
| domi::Status GenerateSingleOp(const std::string& json_file_path) { | domi::Status GenerateSingleOp(const std::string& json_file_path) { | ||||
| @@ -1166,6 +1176,10 @@ domi::Status GenerateOmModel() { | |||||
| } | } | ||||
| options.insert(std::pair<string, string>(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); | options.insert(std::pair<string, string>(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); | ||||
| options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path)); | |||||
| options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path)); | |||||
| // set enable scope fusion passes | // set enable scope fusion passes | ||||
| SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); | SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes); | ||||
| // print atc option map | // print atc option map | ||||
| @@ -48,6 +48,8 @@ constexpr char const *kKeyShapeRange = "shape_range"; | |||||
// JSON keys and parsing limits used when reading a single-op description file.
constexpr char const *kKeyValue = "value";
constexpr char const *kKeyFormat = "format";
constexpr char const *kFileSuffix = ".om";                  // output model file suffix
constexpr char const *kKeyDynamicInput = "dynamic_input";   // marks a tensor of a variadic input
constexpr char const *kKeyDynamicOutput = "dynamic_output"; // marks a tensor of a variadic output
constexpr int kDumpJsonIndent = 2;      // indent width when dumping JSON into logs
constexpr int kShapeRangePairSize = 2;  // a shape range is a [low, high] pair
constexpr int kShapeRangeLow = 0;       // index of the lower bound inside a range pair
| @@ -124,6 +126,10 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { | |||||
| if (tensor_name != j.end()) { | if (tensor_name != j.end()) { | ||||
| desc.name = tensor_name->get<string>(); | desc.name = tensor_name->get<string>(); | ||||
| } | } | ||||
| auto dynamic_input_name = j.find(kKeyDynamicInput); | |||||
| if (dynamic_input_name != j.end()) { | |||||
| desc.dynamic_input_name = dynamic_input_name->get<string>(); | |||||
| } | |||||
| } | } | ||||
| void from_json(const Json &j, SingleOpAttr &attr) { | void from_json(const Json &j, SingleOpAttr &attr) { | ||||
| @@ -276,6 +282,23 @@ std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) { | |||||
| return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type)); | return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type)); | ||||
| } | } | ||||
| Status SingleOpParser::UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc) { | |||||
| std::map<std::string, int> dynamic_name_map; | |||||
| for (auto &tensor : desc) { | |||||
| if (tensor.dynamic_input_name.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) { | |||||
| dynamic_name_map[tensor.dynamic_input_name] = 0; | |||||
| } else { | |||||
| dynamic_name_map[tensor.dynamic_input_name]++; | |||||
| } | |||||
| tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]); | |||||
| } | |||||
| GELOGD("Update dynamic tensor name success!"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SingleOpParser::ConvertToBuildParam(int index, | Status SingleOpParser::ConvertToBuildParam(int index, | ||||
| const SingleOpDesc &single_op_desc, | const SingleOpDesc &single_op_desc, | ||||
| SingleOpBuildParam &build_param) { | SingleOpBuildParam &build_param) { | ||||
| @@ -471,6 +494,11 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<Si | |||||
| SingleOpDesc single_op_desc; | SingleOpDesc single_op_desc; | ||||
| GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); | GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); | ||||
| single_op_desc = single_op_json; | single_op_desc = single_op_json; | ||||
| if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "Update dynamic tensor name failed!"); | |||||
| return FAILED; | |||||
| } | |||||
| if (!Validate(single_op_desc)) { | if (!Validate(single_op_desc)) { | ||||
| GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str()); | GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -33,6 +33,7 @@ struct SingleOpTensorDesc { | |||||
| std::vector<std::vector<int64_t>> dim_ranges; | std::vector<std::vector<int64_t>> dim_ranges; | ||||
| ge::Format format = ge::FORMAT_RESERVED; | ge::Format format = ge::FORMAT_RESERVED; | ||||
| ge::DataType type = ge::DT_UNDEFINED; | ge::DataType type = ge::DT_UNDEFINED; | ||||
| std::string dynamic_input_name; | |||||
| }; | }; | ||||
| struct SingleOpAttr { | struct SingleOpAttr { | ||||
| @@ -70,6 +71,7 @@ class SingleOpParser { | |||||
| static bool Validate(const SingleOpDesc &op_desc); | static bool Validate(const SingleOpDesc &op_desc); | ||||
| static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type); | static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type); | ||||
| static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); | static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); | ||||
| static Status UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc); | |||||
| static Status VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc); | static Status VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc); | ||||
| static Status SetShapeRange(const std::string &op_name, | static Status SetShapeRange(const std::string &op_name, | ||||
| const SingleOpTensorDesc &tensor_desc, | const SingleOpTensorDesc &tensor_desc, | ||||
| @@ -245,6 +245,12 @@ const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||||
// Operator debug switch: 0 closes debug, 1 opens the TBE compiler, 2 opens the ccec compiler.
const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel";

// Save path of the model tuning bank.
const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path";

// Save path of the operator tuning bank.
const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path";

// Graph execution mode: inference (PREDICTION) or training (TRAIN).
enum GraphRunMode { PREDICTION = 0, TRAIN };
| @@ -315,13 +321,27 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c | |||||
// C-string aliases of the ge:: option keys, exposed for the C-style build API.
static const char *const DEBUG_DIR = ge::DEBUG_DIR;
static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR;
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE;
// NOTE(review): taking .c_str() of a namespace-scope std::string during static
// initialization depends on cross-translation-unit initialization order, which
// is unspecified — confirm ge::MDL_BANK_PATH_FLAG / ge::OP_BANK_PATH_FLAG are
// constructed before these aliases are used (or make them constexpr char arrays).
static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str();
| // for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
| const std::set<std::string> ir_builder_suppported_options = { | |||||
| INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, | |||||
| DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, DYNAMIC_DIMS, | |||||
| INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
| AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, | |||||
| INPUT_FP16_NODES, LOG_LEVEL}; | |||||
| const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | |||||
| INPUT_SHAPE, | |||||
| OP_NAME_MAP, | |||||
| DYNAMIC_BATCH_SIZE, | |||||
| DYNAMIC_IMAGE_SIZE, | |||||
| DYNAMIC_DIMS, | |||||
| INSERT_OP_FILE, | |||||
| PRECISION_MODE, | |||||
| EXEC_DISABLE_REUSED_MEMORY, | |||||
| AUTO_TUNE_MODE, | |||||
| OUTPUT_TYPE, | |||||
| OUT_NODES, | |||||
| INPUT_FP16_NODES, | |||||
| LOG_LEVEL, | |||||
| DEBUG_DIR, | |||||
| OP_COMPILER_CACHE_DIR, | |||||
| OP_COMPILER_CACHE_MODE}; | |||||
| // for interface: aclgrphParse | // for interface: aclgrphParse | ||||
| const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | ||||
| @@ -336,7 +356,9 @@ const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | |||||
| OUT_NODES, | OUT_NODES, | ||||
| COMPRESS_WEIGHT_CONF, | COMPRESS_WEIGHT_CONF, | ||||
| ENABLE_SCOPE_FUSION_PASSES, | ENABLE_SCOPE_FUSION_PASSES, | ||||
| LOG_LEVEL}; | |||||
| LOG_LEVEL, | |||||
| MDL_BANK_PATH_FLAG, | |||||
| OP_BANK_PATH_FLAG}; | |||||
| // for interface: aclgrphBuildInitialize | // for interface: aclgrphBuildInitialize | ||||
| const std::set<std::string> global_options = {CORE_TYPE, | const std::set<std::string> global_options = {CORE_TYPE, | ||||