zhaozhixuan 3 years ago
parent
commit
3a54dc6dd8
35 changed files with 361 additions and 414 deletions
  1. +6
    -6
      cmake/external_libs/protobuf_shared.cmake
  2. +4
    -6
      cmake/external_libs/protobuf_static.cmake
  3. +5
    -7
      cmake/external_libs/protoc.cmake
  4. +0
    -15
      ge/graph/common/omg_util.cc
  5. +0
    -9
      ge/graph/common/omg_util.h
  6. +0
    -2
      ge/graph/optimize/graph_optimize.cc
  7. +22
    -1
      ge/graph/partition/dynamic_shape_partition.cc
  8. +1
    -1
      ge/graph/partition/dynamic_shape_partition.h
  9. +8
    -30
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  10. +6
    -0
      ge/graph/passes/mark_graph_unknown_status_pass.cc
  11. +2
    -3
      ge/graph/passes/merge_to_stream_merge_pass.cc
  12. +9
    -1
      ge/graph/passes/next_iteration_pass.cc
  13. +8
    -8
      ge/graph/passes/switch_to_stream_switch_pass.cc
  14. +58
    -8
      ge/hybrid/executor/node_state.cc
  15. +6
    -8
      ge/hybrid/executor/node_state.h
  16. +19
    -8
      ge/hybrid/executor/subgraph_context.cc
  17. +3
    -2
      ge/hybrid/executor/subgraph_context.h
  18. +3
    -15
      ge/hybrid/executor/subgraph_executor.cc
  19. +0
    -1
      ge/hybrid/executor/subgraph_executor.h
  20. +2
    -3
      ge/hybrid/model/node_item.cc
  21. +2
    -3
      ge/hybrid/model/node_item.h
  22. +7
    -10
      ge/hybrid/node_executor/task_context.cc
  23. +1
    -3
      ge/hybrid/node_executor/task_context.h
  24. +1
    -1
      metadef
  25. +1
    -1
      parser
  26. +0
    -112
      tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc
  27. +2
    -1
      tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc
  28. +1
    -17
      tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc
  29. +1
    -0
      tests/ut/ge/CMakeLists.txt
  30. +155
    -44
      tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc
  31. +11
    -18
      tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc
  32. +14
    -11
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc
  33. +0
    -5
      tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc
  34. +3
    -14
      tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc
  35. +0
    -40
      tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc

+ 6
- 6
cmake/external_libs/protobuf_shared.cmake View File

@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()

@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
DESTINATION ${INSTALL_LIBRARY_DIR})
install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
DESTINATION ${INSTALL_LIBRARY_DIR})


+ 4
- 6
cmake/external_libs/protobuf_static.cmake View File

@@ -16,11 +16,11 @@ if(GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()

@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
ExternalProject_Add(protobuf_static_build
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}


+ 5
- 7
cmake/external_libs/protoc.cmake View File

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
endif()

if(GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
else()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()

@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protoc_build
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
TLS_VERIFY OFF
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
BUILD_COMMAND $(MAKE)


+ 0
- 15
ge/graph/common/omg_util.cc View File

@@ -274,21 +274,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
return false;
}

///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
if (!force_unknown) {
return;
}

SetControlFlowGroup(node, group_index);
}

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node


+ 0
- 9
ge/graph/common/omg_util.h View File

@@ -125,15 +125,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size);
///
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);

///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node


+ 0
- 2
ge/graph/optimize/graph_optimize.cc View File

@@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) {
GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str());
continue;
}
#ifndef ONLY_COMPILE_OPEN_SRC
GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str());
ret = (iter->second)->OptimizeAfterStage1(*compute_graph);
#endif
if (ret != SUCCESS) {
REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, "
"graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str());


+ 22
- 1
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -364,6 +364,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
}

void DynamicShapePartitioner::MergeClustersControlFlow() {
std::unordered_set<ClusterPtr> all_merged_clusters;
for (const auto &item : control_clusters_) {
const auto &control_cluster = item.second;
auto rit = control_cluster.rbegin();
@@ -373,17 +374,32 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
}

const auto &cluster = *rit;
if (all_merged_clusters.count(cluster) > 0) {
continue;
}

bool is_unknown_cluster = cluster->IsUnknownShape();
for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit;
if (all_merged_clusters.count(cluster_from) > 0) {
continue;
}

auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
ToString(merged_clusters).c_str());
for (const auto &merged_cluster : merged_clusters) {
all_merged_clusters.emplace(merged_cluster);
for (const auto &node : merged_cluster->Nodes()) {
node_2_cluster_[node] = cluster;
}
}
}

if (!is_unknown_cluster && cluster->IsUnknownShape()) {
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
ordered_cluster_.push_back(cluster);
}
}
}

@@ -703,7 +719,12 @@ void Cluster::Merge(ClusterPtr other) {
if (other->min_ < min_) {
min_ = other->min_;
}
};

if (!IsUnknownShape() && other->IsUnknownShape()) {
type_ = UNKNOWN_SHAPE;
}
}

bool Cluster::TryMerge(ClusterPtr other) {
std::queue<ClusterPtr> forward_reached;
forward_reached.push(other);


+ 1
- 1
ge/graph/partition/dynamic_shape_partition.h View File

@@ -161,7 +161,7 @@ class DynamicShapePartitioner {
ge::ComputeGraphPtr root_graph_; // The original graph to partition
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to
// V1 control flow cluster, need merge to one Graph.
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
// topological sorted clusters, this field will change with the splitting.
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters


+ 8
- 30
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -132,39 +132,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
/// @return
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
};

for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
const auto &op_node1 = it1->first;
const auto &op_desc1 = op_node1->GetOpDesc();
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) {
const auto &op_node = it->first;
const auto &op_desc = op_node->GetOpDesc();
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
MarkForceUnknownShape(op_node1, true, group_index);
for (const auto &n : it1->second) {
MarkForceUnknownShape(n, true, group_index);
}

for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
MarkForceUnknownShape(op_node2, true, group_index);
for (const auto &n : it2->second) {
MarkForceUnknownShape(n, true, group_index);
}
}
}
int64_t group_index = op_desc->GetId();
SetControlFlowGroup(op_node, group_index);
for (const auto &n : it->second) {
SetControlFlowGroup(n, group_index);
}
}
}


+ 6
- 0
ge/graph/passes/mark_graph_unknown_status_pass.cc View File

@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) {
}
}

const auto &node = graph->GetParentNode();
if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) {
GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
}

for (const auto &node : graph->GetDirectNode()) {
GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str());
(void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape);


+ 2
- 3
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
return FAILED, "[Check][Param] Param of pre node is nullptr.");
int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(node, force_unknown, group_index);
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str());
return FAILED;
}
MarkForceUnknownShape(active_node, force_unknown, group_index);
SetControlFlowGroup(active_node, group_index);
}

return SUCCESS;


+ 9
- 1
ge/graph/passes/next_iteration_pass.cc View File

@@ -284,13 +284,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
/// @return void
///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
std::string node_type;
for (const auto &switch_node : loop_group.switch_nodes) {
SetControlFlowGroup(switch_node, group_index);
for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type;
(void)GetOriginalType(node, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(node, group_index);
} else {
// For: Switch -> Cast -> Exit
for (const auto &n : node->GetOutDataNodes()) {
(void)GetOriginalType(n, node_type);
if (kExitOpTypes.count(node_type) > 0) {
SetControlFlowGroup(n, group_index);
}
}
}
}
}


+ 8
- 8
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());

int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(stream_switch, force_unknown, group_index);
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
SetControlFlowGroup(stream_switch, group_index);
return stream_switch;
}

@@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
std::set<NodePtr> same_cond_switch;
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
@@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
};
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
MarkForceUnknownShape(active_node, is_unknown_shape, group_index);
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
SetControlFlowGroup(active_node, group_index);

const std::string &cond_group = cond_node->GetName();
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
GE_IF_BOOL_EXEC(switch_list.empty(), continue);

// select first stream_switch
@@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
"[Add][Edge] between %s and %s failed.",
cast_node->GetName().c_str(), stream_switch->GetName().c_str());

MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
SetControlFlowGroup(stream_switch, group_index);
for (const NodePtr &node : switch_list) {
GE_IF_BOOL_EXEC(node != stream_switch, {
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


+ 58
- 8
ge/hybrid/executor/node_state.cc View File

@@ -19,8 +19,9 @@
#include "framework/common/debug/log.h"
#include "graph/compute_graph.h"
#include "graph/utils/tensor_utils.h"
#include "hybrid_execution_context.h"
#include "subgraph_context.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/task_context.h"

#define INC_ITERATION_COUNT(iteration) \
do { \
@@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex
this->op_desc_ = node_item.node->GetOpDesc();
}

Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
GE_CHECK_NOTNULL(frame_state);
group_ = group;
frame_state_ = frame_state;
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
return SUCCESS;
}

Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
if (node_item_->IsMergeOp()) {
GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size());
@@ -314,15 +325,54 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_;
}

void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
if (node_item_->root_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
}

if (node_item_->enter_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor;
}
}

void NodeState::UpdatePersistTensor(int input_idx) {
const auto it = root_tensor_values_.find(input_idx);
if (it == root_tensor_values_.end()) {
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
return;
}

auto tensor = task_context_->MutableInput(input_idx);
if (tensor == nullptr) {
GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx);
return;
}

*tensor = it->second;
GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
}

void NodeState::ResetContext(uint64_t iteration) {
switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node);
if (iteration == 0) {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
} else {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size());
auto unique_task_context = TaskContext::Create(this, subgraph_context_);
GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context);
task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());

data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
for (auto item : node_item_->root_data_) {
UpdatePersistTensor(item.first);
}

if (iteration > 0) {
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
for (auto item : node_item_->enter_data_) {
UpdatePersistTensor(item.first);
}
}

iteration_count_ = iteration;


+ 6
- 8
ge/hybrid/executor/node_state.h View File

@@ -100,6 +100,8 @@ struct NodeState {
NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
~NodeState() = default;

Status Init(int group, const shared_ptr<FrameState> &frame_state);

OpDesc *GetOpDesc() const {
return op_desc_.get();
}
@@ -129,6 +131,8 @@ struct NodeState {
void RunStreamActive();
void RunNextIteration();

void SavePersistTensor(int input_idx, const TensorValue &tensor);

Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;

void SetScheduleFuture(std::future<Status> &&future);
@@ -150,18 +154,10 @@ struct NodeState {
return merge_index_;
}

void SetGroup(int group) {
group_ = group;
}

int GetGroup() const {
return group_;
}

void SetFrameState(const shared_ptr<FrameState> &frame_state) {
frame_state_ = frame_state;
}

const shared_ptr<NodeTask> &GetKernelTask() const {
return kernel_task_;
}
@@ -187,6 +183,7 @@ struct NodeState {
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t iteration);
void ScheduleContext(const NodeState &node_state);
void UpdatePersistTensor(int input_idx);

const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -199,6 +196,7 @@ struct NodeState {

std::future<Status> schedule_future_;
std::shared_ptr<FrameState> frame_state_;
std::map<int, TensorValue> root_tensor_values_;
uint64_t active_count_ = 0;
uint64_t iteration_count_ = 0;
uint32_t ctrl_scheduled_ = 0;


+ 19
- 8
ge/hybrid/executor/subgraph_context.cc View File

@@ -19,7 +19,7 @@

namespace ge {
namespace hybrid {
SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context)
: graph_item_(graph_item), execution_context_(execution_context) {
}

@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
return nullptr;
}

return CreateNodeState(node_item);
}

NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
GELOGD("[%s] lock for write", node_item->NodeName().c_str());
if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
return nullptr;
}

auto &node_state = node_states_[node_item];
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
node_state->SetFrameState(GetOrCreateFrameState(*node_item));
node_state->SetGroup(group_);
(void)guard;
}
do {
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
break;
}
(void)guard;
}
} while (0);

GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());


+ 3
- 2
ge/hybrid/executor/subgraph_context.h View File

@@ -30,7 +30,7 @@ namespace ge {
namespace hybrid {
class SubgraphContext {
public:
explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context);
explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context);
~SubgraphContext();

Status Init();
@@ -51,10 +51,11 @@ class SubgraphContext {
void NodeDone(const NodePtr &node);

private:
NodeStatePtr CreateNodeState(const NodeItem *node_item);
FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock
friend class TaskContext;
const GraphItem *graph_item_;
const GraphExecutionContext *execution_context_;
GraphExecutionContext *execution_context_;
mmRWLock_t rw_lock_;
std::vector<TensorValue> all_inputs_;
std::vector<TensorValue> all_outputs_;


+ 3
- 15
ge/hybrid/executor/subgraph_executor.cc View File

@@ -175,16 +175,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
GE_CHECK_NOTNULL(node_state);
node_state->SetKernelTask(node_item->kernel_task);

known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(known_shape_task_context_);
node_state->SetTaskContext(known_shape_task_context_);

std::function<void()> callback;
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback),
"[%s] Failed to execute node [%s] for known subgraph.",
graph_item_->GetName().c_str(),
known_shape_task_context_->GetNodeName());
node_state->GetName().c_str());

GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
return SUCCESS;
@@ -271,16 +267,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
} else {
node_state->SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state->GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
return AfterPrepared(p_node_state);
}
@@ -480,19 +472,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
} else {
node_state.SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state.GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str());
REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state.SetTaskContext(shared_task_context);
GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext())); // update op_desc before alloc ws
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
return SUCCESS;
}


+ 0
- 1
ge/hybrid/executor/subgraph_executor.h View File

@@ -127,7 +127,6 @@ class SubgraphExecutor {
ThreadPool pre_run_pool_;
BlockingQueue<NodeState *> ready_queue_;
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
std::shared_ptr<TaskContext> known_shape_task_context_;

std::mutex mu_; // Guard for prepare_queues_.
std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;


+ 2
- 3
ge/hybrid/model/node_item.cc View File

@@ -398,12 +398,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_send_.emplace(node_item);
node_item->data_recv_[this] = anchor_index;
if (is_root_node_) {
node_item->root_data_.emplace(this);
node_item->root_data_[anchor_index] = this;
}
// If Enter feed Not Merge, take as root Node.
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) {
node_item->enter_data_.emplace(this);
node_item->enter_inside_.emplace(anchor_index);
node_item->enter_data_[anchor_index] = this;
}
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}


+ 2
- 3
ge/hybrid/model/node_item.h View File

@@ -148,15 +148,14 @@ struct NodeItem {
int64_t frame_index_ = -1;
int64_t parent_frame_ = -1;
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node
std::set<const NodeItem *> root_data_; // Recv data from root node
std::map<int, const NodeItem *> root_data_; // Recv data from root node
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node
std::set<const NodeItem *> enter_data_; // Recv data from Enter node
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node
std::set<const NodeItem *> data_send_; // Send data notify to
std::map<const NodeItem *, int> data_recv_; // Recv data notify from
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to
std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge.

std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph;


+ 7
- 10
ge/hybrid/node_executor/task_context.cc View File

@@ -52,9 +52,7 @@ void TaskContext::ReleaseWorkspace() {
}
}

std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
GraphExecutionContext *execution_context,
SubgraphContext *subgraph_context) {
std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, SubgraphContext *subgraph_context) {
const NodeItem &node_item = *node_state->GetNodeItem();
GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.",
node_item.NodeName().c_str(),
@@ -75,7 +73,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
}

auto task_context = std::unique_ptr<TaskContext>(
new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context));
new(std::nothrow)TaskContext(subgraph_context->execution_context_, node_state, subgraph_context));
if (task_context == nullptr) {
REPORT_CALL_ERROR("E19999", "Create TaskContext failed for [%s].", node_item.NodeName().c_str());
GELOGE(MEMALLOC_FAILED, "[Create][TaskContext] failed for [%s].", node_item.NodeName().c_str());
@@ -85,7 +83,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state,
task_context->node_item_ = &node_item;
task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start;
task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start;
task_context->iteration_ = execution_context->iteration;
task_context->iteration_ = subgraph_context->execution_context_->iteration;
return task_context;
}

@@ -460,6 +458,10 @@ Status TaskContext::PropagateOutputs() {
subgraph_context_->all_inputs_[input_offset].SetName(
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx));
}

auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SavePersistTensor(dst_input_idx, *tensor);
}
}
(void)guard;
@@ -489,11 +491,6 @@ void TaskContext::ReleaseInputsAndOutputs() {
}

void TaskContext::ReleaseInput(int index) {
if (node_item_->enter_inside_.count(index) > 0) {
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index);
return;
}

auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) {
input_tensor->Destroy();


+ 1
- 3
ge/hybrid/node_executor/task_context.h View File

@@ -36,9 +36,7 @@ class SubgraphContext;

class TaskContext {
public:
static std::unique_ptr<TaskContext> Create(NodeState *node_state,
GraphExecutionContext *execution_context,
SubgraphContext *subgraph_context);
static std::unique_ptr<TaskContext> Create(NodeState *node_state, SubgraphContext *subgraph_context);

~TaskContext();



+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit b27915cd37919430a61953f8998b7acce4a60177
Subproject commit c6030152c6dc05515115765babb5d64fde649df4

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit e75eda62de2b51a0bded5481ca81eb8fc7bf376e
Subproject commit 155d3262ba17f800094abb58b6a809b041cf0a74

+ 0
- 112
tests/ut/common/graph/testcase/ge_graph/ge_graph_anchor_unittest.cc View File

@@ -272,115 +272,3 @@ TEST_F(UtestGeAnchor, graph_utils_test) {
EXPECT_EQ(GraphUtils::RemoveEdge(conv_node->GetOutDataAnchor(0), bn_node->GetInControlAnchor()), GRAPH_SUCCESS);
EXPECT_EQ(GraphUtils::RemoveEdge(conv_node->GetOutDataAnchor(0), bn_node->GetInControlAnchor()), GRAPH_FAILED);
}

TEST_F(UtestGeAnchor, data_anchor_replace_peer) {
ComputeGraphPtr graph_ptr = std::make_shared<ComputeGraph>("graph");
OpDescPtr in_op_ptr = std::make_shared<OpDesc>("in_op_1", "float");
in_op_ptr->AddInputDesc("x1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddInputDesc("x2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddInputDesc("x3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddOutputDesc("y1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddOutputDesc("y2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddOutputDesc("y3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
NodePtr node1 = graph_ptr->AddNode(in_op_ptr);
NodePtr node2 = graph_ptr->AddNode(in_op_ptr);
NodePtr node3 = graph_ptr->AddNode(in_op_ptr);

OutDataAnchorPtr out_data_anchor = node1->GetOutDataAnchor(1);
InDataAnchorPtr in_data_anchor = node2->GetInDataAnchor(1);
EXPECT_EQ(out_data_anchor != nullptr, true);
EXPECT_EQ(in_data_anchor != nullptr, true);
EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(0)), GRAPH_SUCCESS);
EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(1)), GRAPH_SUCCESS);
EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(2)), GRAPH_SUCCESS);

size_t out_idx = 0;
for (; out_idx < out_data_anchor->peer_anchors_.size(); out_idx++) {
if (out_data_anchor->peer_anchors_[out_idx].lock() == in_data_anchor) {
break;
}
}
EXPECT_EQ(out_idx, 1);

size_t in_idx = 0;
for (; in_idx < in_data_anchor->peer_anchors_.size(); in_idx++) {
if (in_data_anchor->peer_anchors_[in_idx].lock() == out_data_anchor) {
break;
}
}
EXPECT_EQ(in_idx, 0);

out_data_anchor->ReplacePeer(in_data_anchor, node3->GetInDataAnchor(1), node3->GetOutDataAnchor(1));

size_t out_idx1 = 0;
for (; out_idx1 < out_data_anchor->peer_anchors_.size(); out_idx1++) {
if (out_data_anchor->peer_anchors_[out_idx1].lock() == node3->GetInDataAnchor(1)) {
break;
}
}
EXPECT_EQ(out_idx1, out_idx);

size_t in_idx1 = 0;
for (; in_idx1 < in_data_anchor->peer_anchors_.size(); in_idx1++) {
if (in_data_anchor->peer_anchors_[in_idx1].lock() == node3->GetOutDataAnchor(1)) {
break;
}
}
EXPECT_EQ(in_idx1, in_idx);
}

TEST_F(UtestGeAnchor, graph_utils_insert_node) {
ComputeGraphPtr graph_ptr = std::make_shared<ComputeGraph>("graph");
OpDescPtr in_op_ptr = std::make_shared<OpDesc>("in_op_1", "float");
in_op_ptr->AddInputDesc("x1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddInputDesc("x2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddInputDesc("x3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddOutputDesc("y1", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddOutputDesc("y2", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
in_op_ptr->AddOutputDesc("y3", GeTensorDesc(GeShape({1, 32, 8, 8}), FORMAT_NCHW));
NodePtr node1 = graph_ptr->AddNode(in_op_ptr);
NodePtr node2 = graph_ptr->AddNode(in_op_ptr);
NodePtr node3 = graph_ptr->AddNode(in_op_ptr);

OutDataAnchorPtr out_data_anchor = node1->GetOutDataAnchor(1);
InDataAnchorPtr in_data_anchor = node2->GetInDataAnchor(1);
EXPECT_EQ(out_data_anchor != nullptr, true);
EXPECT_EQ(in_data_anchor != nullptr, true);
EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(0)), GRAPH_SUCCESS);
EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(1)), GRAPH_SUCCESS);
EXPECT_EQ(node1->GetOutDataAnchor(1)->LinkTo(node2->GetInDataAnchor(2)), GRAPH_SUCCESS);

size_t out_idx = 0;
for (; out_idx < out_data_anchor->peer_anchors_.size(); out_idx++) {
if (out_data_anchor->peer_anchors_[out_idx].lock() == in_data_anchor) {
break;
}
}
EXPECT_EQ(out_idx, 1);

size_t in_idx = 0;
for (; in_idx < in_data_anchor->peer_anchors_.size(); in_idx++) {
if (in_data_anchor->peer_anchors_[in_idx].lock() == out_data_anchor) {
break;
}
}
EXPECT_EQ(in_idx, 0);

GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, in_data_anchor, node3);

size_t out_idx1 = 0;
for (; out_idx1 < out_data_anchor->peer_anchors_.size(); out_idx1++) {
if (out_data_anchor->peer_anchors_[out_idx1].lock() == node3->GetInDataAnchor(0)) {
break;
}
}
EXPECT_EQ(out_idx1, out_idx);

size_t in_idx1 = 0;
for (; in_idx1 < in_data_anchor->peer_anchors_.size(); in_idx1++) {
if (in_data_anchor->peer_anchors_[in_idx1].lock() == node3->GetOutDataAnchor(0)) {
break;
}
}
EXPECT_EQ(in_idx1, in_idx);
}

+ 2
- 1
tests/ut/common/graph/testcase/ge_graph/ge_model_serialize_unittest.cc View File

@@ -30,6 +30,7 @@
#include "graph/model_serialize.h"

#include "graph/detail/model_serialize_imp.h"
#include "graph/node_impl.h"
#include "graph/ge_attr_value.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
@@ -1062,7 +1063,7 @@ TEST(UtestGeModelSerialize, test_model_serialize_imp_invalid_param) {

auto graph = std::make_shared<ComputeGraph>("test_graph");
auto node = graph->AddNode(std::make_shared<OpDesc>());
node->op_ = nullptr;
node->impl_->op_ = nullptr;
ge::proto::ModelDef model_def;
Model model;
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));


+ 1
- 17
tests/ut/common/graph/testcase/ge_graph/ge_tensor_unittest.cc View File

@@ -25,6 +25,7 @@
#include "graph/ge_attr_value.h"
#include "graph/tensor.h"
#include "graph/utils/tensor_utils.h"
#include "graph/ge_tensor_impl.h"
#undef private
#undef protected

@@ -196,23 +197,6 @@ TEST_F(UtestGeTensor, test_shape_copy_move) {
EXPECT_EQ(shape4.GetDimNum(), 3);
}

TEST_F(UtestGeTensor, test_tensor_desc_invalid_null) {
GeTensorDesc tensor_desc(nullptr, nullptr);
EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED);
EXPECT_EQ(tensor_desc.GetFormat(), FORMAT_RESERVED);
EXPECT_EQ(tensor_desc.MutableShape().shape_def_.GetProtoMsg(), nullptr);

GeTensorDesc tensor_desc2;
EXPECT_EQ(tensor_desc2.GetDataType(), DT_FLOAT);
EXPECT_EQ(tensor_desc2.GetFormat(), FORMAT_ND);

tensor_desc2.SetDataType(DT_DUAL_SUB_INT8);
EXPECT_EQ(tensor_desc2.GetDataType(), DT_DUAL_SUB_INT8);

TensorUtils::SetWeightSize(tensor_desc, 100);
EXPECT_EQ(TensorUtils::GetWeightSize(tensor_desc), 0);
}

TEST_F(UtestGeTensor, test_tensor_invalid_null) {
ProtoMsgOwner msg_owner;
GeTensor tensor(msg_owner, nullptr);


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -121,6 +121,7 @@ set(GRAPH_SRC_FILES
"${GE_CODE_DIR}/metadef/register/op_tiling.cpp"
"${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc"
"${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp"
"${GE_CODE_DIR}/metadef/register/op_tiling_registry_impl.cpp"
)

set(PARSER_SRC_FILES


+ 155
- 44
tests/ut/ge/graph/partition/dynamic_shape_partition_unittest.cc View File

@@ -20,9 +20,11 @@
#define protected public
#include "graph/partition/dynamic_shape_partition.h"
#include "compute_graph.h"
#include "graph/compute_graph_impl.h"
#include "inc/framework/common/types.h"
#include "utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/common/omg_util.h"

namespace ge {
namespace {
@@ -37,33 +39,33 @@ GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format fo
}

class NodeBuilder {
public:
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}
NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}
NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) {
op_desc_->AddOutputDesc(tensor_desc->Clone());
return *this;
}
NodePtr Build(const ComputeGraphPtr &graph) {
NodePtr node = graph->AddNode(op_desc_);
return node;
}
private:
OpDescPtr op_desc_;
public:
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}
NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
DataType data_type = DT_FLOAT) {
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}
NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) {
op_desc_->AddOutputDesc(tensor_desc->Clone());
return *this;
}
NodePtr Build(const ComputeGraphPtr &graph) {
NodePtr node = graph->AddNode(op_desc_);
return node;
}
private:
OpDescPtr op_desc_;
};
} // namespace

@@ -92,28 +94,137 @@ TEST_F(UtestDynamicShapePartition, single_op_scene_success) {
EXPECT_EQ(partitioner.Partition(), SUCCESS);
}

/*******************************************************************************
* |
* Merge1
* Active / \ Active
* / \.
* / \.
* Merge2 \.
* Active/ \Active \.
* / \ \.
* Add Sub Relu
* | | |
* | | |
* Switch_f2 Switch_t2 |
* \ / |
* \ / |
* Less2 |
* | |
* | |
* Switch_f Switch_t
* | \ / |
* | Active |
* | | |
* | Less1 |
* | / \ |
* | / \ |
* Data Data
******************************************************************************/
TEST_F(UtestDynamicShapePartition, merge_control_flow_group) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("default");
AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, "session_graph_id");

NodePtr data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
NodePtr data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
NodePtr merge = NodeBuilder("node2", MERGE).AddInputDesc({1}).AddInputDesc({1})
.AddOutputDesc({1}).AddOutputDesc({}).Build(graph);

GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge->GetInDataAnchor(0));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), merge->GetInDataAnchor(1));

(void)AttrUtils::SetBool(data1->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3);
(void)AttrUtils::SetBool(data2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
(void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3);
(void)AttrUtils::SetBool(merge->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
(void)AttrUtils::SetInt(merge->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3);

EXPECT_EQ(graph->sub_graph_.size(), 0);
auto data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);

auto less1 = NodeBuilder("less1", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto active1 = NodeBuilder("active1", STREAMACTIVE).Build(graph);
auto switch_t = NodeBuilder("switch_t", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph);
auto switch_f = NodeBuilder("switch_f", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph);
auto const_01 = NodeBuilder("const_01", CONSTANT).AddOutputDesc({1}).Build(graph);
auto const_11 = NodeBuilder("const_11", CONSTANT).AddOutputDesc({1}).Build(graph);


auto less2 = NodeBuilder("less2", LESS).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto active2 = NodeBuilder("active2", STREAMACTIVE).Build(graph);
auto switch_t2 = NodeBuilder("switch_t2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph);
auto switch_f2 = NodeBuilder("switch_f2", STREAMSWITCH).AddInputDesc({1}).AddInputDesc({1}).Build(graph);
auto const_02 = NodeBuilder("const_02", CONSTANT).AddOutputDesc({1}).Build(graph);
auto const_12 = NodeBuilder("const_12", CONSTANT).AddOutputDesc({1}).Build(graph);

auto add2 = NodeBuilder("add2", ADD).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto sub2 = NodeBuilder("sub2", SUB).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto merge2 = NodeBuilder("merge2", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto active_f2 = NodeBuilder("active_f2", STREAMACTIVE).Build(graph);
auto active_t2 = NodeBuilder("active_t2", STREAMACTIVE).Build(graph);

auto relu1 = NodeBuilder("relu1", RELU).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto merge1 = NodeBuilder("merge1", STREAMMERGE).AddInputDesc({1}).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
auto active_f1 = NodeBuilder("active_f1", STREAMACTIVE).Build(graph);
auto active_t1 = NodeBuilder("active_t1", STREAMACTIVE).Build(graph);

auto output1 = NodeBuilder("noutput1", NETOUTPUT).AddInputDesc({1}).Build(graph);

GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(data2->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0));
GraphUtils::AddEdge(const_01->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1));
GraphUtils::AddEdge(const_11->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutControlAnchor(), active1->GetInControlAnchor());
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_t->GetInControlAnchor());
GraphUtils::AddEdge(active1->GetOutControlAnchor(), switch_f->GetInControlAnchor());


GraphUtils::AddEdge(data1->GetOutDataAnchor(0), less2->GetInDataAnchor(0));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), less2->GetInDataAnchor(1));
GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(0));
GraphUtils::AddEdge(less2->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(0));
GraphUtils::AddEdge(const_02->GetOutDataAnchor(0), switch_t2->GetInDataAnchor(1));
GraphUtils::AddEdge(const_12->GetOutDataAnchor(0), switch_f2->GetInDataAnchor(1));
GraphUtils::AddEdge(less2->GetOutControlAnchor(), active2->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t2->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f2->GetInControlAnchor());


GraphUtils::AddEdge(switch_f2->GetOutControlAnchor(), add2->GetInControlAnchor());
GraphUtils::AddEdge(less2->GetOutDataAnchor(0), add2->GetInDataAnchor(0));
GraphUtils::AddEdge(add2->GetOutDataAnchor(0), merge2->GetInDataAnchor(0));
GraphUtils::AddEdge(add2->GetOutControlAnchor(), active_f2->GetInControlAnchor());
GraphUtils::AddEdge(active_f2->GetOutControlAnchor(), merge2->GetInControlAnchor());

GraphUtils::AddEdge(switch_t2->GetOutControlAnchor(), sub2->GetInControlAnchor());
GraphUtils::AddEdge(less2->GetOutDataAnchor(0), sub2->GetInDataAnchor(0));
GraphUtils::AddEdge(sub2->GetOutDataAnchor(0), merge2->GetInDataAnchor(1));
GraphUtils::AddEdge(sub2->GetOutControlAnchor(), active_t2->GetInControlAnchor());
GraphUtils::AddEdge(active_t2->GetOutControlAnchor(), merge2->GetInControlAnchor());

GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), less2->GetInControlAnchor());
GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), relu1->GetInControlAnchor());


GraphUtils::AddEdge(merge2->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge2->GetOutControlAnchor(), active_f1->GetInControlAnchor());
GraphUtils::AddEdge(active_f1->GetOutControlAnchor(), merge1->GetInControlAnchor());

GraphUtils::AddEdge(data2->GetOutDataAnchor(0), relu1->GetInDataAnchor(1));
GraphUtils::AddEdge(relu1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(relu1->GetOutControlAnchor(), active_t1->GetInControlAnchor());
GraphUtils::AddEdge(active_t1->GetOutControlAnchor(), merge1->GetInControlAnchor());

GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

AttrUtils::SetBool(merge2->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS);

SetControlFlowGroup(merge2, merge2->GetOpDesc()->GetId());
SetControlFlowGroup(switch_f2, merge2->GetOpDesc()->GetId());
SetControlFlowGroup(switch_t2, merge2->GetOpDesc()->GetId());
SetControlFlowGroup(active2, merge2->GetOpDesc()->GetId());
SetControlFlowGroup(active_t2, merge2->GetOpDesc()->GetId());
SetControlFlowGroup(active_f2, merge2->GetOpDesc()->GetId());

SetControlFlowGroup(merge1, merge1->GetOpDesc()->GetId());
SetControlFlowGroup(switch_f, merge1->GetOpDesc()->GetId());
SetControlFlowGroup(switch_t, merge1->GetOpDesc()->GetId());
SetControlFlowGroup(active1, merge1->GetOpDesc()->GetId());
SetControlFlowGroup(active_f1, merge1->GetOpDesc()->GetId());
SetControlFlowGroup(active_t1, merge1->GetOpDesc()->GetId());

EXPECT_EQ(graph->impl_->sub_graph_.size(), 0);
DynamicShapePartitioner partitioner(graph);
EXPECT_EQ(partitioner.Partition(), SUCCESS);
EXPECT_EQ(graph->sub_graph_.size(), 1);
EXPECT_EQ(graph->impl_->sub_graph_.size(), 3); // input less1 uknown
}
} // namespace ge

+ 11
- 18
tests/ut/ge/hybrid/executor/worker/execution_engine_unittest.cc View File

@@ -83,18 +83,14 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_kernel_task) {
execution_context.profiling_level = 1;
SubgraphContext subgraph_context(nullptr, &execution_context);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
node_state.SetTaskContext(shared_task_context);

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
ASSERT_TRUE(node_state->GetTaskContext() != nullptr);

std::function<void()> callback;
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
executor.InitCallback(&node_state, callback);
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
}

TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
@@ -118,21 +114,18 @@ TEST_F(UtestExecutionEngine, ExecuteAsync_without_callback_and_kernel_task) {
execution_context.model = &hybrid_model;
SubgraphContext subgraph_context(nullptr, &execution_context);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
uint32_t task_id = 0;
uint32_t stream_id = 1;
std::string task_type = "rts";
uint32_t block_dim = 0;
task_context->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim);
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
node_state.SetTaskContext(shared_task_context);
node_state->GetTaskContext()->SaveProfilingTaskDescInfo(task_id, stream_id, task_type, block_dim);

ExecutionEngine execution_engine;
ASSERT_TRUE(node_state.GetTaskContext() != nullptr);
ASSERT_TRUE(node_state->GetTaskContext() != nullptr);

std::function<void()> callback;
SubgraphExecutor executor(hybrid_model.GetRootGraphItem(), &execution_context);
executor.InitCallback(&node_state, callback);
EXPECT_EQ(execution_engine.ExecuteAsync(node_state, node_state.GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
executor.InitCallback(node_state.get(), callback);
ExecutionEngine execution_engine;
EXPECT_EQ(execution_engine.ExecuteAsync(*node_state, node_state->GetTaskContext(), execution_context, callback), INTERNAL_ERROR);
}

+ 14
- 11
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -40,6 +40,7 @@
#include "graph/types.h"
#include "graph/utils/tensor_utils.h"
#include "graph/testcase/ge_graph/graph_builder_utils.h"
#include "graph/op_desc_impl.h"
#undef private
#undef protected

@@ -159,11 +160,9 @@ TEST_F(UtestGeHybrid, task_update_tiling_info) {

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
ASSERT_TRUE(task_context != nullptr);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS);
ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS);
ASSERT_EQ(aicore_task->UpdateTilingInfo(*node_state->GetTaskContext()), SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_failed) {
@@ -477,12 +476,14 @@ TEST_F(UtestGeHybrid, TestTaskContext) {
node_item->output_start = 0;

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
GraphItem graph_item;
SubgraphContext subgraph_context(&graph_item, &execution_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
subgraph_context.all_inputs_.resize(2);
subgraph_context.all_outputs_.resize(1);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
auto task_context = node_state->GetTaskContext();
ASSERT_TRUE(task_context != nullptr);
auto desc = task_context->MutableInputDesc(2);
ASSERT_TRUE(desc == nullptr);
@@ -522,12 +523,14 @@ TEST_F(UtestGeHybrid, hybrid_model_executor_update_args) {
node_item->output_start = 0;

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
GraphItem graph_item;
SubgraphContext subgraph_context(&graph_item, &execution_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
subgraph_context.all_inputs_.resize(2);
subgraph_context.all_outputs_.resize(1);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
auto node_state = subgraph_context.GetOrCreateNodeState(node_item.get());
auto task_context = node_state->GetTaskContext();

int32_t buffer[1];
aicore_task->tiling_buffer_ = TensorBuffer::Create(buffer, sizeof(buffer));
@@ -737,7 +740,7 @@ TEST_F(UtestGeHybrid, TestParseDependencies) {
std::vector<std::string> deps;
deps.push_back("Data");
auto op_desc = netoutput->GetOpDesc();
op_desc->input_name_idx_["Data"] = 0;
op_desc->impl_->input_name_idx_["Data"] = 0;
auto data_desc = data->GetOpDesc();
auto tensor = std::make_shared<GeTensor>();
auto tensor_desc = data_desc->MutableInputDesc(0);


+ 0
- 5
tests/ut/ge/hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc View File

@@ -97,11 +97,6 @@ TEST_F(UtestGeLocalNodeExecutor, test_no_op_task) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
GeLocalNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);


+ 3
- 14
tests/ut/ge/hybrid/node_executor/hccl/hccl_node_executor_unittest.cc View File

@@ -94,18 +94,17 @@ TEST_F(UtestHcclNodeExecutor, test_rdmatask_extract_tensor) {
tensor.SetData(data);
ctx->SetTensor(1, 0, tensor.Clone());
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
vector<HcomRemoteAccessAddrInfo> addr_infos;
shared_ptr<RdmaNodeTask> task = MakeShared<RdmaNodeTask>();
task->remote_index_ = {1, 0};
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID);
Shape s2({1});
TensorDesc tensor_desc2(s2);
Tensor tensor2(tensor_desc2);
ctx->SetTensor(1, 0, tensor2.Clone());
task->ExtractTensor(*unique_task_context, addr_infos);
ASSERT_EQ(task->ExtractTensor(*unique_task_context, addr_infos), PARAM_INVALID);
task->ExtractTensor(*node_state->GetTaskContext(), addr_infos);
ASSERT_EQ(task->ExtractTensor(*node_state->GetTaskContext(), addr_infos), PARAM_INVALID);
RuntimeInferenceContext::DestroyContext(std::to_string(graph_context.context_id));
}
@@ -140,11 +139,6 @@ TEST_F(UtestHcclNodeExecutor, gatheralltoallv_execute) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
for (int i=0; i<4; ++i) {
uint64_t value_0 = 512;
TensorValue in_tensor0(&value_0, sizeof(value_0));
@@ -206,11 +200,6 @@ TEST_F(UtestHcclNodeExecutor, alltoallv_execute) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);
auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
for (int i=0; i<5; ++i) {
uint64_t value_0 = 512;
TensorValue in_tensor0(&value_0, sizeof(value_0));


+ 0
- 40
tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc View File

@@ -96,11 +96,6 @@ TEST_F(UtestRtsNodeTask, test_stream_switch_task) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
uint64_t value_1 = 120;
TensorValue in_tensor0(&value_0, sizeof(value_0));
@@ -153,11 +148,6 @@ TEST_F(UtestRtsNodeTask, test_stream_active_task) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
@@ -203,11 +193,6 @@ TEST_F(UtestRtsNodeTask, test_stream_merge_task) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
subgraph_context.SetInput(*node_item, 0, in_tensor0);
@@ -271,11 +256,6 @@ TEST_F(UtestRtsNodeTask, test_memcpy_async_task) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
subgraph_context.SetInput(*node_item, 0, in_tensor0);
@@ -328,11 +308,6 @@ TEST_F(UtestRtsNodeTask, test_pass_through_task) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
subgraph_context.SetInput(*node_item, 0, in_tensor0);
@@ -384,11 +359,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_set) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
@@ -428,11 +398,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
@@ -472,11 +437,6 @@ TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) {
auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);


Loading…
Cancel
Save