From: @zhangxiaokun9 Reviewed-by: @ji_chen Signed-off-by: @ji_chentags/v1.3.0
| @@ -274,21 +274,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||
| return false; | |||
| } | |||
| /// | |||
| /// @brief Set Op _force_unknown_shape flag | |||
| /// @param [in] node | |||
| /// @param [in] force_unknown, set attribute if true | |||
| /// @param [in] group_index, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||
| if (!force_unknown) { | |||
| return; | |||
| } | |||
| SetControlFlowGroup(node, group_index); | |||
| } | |||
| /// | |||
| /// @brief Set Op _control_flow_group flag | |||
| /// @param [in] node | |||
| @@ -125,15 +125,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||
| /// | |||
| bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
| /// | |||
| /// @brief Set Op _force_unknown_shape flag | |||
| /// @param [in] node | |||
| /// @param [in] force_unknown, set attribute if true | |||
| /// @param [in] group_index, condition group index of node. | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
| /// | |||
| /// @brief Set Op _control_flow_group flag | |||
| /// @param [in] node | |||
| @@ -364,6 +364,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||
| } | |||
| void DynamicShapePartitioner::MergeClustersControlFlow() { | |||
| std::unordered_set<ClusterPtr> all_merged_clusters; | |||
| for (const auto &item : control_clusters_) { | |||
| const auto &control_cluster = item.second; | |||
| auto rit = control_cluster.rbegin(); | |||
| @@ -373,17 +374,32 @@ void DynamicShapePartitioner::MergeClustersControlFlow() { | |||
| } | |||
| const auto &cluster = *rit; | |||
| if (all_merged_clusters.count(cluster) > 0) { | |||
| continue; | |||
| } | |||
| bool is_unknown_cluster = cluster->IsUnknownShape(); | |||
| for (++rit; rit != control_cluster.rend(); ++rit) { | |||
| const auto &cluster_from = *rit; | |||
| if (all_merged_clusters.count(cluster_from) > 0) { | |||
| continue; | |||
| } | |||
| auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | |||
| GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | |||
| ToString(merged_clusters).c_str()); | |||
| for (const auto &merged_cluster : merged_clusters) { | |||
| all_merged_clusters.emplace(merged_cluster); | |||
| for (const auto &node : merged_cluster->Nodes()) { | |||
| node_2_cluster_[node] = cluster; | |||
| } | |||
| } | |||
| } | |||
| if (!is_unknown_cluster && cluster->IsUnknownShape()) { | |||
| GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str()); | |||
| ordered_cluster_.push_back(cluster); | |||
| } | |||
| } | |||
| } | |||
| @@ -703,7 +719,12 @@ void Cluster::Merge(ClusterPtr other) { | |||
| if (other->min_ < min_) { | |||
| min_ = other->min_; | |||
| } | |||
| }; | |||
| if (!IsUnknownShape() && other->IsUnknownShape()) { | |||
| type_ = UNKNOWN_SHAPE; | |||
| } | |||
| } | |||
| bool Cluster::TryMerge(ClusterPtr other) { | |||
| std::queue<ClusterPtr> forward_reached; | |||
| forward_reached.push(other); | |||
| @@ -161,7 +161,7 @@ class DynamicShapePartitioner { | |||
| ge::ComputeGraphPtr root_graph_; // The original graph to partition | |||
| std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | |||
| // V1 control flow cluster, need merge to one Graph. | |||
| std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||
| std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||
| // topological sorted clusters, this field will change with the splitting. | |||
| // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | |||
| // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | |||
| @@ -132,39 +132,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
| /// @return | |||
| /// | |||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||
| }; | |||
| for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||
| const auto &op_node1 = it1->first; | |||
| const auto &op_desc1 = op_node1->GetOpDesc(); | |||
| if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||
| const auto &op_node = it->first; | |||
| const auto &op_desc = op_node->GetOpDesc(); | |||
| if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| } | |||
| if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||
| int64_t group_index = op_desc1->GetId(); | |||
| GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||
| MarkForceUnknownShape(op_node1, true, group_index); | |||
| for (const auto &n : it1->second) { | |||
| MarkForceUnknownShape(n, true, group_index); | |||
| } | |||
| for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||
| const auto &op_node2 = it2->first; | |||
| const auto &op_desc2 = op_node2->GetOpDesc(); | |||
| if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
| continue; | |||
| } | |||
| if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||
| MarkForceUnknownShape(op_node2, true, group_index); | |||
| for (const auto &n : it2->second) { | |||
| MarkForceUnknownShape(n, true, group_index); | |||
| } | |||
| } | |||
| } | |||
| int64_t group_index = op_desc->GetId(); | |||
| SetControlFlowGroup(op_node, group_index); | |||
| for (const auto &n : it->second) { | |||
| SetControlFlowGroup(n, group_index); | |||
| } | |||
| } | |||
| } | |||
| @@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { | |||
| } | |||
| } | |||
| const auto &node = graph->GetParentNode(); | |||
| if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) { | |||
| GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | |||
| "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); | |||
| } | |||
| for (const auto &node : graph->GetDirectNode()) { | |||
| GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); | |||
| (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); | |||
| @@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||
| return FAILED, "[Check][Param] Param of pre node is nullptr."); | |||
| int64_t group_index = -1; | |||
| bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| MarkForceUnknownShape(node, force_unknown, group_index); | |||
| (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | |||
| @@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
| GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| MarkForceUnknownShape(active_node, force_unknown, group_index); | |||
| SetControlFlowGroup(active_node, group_index); | |||
| } | |||
| return SUCCESS; | |||
| @@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||
| /// @return void | |||
| /// | |||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||
| std::string node_type; | |||
| for (const auto &switch_node : loop_group.switch_nodes) { | |||
| SetControlFlowGroup(switch_node, group_index); | |||
| for (const auto &node : switch_node->GetOutDataNodes()) { | |||
| std::string node_type; | |||
| (void)GetOriginalType(node, node_type); | |||
| if (kExitOpTypes.count(node_type) > 0) { | |||
| SetControlFlowGroup(node, group_index); | |||
| } else { | |||
| // For: Switch -> Cast -> Exit | |||
| for (const auto &n : node->GetOutDataNodes()) { | |||
| (void)GetOriginalType(n, node_type); | |||
| if (kExitOpTypes.count(node_type) > 0) { | |||
| SetControlFlowGroup(n, group_index); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
| peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | |||
| int64_t group_index = -1; | |||
| bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||
| (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| SetControlFlowGroup(stream_switch, group_index); | |||
| return stream_switch; | |||
| } | |||
| @@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
| Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | |||
| for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | |||
| for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | |||
| std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
| std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
| const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
| const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
| std::set<NodePtr> same_cond_switch; | |||
| same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | |||
| same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | |||
| @@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||
| return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
| }; | |||
| bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
| MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||
| (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
| SetControlFlowGroup(active_node, group_index); | |||
| const std::string &cond_group = cond_node->GetName(); | |||
| for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||
| bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | |||
| std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
| const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
| GE_IF_BOOL_EXEC(switch_list.empty(), continue); | |||
| // select first stream_switch | |||
| @@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
| "[Add][Edge] between %s and %s failed.", | |||
| cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | |||
| MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||
| SetControlFlowGroup(stream_switch, group_index); | |||
| for (const NodePtr &node : switch_list) { | |||
| GE_IF_BOOL_EXEC(node != stream_switch, { | |||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | |||
| @@ -19,8 +19,9 @@ | |||
| #include "framework/common/debug/log.h" | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include "hybrid_execution_context.h" | |||
| #include "subgraph_context.h" | |||
| #include "hybrid/executor/hybrid_execution_context.h" | |||
| #include "hybrid/executor/subgraph_context.h" | |||
| #include "hybrid/node_executor/task_context.h" | |||
| #define INC_ITERATION_COUNT(iteration) \ | |||
| do { \ | |||
| @@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex | |||
| this->op_desc_ = node_item.node->GetOpDesc(); | |||
| } | |||
| Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) { | |||
| GE_CHECK_NOTNULL(frame_state); | |||
| group_ = group; | |||
| frame_state_ = frame_state; | |||
| auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||
| GE_CHECK_NOTNULL(unique_task_context); | |||
| task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| return SUCCESS; | |||
| } | |||
| Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | |||
| if (node_item_->IsMergeOp()) { | |||
| GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size()); | |||
| @@ -314,15 +325,54 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
| return task_context_; | |||
| } | |||
| void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||
| if (node_item_->root_data_.count(input_idx) > 0) { | |||
| GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
| root_tensor_values_[input_idx] = tensor; | |||
| } | |||
| if (node_item_->enter_data_.count(input_idx) > 0) { | |||
| GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | |||
| root_tensor_values_[input_idx] = tensor; | |||
| } | |||
| } | |||
| void NodeState::UpdatePersistTensor(int input_idx) { | |||
| const auto it = root_tensor_values_.find(input_idx); | |||
| if (it == root_tensor_values_.end()) { | |||
| GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | |||
| return; | |||
| } | |||
| auto tensor = task_context_->MutableInput(input_idx); | |||
| if (tensor == nullptr) { | |||
| GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx); | |||
| return; | |||
| } | |||
| *tensor = it->second; | |||
| GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx); | |||
| } | |||
| void NodeState::ResetContext(uint64_t iteration) { | |||
| switch_index_ = -1; | |||
| subgraph_context_->ResetContext(node_item_->node); | |||
| if (iteration == 0) { | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
| } else { | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); | |||
| auto unique_task_context = TaskContext::Create(this, subgraph_context_); | |||
| GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); | |||
| task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
| ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
| for (auto item : node_item_->root_data_) { | |||
| UpdatePersistTensor(item.first); | |||
| } | |||
| if (iteration > 0) { | |||
| data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||
| ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||
| for (auto item : node_item_->enter_data_) { | |||
| UpdatePersistTensor(item.first); | |||
| } | |||
| } | |||
| iteration_count_ = iteration; | |||
| @@ -100,6 +100,8 @@ struct NodeState { | |||
| NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); | |||
| ~NodeState() = default; | |||
| Status Init(int group, const shared_ptr<FrameState> &frame_state); | |||
| OpDesc *GetOpDesc() const { | |||
| return op_desc_.get(); | |||
| } | |||
| @@ -129,6 +131,8 @@ struct NodeState { | |||
| void RunStreamActive(); | |||
| void RunNextIteration(); | |||
| void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||
| Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
| void SetScheduleFuture(std::future<Status> &&future); | |||
| @@ -150,18 +154,10 @@ struct NodeState { | |||
| return merge_index_; | |||
| } | |||
| void SetGroup(int group) { | |||
| group_ = group; | |||
| } | |||
| int GetGroup() const { | |||
| return group_; | |||
| } | |||
| void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||
| frame_state_ = frame_state; | |||
| } | |||
| const shared_ptr<NodeTask> &GetKernelTask() const { | |||
| return kernel_task_; | |||
| } | |||
| @@ -187,6 +183,7 @@ struct NodeState { | |||
| void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
| void ResetContext(uint64_t iteration); | |||
| void ScheduleContext(const NodeState &node_state); | |||
| void UpdatePersistTensor(int input_idx); | |||
| const NodeItem *node_item_ = nullptr; | |||
| std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
| @@ -199,6 +196,7 @@ struct NodeState { | |||
| std::future<Status> schedule_future_; | |||
| std::shared_ptr<FrameState> frame_state_; | |||
| std::map<int, TensorValue> root_tensor_values_; | |||
| uint64_t active_count_ = 0; | |||
| uint64_t iteration_count_ = 0; | |||
| uint32_t ctrl_scheduled_ = 0; | |||
| @@ -19,7 +19,7 @@ | |||
| namespace ge { | |||
| namespace hybrid { | |||
| SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context) | |||
| SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context) | |||
| : graph_item_(graph_item), execution_context_(execution_context) { | |||
| } | |||
| @@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||
| return nullptr; | |||
| } | |||
| return CreateNodeState(node_item); | |||
| } | |||
| NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) { | |||
| GELOGD("[%s] lock for write", node_item->NodeName().c_str()); | |||
| if (mmRWLockWRLock(&rw_lock_) != EN_OK) { | |||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); | |||
| return nullptr; | |||
| } | |||
| auto &node_state = node_states_[node_item]; | |||
| if (node_state == nullptr) { | |||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||
| node_state->SetFrameState(GetOrCreateFrameState(*node_item)); | |||
| node_state->SetGroup(group_); | |||
| (void)guard; | |||
| } | |||
| do { | |||
| if (node_state == nullptr) { | |||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||
| if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str()); | |||
| REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str()); | |||
| break; | |||
| } | |||
| (void)guard; | |||
| } | |||
| } while (0); | |||
| GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | |||
| if (mmWRLockUnLock(&rw_lock_) != EN_OK) { | |||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | |||
| @@ -30,7 +30,7 @@ namespace ge { | |||
| namespace hybrid { | |||
| class SubgraphContext { | |||
| public: | |||
| explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); | |||
| explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context); | |||
| ~SubgraphContext(); | |||
| Status Init(); | |||
| @@ -51,10 +51,11 @@ class SubgraphContext { | |||
| void NodeDone(const NodePtr &node); | |||
| private: | |||
| NodeStatePtr CreateNodeState(const NodeItem *node_item); | |||
| FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | |||
| friend class TaskContext; | |||
| const GraphItem *graph_item_; | |||
| const GraphExecutionContext *execution_context_; | |||
| GraphExecutionContext *execution_context_; | |||
| mmRWLock_t rw_lock_; | |||
| std::vector<TensorValue> all_inputs_; | |||
| std::vector<TensorValue> all_outputs_; | |||
| @@ -175,16 +175,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||
| GE_CHECK_NOTNULL(node_state); | |||
| node_state->SetKernelTask(node_item->kernel_task); | |||
| known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||
| GE_CHECK_NOTNULL(known_shape_task_context_); | |||
| node_state->SetTaskContext(known_shape_task_context_); | |||
| std::function<void()> callback; | |||
| GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | |||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback), | |||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback), | |||
| "[%s] Failed to execute node [%s] for known subgraph.", | |||
| graph_item_->GetName().c_str(), | |||
| known_shape_task_context_->GetNodeName()); | |||
| node_state->GetName().c_str()); | |||
| GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); | |||
| return SUCCESS; | |||
| @@ -271,16 +267,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||
| } else { | |||
| node_state->SetKernelTask(node_item.kernel_task); | |||
| } | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||
| GE_CHECK_NOTNULL(unique_task_context); | |||
| const auto &task = node_state->GetKernelTask(); | |||
| if (task == nullptr) { | |||
| GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); | |||
| REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); | |||
| return AfterPrepared(p_node_state); | |||
| } | |||
| @@ -480,19 +472,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||
| } else { | |||
| node_state.SetKernelTask(node_item.kernel_task); | |||
| } | |||
| auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); | |||
| GE_CHECK_NOTNULL(unique_task_context); | |||
| const auto &task = node_state.GetKernelTask(); | |||
| if (task == nullptr) { | |||
| GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); | |||
| REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state.SetTaskContext(shared_task_context); | |||
| GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); | |||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); | |||
| GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws | |||
| GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws | |||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); | |||
| return SUCCESS; | |||
| } | |||
| @@ -125,7 +125,6 @@ class SubgraphExecutor { | |||
| ThreadPool pre_run_pool_; | |||
| BlockingQueue<NodeState *> ready_queue_; | |||
| std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | |||
| std::shared_ptr<TaskContext> known_shape_task_context_; | |||
| std::mutex mu_; // Guard for prepare_queues_. | |||
| std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_; | |||
| @@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||
| data_send_.emplace(node_item); | |||
| node_item->data_recv_[this] = anchor_index; | |||
| if (is_root_node_) { | |||
| node_item->root_data_.emplace(this); | |||
| node_item->root_data_[anchor_index] = this; | |||
| } | |||
| // If Enter feed Not Merge, take as root Node. | |||
| if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||
| node_item->enter_data_.emplace(this); | |||
| node_item->enter_inside_.emplace(anchor_index); | |||
| node_item->enter_data_[anchor_index] = this; | |||
| } | |||
| GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
| } | |||
| @@ -148,15 +148,14 @@ struct NodeItem { | |||
| int64_t frame_index_ = -1; | |||
| int64_t parent_frame_ = -1; | |||
| std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | |||
| std::set<const NodeItem *> root_data_; // Recv data from root node | |||
| std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||
| std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | |||
| std::set<const NodeItem *> enter_data_; // Recv data from Enter node | |||
| std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||
| std::set<const NodeItem *> data_send_; // Send data notify to | |||
| std::map<const NodeItem *, int> data_recv_; // Recv data notify from | |||
| std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||
| std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | |||
| std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | |||
| std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge. | |||
| std::shared_ptr<NodeTask> kernel_task; | |||
| std::unique_ptr<FusedSubgraph> fused_subgraph; | |||
| @@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() { | |||
| } | |||
| } | |||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||
| GraphExecutionContext *execution_context, | |||
| SubgraphContext *subgraph_context) { | |||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) { | |||
| const NodeItem &node_item = *node_state->GetNodeItem(); | |||
| GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | |||
| node_item.NodeName().c_str(), | |||
| @@ -75,7 +73,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||
| } | |||
| auto task_context = std::unique_ptr<TaskContext>( | |||
| new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); | |||
| new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context)); | |||
| if (task_context == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); | |||
| GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); | |||
| @@ -85,7 +83,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||
| task_context->node_item_ = &node_item; | |||
| task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; | |||
| task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; | |||
| task_context->iteration_ = execution_context->iteration; | |||
| task_context->iteration_ = subgraph_context->execution_context_->iteration; | |||
| return task_context; | |||
| } | |||
| @@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() { | |||
| subgraph_context_->all_inputs_[input_offset].SetName( | |||
| node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | |||
| } | |||
| auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||
| GE_CHECK_NOTNULL(dst_node_state); | |||
| dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||
| } | |||
| } | |||
| (void)guard; | |||
| @@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||
| } | |||
| void TaskContext::ReleaseInput(int index) { | |||
| if (node_item_->enter_inside_.count(index) > 0) { | |||
| GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||
| return; | |||
| } | |||
| auto input_tensor = MutableInput(index); | |||
| if (input_tensor != nullptr) { | |||
| input_tensor->Destroy(); | |||
| @@ -36,9 +36,7 @@ class SubgraphContext; | |||
| class TaskContext { | |||
| public: | |||
| static std::unique_ptr<TaskContext> Create(NodeState *node_state, | |||
| GraphExecutionContext *execution_context, | |||
| SubgraphContext *subgraph_context); | |||
| static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context); | |||
| ~TaskContext(); | |||
| @@ -24,6 +24,7 @@ | |||
| #include "inc/framework/common/types.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/common/omg_util.h" | |||
| namespace ge { | |||
| namespace { | |||
| @@ -38,33 +39,33 @@ GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format fo | |||
| } | |||
| class NodeBuilder { | |||
| public: | |||
| NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); } | |||
| NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, | |||
| DataType data_type = DT_FLOAT) { | |||
| op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
| return *this; | |||
| } | |||
| NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, | |||
| DataType data_type = DT_FLOAT) { | |||
| op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
| return *this; | |||
| } | |||
| NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { | |||
| op_desc_->AddOutputDesc(tensor_desc->Clone()); | |||
| return *this; | |||
| } | |||
| NodePtr Build(const ComputeGraphPtr &graph) { | |||
| NodePtr node = graph->AddNode(op_desc_); | |||
| return node; | |||
| } | |||
| private: | |||
| OpDescPtr op_desc_; | |||
| public: | |||
| NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); } | |||
| NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, | |||
| DataType data_type = DT_FLOAT) { | |||
| op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
| return *this; | |||
| } | |||
| NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, | |||
| DataType data_type = DT_FLOAT) { | |||
| op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); | |||
| return *this; | |||
| } | |||
| NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { | |||
| op_desc_->AddOutputDesc(tensor_desc->Clone()); | |||
| return *this; | |||
| } | |||
| NodePtr Build(const ComputeGraphPtr &graph) { | |||
| NodePtr node = graph->AddNode(op_desc_); | |||
| return node; | |||
| } | |||
| private: | |||
| OpDescPtr op_desc_; | |||
| }; | |||
| } // namespace | |||
| @@ -93,28 +94,137 @@ TEST_F(UtestDynamicShapePartition, single_op_scene_success) { | |||
| EXPECT_EQ(partitioner.Partition(), SUCCESS); | |||
| } | |||
| /******************************************************************************* | |||
| * | | |||
| * Merge1 | |||
| * Active / \ Active | |||
| * / \. | |||
| * / \. | |||
| * Merge2 \. | |||
| * Active/ \Active \. | |||
| * / \ \. | |||
| * Add Sub Relu | |||
| * | | | | |||
| * | | | | |||
| * Switch_f2 Switch_t2 | | |||
| * \ / | | |||
| * \ / | | |||
| * Less2 | | |||
| * | | | |||
| * | | | |||
| * Switch_f Switch_t | |||
| * | \ / | | |||
| * | Active | | |||
| * | | | | |||
| * | Less1 | | |||
| * | / \ | | |||
| * | / \ | | |||
| * Data Data | |||
| ******************************************************************************/ | |||
| TEST_F(UtestDynamicShapePartition, merge_control_flow_group) { | |||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("default"); | |||
| AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, "session_graph_id"); | |||
| NodePtr data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| NodePtr data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| NodePtr merge = NodeBuilder("node2", MERGE).AddInputDesc({1}).AddInputDesc({1}) | |||
| .AddOutputDesc({1}).AddOutputDesc({}).Build(graph); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), merge->GetInDataAnchor(1)); | |||
| (void)AttrUtils::SetBool(data1->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||
| (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); | |||
| (void)AttrUtils::SetBool(data2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||
| (void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); | |||
| (void)AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||
| (void)AttrUtils::SetInt(merge->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); | |||
| auto data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto less1 = NodeBuilder("less1", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto active1 = NodeBuilder("active1", STREAMACTIVE).Build(graph); | |||
| auto switch_t = NodeBuilder("switch_t", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); | |||
| auto switch_f = NodeBuilder("switch_f", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); | |||
| auto const_01 = NodeBuilder("const_01", CONSTANT).AddOutputDesc({1}).Build(graph); | |||
| auto const_11 = NodeBuilder("const_11", CONSTANT).AddOutputDesc({1}).Build(graph); | |||
| auto less2 = NodeBuilder("less2", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto active2 = NodeBuilder("active2", STREAMACTIVE).Build(graph); | |||
| auto switch_t2 = NodeBuilder("switch_t2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); | |||
| auto switch_f2 = NodeBuilder("switch_f2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); | |||
| auto const_02 = NodeBuilder("const_02", CONSTANT).AddOutputDesc({1}).Build(graph); | |||
| auto const_12 = NodeBuilder("const_12", CONSTANT).AddOutputDesc({1}).Build(graph); | |||
| auto add2 = NodeBuilder("add2", ADD).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto sub2 = NodeBuilder("sub2", SUB).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto merge2 = NodeBuilder("merge2", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto active_f2 = NodeBuilder("active_f2", STREAMACTIVE).Build(graph); | |||
| auto active_t2 = NodeBuilder("active_t2", STREAMACTIVE).Build(graph); | |||
| auto relu1 = NodeBuilder("relu1", RELU).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto merge1 = NodeBuilder("merge1", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); | |||
| auto active_f1 = NodeBuilder("active_f1", STREAMACTIVE).Build(graph); | |||
| auto active_t1 = NodeBuilder("active_t1", STREAMACTIVE).Build(graph); | |||
| auto output1 = NodeBuilder("noutput1", NETOUTPUT).AddInputDesc({1}).Build(graph); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(const_01->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(const_11->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(less1->GetOutDataAnchor(0), less2->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(const_02->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(const_12->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(less2->GetOutControlAnchor(), active2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(switch_f2->GetOutControlAnchor(), add2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(less2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(add2->GetOutControlAnchor(), active_f2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active_f2->GetOutControlAnchor(), merge2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(switch_t2->GetOutControlAnchor(), sub2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(less2->GetOutDataAnchor(0), sub2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(sub2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(sub2->GetOutControlAnchor(), active_t2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active_t2->GetOutControlAnchor(), merge2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), less2->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), relu1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(merge2->GetOutControlAnchor(), active_f1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active_f1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), relu1->GetInDataAnchor(1)); | |||
| GraphUtils::AddEdge(relu1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(relu1->GetOutControlAnchor(), active_t1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(active_t1->GetOutControlAnchor(), merge1->GetInControlAnchor()); | |||
| GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||
| AttrUtils::SetBool(merge2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); | |||
| EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); | |||
| SetControlFlowGroup(merge2, merge2->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(switch_f2, merge2->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(switch_t2, merge2->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active2, merge2->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active_t2, merge2->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active_f2, merge2->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(merge1, merge1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(switch_f, merge1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(switch_t, merge1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active1, merge1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active_f1, merge1->GetOpDesc()->GetId()); | |||
| SetControlFlowGroup(active_t1, merge1->GetOpDesc()->GetId()); | |||
| EXPECT_EQ(graph->impl_->sub_graph_.size(), 0); | |||
| DynamicShapePartitioner partitioner(graph); | |||
| EXPECT_EQ(partitioner.Partition(), SUCCESS); | |||
| EXPECT_EQ(graph->impl_->sub_graph_.size(), 1); | |||
| EXPECT_EQ(graph->impl_->sub_graph_.size(), 3); // input less1 uknown | |||
| } | |||
| } // namespace ge | |||
| @@ -83,18 +83,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) { | |||
| execution_context.profiling_level = 1; | |||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||
| NodeState node_state(*node_item, &subgraph_context); | |||
| auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); | |||
| node_state.SetTaskContext(shared_task_context); | |||
| ExecutionEngine execution_engine; | |||
| ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||
| ASSERT_TRUE(node_state->GetTaskContext() != nullptr); | |||
| std::function<void()> callback; | |||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | |||
| executor.InitCallback(&node_state, callback); | |||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||
| executor.InitCallback(node_state.get(), callback); | |||
| ExecutionEngine execution_engine; | |||
| EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||
| } | |||
| TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||
| @@ -118,21 +114,18 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) { | |||
| execution_context.model = &hybrid_model; | |||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||
| NodeState node_state(*node_item, &subgraph_context); | |||
| auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||
| uint32_t task_id = 0; | |||
| uint32_t stream_id = 1; | |||
| std::string task_type = "rts"; | |||
| uint32_t block_dim = 0; | |||
| task_context->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release()); | |||
| node_state.SetTaskContext(shared_task_context); | |||
| node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim); | |||
| ExecutionEngine execution_engine; | |||
| ASSERT_TRUE(node_state.GetTaskContext() != nullptr); | |||
| ASSERT_TRUE(node_state->GetTaskContext() != nullptr); | |||
| std::function<void()> callback; | |||
| SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context); | |||
| executor.InitCallback(&node_state, callback); | |||
| EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||
| executor.InitCallback(node_state.get(), callback); | |||
| ExecutionEngine execution_engine; | |||
| EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR); | |||
| } | |||
| @@ -160,11 +160,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) { | |||
| GraphExecutionContext execution_context; | |||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||
| NodeState node_state(*node_item, &subgraph_context); | |||
| auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
| ASSERT_TRUE(task_context != nullptr); | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||
| ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); | |||
| ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS); | |||
| ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS); | |||
| } | |||
| TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||
| @@ -477,12 +475,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) { | |||
| node_item->output_start = 0; | |||
| GraphExecutionContext execution_context; | |||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||
| GraphItem graph_item; | |||
| SubgraphContext subgraph_context(&graph_item, &execution_context); | |||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||
| subgraph_context.all_inputs_.resize(2); | |||
| subgraph_context.all_outputs_.resize(1); | |||
| NodeState node_state(*node_item, &subgraph_context); | |||
| auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||
| auto task_context = node_state->GetTaskContext(); | |||
| ASSERT_TRUE(task_context != nullptr); | |||
| auto desc = task_context->MutableInputDesc(2); | |||
| ASSERT_TRUE(desc == nullptr); | |||
| @@ -522,12 +522,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) { | |||
| node_item->output_start = 0; | |||
| GraphExecutionContext execution_context; | |||
| SubgraphContext subgraph_context(nullptr, &execution_context); | |||
| GraphItem graph_item; | |||
| SubgraphContext subgraph_context(&graph_item, &execution_context); | |||
| ASSERT_EQ(subgraph_context.Init(), SUCCESS); | |||
| subgraph_context.all_inputs_.resize(2); | |||
| subgraph_context.all_outputs_.resize(1); | |||
| NodeState node_state(*node_item, &subgraph_context); | |||
| auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get()); | |||
| auto task_context = node_state->GetTaskContext(); | |||
| int32_t buffer[1]; | |||
| aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer)); | |||
| @@ -97,11 +97,6 @@ TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| NodeTaskPtr task = nullptr; | |||
| GeLocalNodeExecutor node_executor; | |||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
| @@ -94,18 +94,17 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) { | |||
| tensor.SetData(data); | |||
| ctx->SetTensor(1, 0, tensor.Clone()); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| vector<HcomRemoteAccessAddrInfo> addr_infos; | |||
| shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>(); | |||
| task->remote_index_ = {1, 0}; | |||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||
| ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID); | |||
| Shape s2({1}); | |||
| TensorDesc tensor_desc2(s2); | |||
| Tensor tensor2(tensor_desc2); | |||
| ctx->SetTensor(1, 0, tensor2.Clone()); | |||
| task->ExtractTensor(*unique_task_context, addr_infos); | |||
| ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID); | |||
| task->ExtractTensor(*node_state->GetTaskContext(), addr_infos); | |||
| ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID); | |||
| RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id)); | |||
| } | |||
| @@ -140,11 +139,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| for (int i=0; i<4; ++i) { | |||
| uint64_t value_0 = 512; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| @@ -206,11 +200,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| for (int i=0; i<5; ++i) { | |||
| uint64_t value_0 = 512; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| @@ -96,11 +96,6 @@ TEST_F(UtestRtsNodeTask, test_stream_switch_task) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| uint64_t value_0 = 110; | |||
| uint64_t value_1 = 120; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| @@ -153,11 +148,6 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| NodeTaskPtr task = nullptr; | |||
| RtsNodeExecutor node_executor; | |||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
| @@ -203,11 +193,6 @@ TEST_F(UtestRtsNodeTask, test_stream_merge_task) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| uint64_t value_0 = 110; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||
| @@ -271,11 +256,6 @@ TEST_F(UtestRtsNodeTask, test_memcpy_async_task) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| uint64_t value_0 = 110; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||
| @@ -328,11 +308,6 @@ TEST_F(UtestRtsNodeTask, test_pass_through_task) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| uint64_t value_0 = 110; | |||
| TensorValue in_tensor0(&value_0, sizeof(value_0)); | |||
| subgraph_context.SetInput(*node_item, 0, in_tensor0); | |||
| @@ -384,11 +359,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_set) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| NodeTaskPtr task = nullptr; | |||
| RtsNodeExecutor node_executor; | |||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
| @@ -428,11 +398,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| NodeTaskPtr task = nullptr; | |||
| RtsNodeExecutor node_executor; | |||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||
| @@ -472,11 +437,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) { | |||
| auto node_state = subgraph_context.GetOrCreateNodeState(node_item); | |||
| ASSERT_NE(node_state, nullptr); | |||
| auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context); | |||
| ASSERT_NE(unique_task_context, nullptr); | |||
| auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release()); | |||
| node_state->SetTaskContext(shared_task_context); | |||
| NodeTaskPtr task = nullptr; | |||
| RtsNodeExecutor node_executor; | |||
| ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS); | |||