@@ -391,6 +391,8 @@ set(TRAIN_SRC_LIST
 "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
 "hybrid/node_executor/hccl/hccl_node_executor.cc"
 "hybrid/node_executor/rts/rts_node_executor.cc"
+"hybrid/node_executor/rts/rts_node_task.cc"
+"hybrid/node_executor/rts/rts_task_factory.cc"
 "hybrid/node_executor/node_executor.cc"
 "hybrid/node_executor/task_context.cc"
 "hybrid/hybrid_davinci_model.cc"
@@ -62,6 +62,10 @@ const uint32_t SWITCH_TRUE_OUTPUT = 1;
 const uint32_t SWITCH_DATA_INPUT = 0;
 const uint32_t SWITCH_PRED_INPUT = 1;
+
+// Merge
+const uint32_t MERGE_DATA_OUTPUT = 0;
+const uint32_t MERGE_INDEX_OUTPUT = 1;
 // FunctionOp
 const uint32_t IF_COND_INPUT = 0;
 const uint32_t FOR_START_INPUT = 0;
@@ -110,6 +110,8 @@ set(SRC_LIST
 "../hybrid/node_executor/controlop/control_op_executor.cc"
 "../hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
 "../hybrid/node_executor/rts/rts_node_executor.cc"
+"../hybrid/node_executor/rts/rts_node_task.cc"
+"../hybrid/node_executor/rts/rts_task_factory.cc"
 "../hybrid/node_executor/node_executor.cc"
 "../hybrid/node_executor/task_context.cc"
 "../hybrid/hybrid_davinci_model.cc"
@@ -16,9 +16,6 @@
 #include "graph/common/omg_util.h"
-#include <algorithm>
-#include "framework/common/debug/ge_log.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/utils/graph_utils.h"
 #include "graph/utils/tensor_utils.h"
@@ -244,4 +241,41 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size) {
   output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize;
   return SUCCESS;
 }
+
+///
+/// @brief Check whether a tensor has unknown shape
+/// @param [in] tensor_desc
+/// @return true: unknown / false: known
+///
+bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
+  const static int kUnknowShape = -1;
+  const static int kUnknowRank = -2;
+  for (auto dim_size : tensor_desc.GetShape().GetDims()) {
+    if (dim_size == kUnknowShape || dim_size == kUnknowRank) {
+      return true;
+    }
+  }
+  return false;
+}
+
+///
+/// @brief Set the _force_unknown_shape flag on an op
+/// @param [in] node
+/// @param [in] force_unknown, set the attribute only if true
+/// @return
+///
+void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) {
+  GE_RT_VOID_CHECK_NOTNULL(node);
+  if (!force_unknown) {
+    return;
+  }
+
+  GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str());
+  if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) {
+    REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
+                       node->GetName().c_str(), node->GetType().c_str());
+    GELOGE(FAILED, "Op: %s set %s failed", node->GetName().c_str(), ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str());
+  }
+}
 } // namespace ge
@@ -117,6 +117,21 @@ void AlignMemSize(int64_t &mem_size, int64_t align_size);
 /// @return Status
 ///
 Status GetMemorySize(const NodePtr &node, int64_t &output_size);
+
+///
+/// @brief Check whether a tensor has unknown shape
+/// @param [in] tensor_desc
+/// @return true: unknown / false: known
+///
+bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);
+
+///
+/// @brief Set the _force_unknown_shape flag on an op
+/// @param [in] node
+/// @param [in] force_unknown, set the attribute only if true
+/// @return
+///
+void MarkForceUnknownShape(const NodePtr &node, bool force_unknown);
 } // namespace ge
 #endif  // GE_GRAPH_COMMON_OMG_UTIL_H_
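A minimal usage sketch for the two helpers declared above, assuming only the `graph/common/omg_util.h` declarations; `MarkIfAnyOutputUnknown` is a hypothetical wrapper, not part of this change:

```cpp
// Sketch: mark a node force-unknown when any output dim is -1 (unknown shape)
// or -2 (unknown rank). Assumes the omg_util.h declarations above.
#include "graph/common/omg_util.h"

namespace ge {
void MarkIfAnyOutputUnknown(const NodePtr &node) {  // hypothetical helper
  const auto &op_desc = node->GetOpDesc();
  for (uint32_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
    if (IsUnknownShapeTensor(op_desc->GetOutputDesc(i))) {
      MarkForceUnknownShape(node, true);  // sets ATTR_NAME_FORCE_UNKNOWN_SHAPE
      return;
    }
  }
}
}  // namespace ge
```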
@@ -168,7 +168,7 @@ Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) {
   if (it == allocated_blocks_.end()) {
     REPORT_INNER_ERROR("E19999", "Param ptr not allocated before, device_id:%u, check invalid",
                        device_id);
-    GELOGE(PARAM_INVALID, "Invalid memory pointer");
+    GELOGE(PARAM_INVALID, "Invalid memory pointer: %p", ptr);
     return ge::PARAM_INVALID;
   }
   Block *block = it->second;
@@ -31,6 +31,7 @@
 #include "graph/debug/ge_attr_define.h"
 #include "graph/utils/graph_utils.h"
 #include "graph/utils/op_desc_utils.h"
+#include "graph/common/omg_util.h"
 #define REQUIRE(cond, ...) \
   do { \
@@ -45,6 +46,11 @@
 #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)
 namespace ge {
+namespace {
+const std::set<std::string> kControlFlowOps{
+  STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT
+};
+}  // namespace
 using Cluster = DynamicShapePartitioner::Cluster;
 using ClusterPtr = std::shared_ptr<Cluster>;
@@ -273,7 +279,7 @@ Status DynamicShapePartitioner::InitClusters() {
     auto cluster = MakeShared<Cluster>(rank++, type, node, this);
     REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster.");
     node_2_cluster_[node] = cluster;
-    if (cluster->IsUnknownShape()) {
+    if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) {
       ordered_cluster_.push_back(cluster);
     }
     // Already sorted topologically, so access to the parent cluster is safe
@@ -347,7 +353,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
 void DynamicShapePartitioner::MergeClustersUnknownShape() {
   // Merge unknown shape clusters
   for (const auto &cluster : ordered_cluster_) {
-    if (cluster->IsIndependent()) {
+    if (cluster->IsIndependent() || cluster->IsControlFlow()) {
       continue;
     }
     for (const auto &in_cluster : cluster->Inputs()) {
@@ -545,17 +551,6 @@ Status DynamicShapePartitioner::IsUnknownShapeGraph(ComputeGraphPtr graph, bool
   return SUCCESS;
 }
-bool DynamicShapePartitioner::IsUnknownShapeTensor(const GeTensorDesc &tensor) {
-  const static int kUnknowShape = -1;
-  const static int kUnknowRank = -2;
-  for (auto dim_size : tensor.GetShape().GetDims()) {
-    if (dim_size == kUnknowShape || dim_size == kUnknowRank) {
-      return true;
-    }
-  }
-  return false;
-}
 std::string Cluster::DebugString() const {
   std::stringstream ss;
   switch (type_) {
@@ -612,6 +607,14 @@ bool Cluster::IsRefVariable() const {
   }
   return false;
 }
+
+bool Cluster::IsControlFlow() const {
+  const auto &op_desc = nodes_[0]->GetOpDesc();
+  bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
+  GELOGD("[%s] %s rts control flow Op", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not");
+  return is_ctrl_flow;
+}
+
 void Cluster::AddInput(ClusterPtr in) {
   if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return;
   in_clusters_.insert(in_clusters_.end(), in);
@@ -732,29 +735,33 @@ std::vector<ClusterPtr> Cluster::Outputs() const { return out_clusters_; };
 std::vector<NodePtr> Cluster::Nodes() const { return nodes_; };
 void Cluster::AddFrameInput(InDataAnchorPtr anchor) {
-  inputs_index_[anchor] = inputs_.size();
-  inputs_.push_back(anchor);
-};
+  if (anchor != nullptr && anchor->GetPeerOutAnchor() != nullptr) {
+    inputs_index_[anchor] = inputs_.size();
+    inputs_.push_back(anchor);
+  }
+}
 void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) {
-  outputs_index_[anchor] = outputs_.size();
-  outputs_.push_back(anchor);
-};
+  if (anchor != nullptr) {
+    outputs_index_[anchor] = outputs_.size();
+    outputs_.push_back(anchor);
+  }
+}
 InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) {
   return partition_node_->GetInDataAnchor(static_cast<int>(inputs_index_[anchor]));
-};
+}
 OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) {
   return partition_node_->GetOutDataAnchor(static_cast<int>(outputs_index_[anchor]));
-};
+}
 InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); };
 OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); };
 Status Cluster::BuildFrame() {
-  if (IsUnknownShape() || IsKnownShape() || IsInputNode()) {
+  if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) {
     return BuildPartitionFrame();
   } else {
     auto node = nodes_.front();
@@ -889,7 +896,7 @@ Status Cluster::CombinePartitionFrame() {
 }
 Status Cluster::BuildPartitionSubgraph() {
-  if (IsData() || IsNetOutput() || IsIndependent()) {
+  if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) {
     return SUCCESS;
   }
   int64_t parent_node_index = 0;
@@ -47,6 +47,7 @@ class DynamicShapePartitioner {
     bool IsUnknownShape() const;
     bool IsIndependent() const;
     bool IsNetOutput() const;
+    bool IsControlFlow() const;
     std::vector<std::shared_ptr<Cluster>> Inputs() const;
     std::vector<std::shared_ptr<Cluster>> Outputs() const;
     bool IsInputNode() const;
@@ -151,7 +152,6 @@ class DynamicShapePartitioner {
   Status CollectSpreadUnknownShapeNodes(NodePtr node);
   Status IsUnknownShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow);
   Status IsUnknownShapeNode(ge::NodePtr node, bool &is_unknow);
-  bool IsUnknownShapeTensor(const ge::GeTensorDesc &tensor);
   Status CtrlEdgeTransfer();
   ge::ComputeGraphPtr root_graph_;  // The original graph to partition
   std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_;  // Record nodes and the cluster it belongs to
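For intuition on the new gate, a self-contained toy of the `IsControlFlow()` predicate: a cluster is excluded from partitioning only when its op type is in the control-flow set and the force-unknown attribute is already present. The names below are stand-ins, not GE APIs:

```cpp
#include <iostream>
#include <set>
#include <string>

// Toy mirror of Cluster::IsControlFlow(): both conditions must hold.
bool IsControlFlowCluster(const std::set<std::string> &ctrl_ops,
                          const std::string &type, bool has_force_unknown) {
  return ctrl_ops.count(type) > 0 && has_force_unknown;
}

int main() {
  const std::set<std::string> kCtrlOps{"StreamActive", "StreamSwitch", "StreamMerge", "Enter", "Exit"};
  std::cout << IsControlFlowCluster(kCtrlOps, "StreamMerge", true) << '\n';   // 1: runs standalone
  std::cout << IsControlFlowCluster(kCtrlOps, "StreamMerge", false) << '\n';  // 0: partitioned as before
  std::cout << IsControlFlowCluster(kCtrlOps, "MatMul", true) << '\n';        // 0: not control flow
}
```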
@@ -36,6 +36,7 @@ struct DuringPassNodeSets {
   std::unordered_set<NodePtr> nodes_re_pass;
   std::unordered_set<NodePtr> nodes_re_pass_immediately;
   std::unordered_set<NodePtr> nodes_last;
+  std::unordered_set<NodePtr> nodes_stopped;
 };
 void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
@@ -56,11 +57,18 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i
 }
 void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
-                      std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
+                      DuringPassNodeSets &during_pass_node_set) {
+  std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
+  const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
+  const std::unordered_set<NodePtr> &nodes_stopped = during_pass_node_set.nodes_stopped;
   for (auto &node : nodes) {
     if (node == nullptr) {
       continue;
     }
+    if (nodes_stopped.count(node) > 0) {
+      GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str());
+      continue;
+    }
     if (nodes_last.count(node) != 0) {
       continue;
     }
@@ -73,7 +81,7 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n
 }
 void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
-                        std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass,
+                        std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
                         std::unordered_set<NodePtr> &nodes_re_pass) {
   for (const auto &node_to_re_pass : nodes_to_re_pass) {
     if (node_to_re_pass == nullptr) {
@@ -113,15 +121,24 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
       return result;
     }
-    auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
+    const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
     PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
                        during_pass_node_set.nodes_re_pass);
-    auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
+    const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
     PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
                        during_pass_node_set.nodes_re_pass_immediately);
-    auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
+    for (const auto &node : name_to_pass.second->GetNodesStopped()) {
+      GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
+      during_pass_node_set.nodes_stopped.emplace(node);
+    }
+    for (const auto &node : name_to_pass.second->GetNodesRestored()) {
+      GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
+      during_pass_node_set.nodes_stopped.erase(node);
+    }
+    const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
     during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
     if (nodes_deleted_by_pass.count(node) > 0) {
       GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
@@ -222,8 +239,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
       continue;
     }
-    AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);
+    const auto all_out_nodes = node->GetOutNodes();
     auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
     if (ret != SUCCESS) {
       GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u",
@@ -258,6 +274,8 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
       nodes.push_front(node);
     }
     during_pass_node_set.nodes_re_pass_immediately.clear();
+
+    AddNextIterNodes(all_out_nodes, nodes, during_pass_node_set);
   }
   for (auto &node : during_pass_node_set.nodes_last) {
@@ -51,11 +51,15 @@ class BaseNodePass {
   virtual ~BaseNodePass() = default;
-  std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; }
+  const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; }
-  std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }
+  const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }
-  std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; }
+  const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; }
+
+  const std::unordered_set<NodePtr> &GetNodesStopped() { return nodes_stopped_; }
+
+  const std::unordered_set<NodePtr> &GetNodesRestored() { return nodes_restored_; }
   void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }
@@ -65,6 +69,8 @@ class BaseNodePass {
     nodes_need_re_pass_.clear();
     nodes_deleted_.clear();
     nodes_need_re_pass_immediately_.clear();
+    nodes_stopped_.clear();
+    nodes_restored_.clear();
   }
  protected:
@@ -80,7 +86,7 @@ class BaseNodePass {
   /// optimized by other passes, call this function.
   /// @param node
   ///
-  void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); }
+  void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); }
   ///
   /// Add a node to be optimized immediately again. If you add a new node to the graph, or
@@ -88,13 +94,13 @@ class BaseNodePass {
   /// optimized by other passes, call this function.
   /// @param node
   ///
-  void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }
+  void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }
   ///
   /// Add a node and it's input/output data nodes to be optimized again.
   /// @param node
   ///
-  void AddRePassNodesWithInOut(NodePtr &node) {
+  void AddRePassNodesWithInOut(const NodePtr &node) {
     AddRePassNode(node);
     auto out_nodes = node->GetOutNodes();
     for (auto &out_node : out_nodes) {
@@ -116,12 +122,34 @@ class BaseNodePass {
   ///
   void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }
+
+  ///
+  /// If a node no longer needs processing, the remaining iterations would
+  /// still visit it whenever it is reachable through edge connections,
+  /// which is a waste of time. Call this function to mark such a node as
+  /// stopped, so the following iterations skip it.
+  /// @param node
+  ///
+  void AddNodeStopped(const NodePtr &node) { nodes_stopped_.insert(node); }
+
+  ///
+  /// Counterpart of AddNodeStopped: once a stopped node becomes relevant
+  /// again, call this function to restore it, so the following iterations
+  /// process it once more.
+  /// @param node
+  ///
+  void AddNodeRestored(const NodePtr &node) { nodes_restored_.insert(node); }
   bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }
  private:
   std::unordered_set<NodePtr> nodes_need_re_pass_;
   std::unordered_set<NodePtr> nodes_need_re_pass_immediately_;
   std::unordered_set<NodePtr> nodes_deleted_;
+  std::unordered_set<NodePtr> nodes_stopped_;
+  std::unordered_set<NodePtr> nodes_restored_;
   std::map<NodePassOption, std::string> options_;
 };
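A sketch of how a derived pass might use the new stop/restore hooks, assuming the `BaseNodePass` interface above; the pass, the header path, and the type strings are illustrative stand-ins (the real passes use the GE type constants):

```cpp
#include "graph/passes/base_pass.h"  // assumed location of BaseNodePass

// Hypothetical pass: stop downstream Exit nodes, restore them on NextIteration.
class ExamplePass : public ge::BaseNodePass {
 public:
  ge::Status Run(ge::NodePtr &node) override {
    if (node->GetType() == "Switch") {
      for (auto &out_node : node->GetOutDataNodes()) {
        if (out_node->GetType() == "Exit") {
          AddNodeStopped(out_node);  // GEPass collects this into nodes_stopped
        }
      }
    } else if (node->GetType() == "NextIteration") {
      for (auto &out_node : node->GetOutDataNodes()) {
        AddNodeRestored(out_node);   // erased from nodes_stopped by RunPasses
      }
    }
    return ge::SUCCESS;
  }
};
```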
@@ -17,11 +17,11 @@
 #include "graph/passes/infershape_pass.h"
 #include "common/util/error_manager/error_manager.h"
 #include "framework/common/debug/ge_log.h"
-#include "framework/common/ge_inner_error_codes.h"
 #include "analyzer/analyzer.h"
 #include "framework/common/util.h"
 #include "graph/shape_refiner.h"
 #include "graph/utils/graph_utils.h"
+#include "graph/debug/ge_attr_define.h"
 #include "utils/tensor_utils.h"
 #include "utils/type_utils.h"
@@ -94,8 +94,10 @@ Status InferShapePass::Run(NodePtr &node) {
     GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str());
     return GE_GRAPH_INFERSHAPE_FAILED;
   }
+
+  GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
   bool need_repass = false;
-  auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass);
+  auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
   if (has_attr) {
     if (!OptionExists(kOptimizeAfterSubGraph)) {
       return SUCCESS;
@@ -105,9 +107,57 @@ Status InferShapePass::Run(NodePtr &node) {
     GELOGD("Node %s need repass immediately.", node->GetName().c_str());
   } else {
     // clear attr on while
-    node->GetOpDesc()->DelAttr("_need_infer_again");
+    node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
     }
   }
   return SUCCESS;
 }
+
+Status InferShapePass::RePassLoopNode(const NodePtr &node) {
+  const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
+    for (auto &n : node->GetOutDataNodes()) {
+      GE_CHECK_NOTNULL(n);
+      if (re_pass_types.count(n->GetType()) > 0) {
+        AddImmediateRePassNode(n);
+        (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
+        GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
+      }
+    }
+    return SUCCESS;
+  };
+
+  const auto ExProcNode = [&](const std::set<std::string> &proc_types,
+                              const std::function<void(InferShapePass *, NodePtr)> &proc_func,
+                              const std::string &info) {
+    for (auto &n : node->GetOutDataNodes()) {
+      GE_CHECK_NOTNULL(n);
+      if (proc_types.count(n->GetType()) > 0) {
+        proc_func(this, n);
+        GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
+      }
+    }
+    return SUCCESS;
+  };
+
+  if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) {
+    return RePassNode({MERGE, REFMERGE});  // Re-Pass Merge
+  }
+
+  if (node->GetType() == MERGE || node->GetType() == REFMERGE) {
+    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
+      return RePassNode({SWITCH, REFSWITCH});  // Re-Pass Switch
+    }
+  }
+
+  if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) {
+    if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
+      node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
+      return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeRestored, "need restore");  // Restore Exit
+    } else {
+      return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeStopped, "need stop");  // Stop Exit
+    }
+  }
+
+  return SUCCESS;
+}
 } // namespace ge
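The routing in `RePassLoopNode` reads most easily as a dispatch table. A self-contained toy mirror (plain strings instead of GE nodes and type constants) of which downstream types each loop node schedules for another visit:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy mirror of the routing above; "has_attr" stands for ATTR_NAME_NEED_INFER_AGAIN.
std::vector<std::string> NextTypesToRevisit(const std::string &type, bool has_attr) {
  if (type == "NextIteration") return {"Merge"};       // re-pass Merge
  if (type == "Merge" && has_attr) return {"Switch"};  // re-pass Switch
  if (type == "Switch") {                              // attr present: restore Exit; absent: stop Exit
    return has_attr ? std::vector<std::string>{"Exit: restore"}
                    : std::vector<std::string>{"Exit: stop"};
  }
  return {};
}

int main() {
  for (const auto &t : NextTypesToRevisit("NextIteration", false)) std::cout << t << '\n';  // Merge
  for (const auto &t : NextTypesToRevisit("Merge", true)) std::cout << t << '\n';           // Switch
  for (const auto &t : NextTypesToRevisit("Switch", false)) std::cout << t << '\n';         // Exit: stop
}
```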
@@ -30,6 +30,9 @@ class InferShapePass : public BaseNodePass {
   /// @author
   ///
   Status Run(ge::NodePtr &node) override;
+
+ private:
+  Status RePassLoopNode(const NodePtr &node);
 };
 }  // namespace ge
 #endif  // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_
@@ -15,23 +15,36 @@
  */
 #include "graph/passes/merge_input_memcpy_pass.h"
+#include <queue>
 #include "common/ge/ge_util.h"
 #include "ge/ge_api_types.h"
 #include "graph/common/omg_util.h"
 namespace ge {
+namespace {
+const std::set<std::string> kLoopMergeInputs{
+  ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION
+};
+}  // namespace
+
 Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) {
   GELOGD("MergeInputMemcpyPass Enter");
+  std::unordered_map<NodePtr, std::vector<NodePtr>> switch_groups;
   for (const auto &node : graph->GetDirectNode()) {
     std::string type;
     GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.");
     if ((type != MERGE) && (type != REFMERGE)) {
       continue;
     }
     GE_CHECK_NOTNULL(node->GetOpDesc());
     GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)),
                       "Merge add memcpy node failed.");
+    CollectSwitchGroup(node, switch_groups);
   }
+
+  MarkUnknownForSwitch(switch_groups);
   GELOGD("MergeInputMemcpyPass Leave");
   return SUCCESS;
 }
@@ -101,4 +114,94 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph
   return graph->AddNode(op_desc);
 }
+
+///
+/// @brief Collect the Switch group of a Merge node, mark force unknown shape if needed
+/// @param [in] node: Merge node
+/// @param [out] switch_groups
+/// @return
+///
+void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node,
+                                              std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) {
+  const auto &op_desc = node->GetOpDesc();
+  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
+    const auto &src_out_anchor = in_anchor->GetPeerOutAnchor();
+    if (src_out_anchor == nullptr) {
+      continue;
+    }
+
+    std::string node_type;
+    GetOriginalType(src_out_anchor->GetOwnerNode(), node_type);
+    if (kLoopMergeInputs.count(node_type) > 0) {
+      return;  // loop Merge is handled by NextIterationPass
+    }
+  }
+
+  // Switch --> {Switch --> Merge} --> Merge
+  std::queue<std::pair<NodePtr, uint32_t>> search_queue;
+  search_queue.push({node, 0});
+  std::vector<NodePtr> &switch_group = switch_groups[node];
+  while (!search_queue.empty()) {
+    const auto dst_node = search_queue.front().first;
+    const auto dst_span = search_queue.front().second;
+    search_queue.pop();
+
+    // Switch --> Identity --> Constant
+    for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) {
+      if (in_ctrl_node->GetType() == IDENTITY) {
+        GELOGD("Travel node: %s, In control: %s, span is: %u",
+               dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span);
+        search_queue.push({in_ctrl_node, dst_span});
+      }
+    }
+
+    for (const auto &in_data_node : dst_node->GetInDataNodes()) {
+      std::string node_type;
+      GetOriginalType(in_data_node, node_type);
+      GELOGD("Travel node: %s, %s node: %s, span is: %u",
+             dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span);
+      if (node_type == SWITCH || node_type == REFSWITCH) {
+        if (dst_span > 0) {
+          search_queue.push({in_data_node, dst_span - 1});
+        } else {
+          switch_group.emplace_back(in_data_node);
+        }
+      } else if (node_type == MERGE || node_type == REFMERGE) {
+        search_queue.push({in_data_node, dst_span + 1});
+      } else {
+        search_queue.push({in_data_node, dst_span});
+      }
+    }
+  }
+
+  if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) {
+    GELOGI("Mark [%s] as force unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size());
+    MarkForceUnknownShape(node, true);
+    for (const auto &n : switch_group) {
+      MarkForceUnknownShape(n, true);
+    }
+  }
+}
+
+///
+/// @brief Mark force unknown shape for the collected Switch groups
+/// @param [in] switch_groups
+/// @return
+///
+void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) {
+  std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
+    return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
+  };
+  for (const auto &item : switch_groups) {
+    const auto &node = item.first;
+    if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) {
+      continue;
+    }
+
+    const std::vector<NodePtr> &switch_group = item.second;
+    if (std::any_of(switch_group.begin(), switch_group.end(), callback)) {
+      GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size());
+      MarkForceUnknownShape(node, true);
+      for (const auto &n : switch_group) {
+        MarkForceUnknownShape(n, true);
+      }
+    }
+  }
+}
 } // namespace ge
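The span counter in `CollectSwitchGroup` is the subtle part: walking upstream from a Merge, entering a nested Merge raises the span, a Switch seen at span 0 belongs to the current Merge's group, and a Switch seen at a higher span only closes an inner Merge. A self-contained toy of that traversal, with plain strings instead of GE nodes and the node type inferred from a name prefix:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Toy graph: node name -> upstream (input) node names.
using Graph = std::map<std::string, std::vector<std::string>>;

std::vector<std::string> CollectGroup(const Graph &g, const std::string &merge) {
  std::vector<std::string> group;
  std::queue<std::pair<std::string, uint32_t>> q;
  q.push({merge, 0});
  while (!q.empty()) {
    auto [name, span] = q.front();
    q.pop();
    const auto it = g.find(name);
    if (it == g.end()) continue;
    for (const auto &in : it->second) {
      if (in.rfind("Switch", 0) == 0) {        // upstream Switch
        if (span > 0) q.push({in, span - 1});  // closes a nested Merge
        else group.push_back(in);              // belongs to this Merge
      } else if (in.rfind("Merge", 0) == 0) {
        q.push({in, span + 1});                // enter a nested Merge
      } else {
        q.push({in, span});                    // pass through other nodes
      }
    }
  }
  return group;
}

int main() {
  // SwitchA --> {SwitchB --> MergeB} --> MergeA
  Graph g{{"MergeA", {"MergeB", "SwitchA"}}, {"MergeB", {"SwitchB"}},
          {"SwitchA", {"Data"}}, {"SwitchB", {"Data"}}, {"Data", {}}};
  for (const auto &s : CollectGroup(g, "MergeA")) std::cout << s << '\n';  // SwitchA
}
```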
@@ -44,6 +44,21 @@ class MergeInputMemcpyPass : public GraphPass {
   ///
   NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name,
                                 const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag);
+
+  ///
+  /// @brief Collect the Switch group of a Merge node, mark force unknown shape if needed
+  /// @param [in] node: Merge node
+  /// @param [out] switch_groups
+  /// @return
+  ///
+  void CollectSwitchGroup(const NodePtr &node, std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups);
+
+  ///
+  /// @brief Mark force unknown shape for the collected Switch groups
+  /// @param [in] switch_groups
+  /// @return
+  ///
+  void MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups);
 };
 }  // namespace ge
 #endif  // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_
@@ -69,51 +69,9 @@ Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) {
 Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node) {
   OpDescPtr merge_op_desc = merge_node->GetOpDesc();
   GE_CHECK_NOTNULL(merge_op_desc);
-  const std::string &node_name = merge_node->GetName();
-  GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str());
-  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMMERGE);
-  if (op_desc == nullptr) {
-    REPORT_CALL_ERROR("E19999", "New GeTensor failed");
-    GELOGE(FAILED, "Create op_desc failed, StreamMerge:%s.", node_name.c_str());
-    return FAILED;
-  }
-  for (const InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) {
-    GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS,
-                     REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
-                                       op_desc->GetName().c_str(), op_desc->GetType().c_str());
-                     return FAILED, "Create StreamMerge op: add input desc failed.");
-  }
-  for (const OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) {
-    GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS,
-                     REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
-                                       op_desc->GetName().c_str(), op_desc->GetType().c_str());
-                     return FAILED, "Create StreamMerge op: add output desc failed.");
-  }
-  NodePtr stream_merge = graph->AddNode(op_desc);
-  GE_CHK_BOOL_EXEC(stream_merge != nullptr,
-                   REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
-                                     op_desc->GetName().c_str(), op_desc->GetType().c_str(),
-                                     graph->GetName().c_str());
-                   return FAILED, "Insert StreamMerge node failed.");
-  GE_CHK_STATUS_RET(MoveEdges(merge_node, stream_merge), "Move edges failed.");
-  bypass_nodes_.insert(merge_node);
-  if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
-    std::string next_iteration_name;
-    GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name),
-                    REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed",
-                                      ATTR_NAME_NEXT_ITERATION.c_str(),
-                                      merge_op_desc->GetName().c_str(), merge_op_desc->GetType().c_str());
-                    GELOGE(INTERNAL_ERROR, "Get ATTR_NAME_NEXT_ITERATION failed");
-                    return INTERNAL_ERROR);
-    GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed");
-  }
-  return AddActiveNodes(graph, stream_merge);
+  merge_op_desc->SetType(STREAMMERGE);
+  return AddActiveNodes(graph, merge_node);
 }
 ///
@@ -126,6 +84,8 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
   GE_CHK_BOOL_EXEC(node != nullptr,
                    REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
                    return FAILED, "Param of pre node is null.");
+  bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
+  MarkForceUnknownShape(node, force_unknown);
   for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
     OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
     GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -142,6 +102,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
       GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str());
       return FAILED;
     }
+    MarkForceUnknownShape(active_node, force_unknown);
   }
   return SUCCESS;
@@ -140,6 +140,7 @@ Status NextIterationPass::FindWhileGroups() {
         GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
         return INTERNAL_ERROR;
       }
+      loop_group_iter.second->switch_nodes.emplace_back(switch_node);
       if (loop_group_iter.second->loop_cond == nullptr) {
         loop_group_iter.second->loop_cond = loop_cond;
       } else if (loop_group_iter.second->loop_cond != loop_cond) {
@@ -181,6 +182,12 @@ bool NextIterationPass::VerifyWhileGroup() {
               frame_name.c_str());
        return false;
      }
+
+      // Mark the loop as unknown shape if any Merge has an unknown shape output.
+      const auto &op_desc = pair_iter.first->GetOpDesc();
+      if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) {
+        loop_group_iter.second->is_unknown_shape = true;  // keep checking the rest of the loop, so no break here.
+      }
     }
   }
@@ -194,6 +201,7 @@ bool NextIterationPass::VerifyWhileGroup() {
 ///
 Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
   for (const auto &loop_cond_iter : loop_group_map_) {
+    const LoopCondGroup &loop_group = *loop_cond_iter.second;
     const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
     GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());
@@ -215,6 +223,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
                enter_active->GetName().c_str());
         return INTERNAL_ERROR;
       }
+      MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape);
     }
     for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
@@ -243,6 +252,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
         GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
         return INTERNAL_ERROR;
       }
+
+      MarkForceUnknownShape(next_node, loop_group.is_unknown_shape);
+      MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape);
     }
     if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
@@ -250,6 +262,18 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
       GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
       return INTERNAL_ERROR;
     }
+
+    MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape);
+    MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape);
+    MarkForceUnknownShape(next_active, loop_group.is_unknown_shape);
+    for (const auto &switch_node : loop_group.switch_nodes) {
+      MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape);
+      for (const auto &exit_node : switch_node->GetOutDataNodes()) {
+        if (exit_node->GetType() == EXIT || exit_node->GetType() == REFEXIT) {
+          MarkForceUnknownShape(exit_node, loop_group.is_unknown_shape);
+        }
+      }
+    }
   }
   return SUCCESS;
@@ -20,10 +20,11 @@
 #include "inc/graph_pass.h"
 struct LoopCondGroup {
-  LoopCondGroup() : loop_cond(nullptr) {}
   ge::NodePtr loop_cond;  // LoopCond node
   std::vector<ge::NodePtr> enter_nodes;  // Enter nodes
   std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs;  // <Merge, NextIteration>
+  std::vector<ge::NodePtr> switch_nodes;  // Switch nodes
+  bool is_unknown_shape{false};
 };
 using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>;
@@ -369,6 +369,7 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
   GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
                 "StreamSwitch node add cond edge failed.");
+  MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE));
   return stream_switch;
 }
@@ -487,6 +488,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
       return FAILED;
     }
+
+    std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
+      return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
+    };
+    bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
+    MarkForceUnknownShape(active_node, is_unknown_shape);
     const std::string &cond_group = cond_node->GetName();
     for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
       bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
@@ -515,6 +522,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
       GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)),
                     "Cast add data edge failed.");
+      MarkForceUnknownShape(stream_switch, is_unknown_shape);
       for (const NodePtr &node : switch_list) {
         GE_IF_BOOL_EXEC(node != stream_switch, {
           GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),
@@ -21,6 +21,7 @@
 #include <cstddef>
 #include <memory>
 #include "memory/memory_api.h"
+#include "framework/common/util.h"
 namespace ge {
 namespace hybrid {
@@ -84,6 +85,12 @@ class TensorValue {
   size_t GetSize() const;
+
+  template<typename T>
+  Status CopyScalarValueToHost(T &value) const {
+    GE_CHECK_GE(this->GetSize(), sizeof(value));
+    return rtMemcpy(&value, sizeof(value), this->GetData(), sizeof(value), RT_MEMCPY_DEVICE_TO_HOST);
+  }
  private:
   std::shared_ptr<TensorBuffer> buffer_;
   std::string name_;
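`CopyScalarValueToHost` guards the buffer size before issuing the device copy. A self-contained host-only analogue of that pattern, with `std::memcpy` standing in for `rtMemcpy` and a plain `bool` standing in for `Status`:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Reject undersized buffers (mirrors the GE_CHECK_GE guard), then copy
// exactly sizeof(T) bytes into the scalar.
template <typename T>
bool CopyScalarFromBuffer(const std::vector<uint8_t> &buf, T &value) {
  if (buf.size() < sizeof(T)) {
    return false;
  }
  std::memcpy(&value, buf.data(), sizeof(T));  // rtMemcpy(..., RT_MEMCPY_DEVICE_TO_HOST) in the real code
  return true;
}

int main() {
  std::vector<uint8_t> buf(sizeof(int32_t));
  const int32_t src = 42;
  std::memcpy(buf.data(), &src, sizeof(src));
  int32_t dst = 0;
  if (CopyScalarFromBuffer(buf, dst)) {
    std::cout << dst << '\n';  // 42
  }
}
```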
@@ -28,6 +28,8 @@ const int32_t kModelAbortNormalNew = 507024;
 std::atomic_ulong context_id_gen {};
 }  // namespace
+long GraphExecutionContext::profiling_level = 0;
+
 GraphExecutionContext::GraphExecutionContext() {
   context_id = context_id_gen++;
 }
@@ -73,7 +73,7 @@ struct GraphExecutionContext {
   ExceptionDumper exception_dumper;
   std::vector<std::shared_ptr<ge::DavinciModel>> davinci_model;
   std::atomic_bool is_eos_{false};
-  long profiling_level = 0;
+  static long profiling_level;
   long iteration = 0;
   void *global_step = nullptr;
@@ -82,17 +82,18 @@ struct GraphExecutionContext {
   mutable std::mutex mu;
 };
-#define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...) \
-do { \
-  if ((context != nullptr) && (context)->profiler != nullptr) { \
-    if (node_name != nullptr) { \
-      context->profiler->RecordEvent(evt_type, "tid:%lu [%s@%ld] [%s] " fmt, \
-                                     GeLog::GetTid(), node_name, context->iteration, category, \
-                                     ##__VA_ARGS__); \
-    } else { \
-      context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \
-    }\
-  } \
-} while (0)
+#define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...)                                 \
+do {                                                                                                             \
+  if (ge::hybrid::GraphExecutionContext::profiling_level > 0) {                                                  \
+    if ((context != nullptr) && (context)->profiler != nullptr) {                                                \
+      if (node_name != nullptr) {                                                                                \
+        context->profiler->RecordEvent(evt_type, "tid:%lu [%s@%ld] [%s] " fmt,                                   \
+                                       GeLog::GetTid(), node_name, context->iteration, category, ##__VA_ARGS__); \
+      } else {                                                                                                   \
+        context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \
+      }                                                                                                          \
+    }                                                                                                            \
+  }                                                                                                              \
+} while (0)
 #define RECORD_MODEL_EXECUTION_EVENT(context, fmt, ...) \
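The point of the reworked macro is that the static `profiling_level` check now short-circuits first, so with profiling disabled each event costs one integer comparison and never touches the context or profiler. A stripped-down, compilable illustration of that gating pattern (all names are stand-ins, not the GE types):

```cpp
#include <cstdio>

struct Ctx { void *profiler = nullptr; };
long g_profiling_level = 0;  // stand-in for GraphExecutionContext::profiling_level

// Level check first, then the context/profiler checks, as in the macro above.
#define RECORD_EVENT(ctx, fmt, ...)                             \
  do {                                                          \
    if (g_profiling_level > 0) {                                \
      if (((ctx) != nullptr) && ((ctx)->profiler != nullptr)) { \
        std::printf(fmt "\n", ##__VA_ARGS__);                   \
      }                                                         \
    }                                                           \
  } while (0)

int main() {
  Ctx ctx;
  int dummy = 0;
  RECORD_EVENT(&ctx, "never printed: level is %ld", g_profiling_level);  // gated by level
  g_profiling_level = 1;
  RECORD_EVENT(&ctx, "never printed: no profiler");                      // gated by profiler
  ctx.profiler = &dummy;
  RECORD_EVENT(&ctx, "printed: level=%ld", g_profiling_level);
}
```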
@@ -155,9 +155,9 @@ Status HybridModelExecutor::InitExecutionContext() {
   context_.dump_properties = DumpManager::GetInstance().GetDumpProperties(context_.session_id);
   const char *profiling_level = std::getenv(kEnvProfilingLevel);
   if (profiling_level != nullptr) {
-    context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
-    GELOGD("Got profiling level = %ld", context_.profiling_level);
-    if (context_.profiling_level > 0) {
+    GraphExecutionContext::profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
+    GELOGD("Got profiling level = %ld", GraphExecutionContext::profiling_level);
+    if (GraphExecutionContext::profiling_level > 0) {
       context_.profiler.reset(new(std::nothrow)HybridProfiler());
       GE_CHECK_NOTNULL(context_.profiler);
     }
@@ -187,9 +187,9 @@ void StageExecutor::Reset() {
 Status HybridModelPipelineExecutor::Init() {
   const char *profiling_level = std::getenv(kEnvProfilingLevel);
   if (profiling_level != nullptr) {
-    context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
-    GELOGD("Got profiling level = %ld", context_.profiling_level);
-    if (context_.profiling_level > 0) {
+    GraphExecutionContext::profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
+    GELOGD("Got profiling level = %ld", GraphExecutionContext::profiling_level);
+    if (GraphExecutionContext::profiling_level > 0) {
       context_.profiler.reset(new (std::nothrow) HybridProfiler());
       GE_CHECK_NOTNULL(context_.profiler);
     }
@@ -210,7 +210,6 @@ Status HybridModelPipelineExecutor::InitStageExecutors() {
     if (context_.profiler != nullptr) {
       // will call unique_ptr::release later
      stage_executor->context_.profiler.reset(context_.profiler.get());
-      stage_executor->context_.profiling_level = context_.profiling_level;
     }
     stage_executors_.emplace_back(std::move(stage_executor));
| @@ -36,6 +36,16 @@ bool NodeDoneManager::Cond::Await() { | |||||
| return is_released_; | return is_released_; | ||||
| } | } | ||||
| void NodeDoneManager::Cond::Reset() { | |||||
| std::unique_lock<std::mutex> lk(cond_mu_); | |||||
| if (!is_released_ && !is_cancelled_) { | |||||
| GELOGW("Called before done, released: %d, cancelled: %d", is_released_, is_cancelled_); | |||||
| } | |||||
| is_released_ = false; | |||||
| is_cancelled_ = false; | |||||
| } | |||||
| void NodeDoneManager::Cond::Release() { | void NodeDoneManager::Cond::Release() { | ||||
| std::unique_lock<std::mutex> lk(cond_mu_); | std::unique_lock<std::mutex> lk(cond_mu_); | ||||
| is_released_ = true; | is_released_ = true; | ||||
| @@ -103,5 +113,13 @@ bool NodeDoneManager::Await(const NodePtr &node) { | |||||
| GELOGD("[%s] Await ended. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false"); | GELOGD("[%s] Await ended. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false"); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| void NodeDoneManager::Reset(const NodePtr &node) { | |||||
| auto sub = GetSubject(node); | |||||
| if (sub != nullptr) { | |||||
| sub->Reset(); | |||||
| GELOGD("[%s] Node reset.", node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
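
Cond::Reset and NodeDoneManager::Reset re-arm a node's done-latch so the same node can be awaited again on the next while-loop iteration. A condensed, self-contained model of the assumed semantics (Await blocks until Release or Cancel; Reset is only expected between iterations, hence the warning above when it fires early):

    #include <condition_variable>
    #include <mutex>

    class ResettableLatch {
     public:
      bool Await() {  // returns true if released, false if cancelled
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return released_ || cancelled_; });
        return released_;
      }
      void Release() {
        { std::lock_guard<std::mutex> lk(mu_); released_ = true; }
        cv_.notify_all();
      }
      void Cancel() {
        { std::lock_guard<std::mutex> lk(mu_); cancelled_ = true; }
        cv_.notify_all();
      }
      void Reset() {  // re-arm between loop iterations only
        std::lock_guard<std::mutex> lk(mu_);
        released_ = false;
        cancelled_ = false;
      }
     private:
      std::mutex mu_;
      std::condition_variable cv_;
      bool released_ = false;
      bool cancelled_ = false;
    };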
| @@ -31,6 +31,8 @@ class NodeDoneManager { | |||||
| bool Await(const NodePtr &node); | bool Await(const NodePtr &node); | ||||
| void Reset(const NodePtr &node); | |||||
| void Destroy(); | void Destroy(); | ||||
| private: | private: | ||||
| @@ -40,6 +42,7 @@ class NodeDoneManager { | |||||
| void Release(); | void Release(); | ||||
| void Cancel(); | void Cancel(); | ||||
| bool Await(); | bool Await(); | ||||
| void Reset(); | |||||
| private: | private: | ||||
| std::mutex cond_mu_; | std::mutex cond_mu_; | ||||
| std::condition_variable cv_; | std::condition_variable cv_; | ||||
| @@ -30,6 +30,10 @@ constexpr auto kWaitInternal = 5; | |||||
| constexpr auto kMaxWaitTimes = 120; | constexpr auto kMaxWaitTimes = 120; | ||||
| } | } | ||||
| ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(node_item) { | ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(node_item) { | ||||
| InitShapeState(); | |||||
| } | |||||
| void ShapeInferenceState::InitShapeState() { | |||||
| this->num_pending_shapes_ = node_item.num_inputs - node_item.num_static_input_shapes; | this->num_pending_shapes_ = node_item.num_inputs - node_item.num_static_input_shapes; | ||||
| GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", | GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
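
The constructor body moves into InitShapeState so that NodeState::ResetContext (added further down) can re-arm the pending-shape counter on each loop iteration without reconstructing the object. The shape of the refactor, reduced to its essentials (names hypothetical):

    struct ShapeState {
      ShapeState(int num_inputs, int num_static_inputs)
          : num_inputs_(num_inputs), num_static_inputs_(num_static_inputs) {
        InitShapeState();  // constructor delegates to the reusable init
      }
      void InitShapeState() {
        num_pending_shapes_ = num_inputs_ - num_static_inputs_;
      }
      int num_pending_shapes_ = 0;
     private:
      int num_inputs_;
      int num_static_inputs_;
    };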
| @@ -135,19 +139,22 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| } | } | ||||
| } | } | ||||
| for (size_t i = 0; i < input_tensor_desc.size(); ++i) { | |||||
| auto dst_tensor_desc = node_item.op_desc->MutableInputDesc(i); | |||||
| if (dst_tensor_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| { | |||||
| const auto &guard = node_item.MutexGuard("AwaitShapesReady"); | |||||
| for (size_t i = 0; i < input_tensor_desc.size(); ++i) { | |||||
| auto dst_tensor_desc = node_item.MutableInputDesc(i); | |||||
| if (dst_tensor_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto &tensor_desc = input_tensor_desc[i]; | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(tensor_desc, tensor_size); | |||||
| auto &tensor_desc = input_tensor_desc[i]; | |||||
| int64_t tensor_size = -1; | |||||
| (void)TensorUtils::GetSize(tensor_desc, tensor_size); | |||||
| dst_tensor_desc->SetShape(tensor_desc.MutableShape()); | |||||
| dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*dst_tensor_desc, tensor_size); | |||||
| dst_tensor_desc->SetShape(tensor_desc.MutableShape()); | |||||
| dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); | |||||
| (void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size); | |||||
| } | |||||
| } | } | ||||
| for (auto &p : shape_futures) { | for (auto &p : shape_futures) { | ||||
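
The input-desc copy above is now performed under the node's mutex, so a concurrent reader taking the same guard (InferShape, PropagateOutputShapes, and OnNodeDone in later hunks) never sees a half-updated tensor desc. A sketch of the assumed guard pattern, with MutexGuard modeled as returning a scoped lock:

    #include <mutex>

    struct NodeGuard {
      std::mutex desc_mu;
      // Modeled after NodeItem::MutexGuard(tag); the tag is for tracing only.
      std::unique_lock<std::mutex> MutexGuard(const char * /*tag*/) {
        return std::unique_lock<std::mutex>(desc_mu);
      }
    };

    void UpdateInputDescs(NodeGuard &node) {
      const auto guard = node.MutexGuard("AwaitShapesReady");
      // ...copy shape, origin shape and tensor size while the lock is held...
    }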
| @@ -159,8 +166,6 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | GE_CHECK_NOTNULL(src_tensor_desc); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | ||||
| auto input_desc = node_item.MutableInputDesc(idx); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| int64_t tensor_size = -1; | int64_t tensor_size = -1; | ||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | ||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | ||||
| @@ -169,6 +174,9 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| src_tensor_desc->GetShape().ToString().c_str(), | src_tensor_desc->GetShape().ToString().c_str(), | ||||
| src_tensor_desc->GetOriginShape().ToString().c_str(), | src_tensor_desc->GetOriginShape().ToString().c_str(), | ||||
| tensor_size); | tensor_size); | ||||
| const auto &guard = node_item.MutexGuard("AwaitShapesReady"); | |||||
| auto input_desc = node_item.MutableInputDesc(idx); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| input_desc->SetShape(src_tensor_desc->GetShape()); | input_desc->SetShape(src_tensor_desc->GetShape()); | ||||
| input_desc->SetOriginShape(src_tensor_desc->GetOriginShape()); | input_desc->SetOriginShape(src_tensor_desc->GetOriginShape()); | ||||
| (void) TensorUtils::SetSize(*input_desc, tensor_size); | (void) TensorUtils::SetSize(*input_desc, tensor_size); | ||||
| @@ -207,6 +215,11 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex | |||||
| } | } | ||||
| Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | ||||
| if (node_item_->IsMergeOp()) { | |||||
| GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size()); | |||||
| return SUCCESS; | |||||
| } | |||||
| for (auto &src_node : node_item_->dependents_for_execution) { | for (auto &src_node : node_item_->dependents_for_execution) { | ||||
| GELOGD("[%s] Start to wait for data dependent node: [%s]", | GELOGD("[%s] Start to wait for data dependent node: [%s]", | ||||
| node_item_->NodeName().c_str(), | node_item_->NodeName().c_str(), | ||||
| @@ -225,7 +238,7 @@ Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | |||||
| node_item_->NodeName().c_str(), | node_item_->NodeName().c_str(), | ||||
| "[AwaitNodeDone] [%s] End", | "[AwaitNodeDone] [%s] End", | ||||
| src_node->GetName().c_str()); | src_node->GetName().c_str()); | ||||
| GELOGD("[%s] Done waiting node.", src_node->GetName().c_str()); | |||||
| GELOGD("[%s] Done waiting node: [%s]", node_item_->NodeName().c_str(), src_node->GetName().c_str()); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -255,6 +268,125 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
| return task_context_; | return task_context_; | ||||
| } | } | ||||
| void NodeState::ResetContext(int group) { | |||||
| SetGroup(group); | |||||
| if (loop_count_ == 0) { | |||||
| ++loop_count_; | |||||
| return; | |||||
| } | |||||
| ++loop_count_; | |||||
| if (loop_count_ == UINT64_MAX) { | |||||
| loop_count_ = 1; | |||||
| } | |||||
| switch_index_ = -1; | |||||
| const auto &guard = node_item_->MutexGuard("ResetContext"); | |||||
| shape_inference_state_.InitShapeState(); | |||||
| subgraph_context_->ResetContext(node_item_->node); | |||||
| GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | |||||
| } | |||||
| void NodeState::ResetSchedule() { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||||
| GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_); | |||||
| } | |||||
| Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | |||||
| // Schedule data output. | |||||
| for (const auto &node : node_item_->data_send_) { | |||||
| const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | |||||
| GE_CHECK_NOTNULL(dst_node_state); | |||||
| dst_node_state->SetDataSchedule(node_item_, ready); | |||||
| } | |||||
| // Schedule ctrl output. | |||||
| for (const auto &node : node_item_->ctrl_send_) { | |||||
| const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | |||||
| GE_CHECK_NOTNULL(dst_node_state); | |||||
| dst_node_state->SetCtrlSchedule(node_item_, ready); | |||||
| } | |||||
| // Schedule switch group. | |||||
| if (switch_index_ >= 0 && static_cast<uint32_t>(switch_index_) < node_item_->switch_groups_.size()) { | |||||
| GELOGI("After [%s] scheduled, switch index: %d", GetName().c_str(), switch_index_); | |||||
| for (const auto &node : node_item_->switch_groups_[switch_index_]) { | |||||
| const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | |||||
| GE_CHECK_NOTNULL(dst_node_state); | |||||
| dst_node_state->SetCtrlSchedule(node_item_, ready); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
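
A fan-out model of the scheduling step above (simplified): every data and control successor is notified, and if this node resolved a Switch-like branch, the chosen branch's group is notified as well, while the untaken branch is simply never activated:

    #include <functional>
    #include <vector>

    struct Succ { int id; };

    void FanOut(const std::vector<Succ> &data_send,
                const std::vector<Succ> &ctrl_send,
                const std::vector<std::vector<Succ>> &switch_groups,
                int switch_index,
                const std::function<void(const Succ &)> &notify) {
      for (const auto &s : data_send) notify(s);   // data successors
      for (const auto &s : ctrl_send) notify(s);   // control successors
      if (switch_index >= 0 &&
          static_cast<size_t>(switch_index) < switch_groups.size()) {
        for (const auto &s : switch_groups[switch_index]) notify(s);
      }
    }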
| bool NodeState::IsScheduleReady() const { | |||||
| GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(), | |||||
| node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
| if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||||
| return false; | |||||
| } | |||||
| if (node_item_->IsMergeOp()) { | |||||
| return data_scheduled_ > 0; | |||||
| } | |||||
| // An Exit node may be fed once per loop iteration, so its scheduled count can exceed its static data-input count. | |||||
| return data_scheduled_ >= node_item_->data_recv_.size(); | |||||
| } | |||||
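
The readiness rule above, restated as a condensed model: all control predecessors must have scheduled this node; Merge becomes ready on the first data input, because only the taken Switch branch ever produces data; everything else waits for at least its full data fan-in (">=" because of the Exit case noted above):

    #include <cstddef>

    struct SchedCounts {
      size_t ctrl_in = 0;     // static control fan-in
      size_t data_in = 0;     // static data fan-in
      size_t ctrl_sched = 0;  // control inputs scheduled so far
      size_t data_sched = 0;  // data inputs scheduled so far
    };

    bool IsReady(const SchedCounts &c, bool is_merge) {
      if (c.ctrl_sched != c.ctrl_in) {
        return false;  // all control predecessors must have fired
      }
      if (is_merge) {
        return c.data_sched > 0;  // only one Switch branch produces data
      }
      return c.data_sched >= c.data_in;  // ">=": Exit may fire per loop
    }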
| void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||||
| GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||||
| node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
| node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| ++data_scheduled_; | |||||
| if (node_item_->IsMergeOp()) { | |||||
| const auto it = node_item_->data_recv_.find(node_item); | |||||
| if (it != node_item_->data_recv_.end()) { | |||||
| merge_index_ = it->second; | |||||
| (void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); | |||||
| GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second); | |||||
| } else { | |||||
| GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str()); | |||||
| } | |||||
| } | |||||
| if (IsScheduleReady()) { | |||||
| ready(node_item_); | |||||
| } | |||||
| } | |||||
| void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||||
| GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||||
| node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
| node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| ++ctrl_scheduled_; | |||||
| if (IsScheduleReady()) { | |||||
| ready(node_item_); | |||||
| } | |||||
| } | |||||
| void NodeState::SetScheduleFuture(std::future<Status> &&future) { | |||||
| schedule_future_ = std::move(future); | |||||
| } | |||||
| Status NodeState::WaitForScheduleDone() { | |||||
| if (schedule_future_.valid()) { | |||||
| GELOGD("[%s] Start to wait for schedule future.", GetName().c_str()); | |||||
| GE_CHK_STATUS_RET(schedule_future_.get(), "[Check][Status][%s] wait thread failed", GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | ||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | ||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled"); | HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled"); | ||||
| @@ -20,6 +20,8 @@ | |||||
| #include <condition_variable> | #include <condition_variable> | ||||
| #include <future> | #include <future> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include "common/blocking_queue.h" | |||||
| #include "external/ge/ge_api_error_codes.h" | #include "external/ge/ge_api_error_codes.h" | ||||
| #include "hybrid/model/node_item.h" | #include "hybrid/model/node_item.h" | ||||
| #include "node_done_manager.h" | #include "node_done_manager.h" | ||||
| @@ -32,6 +34,8 @@ class SubgraphContext; | |||||
| class TaskContext; | class TaskContext; | ||||
| struct NodeState; | struct NodeState; | ||||
| using NodeStatePtr = std::shared_ptr<NodeState>; | |||||
| class ShapeFuture { | class ShapeFuture { | ||||
| public: | public: | ||||
| ShapeFuture(NodeState *src_node, uint32_t src_index, SubgraphContext *subgraph_context); | ShapeFuture(NodeState *src_node, uint32_t src_index, SubgraphContext *subgraph_context); | ||||
| @@ -48,6 +52,8 @@ class ShapeFuture { | |||||
| struct ShapeInferenceState { | struct ShapeInferenceState { | ||||
| explicit ShapeInferenceState(const NodeItem &node_item); | explicit ShapeInferenceState(const NodeItem &node_item); | ||||
| void InitShapeState(); | |||||
| Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc); | Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc); | ||||
| void UpdateInputShapeFuture(int idx, ShapeFuture &&future); | void UpdateInputShapeFuture(int idx, ShapeFuture &&future); | ||||
| @@ -100,6 +106,43 @@ struct NodeState { | |||||
| Status UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape); | Status UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape); | ||||
| inline bool IsShapeDependence() const { | |||||
| return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | |||||
| } | |||||
| void ResetContext(int group); | |||||
| void ResetSchedule(); | |||||
| Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||||
| void SetScheduleFuture(std::future<Status> &&future); | |||||
| Status WaitForScheduleDone(); | |||||
| void SetSwitchIndex(int index) { | |||||
| switch_index_ = index; | |||||
| } | |||||
| int GetSwitchIndex() const { | |||||
| return switch_index_; | |||||
| } | |||||
| void SetMergeIndex(int index) { | |||||
| merge_index_ = index; | |||||
| } | |||||
| int GetMergeIndex() const { | |||||
| return merge_index_; | |||||
| } | |||||
| void SetGroup(int group) { | |||||
| group_ = group; | |||||
| } | |||||
| int GetGroup() const { | |||||
| return group_; | |||||
| } | |||||
| const shared_ptr<NodeTask> &GetKernelTask() const { | const shared_ptr<NodeTask> &GetKernelTask() const { | ||||
| return kernel_task_; | return kernel_task_; | ||||
| } | } | ||||
| @@ -120,6 +163,10 @@ struct NodeState { | |||||
| std::shared_ptr<TaskContext> GetTaskContext(); | std::shared_ptr<TaskContext> GetTaskContext(); | ||||
| private: | private: | ||||
| bool IsScheduleReady() const; | |||||
| void SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||||
| void SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||||
| const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
| std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
| std::future<Status> prepare_future_; | std::future<Status> prepare_future_; | ||||
| @@ -128,9 +175,15 @@ struct NodeState { | |||||
| SubgraphContext *subgraph_context_; | SubgraphContext *subgraph_context_; | ||||
| std::shared_ptr<TaskContext> task_context_ = nullptr; | std::shared_ptr<TaskContext> task_context_ = nullptr; | ||||
| std::mutex mu_; | std::mutex mu_; | ||||
| }; | |||||
| using NodeStatePtr = std::shared_ptr<NodeState>; | |||||
| std::future<Status> schedule_future_; | |||||
| uint64_t loop_count_ = 0; | |||||
| uint32_t ctrl_scheduled_ = 0; | |||||
| uint32_t data_scheduled_ = 0; | |||||
| int merge_index_ = -1; // Used during execution (reset after the node is executed). | |||||
| int switch_index_ = -1; // Used during scheduling (reset after the node is prepared). | |||||
| int group_ = -1; | |||||
| }; | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -37,10 +37,15 @@ Status SubgraphContext::Init() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void SubgraphContext::ResetContext(const NodePtr &node) { | |||||
| node_done_manager_.Reset(node); | |||||
| } | |||||
| NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | ||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| auto &node_state = node_states_[node_item]; | auto &node_state = node_states_[node_item]; | ||||
| if (node_state == nullptr) { | if (node_state == nullptr) { | ||||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | ||||
| } | } | ||||
| @@ -34,6 +34,7 @@ class SubgraphContext { | |||||
| ~SubgraphContext() = default; | ~SubgraphContext() = default; | ||||
| Status Init(); | Status Init(); | ||||
| void ResetContext(const NodePtr &node); | |||||
| NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | ||||
| void OnError(Status error); | void OnError(Status error); | ||||
| @@ -178,7 +178,9 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
| known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | ||||
| GE_CHECK_NOTNULL(known_shape_task_context_); | GE_CHECK_NOTNULL(known_shape_task_context_); | ||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_), | |||||
| std::function<void()> callback; | |||||
| GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | |||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback), | |||||
| "[%s] Failed to execute node [%s] for known subgraph.", | "[%s] Failed to execute node [%s] for known subgraph.", | ||||
| graph_item_->GetName().c_str(), | graph_item_->GetName().c_str(), | ||||
| known_shape_task_context_->GetNodeName()); | known_shape_task_context_->GetNodeName()); | ||||
| @@ -206,76 +208,256 @@ Status SubgraphExecutor::ExecuteAsync(TaskContext &task_context) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| BlockingQueue<const NodeItem *> &SubgraphExecutor::GetPrepareQueue(int group) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| return prepare_queues_[group]; | |||||
| } | |||||
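
A model of the per-group queue lookup above: std::map::operator[] default-constructs the queue on first access, and the mutex makes that creation safe when prepare threads for different groups race. std::queue stands in here for GE's BlockingQueue:

    #include <map>
    #include <mutex>
    #include <queue>

    std::mutex queues_mu;
    std::map<int, std::queue<const void *>> prepare_queues;

    std::queue<const void *> &GetQueueForGroup(int group) {
      std::lock_guard<std::mutex> lk(queues_mu);
      return prepare_queues[group];  // default-constructs on first access
    }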
| Status SubgraphExecutor::NodeEnqueue(NodeState *node_state) { | |||||
| if (!ready_queue_.Push(node_state)) { | |||||
| if (context_->is_eos_) { | |||||
| GELOGD("Got end of sequence"); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(INTERNAL_ERROR, "[Check][State][%s] Error occurs while launching tasks. quit from preparing nodes.", | |||||
| graph_item_->GetName().c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "[%s] Error occurs while launching tasks. quit from preparing nodes.", | |||||
| graph_item_->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_state->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||||
| GELOGD("[%s] Start to prepare node [%s].", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | |||||
| // for while op | |||||
| if (force_infer_shape_ && !node_item.is_dynamic) { | |||||
| GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str()); | |||||
| auto &mutable_node_item = const_cast<NodeItem &>(node_item); | |||||
| mutable_node_item.SetToDynamic(); | |||||
| } | |||||
| auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | |||||
| GE_CHECK_NOTNULL(node_state); | |||||
| node_state->ResetContext(group); | |||||
| auto p_node_state = node_state.get(); | |||||
| if (node_item.node_type == NETOUTPUT) { | |||||
| GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | |||||
| return AfterPrepared(p_node_state); | |||||
| } | |||||
| // only do shape inference and compilation for nodes with dynamic shapes. | |||||
| if (node_item.is_dynamic) { | |||||
| auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { | |||||
| GetContext().SetSessionId(context_->session_id); | |||||
| GetContext().SetContextId(context_->context_id); | |||||
| GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); | |||||
| GE_CHK_STATUS_RET_NOLOG(PrepareForExecution(context_, *p_node_state)); | |||||
| return AfterPrepared(p_node_state); | |||||
| }); | |||||
| p_node_state->SetPrepareFuture(std::move(prepare_future)); | |||||
| return NodeEnqueue(p_node_state); | |||||
| } else { | |||||
| GELOGD("[%s] Skipping shape inference and compilation for node with static shape.", | |||||
| node_item.NodeName().c_str()); | |||||
| if (node_item.kernel_task == nullptr) { | |||||
| GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str()); | |||||
| GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_), | |||||
| "[Invoke][Compile] failed for [%s].", p_node_state->GetName().c_str()); | |||||
| } else { | |||||
| node_state->SetKernelTask(node_item.kernel_task); | |||||
| } | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | |||||
| const auto &task = node_state->GetKernelTask(); | |||||
| if (task == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | |||||
| return AfterPrepared(p_node_state); | |||||
| } | |||||
| } | |||||
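
In the static-shape fast path of PrepareNode, the precompiled kernel task is reused and the task context is published as a shared_ptr, since both the node state and the executor keep a reference. The ownership hand-off in isolation (a sketch; TaskCtx/CreateCtx are stand-ins):

    #include <memory>
    #include <new>

    struct TaskCtx {};

    std::unique_ptr<TaskCtx> CreateCtx() {
      return std::unique_ptr<TaskCtx>(new (std::nothrow) TaskCtx());
    }

    int main() {
      auto unique_ctx = CreateCtx();
      if (unique_ctx == nullptr) {
        return 1;
      }
      // release() transfers raw ownership into the shared_ptr
      std::shared_ptr<TaskCtx> shared_ctx(unique_ctx.release());
      return shared_ctx ? 0 : 1;
    }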
| Status SubgraphExecutor::PrepareNodes(int group) { | Status SubgraphExecutor::PrepareNodes(int group) { | ||||
| GELOGD("[%s] Start to prepare nodes. group = %d", | |||||
| graph_item_->GetName().c_str(), | |||||
| group); | |||||
| auto &all_nodes = graph_item_->GetAllNodes(group); | |||||
| for (auto all_node : all_nodes) { | |||||
| auto &node_item = *all_node; | |||||
| // for while op | |||||
| if (force_infer_shape_ && !node_item.is_dynamic) { | |||||
| GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str()); | |||||
| auto &mutable_node_item = const_cast<NodeItem &>(node_item); | |||||
| mutable_node_item.SetToDynamic(); | |||||
| const size_t node_size = graph_item_->GetNodeSize(group); | |||||
| GELOGD("[%s] Start to prepare nodes. group = %d, size = %zu", graph_item_->GetName().c_str(), group, node_size); | |||||
| if (!graph_item_->HasCtrlFlowOp()) { | |||||
| for (const auto &node_item : graph_item_->GetAllNodes(group)) { | |||||
| RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] Start"); | |||||
| GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str()); | |||||
| RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End"); | |||||
| } | } | ||||
| GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("[%s] Start to prepare node [%s].", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | |||||
| auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | |||||
| GE_CHECK_NOTNULL(node_state); | |||||
| auto p_node_state = node_state.get(); | |||||
| if (node_item.node_type != NETOUTPUT) { | |||||
| // only do shape inference and compilation for nodes with dynamic shapes. | |||||
| if (node_item.is_dynamic) { | |||||
| auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { | |||||
| GetContext().SetSessionId(context_->session_id); | |||||
| GetContext().SetContextId(context_->context_id); | |||||
| GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); | |||||
| return PrepareForExecution(context_, *p_node_state); | |||||
| }); | |||||
| p_node_state->SetPrepareFuture(std::move(prepare_future)); | |||||
| } else { | |||||
| GELOGD("[%s] Skipping shape inference and compilation for node with static shape.", | |||||
| node_item.NodeName().c_str()); | |||||
| if (node_item.kernel_task == nullptr) { | |||||
| GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str()); | |||||
| GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_), | |||||
| "[Invoke][Compile] failed for [%s].", p_node_state->GetName().c_str()); | |||||
| } else { | |||||
| node_state->SetKernelTask(node_item.kernel_task); | |||||
| } | |||||
| auto unique_task_context = | |||||
| TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | |||||
| const auto &task = node_state->GetKernelTask(); | |||||
| if (task == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| // Prepare the root nodes first to seed the ready queue | |||||
| size_t node_count = 0; | |||||
| bool node_complete = false; | |||||
| for (const auto &node_item : graph_item_->GetRootNodes(group)) { | |||||
| RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] Start"); | |||||
| GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str()); | |||||
| RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End"); | |||||
| node_complete = node_item->NodeType() == NETOUTPUT; | |||||
| node_count++; | |||||
| } | |||||
| GELOGD("[%s] Done preparing root nodes.", graph_item_->GetName().c_str()); | |||||
| BlockingQueue<const NodeItem *> &prepare_queue = GetPrepareQueue(group); | |||||
| while (((group != -1) && (node_count < node_size)) || ((group == -1) && !node_complete)) { | |||||
| const NodeItem *node_item = nullptr; | |||||
| if (!prepare_queue.Pop(node_item)) { | |||||
| if (context_->is_eos_) { | |||||
| GELOGD("[%s] Got end of sequence.", graph_item_->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| if (context_->GetStatus() != SUCCESS) { | |||||
| GELOGD("[%s] Graph execution Got failed.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| GELOGE(INTERNAL_ERROR, "[%s] failed to pop node.", graph_item_->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | } | ||||
| if (!ready_queue_.Push(p_node_state)) { | |||||
| if (node_item == nullptr) { | |||||
| GELOGD("[%s] Got EOF from queue.", graph_item_->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] Start"); | |||||
| GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str()); | |||||
| RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End"); | |||||
| node_complete = node_item->NodeType() == NETOUTPUT; | |||||
| node_count++; | |||||
| } | |||||
| GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
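
The pump loop's exit condition above is worth restating. For a concrete group, PrepareNodes stops once every node in that group has been prepared; for the whole graph (group == -1) node_size spans all groups, so completion is keyed on having prepared NETOUTPUT instead. As a predicate:

    #include <cstddef>

    bool KeepPumping(int group, size_t prepared, size_t node_size,
                     bool netoutput_done) {
      return (group != -1) ? (prepared < node_size) : !netoutput_done;
    }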
| Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { | |||||
| GELOGD("Graph[%s] After [%s] scheduled, data size: %zu, ctrl size: %zu, switch index: %d, merge index: %d", | |||||
| graph_item_->GetName().c_str(), node_state->GetName().c_str(), | |||||
| node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), | |||||
| node_state->GetSwitchIndex(), node_state->GetMergeIndex()); | |||||
| auto future = pre_run_pool_.commit([this, node_state]() -> Status { | |||||
| RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); | |||||
| std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) { | |||||
| const auto &node_name = node_item->node_name; | |||||
| int group = (node_state->GetGroup() != -1) ? node_item->group : -1; | |||||
| GELOGI("After [%s] scheduled, [%s] is ready for prepare.", node_state->GetName().c_str(), node_name.c_str()); | |||||
| BlockingQueue<const NodeItem *> &prepare_queue = GetPrepareQueue(group); | |||||
| if (!prepare_queue.Push(node_item)) { | |||||
| if (!context_->is_eos_) { | |||||
| GELOGE(INTERNAL_ERROR, "[Check][State][%s] error occurs when push to queue.", graph_item_->GetName().c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "[%s] error occurs when push to queue.", graph_item_->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| }; | |||||
| GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); | |||||
| node_state->ResetSchedule(); | |||||
| RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); | |||||
| return SUCCESS; | |||||
| }); | |||||
| node_state->SetScheduleFuture(std::move(future)); | |||||
| if (schedule_queue_.Push(node_state)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| if (context_->is_eos_) { | |||||
| GELOGD("[%s] Got end of sequence", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(INTERNAL_ERROR, "[Check][State][%s] error occurs when push to queue.", graph_item_->GetName().c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "[%s] error occurs when push to queue.", graph_item_->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
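
The hand-off shape in NodeScheduled, simplified: propagation runs on a worker from the pool, its future is pinned to the node state, and the node is pushed to schedule_queue_, whose consumer (ScheduleNodes below) joins the future so errors are not lost. A self-contained model:

    #include <future>
    #include <queue>

    struct NodeSt {
      std::future<int> schedule_future;
    };

    // Producer side: run propagation on a worker, pin the future to the
    // node, then enqueue the node so the consumer can join it.
    void ScheduleOne(NodeSt &node, std::queue<NodeSt *> &schedule_queue) {
      node.schedule_future = std::async(std::launch::async, [] {
        return 0;  // stand-in for NodeScheduled(callback) + ResetSchedule()
      });
      schedule_queue.push(&node);
    }

    // Consumer side (ScheduleNodes): pop in order, surface any error.
    int DrainSchedules(std::queue<NodeSt *> &schedule_queue) {
      while (!schedule_queue.empty()) {
        NodeSt *node = schedule_queue.front();
        schedule_queue.pop();
        if (node->schedule_future.valid() &&
            node->schedule_future.get() != 0) {
          return -1;
        }
      }
      return 0;
    }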
| Status SubgraphExecutor::AfterPrepared(NodeState *node_state) { | |||||
| if (!graph_item_->HasCtrlFlowOp()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| if (node_state->IsShapeDependence()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| // Not control flow node, propagate state. | |||||
| return NodeScheduled(node_state); | |||||
| } | |||||
| void SubgraphExecutor::AfterExecuted(NodeState *node_state) { | |||||
| if (!node_state->IsShapeDependence()) { | |||||
| return; | |||||
| } | |||||
| // For control flow node, propagate state. | |||||
| auto error = NodeScheduled(node_state); | |||||
| if (error != SUCCESS) { | |||||
| auto task_context = node_state->GetTaskContext(); | |||||
| task_context->OnError(error); | |||||
| } | |||||
| } | |||||
| void SubgraphExecutor::OnNodeDone(NodeState *node_state) { | |||||
| auto task_context = node_state->GetTaskContext(); | |||||
| NodeDoneCallback cb(context_, task_context); | |||||
| auto error = cb.OnNodeDone(); | |||||
| if (error != SUCCESS) { | |||||
| task_context->OnError(error); | |||||
| } | |||||
| if (node_state->IsShapeDependence() && graph_item_->HasCtrlFlowOp()) { | |||||
| AfterExecuted(node_state); | |||||
| } | |||||
| } | |||||
| Status SubgraphExecutor::InitCallback(NodeState *node_state, std::function<void()> &callback) { | |||||
| auto task_context = node_state->GetTaskContext(); | |||||
| GE_CHECK_NOTNULL(task_context); | |||||
| if (task_context->NeedCallback()) { | |||||
| callback = std::bind(&SubgraphExecutor::OnNodeDone, this, node_state); | |||||
| } else if (node_state->IsShapeDependence() && graph_item_->HasCtrlFlowOp()) { | |||||
| callback = std::bind(&SubgraphExecutor::AfterExecuted, this, node_state); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
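
InitCallback's dispatch, as a lambda sketch: a node that needs post-processing gets the full OnNodeDone path; a shape-dependent control-flow node that does not still gets AfterExecuted, since downstream scheduling must be triggered; all other nodes get no callback at all:

    #include <functional>

    std::function<void()> PickCallback(bool need_callback,
                                       bool ctrl_flow_shape_dep,
                                       std::function<void()> on_node_done,
                                       std::function<void()> after_executed) {
      if (need_callback) {
        return on_node_done;    // full post-processing path
      }
      if (ctrl_flow_shape_dep) {
        return after_executed;  // still must trigger downstream scheduling
      }
      return nullptr;           // no callback needed at all
    }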
| Status SubgraphExecutor::ScheduleNodes() { | |||||
| GELOGD("[%s] Start to schedule nodes.", graph_item_->GetName().c_str()); | |||||
| while (true) { | |||||
| NodeState *node_state = nullptr; | |||||
| if (!schedule_queue_.Pop(node_state)) { | |||||
| if (context_->is_eos_) { | if (context_->is_eos_) { | ||||
| GELOGD("Got end of sequence"); | |||||
| GELOGD("[%s] Got end of sequence.", graph_item_->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| if (context_->GetStatus() != SUCCESS) { | |||||
| GELOGD("[%s] Graph execution Got failed.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GELOGE(INTERNAL_ERROR, "[Check][State][%s] Error occurs while launching tasks. quit from preparing nodes.", | |||||
| graph_item_->GetName().c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "[%s] Error occurs while launching tasks. quit from preparing nodes.", | |||||
| graph_item_->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[%s] failed to pop node.", graph_item_->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | |||||
| if (node_state == nullptr) { | |||||
| GELOGD("[%s] Got EOF from queue.", graph_item_->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| GE_CHK_STATUS_RET_NOLOG(node_state->WaitForScheduleDone()); | |||||
| } | } | ||||
| GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); | |||||
| GELOGD("[%s] Done schedule nodes successfully.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -341,7 +523,10 @@ Status SubgraphExecutor::LaunchTasks() { | |||||
| auto shared_task_context = node_state->GetTaskContext(); | auto shared_task_context = node_state->GetTaskContext(); | ||||
| GE_CHECK_NOTNULL(shared_task_context); | GE_CHECK_NOTNULL(shared_task_context); | ||||
| shared_task_context->SetForceInferShape(force_infer_shape_); | shared_task_context->SetForceInferShape(force_infer_shape_); | ||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_), | |||||
| std::function<void()> callback; | |||||
| GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state, callback)); | |||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_, callback), | |||||
| "[Invoke][ExecuteAsync] failed for [%s].", node_state->GetName().c_str()); | "[Invoke][ExecuteAsync] failed for [%s].", node_state->GetName().c_str()); | ||||
| GELOGD("[%s] Done executing node successfully.", node_state->GetName().c_str()); | GELOGD("[%s] Done executing node successfully.", node_state->GetName().c_str()); | ||||
| } | } | ||||
| @@ -354,22 +539,38 @@ Status SubgraphExecutor::ScheduleTasks(int group) { | |||||
| GetContext().SetContextId(context_->context_id); | GetContext().SetContextId(context_->context_id); | ||||
| auto ret = PrepareNodes(group); | auto ret = PrepareNodes(group); | ||||
| ready_queue_.Push(nullptr); | ready_queue_.Push(nullptr); | ||||
| schedule_queue_.Push(nullptr); | |||||
| for (auto &item : prepare_queues_) { | |||||
| item.second.Push(nullptr); | |||||
| } | |||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| auto schedule_future = std::async(std::launch::async, [&]() -> Status { | |||||
| return ScheduleNodes(); | |||||
| }); | |||||
| GELOGD("[%s] Start to execute subgraph.", graph_item_->GetName().c_str()); | GELOGD("[%s] Start to execute subgraph.", graph_item_->GetName().c_str()); | ||||
| auto ret = LaunchTasks(); | auto ret = LaunchTasks(); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| subgraph_context_->OnError(ret); | subgraph_context_->OnError(ret); | ||||
| context_->SetErrorCode(ret); | context_->SetErrorCode(ret); | ||||
| ready_queue_.Stop(); | ready_queue_.Stop(); | ||||
| schedule_queue_.Stop(); | |||||
| for (auto &item : prepare_queues_) { | |||||
| item.second.Stop(); | |||||
| } | |||||
| prepare_future.wait(); | prepare_future.wait(); | ||||
| schedule_future.wait(); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(prepare_future.get(), "[Invoke][get] [%s] Error occurred in task preparation.", | GE_CHK_STATUS_RET(prepare_future.get(), "[Invoke][get] [%s] Error occurred in task preparation.", | ||||
| graph_item_->GetName().c_str()); | graph_item_->GetName().c_str()); | ||||
| GE_CHK_STATUS_RET(schedule_future.get(), "[Invoke][get] [%s] Error occurred in task scheduling.", | |||||
| graph_item_->GetName().c_str()); | |||||
| GELOGD("[%s] Done launching all tasks successfully.", graph_item_->GetName().c_str()); | GELOGD("[%s] Done launching all tasks successfully.", graph_item_->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
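
After this change, ScheduleTasks runs three concurrent stages: prepare and schedule each on their own async thread while LaunchTasks runs inline. The nullptr sentinels pushed at the end of preparation unblock the queues, and both futures are joined on the error path as well as on success, so no worker outlives the call. The concurrency shape, modeled:

    #include <future>

    int ScheduleTasksModel() {
      auto prepare_future = std::async(std::launch::async, [] { return 0; });
      auto schedule_future = std::async(std::launch::async, [] { return 0; });
      int ret = 0;  // stand-in for LaunchTasks() on the current thread
      if (ret != 0) {
        // stop the queues first (not modeled), then join both workers
        prepare_future.wait();
        schedule_future.wait();
        return ret;
      }
      if ((ret = prepare_future.get()) != 0) {
        return ret;
      }
      return schedule_future.get();
    }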
| @@ -105,6 +105,18 @@ class SubgraphExecutor { | |||||
| Status PrepareNodes(int group = -1); | Status PrepareNodes(int group = -1); | ||||
| Status LaunchTasks(); | Status LaunchTasks(); | ||||
| Status SetOutputsToParentNode(TaskContext &task_context); | Status SetOutputsToParentNode(TaskContext &task_context); | ||||
| Status InitCallback(NodeState *node_state, std::function<void()> &callback); | |||||
| Status NodeEnqueue(NodeState *node_state); | |||||
| Status PrepareNode(const NodeItem &node_item, int group); | |||||
| BlockingQueue<const NodeItem *> &GetPrepareQueue(int group); | |||||
| Status ScheduleNodes(); | |||||
| Status NodeScheduled(NodeState *node_state); | |||||
| Status AfterPrepared(NodeState *node_state); | |||||
| void AfterExecuted(NodeState *node_state); | |||||
| void OnNodeDone(NodeState *node_state); | |||||
| const GraphItem *graph_item_; | const GraphItem *graph_item_; | ||||
| GraphExecutionContext *context_; | GraphExecutionContext *context_; | ||||
| @@ -114,6 +126,10 @@ class SubgraphExecutor { | |||||
| BlockingQueue<NodeState *> ready_queue_; | BlockingQueue<NodeState *> ready_queue_; | ||||
| std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | ||||
| std::shared_ptr<TaskContext> known_shape_task_context_; | std::shared_ptr<TaskContext> known_shape_task_context_; | ||||
| std::mutex mu_; // Guard for prepare_queues_. | |||||
| std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_; | |||||
| BlockingQueue<NodeState *> schedule_queue_; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "hybrid/executor//worker//shape_inference_engine.h" | #include "hybrid/executor//worker//shape_inference_engine.h" | ||||
| #include "common/dump/dump_op.h" | |||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -62,22 +61,6 @@ Status LogOutputs(const NodeItem &node_item, const TaskContext &task_context) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| class NodeDoneCallback { | |||||
| public: | |||||
| NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> task_context); | |||||
| ~NodeDoneCallback() = default; | |||||
| Status OnNodeDone(); | |||||
| private: | |||||
| Status PrepareConstInputs(const NodeItem &node_item); | |||||
| Status DumpDynamicNode(); | |||||
| Status ProfilingReport(); | |||||
| Status SaveDumpOpInfo(); | |||||
| Status GetTaskDescInfo(const NodePtr node, const HybridModel *model, | |||||
| std::vector<TaskDescInfo> &task_desc_info); | |||||
| GraphExecutionContext *graph_context_; | |||||
| std::shared_ptr<TaskContext> context_; | |||||
| DumpOp dump_op_; | |||||
| }; | |||||
| NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, | NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, | ||||
| std::shared_ptr<TaskContext> task_context) | std::shared_ptr<TaskContext> task_context) | ||||
| @@ -334,6 +317,7 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); | GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); | ||||
| if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| // update output tensor sizes | // update output tensor sizes | ||||
| const auto &guard = node_item.MutexGuard("OnNodeDone"); | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(context_->GetNodeState()->GetShapeInferenceState().UpdateOutputDesc()); | GE_CHK_STATUS_RET_NOLOG(context_->GetNodeState()->GetShapeInferenceState().UpdateOutputDesc()); | ||||
| } | } | ||||
| @@ -361,31 +345,15 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| Status ExecutionEngine::ExecuteAsync(NodeState &node_state, | Status ExecutionEngine::ExecuteAsync(NodeState &node_state, | ||||
| const std::shared_ptr<TaskContext> &task_context, | const std::shared_ptr<TaskContext> &task_context, | ||||
| GraphExecutionContext &execution_context) { | |||||
| GraphExecutionContext &execution_context, | |||||
| const std::function<void()> &callback) { | |||||
| GELOGI("[%s] Node is ready for execution", task_context->GetNodeName()); | GELOGI("[%s] Node is ready for execution", task_context->GetNodeName()); | ||||
| RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); | RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); | ||||
| std::function<void()> callback = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(InitCallback(task_context, execution_context, callback)); | |||||
| GE_CHK_STATUS_RET_NOLOG(DoExecuteAsync(node_state, *task_context, execution_context, callback)); | GE_CHK_STATUS_RET_NOLOG(DoExecuteAsync(node_state, *task_context, execution_context, callback)); | ||||
| GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_state.GetNodeItem(), *task_context, execution_context)); | GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_state.GetNodeItem(), *task_context, execution_context)); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ExecutionEngine::InitCallback(const std::shared_ptr<TaskContext> &task_context, | |||||
| GraphExecutionContext &execution_context, std::function<void()> &callback) { | |||||
| if (task_context->NeedCallback()) { | |||||
| auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context)); | |||||
| GE_CHECK_NOTNULL(cb); | |||||
| callback = [task_context, cb]() { | |||||
| auto ret = cb->OnNodeDone(); | |||||
| if (ret != SUCCESS) { | |||||
| task_context->OnError(ret); | |||||
| } | |||||
| }; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | ||||
| TaskContext &task_context, | TaskContext &task_context, | ||||
| GraphExecutionContext &context, | GraphExecutionContext &context, | ||||
| @@ -423,7 +391,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | |||||
| node_state.GetName().c_str()); | node_state.GetName().c_str()); | ||||
| RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ValidateInputTensors] End"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ValidateInputTensors] End"); | ||||
| if (context.profiling_level > 0) { | |||||
| if (GraphExecutionContext::profiling_level > 0) { | |||||
| auto *ctx = &context; | auto *ctx = &context; | ||||
| const string &name = node_state.GetName(); | const string &name = node_state.GetName(); | ||||
| (void)task_context.RegisterCallback([ctx, name]() { | (void)task_context.RegisterCallback([ctx, name]() { | ||||
| @@ -19,14 +19,33 @@ | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/node_executor/task_context.h" | #include "hybrid/node_executor/task_context.h" | ||||
| #include "common/dump/dump_op.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| class NodeDoneCallback { | |||||
| public: | |||||
| NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> task_context); | |||||
| ~NodeDoneCallback() = default; | |||||
| Status OnNodeDone(); | |||||
| private: | |||||
| Status PrepareConstInputs(const NodeItem &node_item); | |||||
| Status DumpDynamicNode(); | |||||
| Status ProfilingReport(); | |||||
| Status SaveDumpOpInfo(); | |||||
| Status GetTaskDescInfo(const NodePtr node, const HybridModel *model, | |||||
| std::vector<TaskDescInfo> &task_desc_info); | |||||
| GraphExecutionContext *graph_context_; | |||||
| std::shared_ptr<TaskContext> context_; | |||||
| DumpOp dump_op_; | |||||
| }; | |||||
| class ExecutionEngine { | class ExecutionEngine { | ||||
| public: | public: | ||||
| static Status ExecuteAsync(NodeState &node_state, | static Status ExecuteAsync(NodeState &node_state, | ||||
| const std::shared_ptr<TaskContext> &task_context, | const std::shared_ptr<TaskContext> &task_context, | ||||
| GraphExecutionContext &execution_context); | |||||
| GraphExecutionContext &execution_context, | |||||
| const std::function<void()> &callback); | |||||
| private: | private: | ||||
| static Status ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context); | static Status ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context); | ||||
| @@ -35,8 +54,6 @@ class ExecutionEngine { | |||||
| TaskContext &task_context, | TaskContext &task_context, | ||||
| GraphExecutionContext &context, | GraphExecutionContext &context, | ||||
| const std::function<void()> &callback); | const std::function<void()> &callback); | ||||
| static Status InitCallback(const std::shared_ptr<TaskContext> &task_context, | |||||
| GraphExecutionContext &execution_context, std::function<void()> &callback); | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -45,6 +45,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| const auto &guard = node_item.MutexGuard("InferShape"); | |||||
| if (node_item.fused_subgraph != nullptr) { | if (node_item.fused_subgraph != nullptr) { | ||||
| GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph)); | GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph)); | ||||
| GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item)); | GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item)); | ||||
| @@ -123,8 +124,9 @@ Status ShapeInferenceEngine::PropagateOutputShapes(NodeState &node_state) { | |||||
| node_item.shape_inference_type); | node_item.shape_inference_type); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start"); | ||||
| // propagate each output | // propagate each output | ||||
| const auto &guard = node_item.MutexGuard("PropagateOutputShapes"); | |||||
| for (int i = 0; i < node_item.num_outputs; ++i) { | for (int i = 0; i < node_item.num_outputs; ++i) { | ||||
| auto output_desc = node_item.op_desc->MutableOutputDesc(i); | |||||
| auto output_desc = node_item.MutableOutputDesc(i); | |||||
| auto &output_nodes = node_item.outputs[i]; | auto &output_nodes = node_item.outputs[i]; | ||||
| // propagate output to all sub-inputs | // propagate output to all sub-inputs | ||||
| @@ -43,6 +43,27 @@ const vector<NodeItem *> &GraphItem::GetAllNodes(int group) const { | |||||
| return grouped_node_items_[group]; | return grouped_node_items_[group]; | ||||
| } | } | ||||
| const vector<NodeItem *> &GraphItem::GetRootNodes(int group) const { | |||||
| if (group == -1) { | |||||
| return root_items_; | |||||
| } | |||||
| if (static_cast<uint32_t>(group) >= grouped_root_items_.size()) { | |||||
| static vector<NodeItem *> empty_nodes; | |||||
| return empty_nodes; | |||||
| } | |||||
| return grouped_root_items_[group]; | |||||
| } | |||||
| size_t GraphItem::GetNodeSize(int group) const { | |||||
| if (group == -1) { | |||||
| return node_items_.size(); | |||||
| } | |||||
| return (static_cast<uint32_t>(group) < grouped_node_items_.size()) ? grouped_node_items_[group].size() : 0; | |||||
| } | |||||
| const vector<const NodeItem *> &GraphItem::GetInputNodes() const { | const vector<const NodeItem *> &GraphItem::GetInputNodes() const { | ||||
| return input_nodes_; | return input_nodes_; | ||||
| } | } | ||||
| @@ -88,10 +109,12 @@ const vector<std::pair<const NodeItem *, int>> &GraphItem::GetOutputEdges() cons | |||||
| return output_edges_; | return output_edges_; | ||||
| } | } | ||||
| Status GraphItem::GroupNodes() { | |||||
| Status GraphItem::GroupNodes(const std::vector<NodeItem *> &node_items, | |||||
| std::vector<std::vector<NodeItem *>> &grouped_node_items) const { | |||||
| int curr_group = 0; | |||||
| int last_group = INT32_MIN; | int last_group = INT32_MIN; | ||||
| std::set<int> seen_groups; | std::set<int> seen_groups; | ||||
| for (auto node : node_items_) { | |||||
| for (auto node : node_items) { | |||||
| int group = node->group; | int group = node->group; | ||||
| if (group != last_group) { | if (group != last_group) { | ||||
| if (seen_groups.find(group) != seen_groups.end()) { | if (seen_groups.find(group) != seen_groups.end()) { | ||||
| @@ -101,15 +124,23 @@ Status GraphItem::GroupNodes() { | |||||
| } else { | } else { | ||||
| last_group = group; | last_group = group; | ||||
| seen_groups.insert(group); | seen_groups.insert(group); | ||||
| grouped_node_items_.emplace_back(std::vector<NodeItem *>()); | |||||
| curr_group = static_cast<int>(grouped_node_items.size()); | |||||
| grouped_node_items.emplace_back(std::vector<NodeItem *>()); | |||||
| } | } | ||||
| } | } | ||||
| GELOGD("Adding node [%s] to group %d", node->NodeName().c_str(), group); | |||||
| grouped_node_items_.back().emplace_back(node); | |||||
| node->group = curr_group; | |||||
| GELOGD("Adding node [%s] to group %d", node->NodeName().c_str(), node->group); | |||||
| grouped_node_items.back().emplace_back(node); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphItem::GroupNodes() { | |||||
| GE_CHK_STATUS_RET_NOLOG(GroupNodes(node_items_, grouped_node_items_)); | |||||
| GE_CHK_STATUS_RET_NOLOG(GroupNodes(root_items_, grouped_root_items_)); | |||||
| return SUCCESS; | |||||
| } | |||||
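
The regrouping pass above now also rewrites each node's group id to the dense bucket index, so later lookups (GetAllNodes(group), GetRootNodes(group), GetNodeSize(group)) can index vectors directly. A condensed model over plain ints (buckets collect the dense ids here where the real code collects node pointers):

    #include <climits>
    #include <set>
    #include <vector>

    // Returns false if a group id reappears after a different one, i.e.
    // the topological order interleaves groups.
    bool GroupDense(std::vector<int> &node_groups,
                    std::vector<std::vector<int>> &buckets) {
      int last_group = INT_MIN;
      std::set<int> seen_groups;
      for (auto &group : node_groups) {
        if (group != last_group) {
          if (!seen_groups.insert(group).second) {
            return false;  // non-contiguous grouping
          }
          last_group = group;
          buckets.emplace_back();
        }
        group = static_cast<int>(buckets.size()) - 1;  // dense renumber
        buckets.back().push_back(group);
      }
      return true;
    }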
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -29,6 +29,7 @@ class GraphItem { | |||||
| Status GroupNodes(); | Status GroupNodes(); | ||||
| const vector<NodeItem *> &GetAllNodes() const; | const vector<NodeItem *> &GetAllNodes() const; | ||||
| const vector<NodeItem *> &GetAllNodes(int group) const; | const vector<NodeItem *> &GetAllNodes(int group) const; | ||||
| const vector<NodeItem *> &GetRootNodes(int group) const; | |||||
| const vector<const NodeItem *> &GetInputNodes() const; | const vector<const NodeItem *> &GetInputNodes() const; | ||||
| Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; | Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; | ||||
| const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const; | const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const; | ||||
| @@ -40,6 +41,12 @@ class GraphItem { | |||||
| return total_outputs_; | return total_outputs_; | ||||
| } | } | ||||
| size_t GetNodeSize(int group) const; | |||||
| bool HasCtrlFlowOp() const { | |||||
| return has_ctrl_flow_op_; | |||||
| } | |||||
| const std::string& GetName() const { | const std::string& GetName() const { | ||||
| return name_; | return name_; | ||||
| } | } | ||||
| @@ -60,9 +67,14 @@ class GraphItem { | |||||
| private: | private: | ||||
| friend class HybridModelBuilder; | friend class HybridModelBuilder; | ||||
| Status GroupNodes(const std::vector<NodeItem *> &node_items, | |||||
| std::vector<std::vector<NodeItem *>> &grouped_node_items) const; | |||||
| std::string name_; | std::string name_; | ||||
| std::vector<NodeItem *> node_items_; | std::vector<NodeItem *> node_items_; | ||||
| std::vector<std::vector<NodeItem *>> grouped_node_items_; | std::vector<std::vector<NodeItem *>> grouped_node_items_; | ||||
| std::vector<NodeItem *> root_items_; | |||||
| std::vector<std::vector<NodeItem *>> grouped_root_items_; | |||||
| std::vector<const NodeItem *> input_nodes_; | std::vector<const NodeItem *> input_nodes_; | ||||
| const NodeItem *output_node_ = nullptr; | const NodeItem *output_node_ = nullptr; | ||||
| // <src_node, out_index> | // <src_node, out_index> | ||||
| @@ -71,6 +83,7 @@ class GraphItem { | |||||
| int total_outputs_ = 0; | int total_outputs_ = 0; | ||||
| bool is_dynamic_ = true; | bool is_dynamic_ = true; | ||||
| bool has_ctrl_flow_op_ = false; | |||||
| std::vector<int> input_index_mapping_; | std::vector<int> input_index_mapping_; | ||||
| std::vector<int> output_index_mapping_; | std::vector<int> output_index_mapping_; | ||||
| }; | }; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "hybrid/model/hybrid_model_builder.h" | #include "hybrid/model/hybrid_model_builder.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/build/memory/var_mem_assign_util.h" | #include "graph/build/memory/var_mem_assign_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| @@ -42,6 +43,11 @@ const uint64_t kProfilingFpStartLogid = 1U; | |||||
| const uint64_t kProfilingBpEndLogid = 2U; | const uint64_t kProfilingBpEndLogid = 2U; | ||||
| const uint64_t kProfilingIterEndLogid = 65535U; | const uint64_t kProfilingIterEndLogid = 65535U; | ||||
| const int kBytes = 8; | const int kBytes = 8; | ||||
| const int kDecimal = 10; | |||||
| const uint8_t kStreamActiveIdx = 0; | |||||
| const uint8_t kStreamActiveNum = 1; | |||||
| const uint8_t kStreamSwitchIdx = 1; | |||||
| const uint8_t kStreamSwitchNum = 2; | |||||
| const uint32_t kStringHeadElems = 2; | const uint32_t kStringHeadElems = 2; | ||||
| const char *const kOwnerGraphIsUnknown = "OwnerGraphIsUnknown"; | const char *const kOwnerGraphIsUnknown = "OwnerGraphIsUnknown"; | ||||
| const char *const kProfilingGraph = "ProfilingGraph"; | const char *const kProfilingGraph = "ProfilingGraph"; | ||||
| @@ -213,6 +219,7 @@ Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_ite | |||||
| "[Invoke][GetCanonicalInputIndex] failed, dst_node:[%s].", dst_node->GetName().c_str()); | "[Invoke][GetCanonicalInputIndex] failed, dst_node:[%s].", dst_node->GetName().c_str()); | ||||
| node_item.outputs[i].emplace_back(canonical_index, dst_node_item); | node_item.outputs[i].emplace_back(canonical_index, dst_node_item); | ||||
| node_item.SetDataSend(dst_node_item, dst_in_anchor->GetIdx()); | |||||
| } | } | ||||
| } | } | ||||
| @@ -300,8 +307,9 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| } | } | ||||
| auto src_node = peer_anchor->GetOwnerNode(); | auto src_node = peer_anchor->GetOwnerNode(); | ||||
| GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
| auto src_node_item = MutableNodeItem(src_node); | |||||
| GE_CHECK_NOTNULL(src_node_item); | |||||
| NodeItem *src_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item), | |||||
| "[%s] failed to get or create node item", src_node->GetName().c_str()); | |||||
| if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) { | if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) { | ||||
| GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d", | GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d", | ||||
| @@ -323,15 +331,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| } | } | ||||
| } | } | ||||
| for (const auto &src_node : ge_node->GetInControlNodes()) { | |||||
| auto src_node_item = MutableNodeItem(src_node); | |||||
| if ((src_node_item != nullptr) && (is_hccl_op || src_node_item->IsHcclOp())) { | |||||
| GELOGD("[%s](%s) Add input control dependent node [%s](%s)", | |||||
| ge_node->GetName().c_str(), | |||||
| ge_node->GetType().c_str(), | |||||
| src_node->GetName().c_str(), | |||||
| src_node->GetType().c_str()); | |||||
| dependent_for_execution.emplace(src_node); | |||||
| if (node_item.node_type == NETOUTPUT) { | |||||
| for (const auto &src_node : ge_node->GetInControlNodes()) { | |||||
| auto src_node_item = MutableNodeItem(src_node); | |||||
| if ((src_node_item != nullptr) && src_node_item->IsHcclOp()) { | |||||
| GELOGD("[%s](%s) Add input control dependent node [%s](%s)", | |||||
| ge_node->GetName().c_str(), | |||||
| ge_node->GetType().c_str(), | |||||
| src_node->GetName().c_str(), | |||||
| src_node->GetType().c_str()); | |||||
| dependent_for_execution.emplace(src_node); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -794,6 +804,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| } | } | ||||
| hybrid_model_.root_graph_ = root_graph; | hybrid_model_.root_graph_ = root_graph; | ||||
| GE_CHK_STATUS_RET(RelinkNextIteration(), "[%s] Relink NextIteration failed", GetGraphName()); | |||||
| // Reset node id by topological order across all subgraphs | // Reset node id by topological order across all subgraphs | ||||
| int64_t index = 0; | int64_t index = 0; | ||||
| for (const auto &node : root_graph->GetAllNodes()) { | for (const auto &node : root_graph->GetAllNodes()) { | ||||
| @@ -839,7 +850,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| parent_node_item->NodeName().c_str()); | parent_node_item->NodeName().c_str()); | ||||
| // if parent is function control op. need add a virtual partitioned call | // if parent is function control op. need add a virtual partitioned call | ||||
| if (parent_node_item->IsControlOp()) { | |||||
| if (parent_node_item->IsControlFlowV2Op()) { | |||||
| GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | ||||
| "[Invoke][LoadKnownShapedSubgraph]Failed to load function control op subgraph [%s]", | "[Invoke][LoadKnownShapedSubgraph]Failed to load function control op subgraph [%s]", | ||||
| sub_graph->GetName().c_str()); | sub_graph->GetName().c_str()); | ||||
| @@ -1169,7 +1180,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr | |||||
| auto parent_node = sub_graph.GetParentNode(); | auto parent_node = sub_graph.GetParentNode(); | ||||
| GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
| auto op_type = parent_node->GetType(); | auto op_type = parent_node->GetType(); | ||||
| if (IsControlOp(op_type)) { | |||||
| if (IsControlFlowV2Op(op_type)) { | |||||
| GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", | GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", | ||||
| sub_graph.GetName().c_str(), | sub_graph.GetName().c_str(), | ||||
| ge_model->GetModelTaskDefPtr()->task_size()); | ge_model->GetModelTaskDefPtr()->task_size()); | ||||
| @@ -1325,6 +1336,10 @@ Status HybridModelBuilder::IndexSpecialNodes() { | |||||
| } | } | ||||
| } else if (op_type == CONSTANTOP) { | } else if (op_type == CONSTANTOP) { | ||||
| constant_op_nodes_.emplace(node->GetName(), node); | constant_op_nodes_.emplace(node->GetName(), node); | ||||
| } else if (op_type == STREAMMERGE) { | |||||
| stream_merge_op_nodes_.emplace(node->GetName(), node); | |||||
| } else if (op_type == NEXTITERATION || op_type == REFNEXTITERATION) { | |||||
| next_iteration_op_nodes_.emplace(node->GetName(), node); | |||||
| } else if (op_type == DATA && node->GetOwnerComputeGraph() != root_graph) { | } else if (op_type == DATA && node->GetOwnerComputeGraph() != root_graph) { | ||||
| NodePtr src_node; | NodePtr src_node; | ||||
| int peer_out_index = -1; | int peer_out_index = -1; | ||||
| @@ -1825,7 +1840,7 @@ Status HybridModelBuilder::GenerateEndProfilingTask(const OpDescPtr &op_desc, ve | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node) { | |||||
| Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node, uint32_t &prev_num) { | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| const OpDescPtr &op_desc = node->GetOpDesc(); | const OpDescPtr &op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| @@ -1871,7 +1886,7 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons | |||||
| if (!node_task_map.empty()) { | if (!node_task_map.empty()) { | ||||
| for (const auto &node_task : node_task_map) { | for (const auto &node_task : node_task_map) { | ||||
| NodePtr profiling_node = node_task.first; | NodePtr profiling_node = node_task.first; | ||||
| vector<domi::TaskDef> task_def_lists = node_task.second; | |||||
| const vector<domi::TaskDef> &task_def_lists = node_task.second; | |||||
| for (const auto &task_def : task_def_lists) { | for (const auto &task_def : task_def_lists) { | ||||
| hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | ||||
| } | } | ||||
| @@ -1886,6 +1901,7 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons | |||||
| node_item->input_start = 0; | node_item->input_start = 0; | ||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| graph_item.node_items_.emplace_back(node_item); | graph_item.node_items_.emplace_back(node_item); | ||||
| ++prev_num; | |||||
| } | } | ||||
| } else { | } else { | ||||
| GELOGD("No need to create profiling node before."); | GELOGD("No need to create profiling node before."); | ||||
| @@ -1894,7 +1910,7 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node) { | |||||
| Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node, uint32_t &post_num) { | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| const OpDescPtr &op_desc = node->GetOpDesc(); | const OpDescPtr &op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| @@ -1952,7 +1968,7 @@ Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const | |||||
| if (!node_task_map.empty()) { | if (!node_task_map.empty()) { | ||||
| for (const auto &node_task : node_task_map) { | for (const auto &node_task : node_task_map) { | ||||
| NodePtr profiling_node = node_task.first; | NodePtr profiling_node = node_task.first; | ||||
| vector<domi::TaskDef> task_def_lists = node_task.second; | |||||
| const vector<domi::TaskDef> &task_def_lists = node_task.second; | |||||
| for (const auto &task_def : task_def_lists) { | for (const auto &task_def : task_def_lists) { | ||||
| hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | ||||
| } | } | ||||
| @@ -1967,6 +1983,7 @@ Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const | |||||
| node_item->input_start = 0; | node_item->input_start = 0; | ||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| graph_item.node_items_.emplace_back(node_item); | graph_item.node_items_.emplace_back(node_item); | ||||
| ++post_num; | |||||
| } | } | ||||
| } else { | } else { | ||||
| GELOGD("No need to create profiling node after."); | GELOGD("No need to create profiling node after."); | ||||
| @@ -1986,20 +2003,23 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
| int input_start = 0; | int input_start = 0; | ||||
| int output_start = 0; | int output_start = 0; | ||||
| std::vector<NodeItem *> data_nodes; | std::vector<NodeItem *> data_nodes; | ||||
| std::map<size_t, std::pair<uint32_t, uint32_t>> profiling_nodes; | |||||
| for (auto &node : graph.GetDirectNode()) { | for (auto &node : graph.GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| const auto &op_type = node->GetType(); | const auto &op_type = node->GetType(); | ||||
| if (op_type == NOOP) { | |||||
| GELOGD("[%s] Skip NoOp", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item)); | |||||
| if (node->GetInAllNodes().empty()) { | |||||
| graph_item->root_items_.emplace_back(node_item); | |||||
| GELOGD("[%s] add to root node list", node->GetName().c_str()); | |||||
| } | |||||
| node_item->input_start = input_start; | node_item->input_start = input_start; | ||||
| node_item->output_start = output_start; | node_item->output_start = output_start; | ||||
| input_start += node_item->num_inputs; | input_start += node_item->num_inputs; | ||||
| @@ -2011,9 +2031,16 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
| graph_item->output_node_ = node_item; | graph_item->output_node_ = node_item; | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph)); | GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph)); | ||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node)); | |||||
| uint32_t prev_num = 0; | |||||
| uint32_t post_num = 0; | |||||
| GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node, prev_num)); | |||||
| size_t node_index = graph_item->node_items_.size(); | |||||
| graph_item->node_items_.emplace_back(node_item); | graph_item->node_items_.emplace_back(node_item); | ||||
| GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeAfter(*graph_item, node)); | |||||
| GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeAfter(*graph_item, node, post_num)); | |||||
| if (prev_num > 0 || post_num > 0) { | |||||
| profiling_nodes[node_index] = { prev_num, post_num }; | |||||
| } | |||||
| // parse var outputs | // parse var outputs | ||||
| GE_CHK_STATUS_RET_NOLOG(ParseVarOutputs(*node_item)); | GE_CHK_STATUS_RET_NOLOG(ParseVarOutputs(*node_item)); | ||||
| GELOGD("NodeItem created: %s", node_item->DebugString().c_str()); | GELOGD("NodeItem created: %s", node_item->DebugString().c_str()); | ||||
| @@ -2022,6 +2049,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||||
| graph_item->total_inputs_ = input_start; | graph_item->total_inputs_ = input_start; | ||||
| graph_item->total_outputs_ = output_start; | graph_item->total_outputs_ = output_start; | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildInputMapping(*graph_item, data_nodes, is_root_graph)); | GE_CHK_STATUS_RET_NOLOG(BuildInputMapping(*graph_item, data_nodes, is_root_graph)); | ||||
| GE_CHK_STATUS_RET_NOLOG(BuildProfilingControl(*graph_item, profiling_nodes)); | |||||
| if (is_root_graph) { | if (is_root_graph) { | ||||
| graph_item->SetName("Root-Graph"); | graph_item->SetName("Root-Graph"); | ||||
| GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str()); | GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str()); | ||||
| @@ -2271,5 +2299,299 @@ Status HybridModelBuilder::Convert2HostTensor(const NodePtr &node, int node_id, | |||||
| hybrid_model_.host_tensors_[node_id].emplace_back(output_idx, std::move(tensor)); | hybrid_model_.host_tensors_[node_id].emplace_back(output_idx, std::move(tensor)); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::RelinkNextIteration() { | |||||
| for (const auto &item : stream_merge_op_nodes_) { | |||||
| const auto &merge = item.second; | |||||
| std::string node_name; | |||||
| if (!AttrUtils::GetStr(merge->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, node_name)) { | |||||
| GELOGD("[%s] no attribute[%s], not in while loop", merge->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str()); | |||||
| continue; | |||||
| } | |||||
| const auto it = next_iteration_op_nodes_.find(node_name); | |||||
| if (it == next_iteration_op_nodes_.end()) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] expect NextIteration[%s] not found", merge->GetName().c_str(), node_name.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const auto &iteration = it->second; | |||||
| if (GraphUtils::AddEdge(iteration->GetOutDataAnchor(0), merge->GetInDataAnchor(1)) != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] -> [%s] Add edge failed", node_name.c_str(), merge->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| stream_merge_op_nodes_.clear(); | |||||
| next_iteration_op_nodes_.clear(); | |||||
| return SUCCESS; | |||||
| } | |||||
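For orientation: in a lowered while-loop, StreamMerge takes the initial value on data input 0 and the loop-back value on data input 1, so the relink restores the back edge that graph partitioning removed. With illustrative node names:

    // Enter (initial value) ------> StreamMerge : in(0)
    // NextIteration : out(0) -----> StreamMerge : in(1)   (edge re-added above)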
| Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item, | |||||
| const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes) { | |||||
| const auto node_size = graph_item.node_items_.size(); | |||||
| for (const auto &item : nodes) { | |||||
| const auto node_index = item.first; | |||||
| GE_CHK_BOOL_RET_STATUS(node_index < node_size, FAILED, "node index invalid"); | |||||
| const auto &node_item = graph_item.node_items_[node_index]; | |||||
| if (item.second.first > 0) { | |||||
| const auto prev_num = item.second.first; | |||||
| if (node_index == prev_num) { | |||||
| // Profiling nodes inserted before the root node. | |||||
| for (uint32_t i = 1; i <= prev_num; ++i) { | |||||
| GE_CHK_BOOL_RET_STATUS(node_index - i < node_size, FAILED, "prev index invalid"); | |||||
| const auto &curr_item = graph_item.node_items_[node_index - i]; | |||||
| graph_item.root_items_.emplace(graph_item.root_items_.begin(), curr_item); | |||||
| } | |||||
| } else { | |||||
| GE_CHK_BOOL_RET_STATUS((node_index - prev_num) - 1 < node_size, FAILED, "prev index invalid"); | |||||
| const auto &prev_item = graph_item.node_items_[(node_index - prev_num) - 1]; | |||||
| for (uint32_t i = 1; i <= prev_num; ++i) { | |||||
| GE_CHK_BOOL_RET_STATUS(node_index - i < node_size, FAILED, "prev index invalid"); | |||||
| const auto &curr_item = graph_item.node_items_[node_index - i]; | |||||
| prev_item->SetCtrlSend(curr_item, UINT32_MAX); | |||||
| curr_item->SetCtrlSend(node_item, UINT32_MAX); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (item.second.second > 0) { | |||||
| const auto post_num = item.second.second; | |||||
| if (node_size == node_index + post_num + 1) { | |||||
| // Profiling nodes inserted after the last node. | |||||
| for (uint32_t i = 1; i <= post_num; ++i) { | |||||
| GE_CHK_BOOL_RET_STATUS(node_index + i < node_size, FAILED, "post index invalid"); | |||||
| const auto &curr_item = graph_item.node_items_[node_index + i]; | |||||
| node_item->SetCtrlSend(curr_item, UINT32_MAX); | |||||
| } | |||||
| } else { | |||||
| GE_CHK_BOOL_RET_STATUS((node_index + post_num) + 1 < node_size, FAILED, "post index invalid"); | |||||
| const auto &post_item = graph_item.node_items_[(node_index + post_num) + 1]; | |||||
| for (uint32_t i = 1; i <= post_num; ++i) { | |||||
| GE_CHK_BOOL_RET_STATUS(node_index + i < node_size, FAILED, "post index invalid"); | |||||
| const auto &curr_item = graph_item.node_items_[node_index + i]; | |||||
| node_item->SetCtrlSend(curr_item, UINT32_MAX); | |||||
| curr_item->SetCtrlSend(post_item, UINT32_MAX); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
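To make the index arithmetic concrete, a worked layout with one profiling node on each side of a compute node (all names illustrative):

    // node_items_: [ Prev(1), ProfBefore(2), Conv(3), ProfAfter(4), Next(5) ]
    // nodes entry for Conv: { node_index = 3, { prev_num = 1, post_num = 1 } }
    // Non-root case: Prev = node_items_[(3 - 1) - 1] gets a ctrl edge to ProfBefore,
    // which in turn gets one to Conv; symmetrically Conv -> ProfAfter -> Next.
    // If Conv were first (node_index == prev_num), ProfBefore would instead be
    // prepended to root_items_ so it is still scheduled despite having no predecessor.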
| Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) { | |||||
| GELOGD("Build control flow for node %s", node->GetName().c_str()); | |||||
| using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>; | |||||
| static const std::map<std::string, GroupBuilder> control_flow{ | |||||
| { STREAMACTIVE, &HybridModelBuilder::CreateStreamActiveGroup }, | |||||
| { STREAMSWITCH, &HybridModelBuilder::CreateStreamSwitchGroup }, | |||||
| { STREAMSWITCHN, &HybridModelBuilder::CreateStreamSwitchNGroup }, | |||||
| { NEXTITERATION, &HybridModelBuilder::CreateNextIterationGroup }, | |||||
| { REFNEXTITERATION, &HybridModelBuilder::CreateNextIterationGroup }, | |||||
| { SWITCH, &HybridModelBuilder::CreateSwitchGroup }, | |||||
| { REFSWITCH, &HybridModelBuilder::CreateSwitchGroup }, | |||||
| { LABELSET, &HybridModelBuilder::CreateLabelSetGroup }, | |||||
| { LABELGOTO, &HybridModelBuilder::CreateLabelGotoGroup }, | |||||
| { LABELGOTOEX, &HybridModelBuilder::CreateLabelGotoGroup }, | |||||
| { LABELSWITCH, &HybridModelBuilder::CreateLabelSwitchGroup }, | |||||
| { LABELSWITCHBYINDEX, &HybridModelBuilder::CreateLabelSwitchGroup } | |||||
| }; | |||||
| Status ret = SUCCESS; | |||||
| auto it = control_flow.find(node_item->node_type); | |||||
| if (it == control_flow.end()) { | |||||
| ret = CreateNormalNodeGroup(node, node_item); | |||||
| } else { | |||||
| graph_item.has_ctrl_flow_op_ = true; | |||||
| ret = it->second(this, node, node_item); | |||||
| } | |||||
| GELOGD("Node: %s, control by: %zu, control for: %zu, switch group: %zu", node->GetName().c_str(), | |||||
| node_item->ctrl_recv_.size(), node_item->ctrl_send_.size(), node_item->switch_groups_.size()); | |||||
| return ret; | |||||
| } | |||||
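The dispatch table stores pointer-to-member functions wrapped in std::function, so each builder is invoked with an explicit object pointer. The same pattern in isolation (sketch only):

    using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>;
    GroupBuilder builder = &HybridModelBuilder::CreateSwitchGroup;
    Status status = builder(this, node, node_item);  // same as (this->*ptr)(node, node_item)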
| Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| const auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
| node_item->SetCtrlSend(dst_node_item, UINT32_MAX); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != STREAMACTIVE) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| node_item->switch_groups_.resize(kStreamActiveNum); | |||||
| const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| if (dst_node->GetType() == STREAMMERGE) { | |||||
| GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
| node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != STREAMSWITCH) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // Treated as two groups: group[0] stays empty for the false branch, group[1] collects the true branch. | |||||
| node_item->switch_groups_.resize(kStreamSwitchNum); | |||||
| const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
| node_item->SetCtrlSend(dst_node_item, kStreamSwitchIdx); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
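A rough sketch of how an executor might consume these groups at run time, combining the SetSwitchIndex calls made by the RTS tasks below with switch_groups_; the helper name ActivateGroup is hypothetical:

    // Hypothetical consumer: only successors in the taken group are notified.
    void ActivateGroup(const NodeItem &sw, int switch_index) {
      if (switch_index >= 0 && static_cast<size_t>(switch_index) < sw.switch_groups_.size()) {
        for (const NodeItem *succ : sw.switch_groups_[switch_index]) {
          // wake succ; nodes in the untaken group never receive the notify
          (void)succ;
        }
      }
    }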
| Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != STREAMSWITCHN) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| uint32_t batch_num = 0; | |||||
| if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BATCH_NUM, batch_num)) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_NUM failed", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (batch_num == 0) { | |||||
| GELOGW("[%s] Got empty branch for SwitchN, Please check.", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| node_item->switch_groups_.resize(batch_num); | |||||
| const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| std::string batch_label; | |||||
| if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| std::string::size_type pos = batch_label.rfind("_"); | |||||
| if (pos == std::string::npos) { | |||||
| GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str()); | |||||
| continue; | |||||
| } | |||||
| ++pos; // Skip Separator | |||||
| uint64_t batch_index = std::strtoul(batch_label.data() + pos, nullptr, kDecimal); | |||||
| if (batch_index >= batch_num) { | |||||
| GELOGW("batch label: %s, batch index: %lu great than batch num: %u", batch_label.c_str(), batch_index, batch_num); | |||||
| continue; | |||||
| } | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
| node_item->SetCtrlSend(dst_node_item, batch_index); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
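For example, assuming the multi-batch pass writes labels of the form "Batch_2" (illustrative), the parsing resolves the group as follows:

    // batch_label = "Batch_2"  ->  rfind('_') == 5, ++pos == 6
    // std::strtoul("2", nullptr, kDecimal) == 2  ->  dst_node joins switch_groups_[2]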
| Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != NEXTITERATION && node_item->node_type != REFNEXTITERATION) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != SWITCH && node_item->node_type != REFSWITCH) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
| for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
| const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
| node_item->SetCtrlSend(dst_node_item, UINT32_MAX); | |||||
| } | |||||
| // Group switch flow by output data. | |||||
| node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); | |||||
| for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||||
| const auto &out_anchor = node->GetOutDataAnchor(i); | |||||
| for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
| "[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
| node_item->SetCtrlSend(dst_node_item, i); // take switch data as ctrl. | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::CreateLabelSetGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != LABELSET) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGE(UNSUPPORTED, "[%s] Not implemented.", node->GetName().c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| Status HybridModelBuilder::CreateLabelGotoGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != LABELGOTO && node_item->node_type != LABELGOTOEX) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGE(UNSUPPORTED, "[%s] Not implemented.", node->GetName().c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| Status HybridModelBuilder::CreateLabelSwitchGroup(const NodePtr &node, NodeItem *node_item) { | |||||
| if (node_item->node_type != LABELSWITCH && node_item->node_type != LABELSWITCHBYINDEX) { | |||||
| GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGE(UNSUPPORTED, "[%s] Not implemented.", node->GetName().c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -85,8 +85,8 @@ class HybridModelBuilder { | |||||
| Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item); | Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item); | ||||
| Status RecoverGraphUnknownFlag(); | Status RecoverGraphUnknownFlag(); | ||||
| Status CheckAicpuOpList(); | Status CheckAicpuOpList(); | ||||
| Status CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node); | |||||
| Status CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node); | |||||
| Status CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node, uint32_t &prev_num); | |||||
| Status CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node, uint32_t &post_num); | |||||
| Status GenerateFpProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list); | Status GenerateFpProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list); | ||||
| Status GenerateBpProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list); | Status GenerateBpProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list); | ||||
| Status GenerateEndProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list); | Status GenerateEndProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list); | ||||
| @@ -94,6 +94,20 @@ class HybridModelBuilder { | |||||
| Status OptimizeDependenciesForConstantInputs(); | Status OptimizeDependenciesForConstantInputs(); | ||||
| Status Convert2HostTensor(const NodePtr &node, int node_id, uint32_t output_idx); | Status Convert2HostTensor(const NodePtr &node, int node_id, uint32_t output_idx); | ||||
| Status RelinkNextIteration(); | |||||
| Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | |||||
| Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateNextIterationGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateSwitchGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateLabelSetGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateLabelGotoGroup(const NodePtr &node, NodeItem *node_item); | |||||
| Status CreateLabelSwitchGroup(const NodePtr &node, NodeItem *node_item); | |||||
| const char* GetGraphName() const { | const char* GetGraphName() const { | ||||
| return hybrid_model_.model_name_.c_str(); | return hybrid_model_.model_name_.c_str(); | ||||
| } | } | ||||
| @@ -104,6 +118,8 @@ class HybridModelBuilder { | |||||
| GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
| std::map<std::string, GeModelPtr> subgraph_models_; | std::map<std::string, GeModelPtr> subgraph_models_; | ||||
| std::map<std::string, NodePtr> constant_op_nodes_; | std::map<std::string, NodePtr> constant_op_nodes_; | ||||
| std::map<std::string, NodePtr> stream_merge_op_nodes_; | |||||
| std::map<std::string, NodePtr> next_iteration_op_nodes_; | |||||
| std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | ||||
| std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_; | ||||
| @@ -29,10 +29,19 @@ namespace hybrid { | |||||
| namespace { | namespace { | ||||
| const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
| const char *const kNodeTypeRetVal = "_RetVal"; | const char *const kNodeTypeRetVal = "_RetVal"; | ||||
| std::set<std::string> kControlOpTypes{ | |||||
| const std::set<std::string> kControlOpTypes{ | |||||
| IF, STATELESSIF, CASE, WHILE, STATELESSWHILE | IF, STATELESSIF, CASE, WHILE, STATELESSWHILE | ||||
| }; | }; | ||||
| const std::set<std::string> kControlFlowOpTypes{ | |||||
| STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX, | |||||
| NEXTITERATION, REFNEXTITERATION | |||||
| }; | |||||
| const std::set<std::string> kMergeOpTypes{ | |||||
| MERGE, REFMERGE, STREAMMERGE | |||||
| }; | |||||
| Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
| uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
| if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
| @@ -107,7 +116,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool IsControlOp(const std::string &op_type) { | |||||
| bool IsControlFlowV2Op(const std::string &op_type) { | |||||
| return kControlOpTypes.count(op_type) > 0; | return kControlOpTypes.count(op_type) > 0; | ||||
| } | } | ||||
| @@ -226,7 +235,7 @@ Status NodeItem::ResolveStaticInputsAndOutputs() { | |||||
| } | } | ||||
| void NodeItem::ResolveUnknownShapeType() { | void NodeItem::ResolveUnknownShapeType() { | ||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| if (IsControlFlowV2Op() || (is_dynamic && node_type == PARTITIONEDCALL)) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | shape_inference_type = DEPEND_COMPUTE; | ||||
| } else { | } else { | ||||
| int32_t unknown_shape_type_val = 0; | int32_t unknown_shape_type_val = 0; | ||||
| @@ -236,6 +245,10 @@ void NodeItem::ResolveUnknownShapeType() { | |||||
| } | } | ||||
| Status NodeItem::Init() { | Status NodeItem::Init() { | ||||
| is_ctrl_flow_v2_op_ = ge::hybrid::IsControlFlowV2Op(node_type); | |||||
| is_ctrl_flow_op_ = kControlFlowOpTypes.count(node_type) > 0; | |||||
| is_merge_op_ = kMergeOpTypes.count(node_type) > 0; | |||||
| is_root_node_ = node->GetInAllNodes().empty(); | |||||
| GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | ||||
| GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | ||||
| ResolveUnknownShapeType(); | ResolveUnknownShapeType(); | ||||
| @@ -244,14 +257,12 @@ Status NodeItem::Init() { | |||||
| GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), | GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), | ||||
| "[Invoke][ParseFusedSubgraph][%s] Failed to parse fused subgraph", node_name.c_str()); | "[Invoke][ParseFusedSubgraph][%s] Failed to parse fused subgraph", node_name.c_str()); | ||||
| } | } | ||||
| copy_mu_ = MakeShared<std::mutex>(); | |||||
| GE_CHECK_NOTNULL(copy_mu_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| bool NodeItem::IsControlOp() const { | |||||
| return ge::hybrid::IsControlOp(op_desc->GetType()); | |||||
| } | |||||
| bool NodeItem::IsHcclOp() const { | bool NodeItem::IsHcclOp() const { | ||||
| return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL; | return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL; | ||||
| } | } | ||||
| @@ -383,5 +394,45 @@ bool NodeItem::IsInputShapeStatic(int index) const { | |||||
| return is_input_shape_static_[index]; | return is_input_shape_static_[index]; | ||||
| } | } | ||||
| void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
| data_send_.emplace(node_item); | |||||
| node_item->data_recv_[this] = anchor_index; | |||||
| if (is_root_node_) { | |||||
| node_item->root_data_.emplace(this); | |||||
| } | |||||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||||
| } | |||||
| void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
| if (switch_index < switch_groups_.size()) { | |||||
| std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index]; | |||||
| switch_group.emplace_back(node_item); | |||||
| } else { | |||||
| ctrl_send_.insert(node_item); | |||||
| } | |||||
| node_item->ctrl_recv_.emplace(this); | |||||
| if (is_root_node_) { | |||||
| node_item->root_ctrl_.emplace(this); | |||||
| } | |||||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||||
| } | |||||
| OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { | |||||
| if (mu_ != nullptr) { | |||||
| GELOGD("lock for %s", name_.c_str()); | |||||
| mu_->lock(); | |||||
| } | |||||
| } | |||||
| OptionalMutexGuard::~OptionalMutexGuard() { | |||||
| if (mu_ != nullptr) { | |||||
| GELOGD("unlock for %s", name_.c_str()); | |||||
| mu_->unlock(); | |||||
| mu_ = nullptr; | |||||
| } | |||||
| } | |||||
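A minimal usage sketch for the guard, assuming a NodeItem named item whose copy_mu_ was allocated during Init():

    {
      OptionalMutexGuard guard = item.MutexGuard("CopyOutput");  // locks copy_mu_ when present
      // ... copy tensors while holding the per-node mutex ...
    }  // unlocked on scope exit; with a null mutex the guard is a no-op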
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -37,7 +37,16 @@ struct FusedSubgraph { | |||||
| ComputeGraphPtr graph; | ComputeGraphPtr graph; | ||||
| }; | }; | ||||
| bool IsControlOp(const std::string &op_type); | |||||
| bool IsControlFlowV2Op(const std::string &op_type); | |||||
| class OptionalMutexGuard { | |||||
| public: | |||||
| OptionalMutexGuard(std::mutex *mutex, const string &name); | |||||
| ~OptionalMutexGuard(); | |||||
| private: | |||||
| std::mutex *mu_{nullptr}; | |||||
| std::string name_; | |||||
| }; | |||||
| // for caching static information across execution | // for caching static information across execution | ||||
| struct NodeItem { | struct NodeItem { | ||||
| @@ -70,12 +79,29 @@ struct NodeItem { | |||||
| Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const; | Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const; | ||||
| bool IsControlOp() const; | |||||
| bool IsControlFlowV2Op() const { | |||||
| return is_ctrl_flow_v2_op_; | |||||
| } | |||||
| bool IsControlFlowOp() const { | |||||
| return is_ctrl_flow_op_; | |||||
| } | |||||
| bool IsMergeOp() const { | |||||
| return is_merge_op_; | |||||
| } | |||||
| bool IsHcclOp() const; | bool IsHcclOp() const; | ||||
| void SetToDynamic(); | void SetToDynamic(); | ||||
| void SetDataSend(NodeItem *node_item, int anchor_index); | |||||
| void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); | |||||
| OptionalMutexGuard MutexGuard(const std::string &name) const { | |||||
| return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); | |||||
| } | |||||
| std::string DebugString() const; | std::string DebugString() const; | ||||
| NodePtr node; | NodePtr node; | ||||
| @@ -99,7 +125,20 @@ struct NodeItem { | |||||
| std::set<int> to_const_output_id_list; | std::set<int> to_const_output_id_list; | ||||
| // src_output_id, dst_anchor_id, dst_node | // src_output_id, dst_anchor_id, dst_node | ||||
| vector<vector<pair<int, NodeItem *>>> outputs; | |||||
| std::vector<std::vector<std::pair<int, NodeItem *>>> outputs; | |||||
| // for link-driven scheduling | |||||
| bool is_root_node_ = false; | |||||
| bool is_ctrl_flow_v2_op_ = false; | |||||
| bool is_ctrl_flow_op_ = false; | |||||
| bool is_merge_op_ = false; | |||||
| std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | |||||
| std::set<const NodeItem *> root_data_; // Recv data from root node | |||||
| std::set<const NodeItem *> data_send_; // Send data notify to | |||||
| std::map<const NodeItem *, int> data_recv_; // Recv data notify from | |||||
| std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||||
| std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | |||||
| std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to, grouped by switch branch | |||||
| std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
| std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
| @@ -122,6 +161,7 @@ struct NodeItem { | |||||
| std::vector<bool> is_input_shape_static_; | std::vector<bool> is_input_shape_static_; | ||||
| std::vector<uint32_t> input_desc_indices_; | std::vector<uint32_t> input_desc_indices_; | ||||
| std::shared_ptr<std::mutex> copy_mu_; | |||||
| mutable std::mutex mu_; | mutable std::mutex mu_; | ||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| @@ -32,7 +32,7 @@ REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGR | |||||
| Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start"); | ||||
| GELOGD("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName()); | |||||
| GELOGD("[%s] KnownNodeTask::ExecuteAsync in, model id: %u.", context.GetNodeName(), davinci_model_->Id()); | |||||
| if (davinci_model_->GetTaskList().empty()) { | if (davinci_model_->GetTaskList().empty()) { | ||||
| GELOGW("KnownNodeExecutor::ExecuteAsync davinci model has no taskinfo."); | GELOGW("KnownNodeExecutor::ExecuteAsync davinci model has no taskinfo."); | ||||
| @@ -62,7 +62,7 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> d | |||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End"); | ||||
| GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); | GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); | ||||
| GELOGD("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName()); | |||||
| GELOGD("[%s] KnownNodeTask::ExecuteAsync success, model id: %u.", context.GetNodeName(), davinci_model_->Id()); | |||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -22,18 +22,6 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::CONTROL_OP, ControlOpNodeExecutor); | REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::CONTROL_OP, ControlOpNodeExecutor); | ||||
| namespace { | |||||
| template<typename T> | |||||
| Status CopyScalarValueToHost(const TensorValue &tensor, T &value) { | |||||
| GE_CHECK_GE(tensor.GetSize(), sizeof(value)); | |||||
| GE_CHK_RT_RET(rtMemcpy(&value, | |||||
| sizeof(value), | |||||
| tensor.GetData(), | |||||
| sizeof(value), | |||||
| RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, | Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, | ||||
| TaskContext &task_context, | TaskContext &task_context, | ||||
| @@ -60,12 +48,12 @@ Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, | |||||
| Status ControlOpNodeTask::ToBool(const TensorValue &tensor, DataType data_type, bool &value) { | Status ControlOpNodeTask::ToBool(const TensorValue &tensor, DataType data_type, bool &value) { | ||||
| switch (data_type) { | switch (data_type) { | ||||
| #define CASE(DT, T) \ | |||||
| case (DT): { \ | |||||
| T val{}; \ | |||||
| GE_CHK_STATUS_RET(CopyScalarValueToHost(tensor, val)); \ | |||||
| value = val != 0; \ | |||||
| break; \ | |||||
| #define CASE(DT, T) \ | |||||
| case (DT): { \ | |||||
| T val{}; \ | |||||
| GE_CHK_STATUS_RET(tensor.CopyScalarValueToHost(val)); \ | |||||
| value = val != 0; \ | |||||
| break; \ | |||||
| } | } | ||||
| // DT_STRING was handled in CondPass | // DT_STRING was handled in CondPass | ||||
| CASE(DT_FLOAT, float) | CASE(DT_FLOAT, float) | ||||
| @@ -77,7 +65,7 @@ Status ControlOpNodeTask::ToBool(const TensorValue &tensor, DataType data_type, | |||||
| CASE(DT_INT64, int64_t) | CASE(DT_INT64, int64_t) | ||||
| #undef CASE | #undef CASE | ||||
| case DT_BOOL: | case DT_BOOL: | ||||
| GE_CHK_STATUS_RET(CopyScalarValueToHost(tensor, value)); | |||||
| GE_CHK_STATUS_RET(tensor.CopyScalarValueToHost(value)); | |||||
| break; | break; | ||||
| default: | default: | ||||
| GELOGE(UNSUPPORTED, "Data type %s is not support by cond.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGE(UNSUPPORTED, "Data type %s is not support by cond.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| @@ -182,7 +170,7 @@ Status CaseOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::func | |||||
| auto branch_tensor = task_context.GetInput(kCaseBranchIndex); | auto branch_tensor = task_context.GetInput(kCaseBranchIndex); | ||||
| GE_CHECK_NOTNULL(branch_tensor); | GE_CHECK_NOTNULL(branch_tensor); | ||||
| int32_t branch_index = 0; | int32_t branch_index = 0; | ||||
| GE_CHK_STATUS_RET(CopyScalarValueToHost(*branch_tensor, branch_index)); | |||||
| GE_CHK_STATUS_RET(branch_tensor->CopyScalarValueToHost(branch_index)); | |||||
| const GraphItem *subgraph = SelectBranch(branch_index); | const GraphItem *subgraph = SelectBranch(branch_index); | ||||
| GELOGI("[%s] Taking subgraph [%s] by branch = [%d]", | GELOGI("[%s] Taking subgraph [%s] by branch = [%d]", | ||||
| task_context.GetNodeName(), | task_context.GetNodeName(), | ||||
| @@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node | |||||
| return ExecutorType::GE_LOCAL; | return ExecutorType::GE_LOCAL; | ||||
| } | } | ||||
| if (IsControlOp(op_type)) { | |||||
| if (IsControlFlowV2Op(op_type)) { | |||||
| return ExecutorType::CONTROL_OP; | return ExecutorType::CONTROL_OP; | ||||
| } | } | ||||
| @@ -27,6 +27,8 @@ const uint32_t MEMORY_ALIGN_RATIO = 2; | |||||
| const uint32_t MEMORY_ALIGN_SIZE = 32; | const uint32_t MEMORY_ALIGN_SIZE = 32; | ||||
| namespace hybrid { | namespace hybrid { | ||||
| class HybridModel; | class HybridModel; | ||||
| using NodeTaskPtr = std::shared_ptr<NodeTask>; | |||||
| // Base class of Node Task | // Base class of Node Task | ||||
| class NodeTask { | class NodeTask { | ||||
| public: | public: | ||||
| @@ -14,7 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "rts_node_executor.h" | |||||
| #include "hybrid/node_executor/rts/rts_node_executor.h" | |||||
| #include "hybrid/node_executor/rts/rts_task_factory.h" | |||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/types.h" | #include "common/types.h" | ||||
| @@ -26,6 +28,11 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::RTS, RtsNodeExecutor); | REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::RTS, RtsNodeExecutor); | ||||
| REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); | |||||
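The registry behind REGISTER_RTS_TASK_CREATOR lives in the newly added rts_task_factory.h, which this diff does not show; presumably it is a standard name-to-creator map along these lines (sketch only, all names inferred):

    class RtsTaskFactory {
     public:
      static RtsTaskFactory &GetInstance();
      NodeTaskPtr Create(const std::string &op_type) const;  // nullptr when unregistered
      void RegisterCreator(const std::string &op_type, const std::function<NodeTaskPtr()> &creator);
    };
    // The real macro presumably uniquifies the registrar symbol (e.g. via __COUNTER__),
    // since PassThroughNodeTask is registered under several op types in rts_node_task.cc.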
| Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | ||||
| auto input_desc = context.MutableInputDesc(index); | auto input_desc = context.MutableInputDesc(index); | ||||
| GE_CHECK_NOTNULL(input_desc); | GE_CHECK_NOTNULL(input_desc); | ||||
| @@ -77,10 +84,6 @@ Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void() | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status IdentityNodeTask::UpdateArgs(TaskContext &context) { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status IdentityNNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status IdentityNNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
| GELOGD("[%s] Start to execute.", context.GetNodeName()); | GELOGD("[%s] Start to execute.", context.GetNodeName()); | ||||
| for (int i = 0; i < context.NumInputs(); ++i) { | for (int i = 0; i < context.NumInputs(); ++i) { | ||||
| @@ -95,7 +98,15 @@ Status IdentityNNodeTask::ExecuteAsync(TaskContext &context, std::function<void( | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ProfilingTraceNodeTask::UpdateArgs(TaskContext &context) { | |||||
| Status ProfilingTraceNodeTask::Init(const HybridModel &model, const NodePtr &node) { | |||||
| auto *task_defs = model.GetTaskDefs(node); | |||||
| if (task_defs == nullptr || task_defs->empty()) { | |||||
| GELOGE(INTERNAL_ERROR, "Profiling node has no task to execute."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| task_defs_ = *task_defs; | |||||
| GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -116,32 +127,21 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| }; | |||||
| } | |||||
| Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GELOGD("[%s] Load for local task.", node->GetName().c_str()); | |||||
| auto op_type = node->GetType(); | auto op_type = node->GetType(); | ||||
| if (op_type == IDENTITY) { | |||||
| task = MakeShared<IdentityNodeTask>(); | |||||
| } else if (op_type == IDENTITYN) { | |||||
| task = MakeShared<IdentityNNodeTask>(); | |||||
| } else if (op_type == READVARIABLEOP) { | |||||
| task = MakeShared<ReadVariableOpNodeTask>(); | |||||
| } else if (op_type == PROFILINGTRAININGTRACE) { | |||||
| auto *task_defs = model.GetTaskDefs(node); | |||||
| if (task_defs == nullptr || task_defs->empty()) { | |||||
| GELOGE(INTERNAL_ERROR, "Profiling node has no task to execute."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| task = MakeShared<ProfilingTraceNodeTask>(*task_defs); | |||||
| } else { | |||||
| task = RtsTaskFactory::GetInstance().Create(op_type); | |||||
| if (task == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str()); | GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(task); | |||||
| return SUCCESS; | |||||
| RtsNodeTask *rts_task = dynamic_cast<RtsNodeTask *>(task.get()); | |||||
| GE_CHECK_NOTNULL(rts_task); | |||||
| return rts_task->Init(model, node); | |||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,13 +18,12 @@ | |||||
| #define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_ | #define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_ | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "proto/task.pb.h" | |||||
| #include "hybrid/node_executor/rts/rts_node_task.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| class IdentityNodeTask : public NodeTask { | |||||
| class IdentityNodeTask : public RtsNodeTask { | |||||
| public: | public: | ||||
| Status UpdateArgs(TaskContext &context) override; | |||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
| protected: | protected: | ||||
| @@ -41,12 +40,10 @@ class ReadVariableOpNodeTask : public IdentityNodeTask { | |||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
| }; | }; | ||||
| class ProfilingTraceNodeTask : public NodeTask { | |||||
| class ProfilingTraceNodeTask : public RtsNodeTask { | |||||
| public: | public: | ||||
| explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {} | |||||
| ~ProfilingTraceNodeTask() override = default; | |||||
| Status Init(const HybridModel &model, const NodePtr &node) override; | |||||
| Status UpdateArgs(TaskContext &context) override; | |||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
| private: | private: | ||||
| @@ -0,0 +1,240 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "hybrid/node_executor/rts/rts_node_task.h" | |||||
| #include "hybrid/node_executor/rts/rts_task_factory.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| namespace { | |||||
| constexpr uint8_t kSwitchPredIndex = 0; | |||||
| constexpr uint8_t kSwitchCompIndex = 1; | |||||
| const static std::map<rtCondition_t, std::function<bool(int64_t, int64_t)>> kCompHandle = { | |||||
| {RT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value == comp_value; }}, | |||||
| {RT_NOT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value != comp_value; }}, | |||||
| {RT_GREATER, [](int64_t pred_value, int64_t comp_value) { return pred_value > comp_value; }}, | |||||
| {RT_GREATER_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value >= comp_value; }}, | |||||
| {RT_LESS, [](int64_t pred_value, int64_t comp_value) { return pred_value < comp_value; }}, | |||||
| {RT_LESS_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value <= comp_value; }}, | |||||
| }; | |||||
| } | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(LOOPCOND, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(NEXTITERATION, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(REFNEXTITERATION, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(EXIT, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(REFEXIT, PassThroughNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(LABELSET, LabelSetNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(LABELGOTO, LabelGotoNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(LABELGOTOEX, LabelGotoNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(LABELSWITCH, LabelSwitchNodeTask); | |||||
| REGISTER_RTS_TASK_CREATOR(LABELSWITCHBYINDEX, LabelSwitchNodeTask); | |||||
| Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value) { | |||||
| auto tensor_value = task_context.GetInput(index); | |||||
| GE_CHECK_NOTNULL(tensor_value); | |||||
| auto tensor_desc = task_context.MutableInputDesc(index); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| auto data_type = tensor_desc->GetDataType(); | |||||
| switch (data_type) { | |||||
| #define CASE_TYPE(DT, VT) \ | |||||
| case (DT): { \ | |||||
| VT data_val{}; \ | |||||
| GE_CHK_STATUS_RET(tensor_value->CopyScalarValueToHost(data_val)); \ | |||||
| value = static_cast<int64_t>(data_val); \ | |||||
| break; \ | |||||
| } | |||||
| // Only index data types are accepted here. | |||||
| CASE_TYPE(DT_INT32, int32_t) | |||||
| CASE_TYPE(DT_INT64, int64_t) | |||||
| #undef CASE_TYPE | |||||
| default: { | |||||
| GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
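The `CASE_TYPE` macro keeps the copy-and-widen boilerplate identical across the accepted index types. For `DT_INT32` it expands to roughly the following (a sketch, with the `GE_CHK_STATUS_RET` check written out by hand):

```cpp
case DT_INT32: {
  int32_t data_val{};
  // Reads the scalar back from device memory into data_val.
  const Status ret = tensor_value->CopyScalarValueToHost(data_val);
  if (ret != SUCCESS) {
    return ret;  // GE_CHK_STATUS_RET logs the failure and propagates it
  }
  value = static_cast<int64_t>(data_val);  // widen to the common int64_t
  break;
}
```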
| Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| const auto &node_state = task_context.GetNodeState(); | |||||
| node_state->SetSwitchIndex(0); | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGI("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StreamSwitchNodeTask::Init(const HybridModel &model, const NodePtr &node) { | |||||
| uint32_t value = 0; | |||||
| if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, value)) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Get %s failed.", node->GetName().c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| rtCondition_t cond = static_cast<rtCondition_t>(value); | |||||
| const auto it = kCompHandle.find(cond); | |||||
| if (it == kCompHandle.end()) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Get Condition: %u handle failed.", node->GetName().c_str(), value); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| comp_func_ = it->second; | |||||
| GELOGD("[%s] Done initialization successfully, condition is %u.", node->GetName().c_str(), value); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StreamSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| GE_CHECK_NOTNULL(comp_func_); | |||||
| int64_t pred_value = 0; | |||||
| GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchPredIndex, pred_value)); | |||||
| int64_t comp_value = 0; | |||||
| GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchCompIndex, comp_value)); | |||||
| bool switch_idx = comp_func_(pred_value, comp_value); | |||||
| auto node_state = task_context.GetNodeState(); | |||||
| node_state->SetSwitchIndex(static_cast<int>(switch_idx)); | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGI("[%s] Done executing successfully, pred value: %ld, comp value: %ld, switch index: %d.", | |||||
| task_context.GetNodeName(), pred_value, comp_value, static_cast<int>(switch_idx)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| int index = task_context.GetNodeState()->GetMergeIndex(); | |||||
| GELOGD("[%s] Start to execute, merge index: %d.", task_context.GetNodeName(), index); | |||||
| if (index < 0 || index >= task_context.NumInputs()) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Invalid merge param, inputs num: %d, merge index: %d.", | |||||
| task_context.GetNodeName(), task_context.NumInputs(), index); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const auto in_x = task_context.MutableInput(index); // x | |||||
| GE_CHECK_NOTNULL(in_x); | |||||
| task_context.SetOutput(MERGE_DATA_OUTPUT, *in_x); // y | |||||
| const auto out_y = task_context.MutableOutput(MERGE_INDEX_OUTPUT); // value_index | |||||
| GE_CHECK_NOTNULL(out_y); | |||||
| if (out_y->GetSize() > 0) { | |||||
| GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), &index, sizeof(index), | |||||
| RT_MEMCPY_HOST_TO_DEVICE_EX, task_context.GetStream())); | |||||
| } | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| task_context.GetNodeState()->SetMergeIndex(-1); // Invalidate for loop. | |||||
| GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
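In short, StreamMerge follows the usual Merge contract: output 0 (`y`) forwards whichever input actually arrived, and output 1 (`value_index`) records which one it was. The index itself is set on the node state elsewhere by the executor when the chosen branch becomes ready; that code is not part of this diff. A condensed view of the steps (names as above):

```cpp
// Condensed contract of StreamMergeNodeTask::ExecuteAsync (sketch):
int index = task_context.GetNodeState()->GetMergeIndex();  // chosen input, set upstream
// 1. forward the chosen data tensor by handle, without copying it:
task_context.SetOutput(MERGE_DATA_OUTPUT, *task_context.MutableInput(index));
// 2. value_index lives in device memory, so the host-side int is pushed with
//    an async H2D copy on the node's stream rather than a plain assignment;
// 3. reset the index so a stale value cannot leak into the next loop iteration:
task_context.GetNodeState()->SetMergeIndex(-1);
```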
| Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| const auto in_x = task_context.GetInput(0); // x | |||||
| GE_CHECK_NOTNULL(in_x); | |||||
| const auto out_y = task_context.MutableOutput(0); // y | |||||
| GE_CHECK_NOTNULL(out_y); | |||||
| GELOGD("[%s] input size: %zu, output size: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize()); | |||||
| if (in_x->GetSize() > 0 && out_y->GetSize() > 0) { | |||||
| GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), in_x->GetData(), in_x->GetSize(), | |||||
| RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); | |||||
| } else { | |||||
| GELOGW("[%s] invalid copy size, src: %zu, dst: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize()); | |||||
| } | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| const auto in_x = task_context.GetInput(0); // x | |||||
| GE_CHECK_NOTNULL(in_x); | |||||
| task_context.SetOutput(0, *in_x); // y | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status LabelSetNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| Status LabelGotoNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| Status LabelSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", task_context.GetNodeName()); | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); | |||||
| return UNSUPPORTED; | |||||
| } | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,89 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_TASK_H_ | |||||
| #define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_TASK_H_ | |||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| #include "proto/task.pb.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| class RtsNodeTask : public NodeTask { | |||||
| public: | |||||
| Status Init(TaskContext &task_context) override { | |||||
| return SUCCESS; | |||||
| } | |||||
| virtual Status Init(const HybridModel &model, const NodePtr &node) { | |||||
| GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status UpdateArgs(TaskContext &task_context) override { | |||||
| GELOGD("[%s] Done update args successfully.", task_context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value); | |||||
| }; | |||||
| class StreamActiveNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class StreamSwitchNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status Init(const HybridModel &model, const NodePtr &node) override; | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| private: | |||||
| std::function<bool(int64_t, int64_t)> comp_func_{nullptr}; | |||||
| }; | |||||
| class StreamMergeNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class MemcpyAsyncNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class PassThroughNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class LabelSetNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class LabelGotoNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class LabelSwitchNodeTask : public RtsNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| #endif // GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_TASK_H_ | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "hybrid/node_executor/rts/rts_task_factory.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| NodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { | |||||
| auto it = creators_.find(task_type); | |||||
| if (it == creators_.end()) { | |||||
| GELOGW("Cannot find task type %s in inner map.", task_type.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| return it->second(); | |||||
| } | |||||
| void RtsTaskFactory::RegisterCreator(const std::string &task_type, const RtsTaskCreatorFun &creator) { | |||||
| if (creator == nullptr) { | |||||
| GELOGW("Register %s creator is null", task_type.c_str()); | |||||
| return; | |||||
| } | |||||
| auto it = creators_.find(task_type); | |||||
| if (it != creators_.end()) { | |||||
| GELOGW("Task %s creator already exist", task_type.c_str()); | |||||
| return; | |||||
| } | |||||
| creators_[task_type] = creator; | |||||
| } | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
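A caller is then expected to look tasks up by the node-type string used at registration. A hedged sketch of the expected call site (the surrounding executor code is not part of this diff, and `node_type` is a hypothetical variable):

```cpp
// Sketch: obtaining a task from the factory by node type (e.g. STREAMACTIVE).
NodeTaskPtr task = RtsTaskFactory::GetInstance().Create(node_type);
if (task == nullptr) {
  // Unknown type: Create() already logged a warning, so the caller must
  // fail gracefully instead of dereferencing a null task.
  return UNSUPPORTED;
}
```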
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | |||||
| #define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | |||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| using RtsTaskCreatorFun = std::function<NodeTaskPtr()>; | |||||
| class RtsTaskFactory { | |||||
| public: | |||||
| static RtsTaskFactory &GetInstance() { | |||||
| static RtsTaskFactory instance; | |||||
| return instance; | |||||
| } | |||||
| NodeTaskPtr Create(const std::string &task_type) const; | |||||
| class RtsTaskRegistrar { | |||||
| public: | |||||
| RtsTaskRegistrar(const std::string &task_type, const RtsTaskCreatorFun &creator) { | |||||
| RtsTaskFactory::GetInstance().RegisterCreator(task_type, creator); | |||||
| } | |||||
| ~RtsTaskRegistrar() = default; | |||||
| }; | |||||
| private: | |||||
| RtsTaskFactory() = default; | |||||
| ~RtsTaskFactory() = default; | |||||
| /** | |||||
| * Register creator of RTS node task | |||||
| * @param task_type type of task | |||||
| * @param creator creation function | |||||
| */ | |||||
| void RegisterCreator(const std::string &task_type, const RtsTaskCreatorFun &creator); | |||||
| std::map<std::string, RtsTaskCreatorFun> creators_; | |||||
| }; | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| #define REGISTER_RTS_TASK_CREATOR(task_type, task_clazz) \ | |||||
| REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | |||||
| #define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | |||||
| RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []() -> NodeTaskPtr { return MakeShared<clazz>(); }) | |||||
| #endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | |||||
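For reference, `REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask)` expands to roughly the sketch below. The two-level helper exists so that `__COUNTER__` is expanded to a literal before token pasting, which gives each registrar a unique global name:

```cpp
// Approximate expansion, assuming __COUNTER__ yields 0 at this point:
RtsTaskFactory::RtsTaskRegistrar g_STREAMACTIVE_Creator0(
    STREAMACTIVE, []() -> NodeTaskPtr { return MakeShared<StreamActiveNodeTask>(); });
// The registrar's constructor runs during static initialization and calls
// RtsTaskFactory::GetInstance().RegisterCreator(STREAMACTIVE, <lambda>).
```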
| @@ -418,13 +418,14 @@ Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr | |||||
| return MEMALLOC_FAILED; | return MEMALLOC_FAILED; | ||||
| } | } | ||||
| GELOGD("Allocating workspace of size = %zu successfully", size); | |||||
| GELOGD("[%s] Allocating workspace of size = %zu successfully", node_item_->NodeName().c_str(), size); | |||||
| workspaces_.emplace_back(*buffer); | workspaces_.emplace_back(*buffer); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TaskContext::PropagateOutputs() { | Status TaskContext::PropagateOutputs() { | ||||
| // propagate outputs | // propagate outputs | ||||
| const auto &guard = node_item_->MutexGuard("PropagateOutputs"); | |||||
| for (int i = 0; i < NumOutputs(); ++i) { | for (int i = 0; i < NumOutputs(); ++i) { | ||||
| auto tensor = MutableOutput(i); | auto tensor = MutableOutput(i); | ||||
| GE_CHECK_NOTNULL(tensor); | GE_CHECK_NOTNULL(tensor); | ||||
| @@ -561,7 +562,7 @@ const DumpProperties &TaskContext::GetDumpProperties() const { | |||||
| } | } | ||||
| bool TaskContext::NeedCallback() { | bool TaskContext::NeedCallback() { | ||||
| return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0 || | |||||
| return node_item_->has_observer || IsDumpEnabled() || GraphExecutionContext::profiling_level > 0 || | |||||
| !execution_context_->model->IsSingleOp(); | !execution_context_->model->IsSingleOp(); | ||||
| } | } | ||||
| @@ -54,6 +54,10 @@ GE_FUNC_VISIBILITY extern const uint32_t SWITCH_TRUE_OUTPUT; | |||||
| GE_FUNC_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | GE_FUNC_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | ||||
| GE_FUNC_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | GE_FUNC_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | ||||
| // Merge | |||||
| GE_FUNC_VISIBILITY extern const uint32_t MERGE_DATA_OUTPUT; | |||||
| GE_FUNC_VISIBILITY extern const uint32_t MERGE_INDEX_OUTPUT; | |||||
| // FunctionOp | // FunctionOp | ||||
| GE_FUNC_VISIBILITY extern const uint32_t IF_COND_INPUT; | GE_FUNC_VISIBILITY extern const uint32_t IF_COND_INPUT; | ||||
| GE_FUNC_VISIBILITY extern const uint32_t FOR_START_INPUT; | GE_FUNC_VISIBILITY extern const uint32_t FOR_START_INPUT; | ||||
| @@ -129,7 +133,7 @@ class GE_FUNC_VISIBILITY OpUtils { | |||||
| /// @param [out] output Data pointer after conversion. The format is HWCK | /// @param [out] output Data pointer after conversion. The format is HWCK | ||||
| /// | /// | ||||
| static void TransDataKCHW2HWCK(const void *input, int64_t K, int64_t C, int64_t H, int64_t W, void *output); | static void TransDataKCHW2HWCK(const void *input, int64_t K, int64_t C, int64_t H, int64_t W, void *output); | ||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node); | static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node); | ||||
| static vector<ConstGeTensorPtr> GetWeights(ge::ConstNodePtr node); | static vector<ConstGeTensorPtr> GetWeights(ge::ConstNodePtr node); | ||||
| static vector<GeTensorPtr> MutableWeights(const ge::Node &node); | static vector<GeTensorPtr> MutableWeights(const ge::Node &node); | ||||
| @@ -48,6 +48,14 @@ int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| std::string ErrorManager::GetErrorMessage() { | |||||
| return std::string(); | |||||
| } | |||||
| std::string ErrorManager::GetWarningMessage() { | |||||
| return std::string(); | |||||
| } | |||||
| int ErrorManager::ReportInterErrMessage(std::string error_code, const std::string &error_msg) { | int ErrorManager::ReportInterErrMessage(std::string error_code, const std::string &error_msg) { | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -99,7 +107,7 @@ int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) { | |||||
| const std::string &ErrorManager::GetLogHeader() { return error_context_.log_header; } | const std::string &ErrorManager::GetLogHeader() { return error_context_.log_header; } | ||||
| struct error_message::Context &ErrorManager::GetErrorManagerContext() { | struct error_message::Context &ErrorManager::GetErrorManagerContext() { | ||||
| struct error_message::Context error_context; | |||||
| static struct error_message::Context error_context; | |||||
| return error_context; | return error_context; | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ | |||||
| #cmake_minimum_required(VERSION 2.8) | #cmake_minimum_required(VERSION 2.8) | ||||
| project(STUB_MMPA) | |||||
| project(runtime_stub) | |||||
| file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | ||||
| "src/runtime_stub.cc" | "src/runtime_stub.cc" | ||||
| @@ -26,7 +26,13 @@ include_directories(${GE_CODE_DIR}/inc/framework) | |||||
| add_library(runtime_stub SHARED ${SRCS}) | add_library(runtime_stub SHARED ${SRCS}) | ||||
| target_compile_options(runtime_stub PRIVATE | |||||
| -g | |||||
| ) | |||||
| target_link_libraries(runtime_stub PRIVATE | target_link_libraries(runtime_stub PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| ) | ) | ||||
| target_include_directories(runtime_stub INTERFACE ${CMAKE_CURRENT_LIST_DIR}/src) | |||||
| @@ -17,6 +17,9 @@ | |||||
| #include <cce/dnn.h> | #include <cce/dnn.h> | ||||
| #include <securec.h> | #include <securec.h> | ||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| #define EVENT_LENTH 10 | #define EVENT_LENTH 10 | ||||
| rtError_t rtCtxSetCurrent(rtContext_t ctx) { return RT_ERROR_NONE; } | rtError_t rtCtxSetCurrent(rtContext_t ctx) { return RT_ERROR_NONE; } | ||||
| @@ -96,15 +99,16 @@ rtError_t rtSetDevice(int32_t device) { return RT_ERROR_NONE; } | |||||
| rtError_t rtStreamSynchronize(rtStream_t stream) { return RT_ERROR_NONE; } | rtError_t rtStreamSynchronize(rtStream_t stream) { return RT_ERROR_NONE; } | ||||
| rtError_t rtMemcpy(void *dst, uint64_t dest_max, const void *src, uint64_t count, rtMemcpyKind_t kind) { | rtError_t rtMemcpy(void *dst, uint64_t dest_max, const void *src, uint64_t count, rtMemcpyKind_t kind) { | ||||
| #ifdef OTQT_UT | |||||
| if (dest_max == 12 && count == 12) { // UTEST_kernelinfo_manager.all_success special treatment | |||||
| if (dst != nullptr && src != nullptr) { | |||||
| memcpy_s(dst, dest_max, src, count); | memcpy_s(dst, dest_max, src, count); | ||||
| } | } | ||||
| #endif | |||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtMemcpyAsync(void *dst, uint64_t dest_max, const void *src, uint64_t count, rtMemcpyKind_t kind, | rtError_t rtMemcpyAsync(void *dst, uint64_t dest_max, const void *src, uint64_t count, rtMemcpyKind_t kind, | ||||
| rtStream_t stream) { | rtStream_t stream) { | ||||
| if (dst != nullptr && src != nullptr) { | |||||
| memcpy_s(dst, dest_max, src, count); | |||||
| } | |||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| @@ -125,9 +129,6 @@ rtError_t rtEventElapsedTime(float *time, rtEvent_t start, rtEvent_t end) { | |||||
| *time = 10.0f; | *time = 10.0f; | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtFunctionRegister(void *bin_handle, const void *stub_func, const char *stub_name, const void *dev_func) { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtFunctionRegister(void *bin_handle, const void *stub_func, const char *stub_name, const void *dev_func, | rtError_t rtFunctionRegister(void *bin_handle, const void *stub_func, const char *stub_name, const void *dev_func, | ||||
| uint32_t func_mode) { | uint32_t func_mode) { | ||||
| @@ -156,7 +157,7 @@ rtError_t rtConfigureCall(uint32_t num_blocks, rtSmDesc_t *sm_desc, rtStream_t s | |||||
| rtError_t rtSetProfDir(char *prof_dir) { return RT_ERROR_NONE; } | rtError_t rtSetProfDir(char *prof_dir) { return RT_ERROR_NONE; } | ||||
| rtError_t rtSetProfDirEx(char *prof_dir, char *address, char *job_ctx) { return RT_ERROR_NONE; } | |||||
| rtError_t rtSetProfDirEx(const char *profDir, const char *address, const char *jobCtx) { return RT_ERROR_NONE; } | |||||
| rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aicore_memory_size) { return RT_ERROR_NONE; } | rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aicore_memory_size) { return RT_ERROR_NONE; } | ||||
| @@ -218,9 +219,8 @@ rtError_t rtGetFunctionByName(const char *stub_name, void **stub_func) { | |||||
| *(char **)stub_func = "func"; | *(char **)stub_func = "func"; | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtGetAddrByFun(const void *stubFunc, void **addr) | |||||
| { | |||||
| *(char**)addr = "dev_func"; | |||||
| rtError_t rtGetAddrByFun(const void *stubFunc, void **addr) { | |||||
| *(char **)addr = "dev_func"; | |||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtQueryFunctionRegistered(const char *stub_name) { return RT_ERROR_NONE; } | rtError_t rtQueryFunctionRegistered(const char *stub_name) { return RT_ERROR_NONE; } | ||||
| @@ -244,7 +244,9 @@ rtError_t rtEndGraphEx(rtModel_t model, rtStream_t stream, uint32_t flags) | |||||
| { | { | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtProfilerStop(void) { return RT_ERROR_NONE; } | |||||
| rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtSetDvfsProfile(DvfsProfileMode mode) { return RT_ERROR_NONE; } | rtError_t rtSetDvfsProfile(DvfsProfileMode mode) { return RT_ERROR_NONE; } | ||||
| @@ -256,7 +258,9 @@ rtError_t rtCtxDestroy(rtContext_t ctx) { return RT_ERROR_NONE; } | |||||
| rtError_t rtProfilerInit(const char *prof_dir, const char *address, const char *job_ctx) { return RT_ERROR_NONE; } | rtError_t rtProfilerInit(const char *prof_dir, const char *address, const char *job_ctx) { return RT_ERROR_NONE; } | ||||
| rtError_t rtProfilerStart(void) { return RT_ERROR_NONE; } | |||||
| rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtLabelCreate(rtLabel_t *label) { | rtError_t rtLabelCreate(rtLabel_t *label) { | ||||
| *label = new uint64_t; | *label = new uint64_t; | ||||
| @@ -305,7 +309,9 @@ rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream) { | |||||
| } | } | ||||
| rtError_t rtInvalidCache(uint64_t base, uint32_t len) { return RT_ERROR_NONE; } | |||||
| rtError_t rtInvalidCache(void *base, size_t len) { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtModelLoadComplete(rtModel_t model) { return RT_ERROR_NONE; } | rtError_t rtModelLoadComplete(rtModel_t model) { return RT_ERROR_NONE; } | ||||
| @@ -314,7 +320,9 @@ rtError_t rtStreamCreateWithFlags(rtStream_t *stream, int32_t priority, uint32_t | |||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtFlushCache(uint64_t base, uint32_t len) { return RT_ERROR_NONE; } | |||||
| rtError_t rtFlushCache(void *base, size_t len) { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream_) { return RT_ERROR_NONE; } | rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream_) { return RT_ERROR_NONE; } | ||||
| @@ -445,4 +453,7 @@ rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void | |||||
| rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { | rtError_t rtDebugUnRegisterForStream(rtStream_t stream) { | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | |||||
| } | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "toolchain/slog.h" | #include "toolchain/slog.h" | ||||
| #include "toolchain/plog.h" | |||||
| #include <stdarg.h> | #include <stdarg.h> | ||||
| #include <stdio.h> | #include <stdio.h> | ||||
| @@ -46,3 +47,22 @@ int CheckLogLevel(int moduleId, int logLevel) | |||||
| { | { | ||||
| return 1; | return 1; | ||||
| } | } | ||||
| /** | |||||
| * @ingroup plog | |||||
| * @brief DlogReportInitialize: init log in service process before all device setting. | |||||
| * @return: 0: SUCCEED, others: FAILED | |||||
| */ | |||||
| int DlogReportInitialize() { | |||||
| return 0; | |||||
| } | |||||
| /** | |||||
| * @ingroup plog | |||||
| * @brief DlogReportFinalize: release log resource in service process after all device reset. | |||||
| * @return: 0: SUCCEED, others: FAILED | |||||
| */ | |||||
| int DlogReportFinalize() { | |||||
| return 0; | |||||
| } | |||||
| @@ -166,7 +166,7 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/common/dump/dump_properties.cc" | "${GE_CODE_DIR}/ge/common/dump/dump_properties.cc" | ||||
| "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" | "${GE_CODE_DIR}/ge/common/helper/model_helper.cc" | ||||
| "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" | "${GE_CODE_DIR}/ge/common/dump/dump_manager.cc" | ||||
| "${GE_CODE_DIR}/ge/common/dump/exception_dumper.cc" | |||||
| "${GE_CODE_DIR}/ge/common/dump/exception_dumper.cc" | |||||
| "${GE_CODE_DIR}/ge/common/dump/opdebug_register.cc" | "${GE_CODE_DIR}/ge/common/dump/opdebug_register.cc" | ||||
| "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" | "${GE_CODE_DIR}/ge/common/dump/dump_op.cc" | ||||
| "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" | "${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc" | ||||
| @@ -512,8 +512,8 @@ set(GRAPH_PASS_COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc" | ||||
| @@ -621,6 +621,8 @@ set(SINGLE_OP_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" | ||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/hccl/hccl_node_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/hccl/hccl_node_executor.cc" | ||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_node_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_node_executor.cc" | ||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_node_task.cc" | |||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_task_factory.cc" | |||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/node_executor.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/node_executor.cc" | ||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/task_context.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/task_context.cc" | ||||
| "${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model.cc" | "${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model.cc" | ||||
| @@ -707,8 +709,8 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/transpose_transdata_pass_unittest.cc" | "graph/passes/transpose_transdata_pass_unittest.cc" | ||||
| "graph/passes/parallel_group_pass_unittest.cc" | "graph/passes/parallel_group_pass_unittest.cc" | ||||
| "graph/passes/buffer_pool_memory_pass_unittest.cc" | "graph/passes/buffer_pool_memory_pass_unittest.cc" | ||||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | |||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | |||||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | |||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | |||||
| "graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
| ) | ) | ||||
| @@ -751,12 +753,12 @@ set(KERNEL_TEST_FILES | |||||
| set(MULTI_PARTS_TEST_FILES | set(MULTI_PARTS_TEST_FILES | ||||
| "graph_ir/ge_operator_factory_unittest.cc" | "graph_ir/ge_operator_factory_unittest.cc" | ||||
| "graph_ir/ge_ir_build_unittest.cc" | |||||
| "graph_ir/ge_ir_build_unittest.cc" | |||||
| "graph/transop_util_unittest.cc" | "graph/transop_util_unittest.cc" | ||||
| "common/datatype_transfer_unittest.cc" | "common/datatype_transfer_unittest.cc" | ||||
| "common/dump_manager_unittest.cc" | "common/dump_manager_unittest.cc" | ||||
| "common/dump_op_unittest.cc" | "common/dump_op_unittest.cc" | ||||
| "common/dump_exception_unittest.cc" | |||||
| "common/dump_exception_unittest.cc" | |||||
| "common/opdebug_register_unittest.cc" | "common/opdebug_register_unittest.cc" | ||||
| "common/format_transfer_unittest.cc" | "common/format_transfer_unittest.cc" | ||||
| "common/format_transfer_transpose_unittest.cc" | "common/format_transfer_transpose_unittest.cc" | ||||
| @@ -775,7 +777,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "common/format_transfer_fracz_nhwc_unittest.cc" | "common/format_transfer_fracz_nhwc_unittest.cc" | ||||
| "common/format_transfer_fracz_hwcn_unittest.cc" | "common/format_transfer_fracz_hwcn_unittest.cc" | ||||
| "common/ge_format_util_unittest.cc" | "common/ge_format_util_unittest.cc" | ||||
| "common/ge_auth_file_saver_unittest.cc" | |||||
| "common/ge_auth_file_saver_unittest.cc" | |||||
| "graph/variable_accelerate_ctrl_unittest.cc" | "graph/variable_accelerate_ctrl_unittest.cc" | ||||
| "graph/build/logical_stream_allocator_unittest.cc" | "graph/build/logical_stream_allocator_unittest.cc" | ||||
| "graph/build/model_builder_unittest.cc" | "graph/build/model_builder_unittest.cc" | ||||
| @@ -804,7 +806,7 @@ set(SINGLE_OP_TEST_FILES | |||||
| "single_op/single_op_manager_unittest.cc" | "single_op/single_op_manager_unittest.cc" | ||||
| "single_op/stream_resource_unittest.cc" | "single_op/stream_resource_unittest.cc" | ||||
| "single_op/single_op_task_unittest.cc" | "single_op/single_op_task_unittest.cc" | ||||
| "single_op/single_op_unittest.cc" | |||||
| "single_op/single_op_unittest.cc" | |||||
| ) | ) | ||||
| set(PROFILING_MNG_TEST_FILES | set(PROFILING_MNG_TEST_FILES | ||||
| @@ -814,7 +816,9 @@ set(PROFILING_MNG_TEST_FILES | |||||
| set(HYBRID_TEST_FILES | set(HYBRID_TEST_FILES | ||||
| "hybrid/ge_hybrid_unittest.cc" | "hybrid/ge_hybrid_unittest.cc" | ||||
| "hybrid/known_node_executor_unittest.cc" | "hybrid/known_node_executor_unittest.cc" | ||||
| "hybrid/executor/worker/execution_engine_unittest.cc" | |||||
| "hybrid/executor/worker/execution_engine_unittest.cc" | |||||
| "hybrid/model/hybrid_model_builder_unittest.cc" | |||||
| "hybrid/node_executor/rts/rts_node_task_unittest.cc" | |||||
| ) | ) | ||||
| set(OTHERS_TEST_FILES | set(OTHERS_TEST_FILES | ||||
| @@ -333,8 +333,8 @@ TEST_F(UtestDavinciModel, init_unknown) { | |||||
| TEST_F(UtestDavinciModel, Init_variable_op) { | TEST_F(UtestDavinciModel, Init_variable_op) { | ||||
| DavinciModel model(0, g_local_call_back); | DavinciModel model(0, g_local_call_back); | ||||
| model.ge_model_ = make_shared<GeModel>(); | model.ge_model_ = make_shared<GeModel>(); | ||||
| model.runtime_param_.mem_base = (uint8_t *)0x08000000; | |||||
| model.runtime_param_.mem_size = 5120000; | |||||
| model.runtime_param_.mem_size = 51200; | |||||
| model.runtime_param_.mem_base = (uint8_t *)malloc(model.runtime_param_.mem_size); | |||||
| ComputeGraphPtr graph = make_shared<ComputeGraph>("default"); | ComputeGraphPtr graph = make_shared<ComputeGraph>("default"); | ||||
| GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); | GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); | ||||
| @@ -365,6 +365,8 @@ TEST_F(UtestDavinciModel, Init_variable_op) { | |||||
| EXPECT_EQ(model.CopyOutputData(1, output_data, RT_MEMCPY_DEVICE_TO_HOST), SUCCESS); | EXPECT_EQ(model.CopyOutputData(1, output_data, RT_MEMCPY_DEVICE_TO_HOST), SUCCESS); | ||||
| EXPECT_EQ(model.ReturnResult(1, false, true, &output_data), INTERNAL_ERROR); | EXPECT_EQ(model.ReturnResult(1, false, true, &output_data), INTERNAL_ERROR); | ||||
| free(model.runtime_param_.mem_base); | |||||
| model.runtime_param_.mem_base = nullptr; | |||||
| } | } | ||||
| TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) { | TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) { | ||||
| @@ -20,9 +20,8 @@ | |||||
| #define private public | #define private public | ||||
| #include "graph/passes/infershape_pass.h" | #include "graph/passes/infershape_pass.h" | ||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph/operator.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
| #include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
| #include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
| @@ -36,6 +35,40 @@ class UtestGraphInfershapePass : public testing::Test { | |||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| TensorUtils::SetSize(tensor, 512); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(1024); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(1024); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| op_desc->SetWorkspace({}); | |||||
| op_desc->SetWorkspaceBytes({}); | |||||
| op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
| const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; | |||||
| op_desc->AddInferFunc(stub_func); | |||||
| op_desc->AddInferFormatFunc(stub_func); | |||||
| op_desc->AddVerifierFunc(stub_func); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | TEST_F(UtestGraphInfershapePass, infershape_pass_failed) { | ||||
| GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); | GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16); | ||||
| string type = "AddN"; | string type = "AddN"; | ||||
| @@ -62,4 +95,67 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) { | |||||
| EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); | EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) { | |||||
| /******************************************************************************* | |||||
| * Exit Identity | |||||
| * \ / \. | |||||
| * \ / \. | |||||
| * Switch Add | |||||
| * / | | | |||||
| * / | | | |||||
| * / | | | |||||
| * LoopCond | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Less | | | |||||
| * \ | NextIteration | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Merge <---------| | |||||
| * | | |||||
| * | | |||||
| * Enter | |||||
| ******************************************************************************/ | |||||
| auto graph = std::make_shared<ComputeGraph>("test_infer_shape"); | |||||
| auto data1 = CreateNode(*graph, "data", DATA, 1, 1); | |||||
| auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||||
| auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2); | |||||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
| auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | |||||
| auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2); | |||||
| auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1); | |||||
| auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
| auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | |||||
| auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | |||||
| auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| GEPass ge_passes(graph); | |||||
| NamesToPass names_to_passes; | |||||
| InferShapePass infer_shape_pass; | |||||
| names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); | |||||
| EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -114,9 +114,9 @@ void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, | |||||
| /// Normal graph | /// Normal graph | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -188,10 +188,10 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() { | |||||
| /// Normal graph with multi buffer pool | /// Normal graph with multi buffer pool | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// (pool0) (pool1) (pool0) (pool0) (pool1) | /// (pool0) (pool1) (pool0) (pool0) (pool1) | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -265,9 +265,9 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() { | |||||
| /// SerialGraph: Buffer pool size only can contain one prefetch node | /// SerialGraph: Buffer pool size only can contain one prefetch node | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -345,7 +345,7 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() { | |||||
| /// GraphWithMultiPrefetch: Calc node with more prefetch node | /// GraphWithMultiPrefetch: Calc node with more prefetch node | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 | ||||
| /// \ / \ / \ / | /// \ / \ / \ / | ||||
| /// \ / \ / \ / | /// \ / \ / \ / | ||||
| @@ -426,9 +426,9 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() { | |||||
| /// Subgraph1: Subgraph2: | /// Subgraph1: Subgraph2: | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out | /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -540,9 +540,9 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() { | |||||
| /// Subgraph1: Subgraph2: | /// Subgraph1: Subgraph2: | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out | /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -651,10 +651,10 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() { | |||||
| /// batch_label_128 | /// batch_label_128 | ||||
| /// | /// | ||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | ||||
| /// / / / / / / \ | |||||
| /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \ | |||||
| /// const1 switch_false / / / / / \ | |||||
| /// \ / / / / / / \ | |||||
| /// / / / / / / \. | |||||
| /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \. | |||||
| /// const1 switch_false / / / / / \. | |||||
| /// \ / / / / / / \. | |||||
| /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output | /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output | ||||
| /// / \ \ \ \ \ \ / | /// / \ \ \ \ \ \ / | ||||
| /// const2 switch_true \ \ \ \ \ / | /// const2 switch_true \ \ \ \ \ / | ||||
| @@ -809,7 +809,7 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() { | |||||
| /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| @@ -892,7 +892,7 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() { | |||||
| /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ / \ / \ / \ / \ | |||||
| /// \ / \ / \ / \ / \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| @@ -54,9 +54,9 @@ class BufferPoolGraphBuilder { | |||||
| /// Normal graph | /// Normal graph | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -72,10 +72,10 @@ class BufferPoolGraphBuilder { | |||||
| /// Normal graph with multi buffer pool | /// Normal graph with multi buffer pool | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// (pool0) (pool1) (pool0) (pool0) (pool1) | /// (pool0) (pool1) (pool0) (pool0) (pool1) | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -92,9 +92,9 @@ class BufferPoolGraphBuilder { | |||||
| /// SerialGraph: Buffer pool size only can contain one prefetch node | /// SerialGraph: Buffer pool size only can contain one prefetch node | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -116,7 +116,7 @@ class BufferPoolGraphBuilder { | |||||
| /// GraphWithMultiPrefetch: Calc node with more prefetch node | /// GraphWithMultiPrefetch: Calc node with more prefetch node | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 | ||||
| /// \ / \ / \ / | /// \ / \ / \ / | ||||
| /// \ / \ / \ / | /// \ / \ / \ / | ||||
| @@ -144,9 +144,9 @@ class BufferPoolGraphBuilder { | |||||
| /// Subgraph1: Subgraph2: | /// Subgraph1: Subgraph2: | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out | /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -168,9 +168,9 @@ class BufferPoolGraphBuilder { | |||||
| /// Subgraph1: Subgraph2: | /// Subgraph1: Subgraph2: | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out | /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out | ||||
| /// | /// | ||||
| /// | /// | ||||
| @@ -189,10 +189,10 @@ class BufferPoolGraphBuilder { | |||||
| /// batch_label_128 | /// batch_label_128 | ||||
| /// | /// | ||||
| /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- | ||||
| /// / / / / / / \ | |||||
| /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \ | |||||
| /// const1 switch_false / / / / / \ | |||||
| /// \ / / / / / / \ | |||||
| /// / / / / / / \. | |||||
| /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \. | |||||
| /// const1 switch_false / / / / / \. | |||||
| /// \ / / / / / / \. | |||||
| /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output | /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output | ||||
| /// / \ \ \ \ \ \ / | /// / \ \ \ \ \ \ / | ||||
| /// const2 switch_true \ \ \ \ \ / | /// const2 switch_true \ \ \ \ \ / | ||||
| @@ -215,7 +215,7 @@ class BufferPoolGraphBuilder { | |||||
| /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ \ \ \ \ | |||||
| /// \ \ \ \ \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| @@ -238,7 +238,7 @@ class BufferPoolGraphBuilder { | |||||
| /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | /// GraphWithMultiOutputPrefetch: Prefetch has more than one output | ||||
| /// | /// | ||||
| /// w1 w2 w3 w4 w5 | /// w1 w2 w3 w4 w5 | ||||
| /// \ / \ / \ / \ / \ | |||||
| /// \ / \ / \ / \ / \. | |||||
| /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| /// / \ / \ / \ / \ / | /// / \ / \ / \ / \ / | ||||
| @@ -288,7 +288,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { | |||||
| HybridModel *model_ptr = &model; | HybridModel *model_ptr = &model; | ||||
| uint32_t device_id = 0; | uint32_t device_id = 0; | ||||
| rtStream_t stream; | |||||
| rtStream_t stream = nullptr; | |||||
| HybridModelExecutor executor(model_ptr, device_id, stream); | HybridModelExecutor executor(model_ptr, device_id, stream); | ||||
| executor.Init(); | executor.Init(); | ||||
| } | } | ||||
| @@ -644,17 +644,28 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { | |||||
| std::unique_ptr<NodeItem> node_item_1; | std::unique_ptr<NodeItem> node_item_1; | ||||
| NodeItem::Create(node_1, node_item_1); | NodeItem::Create(node_1, node_item_1); | ||||
| node_item_1->node_id = 1; | node_item_1->node_id = 1; | ||||
| node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor()); | node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor()); | ||||
| OpDescPtr op_desc_2 = CreateOpDesc("net_output", NETOUTPUT); | |||||
| auto node_2 = compute_graph->AddNode(op_desc_2); | |||||
| std::unique_ptr<NodeItem> node_item_2; | |||||
| NodeItem::Create(node_2, node_item_2); | |||||
| node_item_2->node_id = 2; | |||||
| node_1->GetOutControlAnchor()->LinkTo(node_2->GetInControlAnchor()); | |||||
| GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | ||||
| HybridModel model(root_model); | HybridModel model(root_model); | ||||
| model.root_graph_ = compute_graph; | model.root_graph_ = compute_graph; | ||||
| model.node_items_.emplace(node, std::move(node_item)); | model.node_items_.emplace(node, std::move(node_item)); | ||||
| model.node_items_.emplace(node_1, std::move(node_item_1)); | |||||
| model.node_items_.emplace(node_2, std::move(node_item_2)); | |||||
| HybridModelBuilder builder(model); | HybridModelBuilder builder(model); | ||||
| std::vector<std::string> deps; | std::vector<std::string> deps; | ||||
| ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); | |||||
| ASSERT_TRUE(model.GetNodeItem(node)->has_observer); | |||||
| ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | |||||
| ASSERT_EQ(builder.ParseDependentInputNodes(*model.node_items_[node_1], deps), SUCCESS); | |||||
| ASSERT_EQ(builder.ParseDependentInputNodes(*model.node_items_[node_2], deps), SUCCESS); | |||||
| ASSERT_FALSE(model.GetNodeItem(node)->has_observer); | |||||
| ASSERT_TRUE(model.GetNodeItem(node_1)->has_observer); | |||||
| ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 0); | |||||
| ASSERT_EQ(model.node_items_[node_2]->dependents_for_execution.size(), 1); | |||||
| } | } | ||||
| @@ -0,0 +1,233 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <gmock/gmock.h> | |||||
| #include <vector> | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "hybrid/model/hybrid_model_builder.h" | |||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| using namespace hybrid; | |||||
| class UtestHybridModelBuilder : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() { } | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| TensorUtils::SetSize(tensor, 512); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(1024); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(1024); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| op_desc->SetWorkspace({}); | |||||
| op_desc->SetWorkspaceBytes({}); | |||||
| op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
| /******************************************************************************* | |||||
| *                        Exit         Identity | |||||
| * \ / \. | |||||
| * \ / \. | |||||
| * Switch Add | |||||
| * / | | | |||||
| * / | | | |||||
| * / | | | |||||
| * LoopCond | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Less | | | |||||
| * \ | NextIteration | |||||
| * \ | | | |||||
| * \ | | | |||||
| * Merge <---------| | |||||
| * | | |||||
| * | | |||||
| * Enter | |||||
| ******************************************************************************/ | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| GeModelPtr ge_sub_model = make_shared<GeModel>(); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1); | |||||
| auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | |||||
| auto less1 = CreateNode(*graph, "less", LESS, 2, 1); | |||||
| less1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); | |||||
| auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1); | |||||
| auto switch_t = CreateNode(*graph, "switch_t", STREAMSWITCH, 2, 0); | |||||
| auto switch_f = CreateNode(*graph, "switch_f", STREAMSWITCH, 2, 0); | |||||
| auto ident1 = CreateNode(*graph, "identity", IDENTITY, 2, 1); | |||||
| auto add1 = CreateNode(*graph, "add", ADD, 2, 1); | |||||
| add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine"); | |||||
| auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1); | |||||
| auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1); | |||||
| auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1); | |||||
| auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0); | |||||
| auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0); | |||||
| auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0); | |||||
| auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1); | |||||
| GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value0->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), exit1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), ident1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), ident1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | |||||
| GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
| AttrUtils::SetStr(merge1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next1->GetName()); | |||||
| AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | |||||
| AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | |||||
| AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | |||||
| AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | |||||
| // Build -> IndexSpecialNodes --> stream_merge_op_nodes_ | |||||
| // Build -> LoadGraph -> RelinkNextIteration | |||||
| // Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend | |||||
| // Build -> LoadGraph -> LoadDynamicSubgraph --> BuildControlFlowGroup --> NodeItem::SetCtrlSend | |||||
| auto &engine_mapping = NodeExecutorManager::GetInstance().engine_mapping_; | |||||
| engine_mapping.emplace("AIcoreEngine", NodeExecutorManager::ExecutorType::AICORE); | |||||
| engine_mapping.emplace("DNN_VM_GE_LOCAL_OP_STORE", NodeExecutorManager::ExecutorType::GE_LOCAL); | |||||
| engine_mapping.emplace("aicpu_tf_kernel", NodeExecutorManager::ExecutorType::AICPU_TF); | |||||
| engine_mapping.emplace("aicpu_ascend_kernel", NodeExecutorManager::ExecutorType::AICPU_TF); | |||||
| engine_mapping.emplace("ops_kernel_info_hccl", NodeExecutorManager::ExecutorType::HCCL); | |||||
| engine_mapping.emplace("DNN_VM_RTS_OP_STORE", NodeExecutorManager::ExecutorType::RTS); | |||||
| engine_mapping.emplace("DNN_VM_HOST_CPU_OP_STORE", NodeExecutorManager::ExecutorType::HOST_CPU); | |||||
| auto &task_executor = NodeExecutorManager::GetInstance().executors_; | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::AICORE, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::GE_LOCAL, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::AICPU_TF, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::HCCL, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||||
| task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor())); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
| ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS); | |||||
| engine_mapping.clear(); | |||||
| task_executor.clear(); | |||||
| } | |||||
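The fixture above hand-populates `engine_mapping_` and `executors_` so that `Build()` can resolve each node's kernel-lib name to a stub executor. Isolated for reference, the two-level lookup it relies on looks roughly like the sketch below; the types here are hypothetical simplifications, not the GE API:

```cpp
#include <map>
#include <memory>
#include <string>

// Hypothetical miniature of the two-level lookup the fixture populates:
// kernel-lib name -> executor type -> executor instance.
enum class ExecType { AICORE, GE_LOCAL, RTS };

struct Executor {};  // stand-in; LoadTask/ExecuteAsync would live here

std::map<std::string, ExecType> engine_mapping;
std::map<ExecType, std::unique_ptr<Executor>> executors;

Executor *Resolve(const std::string &kernel_lib_name) {
  auto type_it = engine_mapping.find(kernel_lib_name);
  if (type_it == engine_mapping.end()) return nullptr;  // unknown engine
  auto exec_it = executors.find(type_it->second);
  return exec_it == executors.end() ? nullptr : exec_it->second.get();
}

int main() {
  engine_mapping.emplace("DNN_VM_RTS_OP_STORE", ExecType::RTS);
  executors.emplace(ExecType::RTS, std::unique_ptr<Executor>(new Executor()));
  return Resolve("DNN_VM_RTS_OP_STORE") != nullptr ? 0 : 1;
}
```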
| TEST_F(UtestHybridModelBuilder, create_called_invalid) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
| auto node = CreateNode(*graph, "node", PARTITIONEDCALL, 1, 1); | |||||
| NodeItem node_item(node); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateStreamActiveGroup(node, &node_item), INTERNAL_ERROR); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchGroup(node, &node_item), INTERNAL_ERROR); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateNextIterationGroup(node, &node_item), INTERNAL_ERROR); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(node, &node_item), INTERNAL_ERROR); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateSwitchGroup(node, &node_item), INTERNAL_ERROR); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateLabelSetGroup(node, &node_item), INTERNAL_ERROR); | |||||
| node_item.node_type = LABELSET; | |||||
| ASSERT_EQ(hybrid_model_builder.CreateLabelSetGroup(node, &node_item), UNSUPPORTED); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateLabelGotoGroup(node, &node_item), INTERNAL_ERROR); | |||||
| node_item.node_type = LABELGOTO; | |||||
| ASSERT_EQ(hybrid_model_builder.CreateLabelGotoGroup(node, &node_item), UNSUPPORTED); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateLabelSwitchGroup(node, &node_item), INTERNAL_ERROR); | |||||
| node_item.node_type = LABELSWITCH; | |||||
| ASSERT_EQ(hybrid_model_builder.CreateLabelSwitchGroup(node, &node_item), UNSUPPORTED); | |||||
| } | |||||
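The expectations in `create_called_invalid` suggest a two-stage guard in each `Create*Group` helper: a node of the wrong type is an `INTERNAL_ERROR`, while a correctly typed label op is recognized but rejected as `UNSUPPORTED`. A hedged paraphrase with hypothetical simplified signatures, not the builder's real interface:

```cpp
#include <string>

enum Status { SUCCESS = 0, INTERNAL_ERROR = 1, UNSUPPORTED = 2 };

// Hedged paraphrase: wrong node type -> INTERNAL_ERROR; correct type but a
// label op -> UNSUPPORTED (labels are not handled by the hybrid engine).
Status CreateLabelSetGroup(const std::string &node_type) {
  if (node_type != "LabelSet") {
    return INTERNAL_ERROR;  // guard: not the op this group builder expects
  }
  return UNSUPPORTED;  // recognized, but rejected at group-creation time
}

int main() {
  bool ok = CreateLabelSetGroup("PartitionedCall") == INTERNAL_ERROR &&
            CreateLabelSetGroup("LabelSet") == UNSUPPORTED;
  return ok ? 0 : 1;
}
```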
| TEST_F(UtestHybridModelBuilder, stream_switch_n_group) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
| auto switch_n = CreateNode(*graph, "switch_n", STREAMSWITCHN, 1, 0); | |||||
| NodeItem node_item(switch_n); | |||||
| // no batch_num | |||||
| ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(switch_n, &node_item), INTERNAL_ERROR); | |||||
| uint32_t batch_num = 0; | |||||
| AttrUtils::SetInt(switch_n->GetOpDesc(), ATTR_NAME_BATCH_NUM, batch_num); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(switch_n, &node_item), SUCCESS); | |||||
| batch_num = 3; | |||||
| AttrUtils::SetInt(switch_n->GetOpDesc(), ATTR_NAME_BATCH_NUM, batch_num); | |||||
| ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(switch_n, &node_item), SUCCESS); | |||||
| } | |||||
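`stream_switch_n_group` pins down the attribute contract: the builder fails when the batch-num attribute is absent and succeeds for any present value, including 0. A minimal sketch under that reading; the attribute key and signature below are hypothetical stand-ins:

```cpp
#include <cstdint>
#include <map>
#include <string>

enum Status { SUCCESS = 0, INTERNAL_ERROR = 1 };

// Hedged paraphrase of the check the test exercises; "_batch_num" is a
// hypothetical key, not necessarily the real ATTR_NAME_BATCH_NUM string.
Status CreateStreamSwitchNGroup(const std::map<std::string, int64_t> &attrs) {
  if (attrs.find("_batch_num") == attrs.end()) {
    return INTERNAL_ERROR;  // the attribute is mandatory
  }
  return SUCCESS;  // any present value is accepted, including 0
}

int main() {
  std::map<std::string, int64_t> attrs;
  if (CreateStreamSwitchNGroup(attrs) != INTERNAL_ERROR) return 1;
  attrs["_batch_num"] = 0;
  return CreateStreamSwitchNGroup(attrs) == SUCCESS ? 0 : 1;
}
```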
| } // namespace ge | |||||
| @@ -0,0 +1,484 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <gmock/gmock.h> | |||||
| #include <vector> | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "hybrid/executor/subgraph_context.h" | |||||
| #include "hybrid/node_executor/rts/rts_node_executor.h" | |||||
| #include "model/ge_root_model.h" | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| using namespace hybrid; | |||||
| class UtestRtsNodeTask : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() { } | |||||
| }; | |||||
| static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| op_desc->SetStreamId(0); | |||||
| static int32_t index = 0; | |||||
| op_desc->SetId(index++); | |||||
| GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); | |||||
| TensorUtils::SetSize(tensor, 64); | |||||
| vector<int64_t> input_offset; | |||||
| for (int i = 0; i < in_num; i++) { | |||||
| op_desc->AddInputDesc(tensor); | |||||
| input_offset.emplace_back(i * 64); | |||||
| } | |||||
| op_desc->SetInputOffset(input_offset); | |||||
| vector<int64_t> output_offset; | |||||
| for (int i = 0; i < out_num; i++) { | |||||
| op_desc->AddOutputDesc(tensor); | |||||
| output_offset.emplace_back(in_num * 64 + i * 64); | |||||
| } | |||||
| op_desc->SetOutputOffset(output_offset); | |||||
| op_desc->SetWorkspace({}); | |||||
| op_desc->SetWorkspaceBytes({}); | |||||
| op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); | |||||
| return graph.AddNode(op_desc); | |||||
| } | |||||
| TEST_F(UtestRtsNodeTask, test_stream_switch_task) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "switch", STREAMSWITCH, 2, 0); | |||||
| ASSERT_TRUE(AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 0)); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 2; | |||||
| graph_item.total_outputs_ = 2; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| uint64_t value_0 = 110; | |||||
| uint64_t value_1 = 120; | |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
| TensorValue in_tensor1(&value_1, sizeof(value_1)); | |||||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||||
| subgraph_context.SetInput(*node_item, 1, in_tensor1); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| ASSERT_EQ(node_state->GetSwitchIndex(), 0); // inputs differ (110 vs 120): activate branch 0 | |||||
| uint64_t value_2 = 110; | |||||
| TensorValue in_tensor2(&value_2, sizeof(value_2)); | |||||
| subgraph_context.SetInput(*node_item, 1, in_tensor2); | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| ASSERT_EQ(node_state->GetSwitchIndex(), 1); // inputs equal (110 vs 110): activate branch 1 | |||||
| } | |||||
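Read together, the two `ExecuteAsync` calls above imply the branch-selection rule for `STREAMSWITCH` with an equality condition: the true branch (index 1) is activated when the two inputs compare equal, and the false branch (index 0) otherwise. A standalone restatement of just that rule:

```cpp
#include <cassert>
#include <cstdint>

// Branch selection for an equality-conditioned switch, as the assertions
// above imply: predicate holds -> true branch (1), otherwise -> false (0).
int SwitchIndexForEqualCond(std::uint64_t lhs, std::uint64_t rhs) {
  return lhs == rhs ? 1 : 0;
}

int main() {
  assert(SwitchIndexForEqualCond(110, 120) == 0);  // differ -> branch 0
  assert(SwitchIndexForEqualCond(110, 110) == 1);  // equal  -> branch 1
  return 0;
}
```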
| TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "active", STREAMACTIVE, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| ASSERT_EQ(node_state->GetSwitchIndex(), 0); | |||||
| } | |||||
| TEST_F(UtestRtsNodeTask, test_stream_merge_task) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "merge", STREAMMERGE, 2, 2); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 2; | |||||
| graph_item.total_outputs_ = 2; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| uint64_t value_0 = 110; | |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||||
| uint64_t value_1 = 220; | |||||
| TensorValue in_tensor1(&value_1, sizeof(value_1)); | |||||
| subgraph_context.SetInput(*node_item, 1, in_tensor1); | |||||
| uint64_t value_2 = 123; | |||||
| TensorValue out_tensor0(&value_2, sizeof(value_2)); | |||||
| subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
| uint64_t value_3 = 223; | |||||
| TensorValue out_tensor1(&value_3, sizeof(value_3)); | |||||
| subgraph_context.SetOutput(*node_item, 1, out_tensor1); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| node_state->SetMergeIndex(1); | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| ASSERT_EQ(node_state->GetSwitchIndex(), -1); | |||||
| uint64_t value_4 = 323; | |||||
| ASSERT_EQ(node_state->GetTaskContext()->GetOutput(0)->CopyScalarValueToHost(value_4), SUCCESS); | |||||
| ASSERT_EQ(value_4, value_1); | |||||
| uint64_t value_5 = 423; | |||||
| ASSERT_EQ(node_state->GetTaskContext()->GetOutput(1)->CopyScalarValueToHost(value_5), SUCCESS); | |||||
| ASSERT_EQ(value_5, 1); | |||||
| } | |||||
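The merge assertions encode the `STREAMMERGE` output contract: output 0 carries the value of the input selected by the merge index (here input 1, value 220), and output 1 carries the index itself. A hedged sketch with a hypothetical simplified signature, not the GE TaskContext API:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Output contract of the merge task, as asserted above: output 0 takes the
// value of the selected input, output 1 takes the selected index itself.
void MergeSelect(const std::vector<std::uint64_t> &inputs, std::size_t index,
                 std::uint64_t &data_out, std::uint64_t &index_out) {
  data_out = inputs[index];
  index_out = index;
}

int main() {
  std::uint64_t data = 0;
  std::uint64_t index = 0;
  MergeSelect({110, 220}, 1, data, index);
  assert(data == 220);  // value of input 1
  assert(index == 1);   // the merge index
  return 0;
}
```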
| TEST_F(UtestRtsNodeTask, test_memcpy_async_task) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "memcpy", MEMCPYASYNC, 1, 1); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 1; | |||||
| graph_item.total_outputs_ = 1; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| uint64_t value_0 = 110; | |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||||
| uint64_t value_1 = 123; | |||||
| TensorValue out_tensor0(&value_1, sizeof(value_1)); | |||||
| subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| uint64_t value_4 = 323; | |||||
| ASSERT_EQ(node_state->GetTaskContext()->GetOutput(0)->CopyScalarValueToHost(value_4), SUCCESS); | |||||
| ASSERT_EQ(value_4, value_0); | |||||
| ASSERT_EQ(value_1, value_0); | |||||
| } | |||||
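The final `ASSERT_EQ(value_1, value_0)` is the telling one: the buffer registered as output 0 is physically overwritten, so `MEMCPYASYNC` performs a real copy into the destination rather than rebinding the output to the input tensor. A host-side illustration with plain `memcpy` (not the runtime call):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  std::uint64_t src = 110;  // plays the role of the input tensor
  std::uint64_t dst = 123;  // plays the role of the pre-registered output
  std::memcpy(&dst, &src, sizeof(src));  // the effect the task has on device
  assert(dst == src);  // destination bytes overwritten, as the test checks
  return 0;
}
```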
| TEST_F(UtestRtsNodeTask, test_pass_through_task) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "enter", ENTER, 1, 1); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 0; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 1; | |||||
| graph_item.total_outputs_ = 1; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| uint64_t value_0 = 110; | |||||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||||
| uint64_t value_1 = 123; | |||||
| TensorValue out_tensor0(&value_1, sizeof(value_1)); | |||||
| subgraph_context.SetOutput(*node_item, 0, out_tensor0); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS); | |||||
| uint64_t value_4 = 323; | |||||
| ASSERT_EQ(node_state->GetTaskContext()->GetOutput(0)->CopyScalarValueToHost(value_4), SUCCESS); | |||||
| ASSERT_EQ(value_4, value_0); | |||||
| } | |||||
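By contrast, the `ENTER` pass-through test only checks the value observed through the task context, which is consistent with the task forwarding the input tensor to the output slot. A hedged sketch; the `Tensor` type below is a hypothetical stand-in for the GE `TensorValue`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hedged sketch of the pass-through family (ENTER and friends): forward the
// input tensor to the output, so the consumer sees the producer's value.
struct Tensor {
  const void *data;
  std::size_t size;
};

void PassThrough(const Tensor &in, Tensor &out) {
  out = in;  // rebind the output to the input buffer; no byte copy assumed
}

int main() {
  std::uint64_t v = 110;
  Tensor in{&v, sizeof(v)};
  Tensor out{nullptr, 0};
  PassThrough(in, out);
  assert(out.data == &v);  // same buffer, matching the test's value check
  return 0;
}
```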
| TEST_F(UtestRtsNodeTask, test_unsupport_label_set) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "labelset", LABELSET, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 2; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 2; | |||||
| graph_item.total_outputs_ = 2; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), UNSUPPORTED); | |||||
| } | |||||
| TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "labelgoto", LABELGOTO, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 2; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 2; | |||||
| graph_item.total_outputs_ = 2; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), UNSUPPORTED); | |||||
| } | |||||
| TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| GeModelPtr ge_sub_model = std::make_shared<GeModel>(); | |||||
| GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph); | |||||
| ge_root_model->SetModelName("test_name"); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| NodePtr node = CreateNode(*graph, "labelswitch", LABELSWITCH, 0, 0); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS); | |||||
| NodeItem *node_item = new_node.get(); | |||||
| hybrid_model.node_items_[node] = std::move(new_node); | |||||
| node_item->input_start = 0; | |||||
| node_item->output_start = 2; | |||||
| GraphItem graph_item; | |||||
| graph_item.node_items_.emplace_back(node_item); | |||||
| graph_item.total_inputs_ = 2; | |||||
| graph_item.total_outputs_ = 2; | |||||
| GraphExecutionContext graph_context; | |||||
| SubgraphContext subgraph_context(&graph_item, &graph_context); | |||||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||||
| graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager()); | |||||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||||
| ASSERT_NE(node_state, nullptr); | |||||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||||
| ASSERT_NE(unique_task_context, nullptr); | |||||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||||
| node_state->SetTaskContext(shared_task_context); | |||||
| NodeTaskPtr task = nullptr; | |||||
| RtsNodeExecutor node_executor; | |||||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||||
| ASSERT_NE(task, nullptr); | |||||
| std::function<void()> done = []() {}; | |||||
| ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), UNSUPPORTED); | |||||
| } | |||||
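All three label-op tests follow the same shape: `LoadTask` succeeds and only `ExecuteAsync` reports `UNSUPPORTED`. That is consistent with the task factory handing back a stub whose execution path rejects the op. A hedged sketch with hypothetical types, not the RTS executor's real classes:

```cpp
#include <functional>

enum Status { SUCCESS = 0, UNSUPPORTED = 1 };

// Hypothetical stub: loading succeeds, execution refuses the op.
struct StubLabelTask {
  Status ExecuteAsync(const std::function<void()> &done) {
    (void)done;          // the callback is never invoked for a rejected op
    return UNSUPPORTED;  // mirrors the three label-op tests above
  }
};

int main() {
  StubLabelTask task;
  return task.ExecuteAsync([] {}) == UNSUPPORTED ? 0 : 1;
}
```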
| } // namespace ge | |||||