diff --git a/cmake/external_libs/protobuf_shared.cmake b/cmake/external_libs/protobuf_shared.cmake
index 6334c8a3..dfdb0606 100755
--- a/cmake/external_libs/protobuf_shared.cmake
+++ b/cmake/external_libs/protobuf_shared.cmake
@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
     message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
 endif()
 if (GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
 set(INSTALL_BASE_DIR "")
 set(INSTALL_LIBRARY_DIR lib)
 
-install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
+install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
 install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
diff --git a/cmake/external_libs/protobuf_static.cmake b/cmake/external_libs/protobuf_static.cmake
index 22f537cf..b8ff90bb 100755
--- a/cmake/external_libs/protobuf_static.cmake
+++ b/cmake/external_libs/protobuf_static.cmake
@@ -16,11 +16,11 @@ if(GE_PB_PKG)
     set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
 ExternalProject_Add(protobuf_static_build
                     URL ${REQ_URL}
-                    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-                    #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
                     TLS_VERIFY OFF
                     CONFIGURE_COMMAND ${CMAKE_COMMAND}
                     -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
diff --git a/cmake/external_libs/protoc.cmake b/cmake/external_libs/protoc.cmake
index 421f2632..f16f5e22 100755
--- a/cmake/external_libs/protoc.cmake
+++ b/cmake/external_libs/protoc.cmake
@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()
 
 if(GE_PB_PKG)
-    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+    set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
     if (ENABLE_GITEE)
-        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-        set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+        set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+        set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
     else()
-        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-        set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+        set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+        set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
     endif ()
 endif()
@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
 set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 ExternalProject_Add(protoc_build
                     URL ${REQ_URL}
-                    #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-                    #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
                     TLS_VERIFY OFF
                     CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc /cmake
                     BUILD_COMMAND $(MAKE)
diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc
index 52e6cb9c..b2017e4d 100644
--- a/ge/graph/common/omg_util.cc
+++ b/ge/graph/common/omg_util.cc
@@ -274,21 +274,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
   return false;
 }
 
-///
-/// @brief Set Op _force_unknown_shape flag
-/// @param [in] node
-/// @param [in] force_unknown, set attribute if true
-/// @param [in] group_index, condition group index of node.
-/// @return
-///
-void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
-  if (!force_unknown) {
-    return;
-  }
-
-  SetControlFlowGroup(node, group_index);
-}
-
 ///
 /// @brief Set Op _control_flow_group flag
 /// @param [in] node
diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h
index 148e4102..edaafa45 100644
--- a/ge/graph/common/omg_util.h
+++ b/ge/graph/common/omg_util.h
@@ -125,15 +125,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size);
 ///
 bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);
 
-///
-/// @brief Set Op _force_unknown_shape flag
-/// @param [in] node
-/// @param [in] force_unknown, set attribute if true
-/// @param [in] group_index, condition group index of node.
-/// @return
-///
-void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);
-
 ///
 /// @brief Set Op _control_flow_group flag
 /// @param [in] node
diff --git a/ge/graph/optimize/graph_optimize.cc b/ge/graph/optimize/graph_optimize.cc
index 835e257b..55f374eb 100644
--- a/ge/graph/optimize/graph_optimize.cc
+++ b/ge/graph/optimize/graph_optimize.cc
@@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) {
       GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str());
       continue;
     }
-#ifndef ONLY_COMPILE_OPEN_SRC
     GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str());
     ret = (iter->second)->OptimizeAfterStage1(*compute_graph);
-#endif
     if (ret != SUCCESS) {
       REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, "
                          "graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str());
diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc
index 055b2aa4..1db47498 100755
--- a/ge/graph/partition/dynamic_shape_partition.cc
+++ b/ge/graph/partition/dynamic_shape_partition.cc
@@ -364,6 +364,7 @@ static std::string ToString(const std::vector &clusters) {
 }
 
 void DynamicShapePartitioner::MergeClustersControlFlow() {
+  std::unordered_set all_merged_clusters;
   for (const auto &item : control_clusters_) {
     const auto &control_cluster = item.second;
     auto rit = control_cluster.rbegin();
@@ -373,17 +374,32 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
     }
 
     const auto &cluster = *rit;
+    if (all_merged_clusters.count(cluster) > 0) {
+      continue;
+    }
+
+    bool is_unknown_cluster = cluster->IsUnknownShape();
     for (++rit; rit != control_cluster.rend(); ++rit) {
       const auto &cluster_from = *rit;
+      if (all_merged_clusters.count(cluster_from) > 0) {
+        continue;
+      }
+
       auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
       GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
              ToString(merged_clusters).c_str());
       for (const auto &merged_cluster : merged_clusters) {
+        all_merged_clusters.emplace(merged_cluster);
         for (const auto &node : merged_cluster->Nodes()) {
           node_2_cluster_[node] = cluster;
         }
       }
     }
+
+    if (!is_unknown_cluster && cluster->IsUnknownShape()) {
+      GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
+      ordered_cluster_.push_back(cluster);
+    }
   }
 }
 
@@ -703,7 +719,12 @@ void Cluster::Merge(ClusterPtr other) {
   if (other->min_ < min_) {
     min_ = other->min_;
   }
-};
+
+  if (!IsUnknownShape() && other->IsUnknownShape()) {
+    type_ = UNKNOWN_SHAPE;
+  }
+}
+
 bool Cluster::TryMerge(ClusterPtr other) {
   std::queue forward_reached;
   forward_reached.push(other);
diff --git a/ge/graph/partition/dynamic_shape_partition.h b/ge/graph/partition/dynamic_shape_partition.h
index a17c4e4b..bd3b128f 100644
--- a/ge/graph/partition/dynamic_shape_partition.h
+++ b/ge/graph/partition/dynamic_shape_partition.h
@@ -161,7 +161,7 @@ class DynamicShapePartitioner {
   ge::ComputeGraphPtr root_graph_;  // The original graph to partition
   std::unordered_map> node_2_cluster_;  // Record nodes and the cluster it belongs to
   // V1 control flow cluster, need merge to one Graph.
-  std::unordered_map>> control_clusters_;
+  std::map>> control_clusters_;
   // topological sorted clusters, this field will change with the splitting.
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index 08b358ee..74babadc 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -132,39 +132,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: /// @return /// void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map> &switch_groups) { - std::function callback = [](const NodePtr &n) { - return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); - }; - - for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { - const auto &op_node1 = it1->first; - const auto &op_desc1 = op_node1->GetOpDesc(); - if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { + for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { + const auto &op_node = it->first; + const auto &op_desc = op_node->GetOpDesc(); + if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { continue; } - if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { - int64_t group_index = op_desc1->GetId(); - GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); - MarkForceUnknownShape(op_node1, true, group_index); - for (const auto &n : it1->second) { - MarkForceUnknownShape(n, true, group_index); - } - - for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { - const auto &op_node2 = it2->first; - const auto &op_desc2 = op_node2->GetOpDesc(); - if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { - continue; - } - - if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { - MarkForceUnknownShape(op_node2, true, group_index); - for (const auto &n : it2->second) { - MarkForceUnknownShape(n, true, group_index); - } - } - } + int64_t group_index = op_desc->GetId(); + SetControlFlowGroup(op_node, group_index); + for (const auto &n : it->second) { + SetControlFlowGroup(n, group_index); } } } diff --git a/ge/graph/passes/mark_graph_unknown_status_pass.cc b/ge/graph/passes/mark_graph_unknown_status_pass.cc index 2d7b179b..9e460fc7 100644 --- a/ge/graph/passes/mark_graph_unknown_status_pass.cc +++ b/ge/graph/passes/mark_graph_unknown_status_pass.cc @@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { } } + const auto &node = graph->GetParentNode(); + if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) { + GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), + "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); + } + for (const auto &node : graph->GetDirectNode()) { GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str()); (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape); diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index 0b383911..dbcff620 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); return FAILED, "[Check][Param] Param of pre 
node is nullptr."); int64_t group_index = -1; - bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - MarkForceUnknownShape(node, force_unknown, group_index); + (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); @@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); return FAILED; } - MarkForceUnknownShape(active_node, force_unknown, group_index); + SetControlFlowGroup(active_node, group_index); } return SUCCESS; diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index 67735b8b..fb8f8627 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { /// @return void /// void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { + std::string node_type; for (const auto &switch_node : loop_group.switch_nodes) { SetControlFlowGroup(switch_node, group_index); for (const auto &node : switch_node->GetOutDataNodes()) { - std::string node_type; (void)GetOriginalType(node, node_type); if (kExitOpTypes.count(node_type) > 0) { SetControlFlowGroup(node, group_index); + } else { + // For: Switch -> Cast -> Exit + for (const auto &n : node->GetOutDataNodes()) { + (void)GetOriginalType(n, node_type); + if (kExitOpTypes.count(node_type) > 0) { + SetControlFlowGroup(n, group_index); + } + } } } } diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index e7743130..e4ab0111 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); int64_t group_index = -1; - bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - MarkForceUnknownShape(stream_switch, force_unknown, group_index); + (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); + SetControlFlowGroup(stream_switch, group_index); return stream_switch; } @@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { - std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; - std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; + const std::list &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; + const std::list &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; std::set same_cond_switch; same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); @@ -524,13 +524,13 @@ Status 
SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) std::function callback = [&group_index](const NodePtr &n) { return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); }; - bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); - MarkForceUnknownShape(active_node, is_unknown_shape, group_index); + (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); + SetControlFlowGroup(active_node, group_index); const std::string &cond_group = cond_node->GetName(); for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); - std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); + const std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); GE_IF_BOOL_EXEC(switch_list.empty(), continue); // select first stream_switch @@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) "[Add][Edge] between %s and %s failed.", cast_node->GetName().c_str(), stream_switch->GetName().c_str()); - MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); + SetControlFlowGroup(stream_switch, group_index); for (const NodePtr &node : switch_list) { GE_IF_BOOL_EXEC(node != stream_switch, { GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 313a2934..468c84e6 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -19,8 +19,9 @@ #include "framework/common/debug/log.h" #include "graph/compute_graph.h" #include "graph/utils/tensor_utils.h" -#include "hybrid_execution_context.h" -#include "subgraph_context.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_context.h" +#include "hybrid/node_executor/task_context.h" #define INC_ITERATION_COUNT(iteration) \ do { \ @@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex this->op_desc_ = node_item.node->GetOpDesc(); } +Status NodeState::Init(int group, const shared_ptr &frame_state) { + GE_CHECK_NOTNULL(frame_state); + group_ = group; + frame_state_ = frame_state; + auto unique_task_context = TaskContext::Create(this, subgraph_context_); + GE_CHECK_NOTNULL(unique_task_context); + task_context_ = std::shared_ptr(unique_task_context.release()); + return SUCCESS; +} + Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { if (node_item_->IsMergeOp()) { GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size()); @@ -314,15 +325,54 @@ std::shared_ptr NodeState::GetTaskContext() { return task_context_; } +void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { + if (node_item_->root_data_.count(input_idx) > 0) { + GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); + root_tensor_values_[input_idx] = tensor; + } + + if (node_item_->enter_data_.count(input_idx) > 0) { + GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); + root_tensor_values_[input_idx] = tensor; + } +} + +void NodeState::UpdatePersistTensor(int input_idx) { + const auto it = root_tensor_values_.find(input_idx); + if (it == root_tensor_values_.end()) { + GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); + return; + } + + auto tensor = 
task_context_->MutableInput(input_idx); + if (tensor == nullptr) { + GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx); + return; + } + + *tensor = it->second; + GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx); +} + void NodeState::ResetContext(uint64_t iteration) { switch_index_ = -1; subgraph_context_->ResetContext(node_item_->node); - if (iteration == 0) { - data_scheduled_ = static_cast(node_item_->root_data_.size()); - ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); - } else { - data_scheduled_ = static_cast(node_item_->root_data_.size() + node_item_->enter_data_.size()); - ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size()); + auto unique_task_context = TaskContext::Create(this, subgraph_context_); + GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); + task_context_ = std::shared_ptr(unique_task_context.release()); + + data_scheduled_ = static_cast(node_item_->root_data_.size()); + ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); + for (auto item : node_item_->root_data_) { + UpdatePersistTensor(item.first); + } + + if (iteration > 0) { + data_scheduled_ += static_cast(node_item_->enter_data_.size()); + ctrl_scheduled_ += static_cast(node_item_->enter_ctrl_.size()); + for (auto item : node_item_->enter_data_) { + UpdatePersistTensor(item.first); + } } iteration_count_ = iteration; diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 9dd29846..85f9e4c3 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -100,6 +100,8 @@ struct NodeState { NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); ~NodeState() = default; + Status Init(int group, const shared_ptr &frame_state); + OpDesc *GetOpDesc() const { return op_desc_.get(); } @@ -129,6 +131,8 @@ struct NodeState { void RunStreamActive(); void RunNextIteration(); + void SavePersistTensor(int input_idx, const TensorValue &tensor); + Status NodeScheduled(const std::function &ready) const; void SetScheduleFuture(std::future &&future); @@ -150,18 +154,10 @@ struct NodeState { return merge_index_; } - void SetGroup(int group) { - group_ = group; - } - int GetGroup() const { return group_; } - void SetFrameState(const shared_ptr &frame_state) { - frame_state_ = frame_state; - } - const shared_ptr &GetKernelTask() const { return kernel_task_; } @@ -187,6 +183,7 @@ struct NodeState { void SetCtrlSchedule(const NodeState &node_state, const std::function &ready); void ResetContext(uint64_t iteration); void ScheduleContext(const NodeState &node_state); + void UpdatePersistTensor(int input_idx); const NodeItem *node_item_ = nullptr; std::shared_ptr kernel_task_ = nullptr; @@ -199,6 +196,7 @@ struct NodeState { std::future schedule_future_; std::shared_ptr frame_state_; + std::map root_tensor_values_; uint64_t active_count_ = 0; uint64_t iteration_count_ = 0; uint32_t ctrl_scheduled_ = 0; diff --git a/ge/hybrid/executor/subgraph_context.cc b/ge/hybrid/executor/subgraph_context.cc index b6763ffd..5e97a9a2 100644 --- a/ge/hybrid/executor/subgraph_context.cc +++ b/ge/hybrid/executor/subgraph_context.cc @@ -19,7 +19,7 @@ namespace ge { namespace hybrid { -SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context) +SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context) : graph_item_(graph_item), execution_context_(execution_context) { } @@ -79,20 +79,31 @@ 
NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { return nullptr; } + return CreateNodeState(node_item); +} + +NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) { GELOGD("[%s] lock for write", node_item->NodeName().c_str()); if (mmRWLockWRLock(&rw_lock_) != EN_OK) { REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); return nullptr; } + auto &node_state = node_states_[node_item]; - if (node_state == nullptr) { - const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); - node_state.reset(new(std::nothrow)NodeState(*node_item, this)); - node_state->SetFrameState(GetOrCreateFrameState(*node_item)); - node_state->SetGroup(group_); - (void)guard; - } + do { + if (node_state == nullptr) { + const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); + node_state.reset(new(std::nothrow)NodeState(*node_item, this)); + if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str()); + REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str()); + break; + } + (void)guard; + } + } while (0); + GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); if (mmWRLockUnLock(&rw_lock_) != EN_OK) { REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); diff --git a/ge/hybrid/executor/subgraph_context.h b/ge/hybrid/executor/subgraph_context.h index a43cd210..023be981 100644 --- a/ge/hybrid/executor/subgraph_context.h +++ b/ge/hybrid/executor/subgraph_context.h @@ -30,7 +30,7 @@ namespace ge { namespace hybrid { class SubgraphContext { public: - explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); + explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context); ~SubgraphContext(); Status Init(); @@ -51,10 +51,11 @@ class SubgraphContext { void NodeDone(const NodePtr &node); private: + NodeStatePtr CreateNodeState(const NodeItem *node_item); FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock friend class TaskContext; const GraphItem *graph_item_; - const GraphExecutionContext *execution_context_; + GraphExecutionContext *execution_context_; mmRWLock_t rw_lock_; std::vector all_inputs_; std::vector all_outputs_; diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 612e7565..7429acc5 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -175,16 +175,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vectorSetKernelTask(node_item->kernel_task); - known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(known_shape_task_context_); - node_state->SetTaskContext(known_shape_task_context_); - std::function callback; GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); - HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback), + HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback), "[%s] Failed to execute node [%s] for known subgraph.", graph_item_->GetName().c_str(), - 
known_shape_task_context_->GetNodeName()); + node_state->GetName().c_str()); GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); return SUCCESS; @@ -271,16 +267,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { } else { node_state->SetKernelTask(node_item.kernel_task); } - auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(unique_task_context); const auto &task = node_state->GetKernelTask(); if (task == nullptr) { GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str()); REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str()); return INTERNAL_ERROR; } - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state->SetTaskContext(shared_task_context); GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state)); return AfterPrepared(p_node_state); } @@ -480,19 +472,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta } else { node_state.SetKernelTask(node_item.kernel_task); } - auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); - GE_CHECK_NOTNULL(unique_task_context); const auto &task = node_state.GetKernelTask(); if (task == nullptr) { GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str()); REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str()); return INTERNAL_ERROR; } - auto shared_task_context = std::shared_ptr(unique_task_context.release()); - node_state.SetTaskContext(shared_task_context); GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context)); RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start"); - GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws + GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end"); return SUCCESS; } diff --git a/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h index 0f54e4ca..7e1c2d0b 100644 --- a/ge/hybrid/executor/subgraph_executor.h +++ b/ge/hybrid/executor/subgraph_executor.h @@ -127,7 +127,6 @@ class SubgraphExecutor { ThreadPool pre_run_pool_; BlockingQueue ready_queue_; std::unique_ptr shape_inference_engine_; - std::shared_ptr known_shape_task_context_; std::mutex mu_; // Guard for prepare_queues_. std::map> prepare_queues_; diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index b339e630..cef06fc6 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { data_send_.emplace(node_item); node_item->data_recv_[this] = anchor_index; if (is_root_node_) { - node_item->root_data_.emplace(this); + node_item->root_data_[anchor_index] = this; } // If Enter feed Not Merge, take as root Node. 
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { - node_item->enter_data_.emplace(this); - node_item->enter_inside_.emplace(anchor_index); + node_item->enter_data_[anchor_index] = this; } GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); } diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 8de15952..ec66f094 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -148,15 +148,14 @@ struct NodeItem { int64_t frame_index_ = -1; int64_t parent_frame_ = -1; std::set root_ctrl_; // Recv ctrl from root node - std::set root_data_; // Recv data from root node + std::map root_data_; // Recv data from root node std::set enter_ctrl_; // Recv ctrl from Enter node - std::set enter_data_; // Recv data from Enter node + std::map enter_data_; // Recv data from Enter node std::set data_send_; // Send data notify to std::map data_recv_; // Recv data notify from std::set ctrl_send_; // Send ctrl notify to std::set ctrl_recv_; // Recv ctrl notify from std::vector> switch_groups_; // Send ctrl notify to - std::set enter_inside_; // Enter feed loop inside Node, Not cross Merge. std::shared_ptr kernel_task; std::unique_ptr fused_subgraph; diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index 14eb1222..fe580c1e 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() { } } -std::unique_ptr TaskContext::Create(NodeState *node_state, - GraphExecutionContext *execution_context, - SubgraphContext *subgraph_context) { +std::unique_ptr TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) { const NodeItem &node_item = *node_state->GetNodeItem(); GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", node_item.NodeName().c_str(), @@ -75,7 +73,7 @@ std::unique_ptr TaskContext::Create(NodeState *node_state, } auto task_context = std::unique_ptr( - new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); + new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context)); if (task_context == nullptr) { REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str()); GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str()); @@ -85,7 +83,7 @@ std::unique_ptr TaskContext::Create(NodeState *node_state, task_context->node_item_ = &node_item; task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; - task_context->iteration_ = execution_context->iteration; + task_context->iteration_ = subgraph_context->execution_context_->iteration; return task_context; } @@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() { subgraph_context_->all_inputs_[input_offset].SetName( node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); } + + auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); + GE_CHECK_NOTNULL(dst_node_state); + dst_node_state->SavePersistTensor(dst_input_idx, *tensor); } } (void)guard; @@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() { } void TaskContext::ReleaseInput(int index) { - if (node_item_->enter_inside_.count(index) > 0) { - GELOGD("[%s] Tensor of input[%d] is enter, keep it", 
GetNodeName(), index); - return; - } - auto input_tensor = MutableInput(index); if (input_tensor != nullptr) { input_tensor->Destroy(); diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index ba4c62e6..c96e194e 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -36,9 +36,7 @@ class SubgraphContext; class TaskContext { public: - static std::unique_ptr Create(NodeState *node_state, - GraphExecutionContext *execution_context, - SubgraphContext *subgraph_context); + static std::unique_ptr Create(NodeState *node_state, SubgraphContext *subgraph_context); ~TaskContext(); diff --git a/metadef b/metadef index b27915cd..c6030152 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit b27915cd37919430a61953f8998b7acce4a60177 +Subproject commit c6030152c6dc05515115765babb5d64fde649df4 diff --git a/parser b/parser index e75eda62..155d3262 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit e75eda62de2b51a0bded5481ca81eb8fc7bf376e +Subproject commit 155d3262ba17f800094abb58b6a809b041cf0a74 diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc index 5cf7569b..85328b27 100644 --- a/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc +++ b/tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc @@ -272,115 +272,3 @@ TEST_F(UtestGeAnchor, graph_utils_test) { EXPECT_EQ(GraphUtils::RemoveEdge(conv_node->GetOutDataAnchor(0), bn_node->GetInControlAnchor()), GRAPH_SUCCESS); EXPECT_EQ(GraphUtils::RemoveEdge(conv_node->GetOutDataAnchor(0), bn_node->GetInControlAnchor()), GRAPH_FAILED); } - -TEST_F(UtestGeAnchor, data_anchor_replace_peer) { - ComputeGraphPtr graph_ptr = std::make_shared("graph"); - OpDescPtr in_op_ptr = std::make_shared("in_op_1", "float"); - in_op_ptr->AddInputDesc("x1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddInputDesc("x2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddInputDesc("x3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - NodePtr node1 = graph_ptr->AddNode(in_op_ptr); - NodePtr node2 = graph_ptr->AddNode(in_op_ptr); - NodePtr node3 = graph_ptr->AddNode(in_op_ptr); - - OutDataAnchorPtr out_data_anchor = node1->GetOutDataAnchor(1); - InDataAnchorPtr in_data_anchor = node2->GetInDataAnchor(1); - EXPECT_EQ(out_data_anchor != nullptr, true); - EXPECT_EQ(in_data_anchor != nullptr, true); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(0)), GRAPH_SUCCESS); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(1)), GRAPH_SUCCESS); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(2)), GRAPH_SUCCESS); - - size_t out_idx = 0; - for (; out_idx < out_data_anchor->peer_anchors_.size(); out_idx++) { - if (out_data_anchor->peer_anchors_[out_idx].lock() == in_data_anchor) { - break; - } - } - EXPECT_EQ(out_idx, 1); - - size_t in_idx = 0; - for (; in_idx < in_data_anchor->peer_anchors_.size(); in_idx++) { - if (in_data_anchor->peer_anchors_[in_idx].lock() == out_data_anchor) { - break; - } - } - EXPECT_EQ(in_idx, 0); - - out_data_anchor->ReplacePeer(in_data_anchor, node3->GetInDataAnchor(1), 
node3->GetOutDataAnchor(1)); - - size_t out_idx1 = 0; - for (; out_idx1 < out_data_anchor->peer_anchors_.size(); out_idx1++) { - if (out_data_anchor->peer_anchors_[out_idx1].lock() == node3->GetInDataAnchor(1)) { - break; - } - } - EXPECT_EQ(out_idx1, out_idx); - - size_t in_idx1 = 0; - for (; in_idx1 < in_data_anchor->peer_anchors_.size(); in_idx1++) { - if (in_data_anchor->peer_anchors_[in_idx1].lock() == node3->GetOutDataAnchor(1)) { - break; - } - } - EXPECT_EQ(in_idx1, in_idx); -} - -TEST_F(UtestGeAnchor, graph_utils_insert_node) { - ComputeGraphPtr graph_ptr = std::make_shared("graph"); - OpDescPtr in_op_ptr = std::make_shared("in_op_1", "float"); - in_op_ptr->AddInputDesc("x1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddInputDesc("x2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddInputDesc("x3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - in_op_ptr->AddOutputDesc("y3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW)); - NodePtr node1 = graph_ptr->AddNode(in_op_ptr); - NodePtr node2 = graph_ptr->AddNode(in_op_ptr); - NodePtr node3 = graph_ptr->AddNode(in_op_ptr); - - OutDataAnchorPtr out_data_anchor = node1->GetOutDataAnchor(1); - InDataAnchorPtr in_data_anchor = node2->GetInDataAnchor(1); - EXPECT_EQ(out_data_anchor != nullptr, true); - EXPECT_EQ(in_data_anchor != nullptr, true); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(0)), GRAPH_SUCCESS); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(1)), GRAPH_SUCCESS); - EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(2)), GRAPH_SUCCESS); - - size_t out_idx = 0; - for (; out_idx < out_data_anchor->peer_anchors_.size(); out_idx++) { - if (out_data_anchor->peer_anchors_[out_idx].lock() == in_data_anchor) { - break; - } - } - EXPECT_EQ(out_idx, 1); - - size_t in_idx = 0; - for (; in_idx < in_data_anchor->peer_anchors_.size(); in_idx++) { - if (in_data_anchor->peer_anchors_[in_idx].lock() == out_data_anchor) { - break; - } - } - EXPECT_EQ(in_idx, 0); - - GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, in_data_anchor, node3); - - size_t out_idx1 = 0; - for (; out_idx1 < out_data_anchor->peer_anchors_.size(); out_idx1++) { - if (out_data_anchor->peer_anchors_[out_idx1].lock() == node3->GetInDataAnchor(0)) { - break; - } - } - EXPECT_EQ(out_idx1, out_idx); - - size_t in_idx1 = 0; - for (; in_idx1 < in_data_anchor->peer_anchors_.size(); in_idx1++) { - if (in_data_anchor->peer_anchors_[in_idx1].lock() == node3->GetOutDataAnchor(0)) { - break; - } - } - EXPECT_EQ(in_idx1, in_idx); -} diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc index 0366446c..c91f68df 100644 --- a/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc +++ b/tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc @@ -30,6 +30,7 @@ #include "graph/model_serialize.h" #include "graph/detail/model_serialize_imp.h" +#include "graph/node_impl.h" #include "graph/ge_attr_value.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" @@ -1062,7 +1063,7 @@ TEST(UtestGeModelSerialize, test_model_serialize_imp_invalid_param) { auto graph = std::make_shared("test_graph"); auto node = graph->AddNode(std::make_shared()); - node->op_ = 
nullptr; + node->impl_->op_ = nullptr; ge::proto::ModelDef model_def; Model model; model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); diff --git a/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc b/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc index aa43ac99..838df735 100644 --- a/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc +++ b/tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc @@ -25,6 +25,7 @@ #include "graph/ge_attr_value.h" #include "graph/tensor.h" #include "graph/utils/tensor_utils.h" +#include "graph/ge_tensor_impl.h" #undef private #undef protected @@ -196,23 +197,6 @@ TEST_F(UtestGeTensor, test_shape_copy_move) { EXPECT_EQ(shape4.GetDimNum(), 3); } -TEST_F(UtestGeTensor, test_tensor_desc_invalid_null) { - GeTensorDesc tensor_desc(nullptr, nullptr); - EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); - EXPECT_EQ(tensor_desc.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensor_desc.MutableShape().shape_def_.GetProtoMsg(), nullptr); - - GeTensorDesc tensor_desc2; - EXPECT_EQ(tensor_desc2.GetDataType(), DT_FLOAT); - EXPECT_EQ(tensor_desc2.GetFormat(), FORMAT_ND); - - tensor_desc2.SetDataType(DT_DUAL_SUB_INT8); - EXPECT_EQ(tensor_desc2.GetDataType(), DT_DUAL_SUB_INT8); - - TensorUtils::SetWeightSize(tensor_desc, 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor_desc), 0); -} - TEST_F(UtestGeTensor, test_tensor_invalid_null) { ProtoMsgOwner msg_owner; GeTensor tensor(msg_owner, nullptr); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 63579109..0d1ae079 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -121,6 +121,7 @@ set(GRAPH_SRC_FILES "${GE_CODE_DIR}/metadef/register/op_tiling.cpp" "${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" "${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" + "${GE_CODE_DIR}/metadef/register/op_tiling_registry_impl.cpp" ) set(PARSER_SRC_FILES diff --git a/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc b/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc index c8abadb5..da1abd0f 100644 --- a/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc +++ b/tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc @@ -20,9 +20,11 @@ #define protected public #include "graph/partition/dynamic_shape_partition.h" #include "compute_graph.h" +#include "graph/compute_graph_impl.h" #include "inc/framework/common/types.h" #include "utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" +#include "graph/common/omg_util.h" namespace ge { namespace { @@ -37,33 +39,33 @@ GeTensorDescPtr CreateTensorDesc(std::initializer_list shape, Format fo } class NodeBuilder { - public: - NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared(name, type); } - - NodeBuilder &AddInputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, - DataType data_type = DT_FLOAT) { - op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); - return *this; - } - - NodeBuilder &AddOutputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, - DataType data_type = DT_FLOAT) { - op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); - return *this; - } - - NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { - op_desc_->AddOutputDesc(tensor_desc->Clone()); - return *this; - } - - NodePtr Build(const ComputeGraphPtr &graph) { - NodePtr node = graph->AddNode(op_desc_); - return node; - } - 
- private: - OpDescPtr op_desc_; + public: + NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared(name, type); } + + NodeBuilder &AddInputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(std::initializer_list shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, + DataType data_type = DT_FLOAT) { + op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone()); + return *this; + } + + NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { + op_desc_->AddOutputDesc(tensor_desc->Clone()); + return *this; + } + + NodePtr Build(const ComputeGraphPtr &graph) { + NodePtr node = graph->AddNode(op_desc_); + return node; + } + + private: + OpDescPtr op_desc_; }; } // namespace @@ -92,28 +94,137 @@ TEST_F(UtestDynamicShapePartition, single_op_scene_success) { EXPECT_EQ(partitioner.Partition(), SUCCESS); } +/******************************************************************************* + * | + * Merge1 + * Active / \ Active + * / \. + * / \. + * Merge2 \. + * Active/ \Active \. + * / \ \. + * Add Sub Relu + * | | | + * | | | + * Switch_f2 Switch_t2 | + * \ / | + * \ / | + * Less2 | + * | | + * | | + * Switch_f Switch_t + * | \ / | + * | Active | + * | | | + * | Less1 | + * | / \ | + * | / \ | + * Data Data + ******************************************************************************/ TEST_F(UtestDynamicShapePartition, merge_control_flow_group) { ComputeGraphPtr graph = std::make_shared("default"); AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, "session_graph_id"); - NodePtr data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); - NodePtr data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); - NodePtr merge = NodeBuilder("node2", MERGE).AddInputDesc({1}).AddInputDesc({1}) - .AddOutputDesc({1}).AddOutputDesc({}).Build(graph); - - GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2->GetOutDataAnchor(0), merge->GetInDataAnchor(1)); - - (void)AttrUtils::SetBool(data1->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - (void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); - (void)AttrUtils::SetBool(data2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - (void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); - (void)AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true); - (void)AttrUtils::SetInt(merge->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3); - - EXPECT_EQ(graph->sub_graph_.size(), 0); + auto data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + + auto less1 = NodeBuilder("less1", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active1 = NodeBuilder("active1", STREAMACTIVE).Build(graph); + auto switch_t = NodeBuilder("switch_t", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto switch_f = NodeBuilder("switch_f", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto const_01 = NodeBuilder("const_01", CONSTANT).AddOutputDesc({1}).Build(graph); + auto const_11 = NodeBuilder("const_11", CONSTANT).AddOutputDesc({1}).Build(graph); + + + auto less2 = 
NodeBuilder("less2", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active2 = NodeBuilder("active2", STREAMACTIVE).Build(graph); + auto switch_t2 = NodeBuilder("switch_t2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto switch_f2 = NodeBuilder("switch_f2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph); + auto const_02 = NodeBuilder("const_02", CONSTANT).AddOutputDesc({1}).Build(graph); + auto const_12 = NodeBuilder("const_12", CONSTANT).AddOutputDesc({1}).Build(graph); + + auto add2 = NodeBuilder("add2", ADD).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto sub2 = NodeBuilder("sub2", SUB).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto merge2 = NodeBuilder("merge2", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active_f2 = NodeBuilder("active_f2", STREAMACTIVE).Build(graph); + auto active_t2 = NodeBuilder("active_t2", STREAMACTIVE).Build(graph); + + auto relu1 = NodeBuilder("relu1", RELU).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto merge1 = NodeBuilder("merge1", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph); + auto active_f1 = NodeBuilder("active_f1", STREAMACTIVE).Build(graph); + auto active_t1 = NodeBuilder("active_t1", STREAMACTIVE).Build(graph); + + auto output1 = NodeBuilder("noutput1", NETOUTPUT).AddInputDesc({1}).Build(graph); + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0)); + GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0)); + GraphUtils::AddEdge(const_01->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1)); + GraphUtils::AddEdge(const_11->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1)); + GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor()); + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor()); + GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor()); + + + GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less2->GetInDataAnchor(0)); + GraphUtils::AddEdge(less1->GetOutDataAnchor(0), less2->GetInDataAnchor(1)); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(0)); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(0)); + GraphUtils::AddEdge(const_02->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(1)); + GraphUtils::AddEdge(const_12->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(1)); + GraphUtils::AddEdge(less2->GetOutControlAnchor(), active2->GetInControlAnchor()); + GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t2->GetInControlAnchor()); + GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f2->GetInControlAnchor()); + + + GraphUtils::AddEdge(switch_f2->GetOutControlAnchor(), add2->GetInControlAnchor()); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), add2->GetInDataAnchor(0)); + GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(0)); + GraphUtils::AddEdge(add2->GetOutControlAnchor(), active_f2->GetInControlAnchor()); + GraphUtils::AddEdge(active_f2->GetOutControlAnchor(), merge2->GetInControlAnchor()); + + GraphUtils::AddEdge(switch_t2->GetOutControlAnchor(), sub2->GetInControlAnchor()); + GraphUtils::AddEdge(less2->GetOutDataAnchor(0), sub2->GetInDataAnchor(0)); + 
GraphUtils::AddEdge(sub2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1));
+  GraphUtils::AddEdge(sub2->GetOutControlAnchor(), active_t2->GetInControlAnchor());
+  GraphUtils::AddEdge(active_t2->GetOutControlAnchor(), merge2->GetInControlAnchor());
+
+  GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), less2->GetInControlAnchor());
+  GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), relu1->GetInControlAnchor());
+
+
+  GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
+  GraphUtils::AddEdge(merge2->GetOutControlAnchor(), active_f1->GetInControlAnchor());
+  GraphUtils::AddEdge(active_f1->GetOutControlAnchor(), merge1->GetInControlAnchor());
+
+  GraphUtils::AddEdge(data2->GetOutDataAnchor(0), relu1->GetInDataAnchor(1));
+  GraphUtils::AddEdge(relu1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
+  GraphUtils::AddEdge(relu1->GetOutControlAnchor(), active_t1->GetInControlAnchor());
+  GraphUtils::AddEdge(active_t1->GetOutControlAnchor(), merge1->GetInControlAnchor());
+
+  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
+
+  AttrUtils::SetBool(merge2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
+  EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS);
+
+  SetControlFlowGroup(merge2, merge2->GetOpDesc()->GetId());
+  SetControlFlowGroup(switch_f2, merge2->GetOpDesc()->GetId());
+  SetControlFlowGroup(switch_t2, merge2->GetOpDesc()->GetId());
+  SetControlFlowGroup(active2, merge2->GetOpDesc()->GetId());
+  SetControlFlowGroup(active_t2, merge2->GetOpDesc()->GetId());
+  SetControlFlowGroup(active_f2, merge2->GetOpDesc()->GetId());
+
+  SetControlFlowGroup(merge1, merge1->GetOpDesc()->GetId());
+  SetControlFlowGroup(switch_f, merge1->GetOpDesc()->GetId());
+  SetControlFlowGroup(switch_t, merge1->GetOpDesc()->GetId());
+  SetControlFlowGroup(active1, merge1->GetOpDesc()->GetId());
+  SetControlFlowGroup(active_f1, merge1->GetOpDesc()->GetId());
+  SetControlFlowGroup(active_t1, merge1->GetOpDesc()->GetId());
+
+  EXPECT_EQ(graph->impl_->sub_graph_.size(), 0);
   DynamicShapePartitioner partitioner(graph);
   EXPECT_EQ(partitioner.Partition(), SUCCESS);
-  EXPECT_EQ(graph->sub_graph_.size(), 1);
+  EXPECT_EQ(graph->impl_->sub_graph_.size(), 3); // input less1 uknown
 }
 } // namespace ge
\ No newline at end of file
diff --git a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
index 07022230..cc20d614 100644
--- a/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
+++ b/tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
@@ -83,18 +83,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) {
   execution_context.profiling_level = 1;

   SubgraphContext subgraph_context(nullptr, &execution_context);
-  NodeState node_state(*node_item, &subgraph_context);
-  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
-  auto shared_task_context = std::shared_ptr(task_context.release());
-  node_state.SetTaskContext(shared_task_context);
-
-  ExecutionEngine execution_engine;
-  ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
+  ASSERT_TRUE(node_state->GetTaskContext() != nullptr);

   std::function callback;
   SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
-  executor.InitCallback(&node_state, callback);
-  EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
+  executor.InitCallback(node_state.get(), callback);
+  ExecutionEngine execution_engine;
+  EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
 }

 TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
@@ -118,21 +114,18 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
   execution_context.model = &hybrid_model;

   SubgraphContext subgraph_context(nullptr, &execution_context);
-  NodeState node_state(*node_item, &subgraph_context);
-  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
   uint32_t task_id = 0;
   uint32_t stream_id = 1;
   std::string task_type = "rts";
   uint32_t block_dim = 0;
-  task_context->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim);
-  auto shared_task_context = std::shared_ptr(task_context.release());
-  node_state.SetTaskContext(shared_task_context);
+  node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim);

-  ExecutionEngine execution_engine;
-  ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
+  ASSERT_TRUE(node_state->GetTaskContext() != nullptr);

   std::function callback;
   SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
-  executor.InitCallback(&node_state, callback);
-  EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
+  executor.InitCallback(node_state.get(), callback);
+  ExecutionEngine execution_engine;
+  EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
 }
diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
index 088aec50..4f14f628 100644
--- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
+++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
@@ -40,6 +40,7 @@
 #include "graph/types.h"
 #include "graph/utils/tensor_utils.h"
 #include "graph/testcase/ge_graph/graph_builder_utils.h"
+#include "graph/op_desc_impl.h"

 #undef private
 #undef protected
@@ -159,11 +160,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {
   GraphExecutionContext execution_context;

   SubgraphContext subgraph_context(nullptr, &execution_context);
-  NodeState node_state(*node_item, &subgraph_context);
-  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
-  ASSERT_TRUE(task_context != nullptr);
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
   ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS);
-  ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS);
+  ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS);
 }

 TEST_F(UtestGeHybrid, index_taskdefs_failed) {
@@ -477,12 +476,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) {
   node_item->output_start = 0;

   GraphExecutionContext execution_context;
-  SubgraphContext subgraph_context(nullptr, &execution_context);
+  GraphItem graph_item;
+  SubgraphContext subgraph_context(&graph_item, &execution_context);
+  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
   subgraph_context.all_inputs_.resize(2);
   subgraph_context.all_outputs_.resize(1);

-  NodeState node_state(*node_item, &subgraph_context);
-  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
+  auto task_context = node_state->GetTaskContext();
   ASSERT_TRUE(task_context != nullptr);
   auto desc = task_context->MutableInputDesc(2);
   ASSERT_TRUE(desc == nullptr);
@@ -522,12 +523,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) {
   node_item->output_start = 0;

   GraphExecutionContext execution_context;
-  SubgraphContext subgraph_context(nullptr, &execution_context);
+  GraphItem graph_item;
+  SubgraphContext subgraph_context(&graph_item, &execution_context);
+  ASSERT_EQ(subgraph_context.Init(), SUCCESS);
   subgraph_context.all_inputs_.resize(2);
   subgraph_context.all_outputs_.resize(1);

-  NodeState node_state(*node_item, &subgraph_context);
-  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
+  auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
+  auto task_context = node_state->GetTaskContext();

   int32_t buffer[1];
   aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer));
@@ -737,7 +740,7 @@ TEST_F(UtestGeHybrid, TestParseDependencies) {
   std::vector deps;
   deps.push_back("Data");
   auto op_desc = netoutput->GetOpDesc();
-  op_desc->input_name_idx_["Data"] = 0;
+  op_desc->impl_->input_name_idx_["Data"] = 0;
   auto data_desc = data->GetOpDesc();
   auto tensor = std::make_shared();
   auto tensor_desc = data_desc->MutableInputDesc(0);
diff --git a/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc
index a7a407a4..e4d211f9 100644
--- a/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc
+++ b/tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc
@@ -97,11 +97,6 @@ TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   NodeTaskPtr task = nullptr;
   GeLocalNodeExecutor node_executor;
   ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
diff --git a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
index afaf067e..8e6630f6 100644
--- a/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
+++ b/tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
@@ -94,18 +94,17 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
   tensor.SetData(data);
   ctx->SetTensor(1, 0, tensor.Clone());

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
   vector addr_infos;
   shared_ptr task = MakeShared();
   task->remote_index_ = {1, 0};
-  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
+  ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID);

   Shape s2({1});
   TensorDesc tensor_desc2(s2);
   Tensor tensor2(tensor_desc2);
   ctx->SetTensor(1, 0, tensor2.Clone());
-  task->ExtractTensor(*unique_task_context, addr_infos);
-  ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
+  task->ExtractTensor(*node_state->GetTaskContext(), addr_infos);
+  ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID);

   RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
 }
@@ -140,11 +139,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   for (int i=0; i<4; ++i) {
     uint64_t value_0 = 512;
     TensorValue in_tensor0(&value_0, sizeof(value_0));
@@ -206,11 +200,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   for (int i=0; i<5; ++i) {
     uint64_t value_0 = 512;
     TensorValue in_tensor0(&value_0, sizeof(value_0));
diff --git a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc
index 44b2f37f..109e5192 100644
--- a/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc
+++ b/tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc
@@ -96,11 +96,6 @@ TEST_F(UtestRtsNodeTask, test_stream_switch_task) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   uint64_t value_0 = 110;
   uint64_t value_1 = 120;
   TensorValue in_tensor0(&value_0, sizeof(value_0));
@@ -153,11 +148,6 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   NodeTaskPtr task = nullptr;
   RtsNodeExecutor node_executor;
   ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
@@ -203,11 +193,6 @@ TEST_F(UtestRtsNodeTask, test_stream_merge_task) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   uint64_t value_0 = 110;
   TensorValue in_tensor0(&value_0, sizeof(value_0));
   subgraph_context.SetInput(*node_item, 0, in_tensor0);
@@ -271,11 +256,6 @@ TEST_F(UtestRtsNodeTask, test_memcpy_async_task) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   uint64_t value_0 = 110;
   TensorValue in_tensor0(&value_0, sizeof(value_0));
   subgraph_context.SetInput(*node_item, 0, in_tensor0);
@@ -328,11 +308,6 @@ TEST_F(UtestRtsNodeTask, test_pass_through_task) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   uint64_t value_0 = 110;
   TensorValue in_tensor0(&value_0, sizeof(value_0));
   subgraph_context.SetInput(*node_item, 0, in_tensor0);
@@ -384,11 +359,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_set) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   NodeTaskPtr task = nullptr;
   RtsNodeExecutor node_executor;
   ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
@@ -428,11 +398,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   NodeTaskPtr task = nullptr;
   RtsNodeExecutor node_executor;
   ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
@@ -472,11 +437,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) {
   auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
   ASSERT_NE(node_state, nullptr);

-  auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
-  ASSERT_NE(unique_task_context, nullptr);
-  auto shared_task_context = std::shared_ptr(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
-
   NodeTaskPtr task = nullptr;
   RtsNodeExecutor node_executor;
   ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);