@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() {
     if (!in_cluster->IsUnknownShape()) {
       continue;
     }
+    if (!cluster->IsAdjoinNodes(in_cluster)) {
+      continue;
+    }
     auto merged_clusters = cluster->MergeAllPathFrom(in_cluster);
     GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(),
            ToString(merged_clusters).c_str());
@@ -80,6 +80,10 @@ class DynamicShapePartitioner {
   Status BuildPartitionSubgraph();
   // Clear resource and break circular dependency
   void Clear();
+  bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const {
+    const auto &out_clusters = other->out_clusters_;
+    return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end();
+  }

  private:
   static thread_local size_t unique_id_;
@@ -537,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() {

 Status SubgraphExecutor::ScheduleTasks(int group) {
   GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str());
+  subgraph_context_->SetGroup(group);
   auto prepare_future = std::async(std::launch::async, [&]() -> Status {
     GetContext().SetSessionId(context_->session_id);
     GetContext().SetContextId(context_->context_id);
@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
   if (is_root_node_) {
     node_item->root_data_.emplace(this);
   }
+  // If Enter feed Not Merge, take as root Node.
+  if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) {
+    node_item->root_data_.emplace(this);
+    node_item->enter_inside_.emplace(anchor_index);
+  }
   GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
 }

@@ -142,6 +142,7 @@ struct NodeItem {
   std::set<const NodeItem *> ctrl_send_;                      // Send ctrl notify to
   std::set<const NodeItem *> ctrl_recv_;                      // Recv ctrl notify from
   std::vector<std::vector<const NodeItem *>> switch_groups_;  // Send ctrl notify to
+  std::set<int> enter_inside_;                                // Enter feed loop inside Node, Not cross Merge.

   std::shared_ptr<NodeTask> kernel_task;
   std::unique_ptr<FusedSubgraph> fused_subgraph;
| @@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
| } | } | ||||
| void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
| if (node_item_->enter_inside_.count(index) > 0) { | |||||
| GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||||
| return; | |||||
| } | |||||
| auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
| if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
| input_tensor->Destroy(); | input_tensor->Destroy(); | ||||