Browse Source

GE local task for executor

tags/v1.3.0
zhangxiaokun 3 years ago
parent
commit
4a8338d2df
64 changed files with 2870 additions and 390 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +4
    -0
      ge/common/op/ge_op_utils.cc
  3. +2
    -0
      ge/executor/CMakeLists.txt
  4. +37
    -3
      ge/graph/common/omg_util.cc
  5. +15
    -0
      ge/graph/common/omg_util.h
  6. +1
    -1
      ge/graph/manager/graph_caching_allocator.cc
  7. +30
    -23
      ge/graph/partition/dynamic_shape_partition.cc
  8. +1
    -1
      ge/graph/partition/dynamic_shape_partition.h
  9. +25
    -7
      ge/graph/passes/base_pass.cc
  10. +34
    -6
      ge/graph/passes/base_pass.h
  11. +53
    -3
      ge/graph/passes/infershape_pass.cc
  12. +3
    -0
      ge/graph/passes/infershape_pass.h
  13. +103
    -0
      ge/graph/passes/merge_input_memcpy_pass.cc
  14. +15
    -0
      ge/graph/passes/merge_input_memcpy_pass.h
  15. +5
    -44
      ge/graph/passes/merge_to_stream_merge_pass.cc
  16. +24
    -0
      ge/graph/passes/next_iteration_pass.cc
  17. +2
    -1
      ge/graph/passes/next_iteration_pass.h
  18. +8
    -0
      ge/graph/passes/switch_to_stream_switch_pass.cc
  19. +7
    -0
      ge/hybrid/common/tensor_value.h
  20. +2
    -0
      ge/hybrid/executor/hybrid_execution_context.cc
  21. +13
    -12
      ge/hybrid/executor/hybrid_execution_context.h
  22. +3
    -3
      ge/hybrid/executor/hybrid_model_executor.cc
  23. +3
    -4
      ge/hybrid/executor/hybrid_model_pipeline_executor.cc
  24. +18
    -0
      ge/hybrid/executor/node_done_manager.cc
  25. +3
    -0
      ge/hybrid/executor/node_done_manager.h
  26. +146
    -14
      ge/hybrid/executor/node_state.cc
  27. +55
    -2
      ge/hybrid/executor/node_state.h
  28. +5
    -0
      ge/hybrid/executor/subgraph_context.cc
  29. +1
    -0
      ge/hybrid/executor/subgraph_context.h
  30. +259
    -58
      ge/hybrid/executor/subgraph_executor.cc
  31. +16
    -0
      ge/hybrid/executor/subgraph_executor.h
  32. +4
    -36
      ge/hybrid/executor/worker/execution_engine.cc
  33. +20
    -3
      ge/hybrid/executor/worker/execution_engine.h
  34. +3
    -1
      ge/hybrid/executor/worker/shape_inference_engine.cc
  35. +36
    -5
      ge/hybrid/model/graph_item.cc
  36. +13
    -0
      ge/hybrid/model/graph_item.h
  37. +345
    -23
      ge/hybrid/model/hybrid_model_builder.cc
  38. +18
    -2
      ge/hybrid/model/hybrid_model_builder.h
  39. +58
    -7
      ge/hybrid/model/node_item.cc
  40. +43
    -3
      ge/hybrid/model/node_item.h
  41. +2
    -2
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
  42. +8
    -20
      ge/hybrid/node_executor/controlop/control_op_executor.cc
  43. +1
    -1
      ge/hybrid/node_executor/node_executor.cc
  44. +2
    -0
      ge/hybrid/node_executor/node_executor.h
  45. +24
    -24
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  46. +4
    -7
      ge/hybrid/node_executor/rts/rts_node_executor.h
  47. +240
    -0
      ge/hybrid/node_executor/rts/rts_node_task.cc
  48. +89
    -0
      ge/hybrid/node_executor/rts/rts_node_task.h
  49. +46
    -0
      ge/hybrid/node_executor/rts/rts_task_factory.cc
  50. +65
    -0
      ge/hybrid/node_executor/rts/rts_task_factory.h
  51. +3
    -2
      ge/hybrid/node_executor/task_context.cc
  52. +5
    -1
      inc/framework/common/op/ge_op_utils.h
  53. +9
    -1
      tests/depends/error_manager/src/error_manager_stub.cc
  54. +7
    -1
      tests/depends/runtime/CMakeLists.txt
  55. +26
    -15
      tests/depends/runtime/src/runtime_stub.cc
  56. +20
    -0
      tests/depends/slog/src/slog_stub.cc
  57. +14
    -10
      tests/ut/ge/CMakeLists.txt
  58. +4
    -2
      tests/ut/ge/graph/load/davinci_model_unittest.cc
  59. +99
    -3
      tests/ut/ge/graph/passes/infershape_pass_unittest.cc
  60. +17
    -17
      tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc
  61. +17
    -17
      tests/ut/ge/graph/utils/buffer_pool_graph_builder.h
  62. +16
    -5
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc
  63. +233
    -0
      tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc
  64. +484
    -0
      tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc

+ 2
- 0
ge/CMakeLists.txt View File

@@ -391,6 +391,8 @@ set(TRAIN_SRC_LIST
"hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
"hybrid/node_executor/hccl/hccl_node_executor.cc"
"hybrid/node_executor/rts/rts_node_executor.cc"
"hybrid/node_executor/rts/rts_node_task.cc"
"hybrid/node_executor/rts/rts_task_factory.cc"
"hybrid/node_executor/node_executor.cc"
"hybrid/node_executor/task_context.cc"
"hybrid/hybrid_davinci_model.cc"


+ 4
- 0
ge/common/op/ge_op_utils.cc View File

@@ -62,6 +62,10 @@ const uint32_t SWITCH_TRUE_OUTPUT = 1;
const uint32_t SWITCH_DATA_INPUT = 0;
const uint32_t SWITCH_PRED_INPUT = 1;

// Merge
const uint32_t MERGE_DATA_OUTPUT = 0;
const uint32_t MERGE_INDEX_OUTPUT = 1;

// FunctionOp
const uint32_t IF_COND_INPUT = 0;
const uint32_t FOR_START_INPUT = 0;


+ 2
- 0
ge/executor/CMakeLists.txt View File

@@ -110,6 +110,8 @@ set(SRC_LIST
"../hybrid/node_executor/controlop/control_op_executor.cc"
"../hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
"../hybrid/node_executor/rts/rts_node_executor.cc"
"../hybrid/node_executor/rts/rts_node_task.cc"
"../hybrid/node_executor/rts/rts_task_factory.cc"
"../hybrid/node_executor/node_executor.cc"
"../hybrid/node_executor/task_context.cc"
"../hybrid/hybrid_davinci_model.cc"


+ 37
- 3
ge/graph/common/omg_util.cc View File

@@ -16,9 +16,6 @@

#include "graph/common/omg_util.h"

#include <algorithm>

#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
@@ -244,4 +241,41 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size) {
output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize;
return SUCCESS;
}

///
/// @brief Check Is Unknown shape Tensor
/// @param [in] tensor_desc
/// @return true: Unknown / false: Known
///
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
const static int kUnknowShape = -1;
const static int kUnknowRank = -2;
for (auto dim_size : tensor_desc.GetShape().GetDims()) {
if (dim_size == kUnknowShape || dim_size == kUnknowRank) {
return true;
}
}

return false;
}

///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) {
GE_RT_VOID_CHECK_NOTNULL(node);
if (!force_unknown) {
return;
}

GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str());
if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "Op: %s set %s failed", node->GetName().c_str(), ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str());
}
}
} // namespace ge

+ 15
- 0
ge/graph/common/omg_util.h View File

@@ -117,6 +117,21 @@ void AlignMemSize(int64_t &mem_size, int64_t align_size);
/// @return Status
///
Status GetMemorySize(const NodePtr &node, int64_t &output_size);

///
/// @brief Check Is Unknown shape Tensor
/// @param [in] tensor_desc
/// @return true: Unknown / false: Known
///
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);

///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown);
} // namespace ge

#endif // GE_GRAPH_COMMON_OMG_UTIL_H_

+ 1
- 1
ge/graph/manager/graph_caching_allocator.cc View File

@@ -168,7 +168,7 @@ Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) {
if (it == allocated_blocks_.end()) {
REPORT_INNER_ERROR("E19999", "Param ptr not allocated before, device_id:%u, check invalid",
device_id);
GELOGE(PARAM_INVALID, "Invalid memory pointer");
GELOGE(PARAM_INVALID, "Invalid memory pointer: %p", ptr);
return ge::PARAM_INVALID;
}
Block *block = it->second;


+ 30
- 23
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -31,6 +31,7 @@
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/common/omg_util.h"

#define REQUIRE(cond, ...) \
do { \
@@ -45,6 +46,11 @@
#define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)

namespace ge {
namespace {
const std::set<std::string> kControlFlowOps{
STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT
};
}
using Cluster = DynamicShapePartitioner::Cluster;
using ClusterPtr = std::shared_ptr<Cluster>;

@@ -273,7 +279,7 @@ Status DynamicShapePartitioner::InitClusters() {
auto cluster = MakeShared<Cluster>(rank++, type, node, this);
REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster.");
node_2_cluster_[node] = cluster;
if (cluster->IsUnknownShape()) {
if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) {
ordered_cluster_.push_back(cluster);
}
// Already sorted topologically, so access to the parent cluster is safe
@@ -347,7 +353,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
void DynamicShapePartitioner::MergeClustersUnknownShape() {
// Merge unknown shape clusters
for (const auto &cluster : ordered_cluster_) {
if (cluster->IsIndependent()) {
if (cluster->IsIndependent() || cluster->IsControlFlow()) {
continue;
}
for (const auto &in_cluster : cluster->Inputs()) {
@@ -545,17 +551,6 @@ Status DynamicShapePartitioner::IsUnknownShapeGraph(ComputeGraphPtr graph, bool
return SUCCESS;
}

bool DynamicShapePartitioner::IsUnknownShapeTensor(const GeTensorDesc &tensor) {
const static int kUnknowShape = -1;
const static int kUnknowRank = -2;
for (auto dim_size : tensor.GetShape().GetDims()) {
if (dim_size == kUnknowShape || dim_size == kUnknowRank) {
return true;
}
}
return false;
}

std::string Cluster::DebugString() const {
std::stringstream ss;
switch (type_) {
@@ -612,6 +607,14 @@ bool Cluster::IsRefVariable() const {
}
return false;
}

bool Cluster::IsControlFlow() const {
const auto &op_desc = nodes_[0]->GetOpDesc();
bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
GELOGD("[%s] %s rts control flow Op ", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not");
return is_ctrl_flow;
}

void Cluster::AddInput(ClusterPtr in) {
if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return;
in_clusters_.insert(in_clusters_.end(), in);
@@ -732,29 +735,33 @@ std::vector<ClusterPtr> Cluster::Outputs() const { return out_clusters_; };
std::vector<NodePtr> Cluster::Nodes() const { return nodes_; };

void Cluster::AddFrameInput(InDataAnchorPtr anchor) {
inputs_index_[anchor] = inputs_.size();
inputs_.push_back(anchor);
};
if (anchor != nullptr && anchor->GetPeerOutAnchor() != nullptr) {
inputs_index_[anchor] = inputs_.size();
inputs_.push_back(anchor);
}
}

void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) {
outputs_index_[anchor] = outputs_.size();
outputs_.push_back(anchor);
};
if (anchor != nullptr) {
outputs_index_[anchor] = outputs_.size();
outputs_.push_back(anchor);
}
}

InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) {
return partition_node_->GetInDataAnchor(static_cast<int>(inputs_index_[anchor]));
};
}

OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) {
return partition_node_->GetOutDataAnchor(static_cast<int>(outputs_index_[anchor]));
};
}

InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); };

OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); };

Status Cluster::BuildFrame() {
if (IsUnknownShape() || IsKnownShape() || IsInputNode()) {
if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) {
return BuildPartitionFrame();
} else {
auto node = nodes_.front();
@@ -889,7 +896,7 @@ Status Cluster::CombinePartitionFrame() {
}

Status Cluster::BuildPartitionSubgraph() {
if (IsData() || IsNetOutput() || IsIndependent()) {
if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) {
return SUCCESS;
}
int64_t parent_node_index = 0;


+ 1
- 1
ge/graph/partition/dynamic_shape_partition.h View File

@@ -47,6 +47,7 @@ class DynamicShapePartitioner {
bool IsUnknownShape() const;
bool IsIndependent() const;
bool IsNetOutput() const;
bool IsControlFlow() const;
std::vector<std::shared_ptr<Cluster>> Inputs() const;
std::vector<std::shared_ptr<Cluster>> Outputs() const;
bool IsInputNode() const;
@@ -151,7 +152,6 @@ class DynamicShapePartitioner {
Status CollectSpreadUnknownShapeNodes(NodePtr node);
Status IsUnknownShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow);
Status IsUnknownShapeNode(ge::NodePtr node, bool &is_unknow);
bool IsUnknownShapeTensor(const ge::GeTensorDesc &tensor);
Status CtrlEdgeTransfer();
ge::ComputeGraphPtr root_graph_; // The original graph to partition
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to


+ 25
- 7
ge/graph/passes/base_pass.cc View File

@@ -36,6 +36,7 @@ struct DuringPassNodeSets {
std::unordered_set<NodePtr> nodes_re_pass;
std::unordered_set<NodePtr> nodes_re_pass_immediately;
std::unordered_set<NodePtr> nodes_last;
std::unordered_set<NodePtr> nodes_stopped;
};

void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes,
@@ -56,11 +57,18 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i
}

void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) {
DuringPassNodeSets &during_pass_node_set) {
std::unordered_set<Node *> &nodes_seen = during_pass_node_set.nodes_seen;
const std::unordered_set<NodePtr> &nodes_last = during_pass_node_set.nodes_last;
const std::unordered_set<NodePtr> &nodes_stopped = during_pass_node_set.nodes_stopped;
for (auto &node : nodes) {
if (node == nullptr) {
continue;
}
if (nodes_stopped.count(node) > 0) {
GELOGD("The node %s was stopped by pass, skip it.", node->GetName().c_str());
continue;
}
if (nodes_last.count(node) != 0) {
continue;
}
@@ -73,7 +81,7 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n
}

void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass,
std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass,
std::unordered_set<NodePtr> &nodes_re_pass) {
for (const auto &node_to_re_pass : nodes_to_re_pass) {
if (node_to_re_pass == nullptr) {
@@ -113,15 +121,24 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo
return result;
}

auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass,
during_pass_node_set.nodes_re_pass);

auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately();
PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately,
during_pass_node_set.nodes_re_pass_immediately);

auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
for (const auto &node : name_to_pass.second->GetNodesStopped()) {
GELOGD("The node %s was stopped by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
during_pass_node_set.nodes_stopped.emplace(node);
}
for (const auto &node : name_to_pass.second->GetNodesRestored()) {
GELOGD("The node %s was restored by pass %s", node->GetName().c_str(), name_to_pass.first.c_str());
during_pass_node_set.nodes_stopped.erase(node);
}

const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted();
during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end());
if (nodes_deleted_by_pass.count(node) > 0) {
GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(),
@@ -222,8 +239,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
continue;
}

AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last);

const auto all_out_nodes = node->GetOutNodes();
auto ret = RunPasses(node, names_to_passes, during_pass_node_set);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u",
@@ -258,6 +274,8 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
nodes.push_front(node);
}
during_pass_node_set.nodes_re_pass_immediately.clear();

AddNextIterNodes(all_out_nodes, nodes, during_pass_node_set);
}

for (auto &node : during_pass_node_set.nodes_last) {


+ 34
- 6
ge/graph/passes/base_pass.h View File

@@ -51,11 +51,15 @@ class BaseNodePass {

virtual ~BaseNodePass() = default;

std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; }
const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; }

std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }
const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; }

std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; }
const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; }

const std::unordered_set<NodePtr> &GetNodesStopped() { return nodes_stopped_; }

const std::unordered_set<NodePtr> &GetNodesRestored() { return nodes_restored_; }

void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; }

@@ -65,6 +69,8 @@ class BaseNodePass {
nodes_need_re_pass_.clear();
nodes_deleted_.clear();
nodes_need_re_pass_immediately_.clear();
nodes_stopped_.clear();
nodes_restored_.clear();
}

protected:
@@ -80,7 +86,7 @@ class BaseNodePass {
/// optimized by other passes, call this function.
/// @param node
///
void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); }
void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); }

///
/// Add a node to be optimized immediately again. If you add a new node to the graph, or
@@ -88,13 +94,13 @@ class BaseNodePass {
/// optimized by other passes, call this function.
/// @param node
///
void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }
void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); }

///
/// Add a node and it's input/output data nodes to be optimized again.
/// @param node
///
void AddRePassNodesWithInOut(NodePtr &node) {
void AddRePassNodesWithInOut(const NodePtr &node) {
AddRePassNode(node);
auto out_nodes = node->GetOutNodes();
for (auto &out_node : out_nodes) {
@@ -116,12 +122,34 @@ class BaseNodePass {
///
void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); }

///
/// If you stop a node from the graph, especially following node. The remain
/// iterate passes will stop process on the stopped node(if it can be
/// reached by edge connections) till the last one. Obviously it is a waste of
/// time. You can add the stopped nodes by calling this function, to stop the
/// next iterations.
/// @param node
///
void AddNodeStopped(const NodePtr &node) { nodes_stopped_.insert(node); }

///
/// If you restore a node from the graph, especially following node. The remain
/// iterate passes will continue process on the stopped node(if it can be
/// reached by edge connections) till the last one.
/// You can add the restored nodes by calling this function, to restore the
/// next iterations.
/// @param node
///
void AddNodeRestored(const NodePtr &node) { nodes_restored_.insert(node); }

bool OptionExists(NodePassOption option) { return options_.count(option) > 0; }

private:
std::unordered_set<NodePtr> nodes_need_re_pass_;
std::unordered_set<NodePtr> nodes_need_re_pass_immediately_;
std::unordered_set<NodePtr> nodes_deleted_;
std::unordered_set<NodePtr> nodes_stopped_;
std::unordered_set<NodePtr> nodes_restored_;
std::map<NodePassOption, std::string> options_;
};



+ 53
- 3
ge/graph/passes/infershape_pass.cc View File

@@ -17,11 +17,11 @@
#include "graph/passes/infershape_pass.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "analyzer/analyzer.h"
#include "framework/common/util.h"
#include "graph/shape_refiner.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

@@ -94,8 +94,10 @@ Status InferShapePass::Run(NodePtr &node) {
GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str());
return GE_GRAPH_INFERSHAPE_FAILED;
}

GE_CHK_STATUS_RET_NOLOG(RePassLoopNode(node));
bool need_repass = false;
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "_need_infer_again", need_repass);
auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_repass);
if (has_attr) {
if (!OptionExists(kOptimizeAfterSubGraph)) {
return SUCCESS;
@@ -105,9 +107,57 @@ Status InferShapePass::Run(NodePtr &node) {
GELOGD("Node %s need repass immediately.", node->GetName().c_str());
} else {
// clear attr on while
node->GetOpDesc()->DelAttr("_need_infer_again");
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
}
}
return SUCCESS;
}

Status InferShapePass::RePassLoopNode(const NodePtr &node) {
const auto RePassNode = [&](const std::set<std::string> &re_pass_types) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
if (re_pass_types.count(n->GetType()) > 0) {
AddImmediateRePassNode(n);
(void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false);
GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str());
}
}
return SUCCESS;
};

const auto ExProcNode = [&](const std::set<std::string> &proc_types,
const std::function<void(InferShapePass *, NodePtr)> &proc_func,
const std::string &info) {
for (auto &n : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(n);
if (proc_types.count(n->GetType()) > 0) {
proc_func(this, n);
GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str());
}
}
return SUCCESS;
};

if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) {
return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge
}

if (node->GetType() == MERGE || node->GetType() == REFMERGE) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
return RePassNode({SWITCH, REFSWITCH}); // Re-Pass Switch
}
}

if (node->GetType() == SWITCH || node->GetType() == REFSWITCH) {
if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) {
node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN);
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeRestored, "need restore"); // Restore Exit
} else {
return ExProcNode({EXIT, REFEXIT}, &InferShapePass::AddNodeStopped, "need stop"); // Stop Exit
}
}

return SUCCESS;
}
} // namespace ge

+ 3
- 0
ge/graph/passes/infershape_pass.h View File

@@ -30,6 +30,9 @@ class InferShapePass : public BaseNodePass {
/// @author
///
Status Run(ge::NodePtr &node) override;

private:
Status RePassLoopNode(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_INFERSHAPE_PASS_H_

+ 103
- 0
ge/graph/passes/merge_input_memcpy_pass.cc View File

@@ -15,23 +15,36 @@
*/

#include "graph/passes/merge_input_memcpy_pass.h"

#include <queue>

#include "common/ge/ge_util.h"
#include "ge/ge_api_types.h"
#include "graph/common/omg_util.h"

namespace ge {
namespace {
const std::set<std::string> kLoopMergeInputs{
ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION
};
}
Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) {
GELOGD("MergeInputMemcpyPass Enter");
std::unordered_map<NodePtr, std::vector<NodePtr>> switch_groups;
for (const auto &node : graph->GetDirectNode()) {
std::string type;
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.");
if ((type != MERGE) && (type != REFMERGE)) {
continue;
}

GE_CHECK_NOTNULL(node->GetOpDesc());
GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)),
"Merge add memcpy node failed.");
CollectSwitchGroup(node, switch_groups);
}

MarkUnknownForSwitch(switch_groups);
GELOGD("MergeInputMemcpyPass Leave");
return SUCCESS;
}
@@ -101,4 +114,94 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph

return graph->AddNode(op_desc);
}

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch_groups
/// @return
///
void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node,
std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) {
const auto &op_desc = node->GetOpDesc();
for (const auto &in_anchor : node->GetAllInDataAnchors()) {
const auto &src_out_anchor = in_anchor->GetPeerOutAnchor();
if (src_out_anchor == nullptr) {
continue;
}

std::string node_type;
GetOriginalType(src_out_anchor->GetOwnerNode(), node_type);
if (kLoopMergeInputs.count(node_type) > 0) {
return;
}
}

// Switch --> {Switch --> Merge} --> Merge
std::queue<std::pair<NodePtr, uint32_t>> search_queue;
search_queue.push({node, 0});
std::vector<NodePtr> &switch_group = switch_groups[node];
while (!search_queue.empty()) {
const auto dst_node = search_queue.front().first;
const auto dst_span = search_queue.front().second;
search_queue.pop();

// Switch --> Identity --> Constant
for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) {
if (in_ctrl_node->GetType() == IDENTITY) {
GELOGD("Travel node: %s, In control: %s, span is: %u",
dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span);
search_queue.push({in_ctrl_node, dst_span});
}
}

for (const auto &in_data_node : dst_node->GetInDataNodes()) {
std::string node_type;
GetOriginalType(in_data_node, node_type);
GELOGD("Travel node: %s, %s node: %s, span is: %u",
dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span);
if (node_type == SWITCH || node_type == REFSWITCH) {
if (dst_span > 0) {
search_queue.push({in_data_node, dst_span - 1});
} else {
switch_group.emplace_back(in_data_node);
}
} else if (node_type == MERGE || node_type == REFMERGE) {
search_queue.push({in_data_node, dst_span + 1});
} else {
search_queue.push({in_data_node, dst_span});
}
}
}

if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) {
GELOGI("Mark [%s] as for unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size());
MarkForceUnknownShape(node, true);
for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
}
}
}

void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
};

for (const auto &item : switch_groups) {
const auto &node = item.first;
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) {
continue;
}

const std::vector<NodePtr> &switch_group = item.second;
if (std::any_of(switch_group.begin(), switch_group.end(), callback)) {
GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size());
MarkForceUnknownShape(node, true);
for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
}
}
}
}
} // namespace ge

+ 15
- 0
ge/graph/passes/merge_input_memcpy_pass.h View File

@@ -44,6 +44,21 @@ class MergeInputMemcpyPass : public GraphPass {
///
NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name,
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag);

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch_groups
/// @return
///
void CollectSwitchGroup(const NodePtr &node, std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups);

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] switch_groups
/// @return
///
void MarkUnknownForSwitch(const std::unordered_map<NodePtr, std::vector<NodePtr>> &switch_groups);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_

+ 5
- 44
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -69,51 +69,9 @@ Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) {
Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node) {
OpDescPtr merge_op_desc = merge_node->GetOpDesc();
GE_CHECK_NOTNULL(merge_op_desc);
merge_op_desc->SetType(STREAMMERGE);

const std::string &node_name = merge_node->GetName();
GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str());
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMMERGE);
if (op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New GeTensor failed");
GELOGE(FAILED, "Create op_desc failed, StreamMerge:%s.", node_name.c_str());
return FAILED;
}

for (const InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) {
GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return FAILED, "Create StreamMerge op: add input desc failed.");
}

for (const OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) {
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS,
REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return FAILED, "Create StreamMerge op: add output desc failed.");
}

NodePtr stream_merge = graph->AddNode(op_desc);
GE_CHK_BOOL_EXEC(stream_merge != nullptr,
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(),
graph->GetName().c_str());
return FAILED, "Insert StreamMerge node failed.");
GE_CHK_STATUS_RET(MoveEdges(merge_node, stream_merge), "Move edges failed.");
bypass_nodes_.insert(merge_node);

if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
std::string next_iteration_name;
GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name),
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed",
ATTR_NAME_NEXT_ITERATION.c_str(),
merge_op_desc->GetName().c_str(), merge_op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Get ATTR_NAME_NEXT_ITERATION failed");
return INTERNAL_ERROR);
GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed");
}

return AddActiveNodes(graph, stream_merge);
return AddActiveNodes(graph, merge_node);
}

///
@@ -126,6 +84,8 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GE_CHK_BOOL_EXEC(node != nullptr,
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
return FAILED, "Param of pre node is null.");
bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
MarkForceUnknownShape(node, force_unknown);
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -142,6 +102,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str());
return FAILED;
}
MarkForceUnknownShape(active_node, force_unknown);
}

return SUCCESS;


+ 24
- 0
ge/graph/passes/next_iteration_pass.cc View File

@@ -140,6 +140,7 @@ Status NextIterationPass::FindWhileGroups() {
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
loop_group_iter.second->switch_nodes.emplace_back(switch_node);
if (loop_group_iter.second->loop_cond == nullptr) {
loop_group_iter.second->loop_cond = loop_cond;
} else if (loop_group_iter.second->loop_cond != loop_cond) {
@@ -181,6 +182,12 @@ bool NextIterationPass::VerifyWhileGroup() {
frame_name.c_str());
return false;
}

// Mark loop as unknown shape If any merge has unknown shape output.
const auto &op_desc = pair_iter.first->GetOpDesc();
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) {
loop_group_iter.second->is_unknown_shape = true; // under check loop, cannot break.
}
}
}

@@ -194,6 +201,7 @@ bool NextIterationPass::VerifyWhileGroup() {
///
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
for (const auto &loop_cond_iter : loop_group_map_) {
const LoopCondGroup &loop_group = *loop_cond_iter.second;
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

@@ -215,6 +223,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
enter_active->GetName().c_str());
return INTERNAL_ERROR;
}
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape);
}

for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
@@ -243,6 +252,9 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
return INTERNAL_ERROR;
}

MarkForceUnknownShape(next_node, loop_group.is_unknown_shape);
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape);
}

if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
@@ -250,6 +262,18 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR;
}

MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape);
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape);
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape);
for (const auto &switch_node : loop_group.switch_nodes) {
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape);
for (const auto &exit_node : switch_node->GetOutDataNodes()) {
if (exit_node->GetType() == EXIT || exit_node->GetType() == REFEXIT) {
MarkForceUnknownShape(exit_node, loop_group.is_unknown_shape);
}
}
}
}

return SUCCESS;


+ 2
- 1
ge/graph/passes/next_iteration_pass.h View File

@@ -20,10 +20,11 @@
#include "inc/graph_pass.h"

struct LoopCondGroup {
LoopCondGroup() : loop_cond(nullptr) {}
ge::NodePtr loop_cond; // LoopCond node
std::vector<ge::NodePtr> enter_nodes; // Enter nodes
std::vector<std::pair<ge::NodePtr, ge::NodePtr>> merge_next_pairs; // <Merge, NextIteration>
std::vector<ge::NodePtr> switch_nodes; // Switch nodes
bool is_unknown_shape{false};
};
using LoopCondGroupPtr = std::shared_ptr<LoopCondGroup>;



+ 8
- 0
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -369,6 +369,7 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
"StreamSwitch node add cond edge failed.");

MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE));
return stream_switch;
}

@@ -487,6 +488,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
return FAILED;
}

std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
};
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
MarkForceUnknownShape(active_node, is_unknown_shape);

const std::string &cond_group = cond_node->GetName();
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
@@ -515,6 +522,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)),
"Cast add data edge failed.");

MarkForceUnknownShape(stream_switch, is_unknown_shape);
for (const NodePtr &node : switch_list) {
GE_IF_BOOL_EXEC(node != stream_switch, {
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


+ 7
- 0
ge/hybrid/common/tensor_value.h View File

@@ -21,6 +21,7 @@
#include <cstddef>
#include <memory>
#include "memory/memory_api.h"
#include "framework/common/util.h"

namespace ge {
namespace hybrid {
@@ -84,6 +85,12 @@ class TensorValue {

size_t GetSize() const;

template<typename T>
Status CopyScalarValueToHost(T &value) const {
GE_CHECK_GE(this->GetSize(), sizeof(value));
return rtMemcpy(&value, sizeof(value), this->GetData(), sizeof(value), RT_MEMCPY_DEVICE_TO_HOST);
}

private:
std::shared_ptr<TensorBuffer> buffer_;
std::string name_;


+ 2
- 0
ge/hybrid/executor/hybrid_execution_context.cc View File

@@ -28,6 +28,8 @@ const int32_t kModelAbortNormalNew = 507024;
std::atomic_ulong context_id_gen {};
} // namespace

long GraphExecutionContext::profiling_level = 0;

GraphExecutionContext::GraphExecutionContext() {
context_id = context_id_gen++;
}


+ 13
- 12
ge/hybrid/executor/hybrid_execution_context.h View File

@@ -73,7 +73,7 @@ struct GraphExecutionContext {
ExceptionDumper exception_dumper;
std::vector<std::shared_ptr<ge::DavinciModel>> davinci_model;
std::atomic_bool is_eos_{false};
long profiling_level = 0;
static long profiling_level;
long iteration = 0;
void *global_step = nullptr;

@@ -82,17 +82,18 @@ struct GraphExecutionContext {
mutable std::mutex mu;
};

#define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...) \
do { \
if ((context != nullptr) && (context)->profiler != nullptr) { \
if (node_name != nullptr) { \
context->profiler->RecordEvent(evt_type, "tid:%lu [%s@%ld] [%s] " fmt, \
GeLog::GetTid(), node_name, context->iteration, category, \
##__VA_ARGS__); \
} else { \
context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \
}\
} \
#define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...) \
do { \
if (ge::hybrid::GraphExecutionContext::profiling_level > 0) { \
if ((context != nullptr) && (context)->profiler != nullptr) { \
if (node_name != nullptr) { \
context->profiler->RecordEvent(evt_type, "tid:%lu [%s@%ld] [%s] " fmt, \
GeLog::GetTid(), node_name, context->iteration, category, ##__VA_ARGS__); \
} else { \
context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \
} \
} \
} \
} while (0)

#define RECORD_MODEL_EXECUTION_EVENT(context, fmt, ...) \


+ 3
- 3
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -155,9 +155,9 @@ Status HybridModelExecutor::InitExecutionContext() {
context_.dump_properties = DumpManager::GetInstance().GetDumpProperties(context_.session_id);
const char *profiling_level = std::getenv(kEnvProfilingLevel);
if (profiling_level != nullptr) {
context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
GELOGD("Got profiling level = %ld", context_.profiling_level);
if (context_.profiling_level > 0) {
GraphExecutionContext::profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
GELOGD("Got profiling level = %ld", GraphExecutionContext::profiling_level);
if (GraphExecutionContext::profiling_level > 0) {
context_.profiler.reset(new(std::nothrow)HybridProfiler());
GE_CHECK_NOTNULL(context_.profiler);
}


+ 3
- 4
ge/hybrid/executor/hybrid_model_pipeline_executor.cc View File

@@ -187,9 +187,9 @@ void StageExecutor::Reset() {
Status HybridModelPipelineExecutor::Init() {
const char *profiling_level = std::getenv(kEnvProfilingLevel);
if (profiling_level != nullptr) {
context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
GELOGD("Got profiling level = %ld", context_.profiling_level);
if (context_.profiling_level > 0) {
GraphExecutionContext::profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
GELOGD("Got profiling level = %ld", GraphExecutionContext::profiling_level);
if (GraphExecutionContext::profiling_level > 0) {
context_.profiler.reset(new (std::nothrow) HybridProfiler());
GE_CHECK_NOTNULL(context_.profiler);
}
@@ -210,7 +210,6 @@ Status HybridModelPipelineExecutor::InitStageExecutors() {
if (context_.profiler != nullptr) {
// will call unique_ptr::release later
stage_executor->context_.profiler.reset(context_.profiler.get());
stage_executor->context_.profiling_level = context_.profiling_level;
}

stage_executors_.emplace_back(std::move(stage_executor));


+ 18
- 0
ge/hybrid/executor/node_done_manager.cc View File

@@ -36,6 +36,16 @@ bool NodeDoneManager::Cond::Await() {
return is_released_;
}

void NodeDoneManager::Cond::Reset() {
std::unique_lock<std::mutex> lk(cond_mu_);
if (!is_released_ && !is_cancelled_) {
GELOGW("Called before done, released: %d, cancelled: %d", is_released_, is_cancelled_);
}

is_released_ = false;
is_cancelled_ = false;
}

void NodeDoneManager::Cond::Release() {
std::unique_lock<std::mutex> lk(cond_mu_);
is_released_ = true;
@@ -103,5 +113,13 @@ bool NodeDoneManager::Await(const NodePtr &node) {
GELOGD("[%s] Await ended. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false");
return ret;
}

void NodeDoneManager::Reset(const NodePtr &node) {
auto sub = GetSubject(node);
if (sub != nullptr) {
sub->Reset();
GELOGD("[%s] Node reset.", node->GetName().c_str());
}
}
} // namespace hybrid
} // namespace ge

+ 3
- 0
ge/hybrid/executor/node_done_manager.h View File

@@ -31,6 +31,8 @@ class NodeDoneManager {

bool Await(const NodePtr &node);

void Reset(const NodePtr &node);

void Destroy();

private:
@@ -40,6 +42,7 @@ class NodeDoneManager {
void Release();
void Cancel();
bool Await();
void Reset();
private:
std::mutex cond_mu_;
std::condition_variable cv_;


+ 146
- 14
ge/hybrid/executor/node_state.cc View File

@@ -30,6 +30,10 @@ constexpr auto kWaitInternal = 5;
constexpr auto kMaxWaitTimes = 120;
}
ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(node_item) {
InitShapeState();
}

void ShapeInferenceState::InitShapeState() {
this->num_pending_shapes_ = node_item.num_inputs - node_item.num_static_input_shapes;
GELOGD("[%s] ShapeInferenceState created, pending shape count = %d",
node_item.NodeName().c_str(),
@@ -135,19 +139,22 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex
}
}

for (size_t i = 0; i < input_tensor_desc.size(); ++i) {
auto dst_tensor_desc = node_item.op_desc->MutableInputDesc(i);
if (dst_tensor_desc == nullptr) {
continue;
}
{
const auto &guard = node_item.MutexGuard("AwaitShapesReady");
for (size_t i = 0; i < input_tensor_desc.size(); ++i) {
auto dst_tensor_desc = node_item.MutableInputDesc(i);
if (dst_tensor_desc == nullptr) {
continue;
}

auto &tensor_desc = input_tensor_desc[i];
int64_t tensor_size = -1;
(void) TensorUtils::GetSize(tensor_desc, tensor_size);
auto &tensor_desc = input_tensor_desc[i];
int64_t tensor_size = -1;
(void)TensorUtils::GetSize(tensor_desc, tensor_size);

dst_tensor_desc->SetShape(tensor_desc.MutableShape());
dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape());
(void) TensorUtils::SetSize(*dst_tensor_desc, tensor_size);
dst_tensor_desc->SetShape(tensor_desc.MutableShape());
dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape());
(void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size);
}
}

for (auto &p : shape_futures) {
@@ -159,8 +166,6 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex
GE_CHECK_NOTNULL(src_tensor_desc);
RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx);

auto input_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(input_desc);
int64_t tensor_size = -1;
(void) TensorUtils::GetSize(*src_tensor_desc, tensor_size);
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu",
@@ -169,6 +174,9 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex
src_tensor_desc->GetShape().ToString().c_str(),
src_tensor_desc->GetOriginShape().ToString().c_str(),
tensor_size);
const auto &guard = node_item.MutexGuard("AwaitShapesReady");
auto input_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(input_desc);
input_desc->SetShape(src_tensor_desc->GetShape());
input_desc->SetOriginShape(src_tensor_desc->GetOriginShape());
(void) TensorUtils::SetSize(*input_desc, tensor_size);
@@ -207,6 +215,11 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex
}

Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
if (node_item_->IsMergeOp()) {
GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size());
return SUCCESS;
}

for (auto &src_node : node_item_->dependents_for_execution) {
GELOGD("[%s] Start to wait for data dependent node: [%s]",
node_item_->NodeName().c_str(),
@@ -225,7 +238,7 @@ Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
node_item_->NodeName().c_str(),
"[AwaitNodeDone] [%s] End",
src_node->GetName().c_str());
GELOGD("[%s] Done waiting node.", src_node->GetName().c_str());
GELOGD("[%s] Done waiting node: [%s]", node_item_->NodeName().c_str(), src_node->GetName().c_str());
}

return SUCCESS;
@@ -255,6 +268,125 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_;
}

void NodeState::ResetContext(int group) {
SetGroup(group);
if (loop_count_ == 0) {
++loop_count_;
return;
}

++loop_count_;
if (loop_count_ == UINT64_MAX) {
loop_count_ = 1;
}

switch_index_ = -1;
const auto &guard = node_item_->MutexGuard("ResetContext");
shape_inference_state_.InitShapeState();
subgraph_context_->ResetContext(node_item_->node);
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_);
}

void NodeState::ResetSchedule() {
std::lock_guard<std::mutex> lk(mu_);
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_);
}

Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const {
// Schedule data output.
for (const auto &node : node_item_->data_send_) {
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SetDataSchedule(node_item_, ready);
}

// Schedule ctrl output.
for (const auto &node : node_item_->ctrl_send_) {
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SetCtrlSchedule(node_item_, ready);
}

// Schedule switch group.
if (switch_index_ >= 0 && static_cast<uint32_t>(switch_index_) < node_item_->switch_groups_.size()) {
GELOGI("After [%s] scheduled, switch index: %d", GetName().c_str(), switch_index_);
for (const auto &node : node_item_->switch_groups_[switch_index_]) {
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node);
GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SetCtrlSchedule(node_item_, ready);
}
}

return SUCCESS;
}

bool NodeState::IsScheduleReady() const {
GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(),
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_);
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) {
return false;
}

if (node_item_->IsMergeOp()) {
return data_scheduled_ > 0;
}

// Exit may feed loop times...
return data_scheduled_ >= node_item_->data_recv_.size();
}

void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u",
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), ctrl_scheduled_);

std::lock_guard<std::mutex> lk(mu_);
++data_scheduled_;

if (node_item_->IsMergeOp()) {
const auto it = node_item_->data_recv_.find(node_item);
if (it != node_item_->data_recv_.end()) {
merge_index_ = it->second;
(void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second);
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second);
} else {
GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str());
}
}

if (IsScheduleReady()) {
ready(node_item_);
}
}

void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u",
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), ctrl_scheduled_);

std::lock_guard<std::mutex> lk(mu_);
++ctrl_scheduled_;

if (IsScheduleReady()) {
ready(node_item_);
}
}

void NodeState::SetScheduleFuture(std::future<Status> &&future) {
schedule_future_ = std::move(future);
}

Status NodeState::WaitForScheduleDone() {
if (schedule_future_.valid()) {
GELOGD("[%s] Start to wait for schedule future.", GetName().c_str());
GE_CHK_STATUS_RET(schedule_future_.get(), "[Check][Status][%s] wait thread failed", GetName().c_str());
}

return SUCCESS;
}

Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) {
GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str());
HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled");


+ 55
- 2
ge/hybrid/executor/node_state.h View File

@@ -20,6 +20,8 @@
#include <condition_variable>
#include <future>
#include <mutex>

#include "common/blocking_queue.h"
#include "external/ge/ge_api_error_codes.h"
#include "hybrid/model/node_item.h"
#include "node_done_manager.h"
@@ -32,6 +34,8 @@ class SubgraphContext;
class TaskContext;
struct NodeState;

using NodeStatePtr = std::shared_ptr<NodeState>;

class ShapeFuture {
public:
ShapeFuture(NodeState *src_node, uint32_t src_index, SubgraphContext *subgraph_context);
@@ -48,6 +52,8 @@ class ShapeFuture {
struct ShapeInferenceState {
explicit ShapeInferenceState(const NodeItem &node_item);

void InitShapeState();

Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc);

void UpdateInputShapeFuture(int idx, ShapeFuture &&future);
@@ -100,6 +106,43 @@ struct NodeState {

Status UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape);

inline bool IsShapeDependence() const {
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE;
}

void ResetContext(int group);

void ResetSchedule();

Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;

void SetScheduleFuture(std::future<Status> &&future);
Status WaitForScheduleDone();

void SetSwitchIndex(int index) {
switch_index_ = index;
}

int GetSwitchIndex() const {
return switch_index_;
}

void SetMergeIndex(int index) {
merge_index_ = index;
}

int GetMergeIndex() const {
return merge_index_;
}

void SetGroup(int group) {
group_ = group;
}

int GetGroup() const {
return group_;
}

const shared_ptr<NodeTask> &GetKernelTask() const {
return kernel_task_;
}
@@ -120,6 +163,10 @@ struct NodeState {
std::shared_ptr<TaskContext> GetTaskContext();

private:
bool IsScheduleReady() const;
void SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready);
void SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready);

const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr;
std::future<Status> prepare_future_;
@@ -128,9 +175,15 @@ struct NodeState {
SubgraphContext *subgraph_context_;
std::shared_ptr<TaskContext> task_context_ = nullptr;
std::mutex mu_;
};

using NodeStatePtr = std::shared_ptr<NodeState>;
std::future<Status> schedule_future_;
uint64_t loop_count_ = 0;
uint32_t ctrl_scheduled_ = 0;
uint32_t data_scheduled_ = 0;
int merge_index_ = -1; // Use for Execute (Reset after Executed).
int switch_index_ = -1; // Use for Schedule (Reset after Prepared).
int group_ = -1;
};
} // namespace hybrid
} // namespace ge



+ 5
- 0
ge/hybrid/executor/subgraph_context.cc View File

@@ -37,10 +37,15 @@ Status SubgraphContext::Init() {
return SUCCESS;
}

void SubgraphContext::ResetContext(const NodePtr &node) {
node_done_manager_.Reset(node);
}

NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
std::lock_guard<std::mutex> lk(mu_);
auto &node_state = node_states_[node_item];
if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state.reset(new(std::nothrow)NodeState(*node_item, this));
}



+ 1
- 0
ge/hybrid/executor/subgraph_context.h View File

@@ -34,6 +34,7 @@ class SubgraphContext {
~SubgraphContext() = default;

Status Init();
void ResetContext(const NodePtr &node);
NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item);

void OnError(Status error);


+ 259
- 58
ge/hybrid/executor/subgraph_executor.cc View File

@@ -178,7 +178,9 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(known_shape_task_context_);

HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_),
std::function<void()> callback;
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
"[%s] Failed to execute node [%s] for known subgraph.",
graph_item_->GetName().c_str(),
known_shape_task_context_->GetNodeName());
@@ -206,76 +208,256 @@ Status SubgraphExecutor::ExecuteAsync(TaskContext &task_context) {
return SUCCESS;
}

BlockingQueue<const NodeItem *> &SubgraphExecutor::GetPrepareQueue(int group) {
std::lock_guard<std::mutex> lk(mu_);
return prepare_queues_[group];
}

Status SubgraphExecutor::NodeEnqueue(NodeState *node_state) {
if (!ready_queue_.Push(node_state)) {
if (context_->is_eos_) {
GELOGD("Got end of sequence");
return SUCCESS;
}
GELOGE(INTERNAL_ERROR, "[Check][State][%s] Error occurs while launching tasks. quit from preparing nodes.",
graph_item_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "[%s] Error occurs while launching tasks. quit from preparing nodes.",
graph_item_->GetName().c_str());
return INTERNAL_ERROR;
}

GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_state->GetName().c_str());
return SUCCESS;
}

Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
GELOGD("[%s] Start to prepare node [%s].", graph_item_->GetName().c_str(), node_item.NodeName().c_str());
// for while op
if (force_infer_shape_ && !node_item.is_dynamic) {
GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str());
auto &mutable_node_item = const_cast<NodeItem &>(node_item);
mutable_node_item.SetToDynamic();
}

auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item);
GE_CHECK_NOTNULL(node_state);
node_state->ResetContext(group);
auto p_node_state = node_state.get();

if (node_item.node_type == NETOUTPUT) {
GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
return AfterPrepared(p_node_state);
}

// only do shape inference and compilation for nodes with dynamic shapes.
if (node_item.is_dynamic) {
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status {
GetContext().SetSessionId(context_->session_id);
GetContext().SetContextId(context_->context_id);
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state));
GE_CHK_STATUS_RET_NOLOG(PrepareForExecution(context_, *p_node_state));
return AfterPrepared(p_node_state);
});

p_node_state->SetPrepareFuture(std::move(prepare_future));
return NodeEnqueue(p_node_state);
} else {
GELOGD("[%s] Skipping shape inference and compilation for node with static shape.",
node_item.NodeName().c_str());
if (node_item.kernel_task == nullptr) {
GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str());
GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_),
"[Invoke][Compile] failed for [%s].", p_node_state->GetName().c_str());
} else {
node_state->SetKernelTask(node_item.kernel_task);
}
auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state->GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
return AfterPrepared(p_node_state);
}
}

Status SubgraphExecutor::PrepareNodes(int group) {
GELOGD("[%s] Start to prepare nodes. group = %d",
graph_item_->GetName().c_str(),
group);
auto &all_nodes = graph_item_->GetAllNodes(group);
for (auto all_node : all_nodes) {
auto &node_item = *all_node;
// for while op
if (force_infer_shape_ && !node_item.is_dynamic) {
GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str());
auto &mutable_node_item = const_cast<NodeItem &>(node_item);
mutable_node_item.SetToDynamic();
const size_t node_size = graph_item_->GetNodeSize(group);
GELOGD("[%s] Start to prepare nodes. group = %d, size = %zu", graph_item_->GetName().c_str(), group, node_size);
if (!graph_item_->HasCtrlFlowOp()) {
for (const auto &node_item : graph_item_->GetAllNodes(group)) {
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] Start");
GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str());
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End");
}
GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str());
return SUCCESS;
}

GELOGD("[%s] Start to prepare node [%s].", graph_item_->GetName().c_str(), node_item.NodeName().c_str());
auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item);
GE_CHECK_NOTNULL(node_state);
auto p_node_state = node_state.get();

if (node_item.node_type != NETOUTPUT) {
// only do shape inference and compilation for nodes with dynamic shapes.
if (node_item.is_dynamic) {
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status {
GetContext().SetSessionId(context_->session_id);
GetContext().SetContextId(context_->context_id);
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state));
return PrepareForExecution(context_, *p_node_state);
});

p_node_state->SetPrepareFuture(std::move(prepare_future));
} else {
GELOGD("[%s] Skipping shape inference and compilation for node with static shape.",
node_item.NodeName().c_str());
if (node_item.kernel_task == nullptr) {
GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str());
GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_),
"[Invoke][Compile] failed for [%s].", p_node_state->GetName().c_str());
} else {
node_state->SetKernelTask(node_item.kernel_task);
}
auto unique_task_context =
TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(unique_task_context);
const auto &task = node_state->GetKernelTask();
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
return INTERNAL_ERROR;
}
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);
// Initialize the ready queue
size_t node_count = 0;
bool node_complete = false;
for (const auto &node_item : graph_item_->GetRootNodes(group)) {
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] Start");
GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str());
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End");
node_complete = node_item->NodeType() == NETOUTPUT;
node_count++;
}

GELOGD("[%s] Done preparing root nodes.", graph_item_->GetName().c_str());
BlockingQueue<const NodeItem *> &prepare_queue = GetPrepareQueue(group);
while (((group != -1) && (node_count < node_size)) || ((group == -1) && !node_complete)) {
const NodeItem *node_item = nullptr;
if (!prepare_queue.Pop(node_item)) {
if (context_->is_eos_) {
GELOGD("[%s] Got end of sequence.", graph_item_->GetName().c_str());
break;
}
if (context_->GetStatus() != SUCCESS) {
GELOGD("[%s] Graph execution Got failed.", graph_item_->GetName().c_str());
return SUCCESS;
}
GELOGE(INTERNAL_ERROR, "[%s] failed to pop node.", graph_item_->GetName().c_str());
return INTERNAL_ERROR;
}

if (!ready_queue_.Push(p_node_state)) {
if (node_item == nullptr) {
GELOGD("[%s] Got EOF from queue.", graph_item_->GetName().c_str());
break;
}

RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] Start");
GE_CHK_STATUS_RET(PrepareNode(*node_item, group), "[%s] failed to prepare task.", node_item->NodeName().c_str());
RECORD_EXECUTION_EVENT(context_, node_item->NodeName().c_str(), "[PrepareNode] End");
node_complete = node_item->NodeType() == NETOUTPUT;
node_count++;
}

GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str());
return SUCCESS;
}

Status SubgraphExecutor::NodeScheduled(NodeState *node_state) {
GELOGD("Graph[%s] After [%s] scheduled, data size: %zu, ctrl size: %zu, switch index: %d, merge index: %d",
graph_item_->GetName().c_str(), node_state->GetName().c_str(),
node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(),
node_state->GetSwitchIndex(), node_state->GetMergeIndex());
auto future = pre_run_pool_.commit([this, node_state]() -> Status {
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start");
std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) {
const auto &node_name = node_item->node_name;
int group = (node_state->GetGroup() != -1) ? node_item->group : -1;
GELOGI("After [%s] scheduled, [%s] is ready for prepare.", node_state->GetName().c_str(), node_name.c_str());
BlockingQueue<const NodeItem *> &prepare_queue = GetPrepareQueue(group);
if (!prepare_queue.Push(node_item)) {
if (!context_->is_eos_) {
GELOGE(INTERNAL_ERROR, "[Check][State][%s] error occurs when push to queue.", graph_item_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "[%s] error occurs when push to queue.", graph_item_->GetName().c_str());
}
}
};

GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback));
node_state->ResetSchedule();
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End");
return SUCCESS;
});

node_state->SetScheduleFuture(std::move(future));
if (schedule_queue_.Push(node_state)) {
return SUCCESS;
}

if (context_->is_eos_) {
GELOGD("[%s] Got end of sequence", graph_item_->GetName().c_str());
return SUCCESS;
}

GELOGE(INTERNAL_ERROR, "[Check][State][%s] error occurs when push to queue.", graph_item_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "[%s] error occurs when push to queue.", graph_item_->GetName().c_str());
return INTERNAL_ERROR;
}

Status SubgraphExecutor::AfterPrepared(NodeState *node_state) {
if (!graph_item_->HasCtrlFlowOp()) {
return SUCCESS;
}
if (node_state->IsShapeDependence()) {
return SUCCESS;
}
// Not control flow node, propagate state.
return NodeScheduled(node_state);
}

void SubgraphExecutor::AfterExecuted(NodeState *node_state) {
if (!node_state->IsShapeDependence()) {
return;
}
// For control flow node, propagate state.
auto error = NodeScheduled(node_state);
if (error != SUCCESS) {
auto task_context = node_state->GetTaskContext();
task_context->OnError(error);
}
}

void SubgraphExecutor::OnNodeDone(NodeState *node_state) {
auto task_context = node_state->GetTaskContext();
NodeDoneCallback cb(context_, task_context);
auto error = cb.OnNodeDone();
if (error != SUCCESS) {
task_context->OnError(error);
}

if (node_state->IsShapeDependence() && graph_item_->HasCtrlFlowOp()) {
AfterExecuted(node_state);
}
}

Status SubgraphExecutor::InitCallback(NodeState *node_state, std::function<void()> &callback) {
auto task_context = node_state->GetTaskContext();
GE_CHECK_NOTNULL(task_context);
if (task_context->NeedCallback()) {
callback = std::bind(&SubgraphExecutor::OnNodeDone, this, node_state);
} else if (node_state->IsShapeDependence() && graph_item_->HasCtrlFlowOp()) {
callback = std::bind(&SubgraphExecutor::AfterExecuted, this, node_state);
}

return SUCCESS;
}

Status SubgraphExecutor::ScheduleNodes() {
GELOGD("[%s] Start to schedule nodes.", graph_item_->GetName().c_str());
while (true) {
NodeState *node_state = nullptr;
if (!schedule_queue_.Pop(node_state)) {
if (context_->is_eos_) {
GELOGD("Got end of sequence");
GELOGD("[%s] Got end of sequence.", graph_item_->GetName().c_str());
break;
}
if (context_->GetStatus() != SUCCESS) {
GELOGD("[%s] Graph execution Got failed.", graph_item_->GetName().c_str());
return SUCCESS;
}
GELOGE(INTERNAL_ERROR, "[Check][State][%s] Error occurs while launching tasks. quit from preparing nodes.",
graph_item_->GetName().c_str());
REPORT_INNER_ERROR("E19999", "[%s] Error occurs while launching tasks. quit from preparing nodes.",
graph_item_->GetName().c_str());
GELOGE(INTERNAL_ERROR, "[%s] failed to pop node.", graph_item_->GetName().c_str());
return INTERNAL_ERROR;
}

GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str());
if (node_state == nullptr) {
GELOGD("[%s] Got EOF from queue.", graph_item_->GetName().c_str());
break;
}

GE_CHK_STATUS_RET_NOLOG(node_state->WaitForScheduleDone());
}

GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str());
GELOGD("[%s] Done schedule nodes successfully.", graph_item_->GetName().c_str());
return SUCCESS;
}

@@ -341,7 +523,10 @@ Status SubgraphExecutor::LaunchTasks() {
auto shared_task_context = node_state->GetTaskContext();
GE_CHECK_NOTNULL(shared_task_context);
shared_task_context->SetForceInferShape(force_infer_shape_);
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_),

std::function<void()> callback;
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state, callback));
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_, callback),
"[Invoke][ExecuteAsync] failed for [%s].", node_state->GetName().c_str());
GELOGD("[%s] Done executing node successfully.", node_state->GetName().c_str());
}
@@ -354,22 +539,38 @@ Status SubgraphExecutor::ScheduleTasks(int group) {
GetContext().SetContextId(context_->context_id);
auto ret = PrepareNodes(group);
ready_queue_.Push(nullptr);
schedule_queue_.Push(nullptr);
for (auto &item : prepare_queues_) {
item.second.Push(nullptr);
}
return ret;
});

auto schedule_future = std::async(std::launch::async, [&]() -> Status {
return ScheduleNodes();
});

GELOGD("[%s] Start to execute subgraph.", graph_item_->GetName().c_str());
auto ret = LaunchTasks();
if (ret != SUCCESS) {
subgraph_context_->OnError(ret);
context_->SetErrorCode(ret);
ready_queue_.Stop();
schedule_queue_.Stop();
for (auto &item : prepare_queues_) {
item.second.Stop();
}
prepare_future.wait();
schedule_future.wait();
return ret;
}

GE_CHK_STATUS_RET(prepare_future.get(), "[Invoke][get] [%s] Error occurred in task preparation.",
graph_item_->GetName().c_str());

GE_CHK_STATUS_RET(schedule_future.get(), "[Invoke][get] [%s] Error occurred in task preparation.",
graph_item_->GetName().c_str());

GELOGD("[%s] Done launching all tasks successfully.", graph_item_->GetName().c_str());
return SUCCESS;
}


+ 16
- 0
ge/hybrid/executor/subgraph_executor.h View File

@@ -105,6 +105,18 @@ class SubgraphExecutor {
Status PrepareNodes(int group = -1);
Status LaunchTasks();
Status SetOutputsToParentNode(TaskContext &task_context);
Status InitCallback(NodeState *node_state, std::function<void()> &callback);

Status NodeEnqueue(NodeState *node_state);
Status PrepareNode(const NodeItem &node_item, int group);

BlockingQueue<const NodeItem *> &GetPrepareQueue(int group);

Status ScheduleNodes();
Status NodeScheduled(NodeState *node_state);
Status AfterPrepared(NodeState *node_state);
void AfterExecuted(NodeState *node_state);
void OnNodeDone(NodeState *node_state);

const GraphItem *graph_item_;
GraphExecutionContext *context_;
@@ -114,6 +126,10 @@ class SubgraphExecutor {
BlockingQueue<NodeState *> ready_queue_;
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
std::shared_ptr<TaskContext> known_shape_task_context_;

std::mutex mu_; // Guard for prepare_queues_.
std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;
BlockingQueue<NodeState *> schedule_queue_;
};
} // namespace hybrid
} // namespace ge


+ 4
- 36
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -22,7 +22,6 @@
#include "graph/load/model_manager/model_manager.h"
#include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor//worker//shape_inference_engine.h"
#include "common/dump/dump_op.h"
#include "common/profiling/profiling_manager.h"

namespace ge {
@@ -62,22 +61,6 @@ Status LogOutputs(const NodeItem &node_item, const TaskContext &task_context) {
return SUCCESS;
}
} // namespace
class NodeDoneCallback {
public:
NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> task_context);
~NodeDoneCallback() = default;
Status OnNodeDone();
private:
Status PrepareConstInputs(const NodeItem &node_item);
Status DumpDynamicNode();
Status ProfilingReport();
Status SaveDumpOpInfo();
Status GetTaskDescInfo(const NodePtr node, const HybridModel *model,
std::vector<TaskDescInfo> &task_desc_info);
GraphExecutionContext *graph_context_;
std::shared_ptr<TaskContext> context_;
DumpOp dump_op_;
};

NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context,
std::shared_ptr<TaskContext> task_context)
@@ -334,6 +317,7 @@ Status NodeDoneCallback::OnNodeDone() {
GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item));
if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) {
// update output tensor sizes
const auto &guard = node_item.MutexGuard("OnNodeDone");
GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item));
GE_CHK_STATUS_RET_NOLOG(context_->GetNodeState()->GetShapeInferenceState().UpdateOutputDesc());
}
@@ -361,31 +345,15 @@ Status NodeDoneCallback::OnNodeDone() {

Status ExecutionEngine::ExecuteAsync(NodeState &node_state,
const std::shared_ptr<TaskContext> &task_context,
GraphExecutionContext &execution_context) {
GraphExecutionContext &execution_context,
const std::function<void()> &callback) {
GELOGI("[%s] Node is ready for execution", task_context->GetNodeName());
RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start");
std::function<void()> callback = nullptr;
GE_CHK_STATUS_RET_NOLOG(InitCallback(task_context, execution_context, callback));
GE_CHK_STATUS_RET_NOLOG(DoExecuteAsync(node_state, *task_context, execution_context, callback));
GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_state.GetNodeItem(), *task_context, execution_context));
return SUCCESS;
}

Status ExecutionEngine::InitCallback(const std::shared_ptr<TaskContext> &task_context,
GraphExecutionContext &execution_context, std::function<void()> &callback) {
if (task_context->NeedCallback()) {
auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context));
GE_CHECK_NOTNULL(cb);
callback = [task_context, cb]() {
auto ret = cb->OnNodeDone();
if (ret != SUCCESS) {
task_context->OnError(ret);
}
};
}
return SUCCESS;
}

Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
TaskContext &task_context,
GraphExecutionContext &context,
@@ -423,7 +391,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
node_state.GetName().c_str());
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ValidateInputTensors] End");

if (context.profiling_level > 0) {
if (GraphExecutionContext::profiling_level > 0) {
auto *ctx = &context;
const string &name = node_state.GetName();
(void)task_context.RegisterCallback([ctx, name]() {


+ 20
- 3
ge/hybrid/executor/worker/execution_engine.h View File

@@ -19,14 +19,33 @@

#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/node_executor/task_context.h"
#include "common/dump/dump_op.h"

namespace ge {
namespace hybrid {
class NodeDoneCallback {
public:
NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> task_context);
~NodeDoneCallback() = default;
Status OnNodeDone();
private:
Status PrepareConstInputs(const NodeItem &node_item);
Status DumpDynamicNode();
Status ProfilingReport();
Status SaveDumpOpInfo();
Status GetTaskDescInfo(const NodePtr node, const HybridModel *model,
std::vector<TaskDescInfo> &task_desc_info);
GraphExecutionContext *graph_context_;
std::shared_ptr<TaskContext> context_;
DumpOp dump_op_;
};

class ExecutionEngine {
public:
static Status ExecuteAsync(NodeState &node_state,
const std::shared_ptr<TaskContext> &task_context,
GraphExecutionContext &execution_context);
GraphExecutionContext &execution_context,
const std::function<void()> &callback);

private:
static Status ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context);
@@ -35,8 +54,6 @@ class ExecutionEngine {
TaskContext &task_context,
GraphExecutionContext &context,
const std::function<void()> &callback);
static Status InitCallback(const std::shared_ptr<TaskContext> &task_context,
GraphExecutionContext &execution_context, std::function<void()> &callback);
};
} // namespace hybrid
} // namespace ge


+ 3
- 1
ge/hybrid/executor/worker/shape_inference_engine.cc View File

@@ -45,6 +45,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
return SUCCESS;
}

const auto &guard = node_item.MutexGuard("InferShape");
if (node_item.fused_subgraph != nullptr) {
GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph));
GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item));
@@ -123,8 +124,9 @@ Status ShapeInferenceEngine::PropagateOutputShapes(NodeState &node_state) {
node_item.shape_inference_type);
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start");
// propagate each output
const auto &guard = node_item.MutexGuard("PropagateOutputShapes");
for (int i = 0; i < node_item.num_outputs; ++i) {
auto output_desc = node_item.op_desc->MutableOutputDesc(i);
auto output_desc = node_item.MutableOutputDesc(i);
auto &output_nodes = node_item.outputs[i];

// propagate output to all sub-inputs


+ 36
- 5
ge/hybrid/model/graph_item.cc View File

@@ -43,6 +43,27 @@ const vector<NodeItem *> &GraphItem::GetAllNodes(int group) const {
return grouped_node_items_[group];
}

const vector<NodeItem *> &GraphItem::GetRootNodes(int group) const {
if (group == -1) {
return root_items_;
}

if (static_cast<uint32_t>(group) >= grouped_root_items_.size()) {
static vector<NodeItem *> empty_nodes;
return empty_nodes;
}

return grouped_root_items_[group];
}

size_t GraphItem::GetNodeSize(int group) const {
if (group == -1) {
return node_items_.size();
}

return (static_cast<uint32_t>(group) < grouped_node_items_.size()) ? grouped_node_items_[group].size() : 0;
}

const vector<const NodeItem *> &GraphItem::GetInputNodes() const {
return input_nodes_;
}
@@ -88,10 +109,12 @@ const vector<std::pair<const NodeItem *, int>> &GraphItem::GetOutputEdges() cons
return output_edges_;
}

Status GraphItem::GroupNodes() {
Status GraphItem::GroupNodes(const std::vector<NodeItem *> &node_items,
std::vector<std::vector<NodeItem *>> &grouped_node_items) const {
int curr_group = 0;
int last_group = INT32_MIN;
std::set<int> seen_groups;
for (auto node : node_items_) {
for (auto node : node_items) {
int group = node->group;
if (group != last_group) {
if (seen_groups.find(group) != seen_groups.end()) {
@@ -101,15 +124,23 @@ Status GraphItem::GroupNodes() {
} else {
last_group = group;
seen_groups.insert(group);
grouped_node_items_.emplace_back(std::vector<NodeItem *>());
curr_group = static_cast<int>(grouped_node_items.size());
grouped_node_items.emplace_back(std::vector<NodeItem *>());
}
}

GELOGD("Adding node [%s] to group %d", node->NodeName().c_str(), group);
grouped_node_items_.back().emplace_back(node);
node->group = curr_group;
GELOGD("Adding node [%s] to group %d", node->NodeName().c_str(), node->group);
grouped_node_items.back().emplace_back(node);
}

return SUCCESS;
}

Status GraphItem::GroupNodes() {
GE_CHK_STATUS_RET_NOLOG(GroupNodes(node_items_, grouped_node_items_));
GE_CHK_STATUS_RET_NOLOG(GroupNodes(root_items_, grouped_root_items_));
return SUCCESS;
}
} // namespace hybrid
} // namespace ge

+ 13
- 0
ge/hybrid/model/graph_item.h View File

@@ -29,6 +29,7 @@ class GraphItem {
Status GroupNodes();
const vector<NodeItem *> &GetAllNodes() const;
const vector<NodeItem *> &GetAllNodes(int group) const;
const vector<NodeItem *> &GetRootNodes(int group) const;
const vector<const NodeItem *> &GetInputNodes() const;
Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const;
const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const;
@@ -40,6 +41,12 @@ class GraphItem {
return total_outputs_;
}

size_t GetNodeSize(int group) const;

bool HasCtrlFlowOp() const {
return has_ctrl_flow_op_;
}

const std::string& GetName() const {
return name_;
}
@@ -60,9 +67,14 @@ class GraphItem {

private:
friend class HybridModelBuilder;
Status GroupNodes(const std::vector<NodeItem *> &node_items,
std::vector<std::vector<NodeItem *>> &grouped_node_items) const;

std::string name_;
std::vector<NodeItem *> node_items_;
std::vector<std::vector<NodeItem *>> grouped_node_items_;
std::vector<NodeItem *> root_items_;
std::vector<std::vector<NodeItem *>> grouped_root_items_;
std::vector<const NodeItem *> input_nodes_;
const NodeItem *output_node_ = nullptr;
// <src_node, out_index>
@@ -71,6 +83,7 @@ class GraphItem {
int total_outputs_ = 0;

bool is_dynamic_ = true;
bool has_ctrl_flow_op_ = false;
std::vector<int> input_index_mapping_;
std::vector<int> output_index_mapping_;
};


+ 345
- 23
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -17,6 +17,7 @@
#include "hybrid/model/hybrid_model_builder.h"
#include <algorithm>
#include "common/math/math_util.h"
#include "common/op/ge_op_utils.h"
#include "graph/ge_context.h"
#include "graph/build/memory/var_mem_assign_util.h"
#include "graph/debug/ge_attr_define.h"
@@ -42,6 +43,11 @@ const uint64_t kProfilingFpStartLogid = 1U;
const uint64_t kProfilingBpEndLogid = 2U;
const uint64_t kProfilingIterEndLogid = 65535U;
const int kBytes = 8;
const int kDecimal = 10;
const uint8_t kStreamActiveIdx = 0;
const uint8_t kStreamActiveNum = 1;
const uint8_t kStreamSwitchIdx = 1;
const uint8_t kStreamSwitchNum = 2;
const uint32_t kStringHeadElems = 2;
const char *const kOwnerGraphIsUnknown = "OwnerGraphIsUnknown";
const char *const kProfilingGraph = "ProfilingGraph";
@@ -213,6 +219,7 @@ Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_ite
"[Invoke][GetCanonicalInputIndex] failed, dst_node:[%s].", dst_node->GetName().c_str());

node_item.outputs[i].emplace_back(canonical_index, dst_node_item);
node_item.SetDataSend(dst_node_item, dst_in_anchor->GetIdx());
}
}

@@ -300,8 +307,9 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
}
auto src_node = peer_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
NodeItem *src_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(src_node, &src_node_item),
"[%s] failed to get or create node item", src_node->GetName().c_str());

if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) {
GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d",
@@ -323,15 +331,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
}
}

for (const auto &src_node : ge_node->GetInControlNodes()) {
auto src_node_item = MutableNodeItem(src_node);
if ((src_node_item != nullptr) && (is_hccl_op || src_node_item->IsHcclOp())) {
GELOGD("[%s](%s) Add input control dependent node [%s](%s)",
ge_node->GetName().c_str(),
ge_node->GetType().c_str(),
src_node->GetName().c_str(),
src_node->GetType().c_str());
dependent_for_execution.emplace(src_node);
if (node_item.node_type == NETOUTPUT) {
for (const auto &src_node : ge_node->GetInControlNodes()) {
auto src_node_item = MutableNodeItem(src_node);
if ((src_node_item != nullptr) && src_node_item->IsHcclOp()) {
GELOGD("[%s](%s) Add input control dependent node [%s](%s)",
ge_node->GetName().c_str(),
ge_node->GetType().c_str(),
src_node->GetName().c_str(),
src_node->GetType().c_str());
dependent_for_execution.emplace(src_node);
}
}
}

@@ -794,6 +804,7 @@ Status HybridModelBuilder::LoadGraph() {
}

hybrid_model_.root_graph_ = root_graph;
GE_CHK_STATUS_RET(RelinkNextIteration(), "[%s] Relink NextIteration failed", GetGraphName());
// Reset node id by topological order across all subgraphs
int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) {
@@ -839,7 +850,7 @@ Status HybridModelBuilder::LoadGraph() {
parent_node_item->NodeName().c_str());

// if parent is function control op. need add a virtual partitioned call
if (parent_node_item->IsControlOp()) {
if (parent_node_item->IsControlFlowV2Op()) {
GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item),
"[Invoke][LoadKnownShapedSubgraph]Failed to load function control op subgraph [%s]",
sub_graph->GetName().c_str());
@@ -1169,7 +1180,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr
auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node);
auto op_type = parent_node->GetType();
if (IsControlOp(op_type)) {
if (IsControlFlowV2Op(op_type)) {
GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d",
sub_graph.GetName().c_str(),
ge_model->GetModelTaskDefPtr()->task_size());
@@ -1325,6 +1336,10 @@ Status HybridModelBuilder::IndexSpecialNodes() {
}
} else if (op_type == CONSTANTOP) {
constant_op_nodes_.emplace(node->GetName(), node);
} else if (op_type == STREAMMERGE) {
stream_merge_op_nodes_.emplace(node->GetName(), node);
} else if (op_type == NEXTITERATION || op_type == REFNEXTITERATION) {
next_iteration_op_nodes_.emplace(node->GetName(), node);
} else if (op_type == DATA && node->GetOwnerComputeGraph() != root_graph) {
NodePtr src_node;
int peer_out_index = -1;
@@ -1825,7 +1840,7 @@ Status HybridModelBuilder::GenerateEndProfilingTask(const OpDescPtr &op_desc, ve
return SUCCESS;
}

Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node) {
Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node, uint32_t &prev_num) {
GE_CHECK_NOTNULL(node);
const OpDescPtr &op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -1871,7 +1886,7 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons
if (!node_task_map.empty()) {
for (const auto &node_task : node_task_map) {
NodePtr profiling_node = node_task.first;
vector<domi::TaskDef> task_def_lists = node_task.second;
const vector<domi::TaskDef> &task_def_lists = node_task.second;
for (const auto &task_def : task_def_lists) {
hybrid_model_.task_defs_[profiling_node].emplace_back(task_def);
}
@@ -1886,6 +1901,7 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons
node_item->input_start = 0;
node_item->output_start = 0;
graph_item.node_items_.emplace_back(node_item);
++prev_num;
}
} else {
GELOGD("No need to create profiling node before.");
@@ -1894,7 +1910,7 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons
return SUCCESS;
}

Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node) {
Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node, uint32_t &post_num) {
GE_CHECK_NOTNULL(node);
const OpDescPtr &op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -1952,7 +1968,7 @@ Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const
if (!node_task_map.empty()) {
for (const auto &node_task : node_task_map) {
NodePtr profiling_node = node_task.first;
vector<domi::TaskDef> task_def_lists = node_task.second;
const vector<domi::TaskDef> &task_def_lists = node_task.second;
for (const auto &task_def : task_def_lists) {
hybrid_model_.task_defs_[profiling_node].emplace_back(task_def);
}
@@ -1967,6 +1983,7 @@ Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const
node_item->input_start = 0;
node_item->output_start = 0;
graph_item.node_items_.emplace_back(node_item);
++post_num;
}
} else {
GELOGD("No need to create profiling node after.");
@@ -1986,20 +2003,23 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
int input_start = 0;
int output_start = 0;
std::vector<NodeItem *> data_nodes;
std::map<size_t, std::pair<uint32_t, uint32_t>> profiling_nodes;
for (auto &node : graph.GetDirectNode()) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
const auto &op_type = node->GetType();
if (op_type == NOOP) {
GELOGD("[%s] Skip NoOp", node->GetName().c_str());
continue;
}

NodeItem *node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item));
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item));
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task

GE_CHK_STATUS_RET_NOLOG(BuildControlFlowGroup(*graph_item, node, node_item));
if (node->GetInAllNodes().empty()) {
graph_item->root_items_.emplace_back(node_item);
GELOGD("[%s] add to root node list", node->GetName().c_str());
}

node_item->input_start = input_start;
node_item->output_start = output_start;
input_start += node_item->num_inputs;
@@ -2011,9 +2031,16 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
graph_item->output_node_ = node_item;
GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph));
}
GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node));

uint32_t prev_num = 0;
uint32_t post_num = 0;
GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeBefore(*graph_item, node, prev_num));
size_t node_index = graph_item->node_items_.size();
graph_item->node_items_.emplace_back(node_item);
GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeAfter(*graph_item, node));
GE_CHK_STATUS_RET_NOLOG(CreateProfilingNodeAfter(*graph_item, node, post_num));
if (prev_num > 0 || post_num > 0) {
profiling_nodes[node_index] = { prev_num, post_num };
}
// parse var outputs
GE_CHK_STATUS_RET_NOLOG(ParseVarOutputs(*node_item));
GELOGD("NodeItem created: %s", node_item->DebugString().c_str());
@@ -2022,6 +2049,7 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
graph_item->total_inputs_ = input_start;
graph_item->total_outputs_ = output_start;
GE_CHK_STATUS_RET_NOLOG(BuildInputMapping(*graph_item, data_nodes, is_root_graph));
GE_CHK_STATUS_RET_NOLOG(BuildProfilingControl(*graph_item, profiling_nodes));
if (is_root_graph) {
graph_item->SetName("Root-Graph");
GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str());
@@ -2271,5 +2299,299 @@ Status HybridModelBuilder::Convert2HostTensor(const NodePtr &node, int node_id,
hybrid_model_.host_tensors_[node_id].emplace_back(output_idx, std::move(tensor));
return SUCCESS;
}

Status HybridModelBuilder::RelinkNextIteration() {
for (const auto &item : stream_merge_op_nodes_) {
const auto &merge = item.second;
std::string node_name;
if (!AttrUtils::GetStr(merge->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, node_name)) {
GELOGD("[%s] no attribute[%s], not in while loop", merge->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str());
continue;
}

const auto it = next_iteration_op_nodes_.find(node_name);
if (it == next_iteration_op_nodes_.end()) {
GELOGE(INTERNAL_ERROR, "[%s] expect NextIteration[%s] not found", merge->GetName().c_str(), node_name.c_str());
return INTERNAL_ERROR;
}

const auto &iteration = it->second;
if (GraphUtils::AddEdge(iteration->GetOutDataAnchor(0), merge->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[%s] -> [%s] Add edge failed", node_name.c_str(), merge->GetName().c_str());
return INTERNAL_ERROR;
}
}

stream_merge_op_nodes_.clear();
next_iteration_op_nodes_.clear();
return SUCCESS;
}

Status HybridModelBuilder::BuildProfilingControl(GraphItem &graph_item,
const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes) {
const auto node_size = graph_item.node_items_.size();
for (const auto &item : nodes) {
const auto node_index = item.first;
GE_CHK_BOOL_RET_STATUS(node_index < node_size, FAILED, "node index invalid");
const auto &node_item = graph_item.node_items_[node_index];
if (item.second.first > 0) {
const auto prev_num = item.second.first;
if (node_index == prev_num) {
// Profiling Before root node.
for (uint32_t i = 1; i <= prev_num; ++i) {
GE_CHK_BOOL_RET_STATUS(node_index - i < node_size, FAILED, "prev index invalid");
const auto &curr_item = graph_item.node_items_[node_index - i];
graph_item.root_items_.emplace(graph_item.root_items_.begin(), curr_item);
}
} else {
GE_CHK_BOOL_RET_STATUS((node_index - prev_num) - 1 < node_size, FAILED, "prev index invalid");
const auto &prev_item = graph_item.node_items_[(node_index - prev_num) - 1];
for (uint32_t i = 1; i <= prev_num; ++i) {
GE_CHK_BOOL_RET_STATUS(node_index - i < node_size, FAILED, "prev index invalid");
const auto &curr_item = graph_item.node_items_[node_index - i];
prev_item->SetCtrlSend(curr_item, UINT32_MAX);
curr_item->SetCtrlSend(node_item, UINT32_MAX);
}
}
}

if (item.second.second > 0) {
const auto post_num = item.second.second;
if (node_size == node_index + post_num + 1) {
// Profiling After last node.
for (uint32_t i = 1; i <= post_num; ++i) {
GE_CHK_BOOL_RET_STATUS(node_index + i < node_size, FAILED, "post index invalid");
const auto &curr_item = graph_item.node_items_[node_index + i];
node_item->SetCtrlSend(curr_item, UINT32_MAX);
}
} else {
GE_CHK_BOOL_RET_STATUS((node_index + post_num) + 1 < node_size, FAILED, "post index invalid");
const auto &post_item = graph_item.node_items_[(node_index + post_num) + 1];
for (uint32_t i = 1; i <= post_num; ++i) {
GE_CHK_BOOL_RET_STATUS(node_index + i < node_size, FAILED, "post index invalid");
const auto &curr_item = graph_item.node_items_[node_index + i];
node_item->SetCtrlSend(curr_item, UINT32_MAX);
curr_item->SetCtrlSend(post_item, UINT32_MAX);
}
}
}
}
return SUCCESS;
}

Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item) {
GELOGD("Build control flow for node %s", node->GetName().c_str());
using GroupBuilder = std::function<Status(HybridModelBuilder *, const NodePtr &, NodeItem *)>;
static const std::map<std::string, GroupBuilder> control_flow{
{ STREAMACTIVE, &HybridModelBuilder::CreateStreamActiveGroup },
{ STREAMSWITCH, &HybridModelBuilder::CreateStreamSwitchGroup },
{ STREAMSWITCHN, &HybridModelBuilder::CreateStreamSwitchNGroup },
{ NEXTITERATION, &HybridModelBuilder::CreateNextIterationGroup },
{ REFNEXTITERATION, &HybridModelBuilder::CreateNextIterationGroup },
{ SWITCH, &HybridModelBuilder::CreateSwitchGroup },
{ REFSWITCH, &HybridModelBuilder::CreateSwitchGroup },
{ LABELSET, &HybridModelBuilder::CreateLabelSetGroup },
{ LABELGOTO, &HybridModelBuilder::CreateLabelGotoGroup },
{ LABELGOTOEX, &HybridModelBuilder::CreateLabelGotoGroup },
{ LABELSWITCH, &HybridModelBuilder::CreateLabelSwitchGroup },
{ LABELSWITCHBYINDEX, &HybridModelBuilder::CreateLabelSwitchGroup }
};

Status ret = SUCCESS;
auto it = control_flow.find(node_item->node_type);
if (it == control_flow.end()) {
ret = CreateNormalNodeGroup(node, node_item);
} else {
graph_item.has_ctrl_flow_op_ = true;
ret = it->second(this, node, node_item);
}
GELOGD("Node: %s, control by: %zu, control for: %zu, switch group: %zu", node->GetName().c_str(),
node_item->ctrl_recv_.size(), node_item->ctrl_send_.size(), node_item->switch_groups_.size());
return ret;
}

Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) {
const auto out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, UINT32_MAX);
}
return SUCCESS;
}

Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != STREAMACTIVE) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str());
return INTERNAL_ERROR;
}

node_item->switch_groups_.resize(kStreamActiveNum);
const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
if (dst_node->GetType() == STREAMMERGE) {
GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str());
continue;
}

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx);
}
return SUCCESS;
}

Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != STREAMSWITCH) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str());
return INTERNAL_ERROR;
}

// Consider as two groups, group[0] set empty for false, group[1] for true.
node_item->switch_groups_.resize(kStreamSwitchNum);
const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, kStreamSwitchIdx);
}
return SUCCESS;
}

Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != STREAMSWITCHN) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str());
return INTERNAL_ERROR;
}

uint32_t batch_num = 0;
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_BATCH_NUM, batch_num)) {
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_NUM failed", node->GetName().c_str());
return INTERNAL_ERROR;
}

if (batch_num == 0) {
GELOGW("[%s] Got empty branch for SwitchN, Please check.", node->GetName().c_str());
return SUCCESS;
}

node_item->switch_groups_.resize(batch_num);
const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);

std::string batch_label;
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str());
return INTERNAL_ERROR;
}

std::string::size_type pos = batch_label.rfind("_");
if (pos == std::string::npos) {
GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str());
continue;
}

++pos; // Skip Separator
uint64_t batch_index = std::strtoul(batch_label.data() + pos, nullptr, kDecimal);
if (batch_index >= batch_num) {
GELOGW("batch label: %s, batch index: %lu great than batch num: %u", batch_label.c_str(), batch_index, batch_num);
continue;
}

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, batch_index);
}

return SUCCESS;
}

Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != NEXTITERATION && node_item->node_type != REFNEXTITERATION) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str());
return INTERNAL_ERROR;
}

return SUCCESS;
}

Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != SWITCH && node_item->node_type != REFSWITCH) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str());
return INTERNAL_ERROR;
}

const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, UINT32_MAX);
}

// Group switch flow by out put data.
node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM);
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
const auto &out_anchor = node->GetOutDataAnchor(i);
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, i); // take switch data as ctrl.
}
}

return SUCCESS;
}

Status HybridModelBuilder::CreateLabelSetGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != LABELSET) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str());
return INTERNAL_ERROR;
}

GELOGE(UNSUPPORTED, "[%s] Not implemented.", node->GetName().c_str());
return UNSUPPORTED;
}

Status HybridModelBuilder::CreateLabelGotoGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != LABELGOTO && node_item->node_type != LABELGOTOEX) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str());
return INTERNAL_ERROR;
}

GELOGE(UNSUPPORTED, "[%s] Not implemented.", node->GetName().c_str());
return UNSUPPORTED;
}

Status HybridModelBuilder::CreateLabelSwitchGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != LABELSWITCH && node_item->node_type != LABELSWITCHBYINDEX) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node->GetName().c_str());
return INTERNAL_ERROR;
}

GELOGE(UNSUPPORTED, "[%s] Not implemented.", node->GetName().c_str());
return UNSUPPORTED;
}
} // namespace hybrid
} // namespace ge

+ 18
- 2
ge/hybrid/model/hybrid_model_builder.h View File

@@ -85,8 +85,8 @@ class HybridModelBuilder {
Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item);
Status RecoverGraphUnknownFlag();
Status CheckAicpuOpList();
Status CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node);
Status CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node);
Status CreateProfilingNodeBefore(GraphItem &graph_item, const NodePtr &node, uint32_t &prev_num);
Status CreateProfilingNodeAfter(GraphItem &graph_item, const NodePtr &node, uint32_t &post_num);
Status GenerateFpProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list);
Status GenerateBpProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list);
Status GenerateEndProfilingTask(const OpDescPtr &op_desc, vector<domi::TaskDef> &task_def_list);
@@ -94,6 +94,20 @@ class HybridModelBuilder {
Status OptimizeDependenciesForConstantInputs();
Status Convert2HostTensor(const NodePtr &node, int node_id, uint32_t output_idx);

Status RelinkNextIteration();
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes);
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item);
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item);
Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item);
Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item);
Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item);
Status CreateNextIterationGroup(const NodePtr &node, NodeItem *node_item);

Status CreateSwitchGroup(const NodePtr &node, NodeItem *node_item);
Status CreateLabelSetGroup(const NodePtr &node, NodeItem *node_item);
Status CreateLabelGotoGroup(const NodePtr &node, NodeItem *node_item);
Status CreateLabelSwitchGroup(const NodePtr &node, NodeItem *node_item);

const char* GetGraphName() const {
return hybrid_model_.model_name_.c_str();
}
@@ -104,6 +118,8 @@ class HybridModelBuilder {
GeRootModelPtr ge_root_model_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, NodePtr> stream_merge_op_nodes_;
std::map<std::string, NodePtr> next_iteration_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;



+ 58
- 7
ge/hybrid/model/node_item.cc View File

@@ -29,10 +29,19 @@ namespace hybrid {
namespace {
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal";
std::set<std::string> kControlOpTypes{
const std::set<std::string> kControlOpTypes{
IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
};

const std::set<std::string> kControlFlowOpTypes{
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX,
NEXTITERATION, REFNEXTITERATION
};

const std::set<std::string> kMergeOpTypes{
MERGE, REFMERGE, STREAMMERGE
};

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
@@ -107,7 +116,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
}
} // namespace

bool IsControlOp(const std::string &op_type) {
bool IsControlFlowV2Op(const std::string &op_type) {
return kControlOpTypes.count(op_type) > 0;
}

@@ -226,7 +235,7 @@ Status NodeItem::ResolveStaticInputsAndOutputs() {
}

void NodeItem::ResolveUnknownShapeType() {
if (IsControlOp() || node_type == PARTITIONEDCALL) {
if (IsControlFlowV2Op() || (is_dynamic && node_type == PARTITIONEDCALL)) {
shape_inference_type = DEPEND_COMPUTE;
} else {
int32_t unknown_shape_type_val = 0;
@@ -236,6 +245,10 @@ void NodeItem::ResolveUnknownShapeType() {
}

Status NodeItem::Init() {
is_ctrl_flow_v2_op_ = ge::hybrid::IsControlFlowV2Op(node_type);
is_ctrl_flow_op_ = kControlFlowOpTypes.count(node_type) > 0;
is_merge_op_ = kMergeOpTypes.count(node_type) > 0;
is_root_node_ = node->GetInAllNodes().empty();
GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs());
GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState());
ResolveUnknownShapeType();
@@ -244,14 +257,12 @@ Status NodeItem::Init() {
GE_CHK_STATUS_RET(ParseFusedSubgraph(*this),
"[Invoke][ParseFusedSubgraph][%s] Failed to parse fused subgraph", node_name.c_str());
}
copy_mu_ = MakeShared<std::mutex>();
GE_CHECK_NOTNULL(copy_mu_);

return SUCCESS;
}

bool NodeItem::IsControlOp() const {
return ge::hybrid::IsControlOp(op_desc->GetType());
}

bool NodeItem::IsHcclOp() const {
return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
}
@@ -383,5 +394,45 @@ bool NodeItem::IsInputShapeStatic(int index) const {

return is_input_shape_static_[index];
}

void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
data_send_.emplace(node_item);
node_item->data_recv_[this] = anchor_index;
if (is_root_node_) {
node_item->root_data_.emplace(this);
}
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}

void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
if (switch_index < switch_groups_.size()) {
std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index];
switch_group.emplace_back(node_item);
} else {
ctrl_send_.insert(node_item);
}

node_item->ctrl_recv_.emplace(this);
if (is_root_node_) {
node_item->root_ctrl_.emplace(this);
}

GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}

OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) {
if (mu_ != nullptr) {
GELOGD("lock for %s", name_.c_str());
mu_->lock();
}
}

OptionalMutexGuard::~OptionalMutexGuard() {
if (mu_ != nullptr) {
GELOGD("unlock for %s", name_.c_str());
mu_->unlock();
mu_ = nullptr;
}
}
} // namespace hybrid
} // namespace ge

+ 43
- 3
ge/hybrid/model/node_item.h View File

@@ -37,7 +37,16 @@ struct FusedSubgraph {
ComputeGraphPtr graph;
};

bool IsControlOp(const std::string &op_type);
bool IsControlFlowV2Op(const std::string &op_type);

class OptionalMutexGuard {
public:
OptionalMutexGuard(std::mutex *mutex, const string &name);
~OptionalMutexGuard();
private:
std::mutex *mu_{nullptr};
std::string name_;
};

// for caching static information across execution
struct NodeItem {
@@ -70,12 +79,29 @@ struct NodeItem {

Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;

bool IsControlOp() const;
bool IsControlFlowV2Op() const {
return is_ctrl_flow_v2_op_;
}

bool IsControlFlowOp() const {
return is_ctrl_flow_op_;
}

bool IsMergeOp() const {
return is_merge_op_;
}

bool IsHcclOp() const;

void SetToDynamic();

void SetDataSend(NodeItem *node_item, int anchor_index);
void SetCtrlSend(NodeItem *node_item, uint32_t switch_index);

OptionalMutexGuard MutexGuard(const std::string &name) const {
return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name);
}

std::string DebugString() const;

NodePtr node;
@@ -99,7 +125,20 @@ struct NodeItem {
std::set<int> to_const_output_id_list;

// src_output_id, dst_anchor_id, dst_node
vector<vector<pair<int, NodeItem *>>> outputs;
std::vector<std::vector<std::pair<int, NodeItem *>>> outputs;

// for linked drive
bool is_root_node_ = false;
bool is_ctrl_flow_v2_op_ = false;
bool is_ctrl_flow_op_ = false;
bool is_merge_op_ = false;
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node
std::set<const NodeItem *> root_data_; // Recv data from root node
std::set<const NodeItem *> data_send_; // Send data notify to
std::map<const NodeItem *, int> data_recv_; // Recv data notify from
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to

std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph;
@@ -122,6 +161,7 @@ struct NodeItem {

std::vector<bool> is_input_shape_static_;
std::vector<uint32_t> input_desc_indices_;
std::shared_ptr<std::mutex> copy_mu_;
mutable std::mutex mu_;
};
} // namespace hybrid


+ 2
- 2
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc View File

@@ -32,7 +32,7 @@ REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGR

Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] Start");
GELOGD("[%s] KnownNodeTask::ExecuteAsync in.", context.GetNodeName());
GELOGD("[%s] KnownNodeTask::ExecuteAsync in, model id: %u.", context.GetNodeName(), davinci_model_->Id());
if (davinci_model_->GetTaskList().empty()) {
GELOGW("KnownNodeExecutor::ExecuteAsync davinci model has no taskinfo.");

@@ -62,7 +62,7 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> d
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End");

GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback));
GELOGD("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName());
GELOGD("[%s] KnownNodeTask::ExecuteAsync success, model id: %u.", context.GetNodeName(), davinci_model_->Id());
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End");
return SUCCESS;
}


+ 8
- 20
ge/hybrid/node_executor/controlop/control_op_executor.cc View File

@@ -22,18 +22,6 @@
namespace ge {
namespace hybrid {
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::CONTROL_OP, ControlOpNodeExecutor);
namespace {
template<typename T>
Status CopyScalarValueToHost(const TensorValue &tensor, T &value) {
GE_CHECK_GE(tensor.GetSize(), sizeof(value));
GE_CHK_RT_RET(rtMemcpy(&value,
sizeof(value),
tensor.GetData(),
sizeof(value),
RT_MEMCPY_DEVICE_TO_HOST));
return SUCCESS;
}
}

Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph,
TaskContext &task_context,
@@ -60,12 +48,12 @@ Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph,

Status ControlOpNodeTask::ToBool(const TensorValue &tensor, DataType data_type, bool &value) {
switch (data_type) {
#define CASE(DT, T) \
case (DT): { \
T val{}; \
GE_CHK_STATUS_RET(CopyScalarValueToHost(tensor, val)); \
value = val != 0; \
break; \
#define CASE(DT, T) \
case (DT): { \
T val{}; \
GE_CHK_STATUS_RET(tensor.CopyScalarValueToHost(val)); \
value = val != 0; \
break; \
}
// DT_STRING was handled in CondPass
CASE(DT_FLOAT, float)
@@ -77,7 +65,7 @@ Status ControlOpNodeTask::ToBool(const TensorValue &tensor, DataType data_type,
CASE(DT_INT64, int64_t)
#undef CASE
case DT_BOOL:
GE_CHK_STATUS_RET(CopyScalarValueToHost(tensor, value));
GE_CHK_STATUS_RET(tensor.CopyScalarValueToHost(value));
break;
default:
GELOGE(UNSUPPORTED, "Data type %s is not support by cond.", TypeUtils::DataTypeToSerialString(data_type).c_str());
@@ -182,7 +170,7 @@ Status CaseOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::func
auto branch_tensor = task_context.GetInput(kCaseBranchIndex);
GE_CHECK_NOTNULL(branch_tensor);
int32_t branch_index = 0;
GE_CHK_STATUS_RET(CopyScalarValueToHost(*branch_tensor, branch_index));
GE_CHK_STATUS_RET(branch_tensor->CopyScalarValueToHost(branch_index));
const GraphItem *subgraph = SelectBranch(branch_index);
GELOGI("[%s] Taking subgraph [%s] by branch = [%d]",
task_context.GetNodeName(),


+ 1
- 1
ge/hybrid/node_executor/node_executor.cc View File

@@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node
return ExecutorType::GE_LOCAL;
}

if (IsControlOp(op_type)) {
if (IsControlFlowV2Op(op_type)) {
return ExecutorType::CONTROL_OP;
}



+ 2
- 0
ge/hybrid/node_executor/node_executor.h View File

@@ -27,6 +27,8 @@ const uint32_t MEMORY_ALIGN_RATIO = 2;
const uint32_t MEMORY_ALIGN_SIZE = 32;
namespace hybrid {
class HybridModel;
using NodeTaskPtr = std::shared_ptr<NodeTask>;

// Base class of Node Task
class NodeTask {
public:


+ 24
- 24
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -14,7 +14,9 @@
* limitations under the License.
*/

#include "rts_node_executor.h"
#include "hybrid/node_executor/rts/rts_node_executor.h"
#include "hybrid/node_executor/rts/rts_task_factory.h"

#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/types.h"
@@ -26,6 +28,11 @@ namespace ge {
namespace hybrid {
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::RTS, RtsNodeExecutor);

REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask);
REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask);
REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask);
REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask);

Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) {
auto input_desc = context.MutableInputDesc(index);
GE_CHECK_NOTNULL(input_desc);
@@ -77,10 +84,6 @@ Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()
return SUCCESS;
}

Status IdentityNodeTask::UpdateArgs(TaskContext &context) {
return SUCCESS;
}

Status IdentityNNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", context.GetNodeName());
for (int i = 0; i < context.NumInputs(); ++i) {
@@ -95,7 +98,15 @@ Status IdentityNNodeTask::ExecuteAsync(TaskContext &context, std::function<void(
return SUCCESS;
}

Status ProfilingTraceNodeTask::UpdateArgs(TaskContext &context) {
Status ProfilingTraceNodeTask::Init(const HybridModel &model, const NodePtr &node) {
auto *task_defs = model.GetTaskDefs(node);
if (task_defs == nullptr || task_defs->empty()) {
GELOGE(INTERNAL_ERROR, "Profiling node has no task to execute.");
return INTERNAL_ERROR;
}

task_defs_ = *task_defs;
GELOGD("[%s] Done initialization successfully.", node->GetName().c_str());
return SUCCESS;
}

@@ -116,32 +127,21 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<
}

return SUCCESS;
};
}

Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
GE_CHECK_NOTNULL(node);
GELOGD("[%s] Load for local task.", node->GetName().c_str());
auto op_type = node->GetType();
if (op_type == IDENTITY) {
task = MakeShared<IdentityNodeTask>();
} else if (op_type == IDENTITYN) {
task = MakeShared<IdentityNNodeTask>();
} else if (op_type == READVARIABLEOP) {
task = MakeShared<ReadVariableOpNodeTask>();
} else if (op_type == PROFILINGTRAININGTRACE) {
auto *task_defs = model.GetTaskDefs(node);
if (task_defs == nullptr || task_defs->empty()) {
GELOGE(INTERNAL_ERROR, "Profiling node has no task to execute.");
return INTERNAL_ERROR;
}
task = MakeShared<ProfilingTraceNodeTask>(*task_defs);
} else {
task = RtsTaskFactory::GetInstance().Create(op_type);
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str());
return INTERNAL_ERROR;
}

GE_CHECK_NOTNULL(task);
return SUCCESS;
RtsNodeTask *rts_task = dynamic_cast<RtsNodeTask *>(task.get());
GE_CHECK_NOTNULL(rts_task);
return rts_task->Init(model, node);
}
} // namespace hybrid
} // namespace ge

+ 4
- 7
ge/hybrid/node_executor/rts/rts_node_executor.h View File

@@ -18,13 +18,12 @@
#define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_EXECUTOR_H_

#include "hybrid/node_executor/node_executor.h"
#include "proto/task.pb.h"
#include "hybrid/node_executor/rts/rts_node_task.h"

namespace ge {
namespace hybrid {
class IdentityNodeTask : public NodeTask {
class IdentityNodeTask : public RtsNodeTask {
public:
Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;

protected:
@@ -41,12 +40,10 @@ class ReadVariableOpNodeTask : public IdentityNodeTask {
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
};

class ProfilingTraceNodeTask : public NodeTask {
class ProfilingTraceNodeTask : public RtsNodeTask {
public:
explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {}
~ProfilingTraceNodeTask() override = default;
Status Init(const HybridModel &model, const NodePtr &node) override;

Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;

private:


+ 240
- 0
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -0,0 +1,240 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "hybrid/node_executor/rts/rts_node_task.h"
#include "hybrid/node_executor/rts/rts_task_factory.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"

namespace {
constexpr uint8_t kSwitchPredIndex = 0;
constexpr uint8_t kSwitchCompIndex = 1;

const static std::map<rtCondition_t, std::function<bool(int64_t, int64_t)>> kCompHandle = {
{RT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value == comp_value; }},
{RT_NOT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value != comp_value; }},
{RT_GREATER, [](int64_t pred_value, int64_t comp_value) { return pred_value > comp_value; }},
{RT_GREATER_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value >= comp_value; }},
{RT_LESS, [](int64_t pred_value, int64_t comp_value) { return pred_value < comp_value; }},
{RT_LESS_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value <= comp_value; }},
};
}

namespace ge {
namespace hybrid {
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask);

REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(LOOPCOND, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(NEXTITERATION, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFNEXTITERATION, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(EXIT, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFEXIT, PassThroughNodeTask);

REGISTER_RTS_TASK_CREATOR(LABELSET, LabelSetNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELGOTO, LabelGotoNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELGOTOEX, LabelGotoNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELSWITCH, LabelSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELSWITCHBYINDEX, LabelSwitchNodeTask);

Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value) {
auto tensor_value = task_context.GetInput(index);
GE_CHECK_NOTNULL(tensor_value);
auto tensor_desc = task_context.MutableInputDesc(index);
GE_CHECK_NOTNULL(tensor_desc);

auto data_type = tensor_desc->GetDataType();
switch (data_type) {
#define CASE_TYPE(DT, VT) \
case (DT): { \
VT data_val{}; \
GE_CHK_STATUS_RET(tensor_value->CopyScalarValueToHost(data_val)); \
value = static_cast<int64_t>(data_val); \
break; \
}
// Just accept index data type.
CASE_TYPE(DT_INT32, int32_t)
CASE_TYPE(DT_INT64, int64_t)
#undef CASE_TYPE
default: {
GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str());
return UNSUPPORTED;
}
}

return SUCCESS;
}

Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto &node_state = task_context.GetNodeState();
node_state->SetSwitchIndex(0);
if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGI("[%s] Done executing successfully.", task_context.GetNodeName());
return SUCCESS;
}

Status StreamSwitchNodeTask::Init(const HybridModel &model, const NodePtr &node) {
uint32_t value = 0;
if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, value)) {
GELOGE(INTERNAL_ERROR, "[%s] Get %s failed.", node->GetName().c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str());
return INTERNAL_ERROR;
}
rtCondition_t cond = static_cast<rtCondition_t>(value);
const auto it = kCompHandle.find(cond);
if (it == kCompHandle.end()) {
GELOGE(INTERNAL_ERROR, "[%s] Get Condition: %u handle failed.", node->GetName().c_str(), value);
return INTERNAL_ERROR;
}

comp_func_ = it->second;
GELOGD("[%s] Done initialization successfully, condition is %u.", node->GetName().c_str(), value);
return SUCCESS;
}

Status StreamSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
GE_CHECK_NOTNULL(comp_func_);

int64_t pred_value = 0;
GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchPredIndex, pred_value));
int64_t comp_value = 0;
GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchCompIndex, comp_value));

bool switch_idx = comp_func_(pred_value, comp_value);
auto node_state = task_context.GetNodeState();
node_state->SetSwitchIndex(static_cast<int>(switch_idx));

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGI("[%s] Done executing successfully, pred value: %ld, comp value: %ld, switch index: %d.",
task_context.GetNodeName(), pred_value, comp_value, static_cast<int>(switch_idx));
return SUCCESS;
}

Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
int index = task_context.GetNodeState()->GetMergeIndex();
GELOGD("[%s] Start to execute, merge index: %d.", task_context.GetNodeName(), index);
if (index < 0 || index >= task_context.NumInputs()) {
GELOGE(INTERNAL_ERROR, "[%s] Invalid merge param, inputs num: %d, merge index: %d.",
task_context.GetNodeName(), task_context.NumInputs(), index);
return INTERNAL_ERROR;
}

const auto in_x = task_context.MutableInput(index); // x
GE_CHECK_NOTNULL(in_x);
task_context.SetOutput(MERGE_DATA_OUTPUT, *in_x); // y

const auto out_y = task_context.MutableOutput(MERGE_INDEX_OUTPUT); // value_index
GE_CHECK_NOTNULL(out_y);
if (out_y->GetSize() > 0) {
GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), &index, sizeof(index),
RT_MEMCPY_HOST_TO_DEVICE_EX, task_context.GetStream()));
}

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

task_context.GetNodeState()->SetMergeIndex(-1); // Invalidate for loop.
GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return SUCCESS;
}

Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto in_x = task_context.GetInput(0); // x
GE_CHECK_NOTNULL(in_x);
const auto out_y = task_context.MutableOutput(0); // value_index
GE_CHECK_NOTNULL(out_y);

GELOGD("[%s] input size: %zu, output size: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize());
if (in_x->GetSize() > 0 && out_y->GetSize() > 0) {
GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), in_x->GetData(), in_x->GetSize(),
RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
} else {
GELOGW("[%s] invalid copy size, src: %zu, dst: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize());
}

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return SUCCESS;
}

Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto in_x = task_context.GetInput(0); // x
GE_CHECK_NOTNULL(in_x);
task_context.SetOutput(0, *in_x); // y

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return SUCCESS;
}

Status LabelSetNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return UNSUPPORTED;
}

Status LabelGotoNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return UNSUPPORTED;
}

Status LabelSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());

if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
return UNSUPPORTED;
}
} // namespace hybrid
} // namespace ge

+ 89
- 0
ge/hybrid/node_executor/rts/rts_node_task.h View File

@@ -0,0 +1,89 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_TASK_H_
#define GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_TASK_H_

#include "hybrid/node_executor/node_executor.h"
#include "proto/task.pb.h"

namespace ge {
namespace hybrid {
class RtsNodeTask : public NodeTask {
public:
Status Init(TaskContext &task_context) override {
return SUCCESS;
}

virtual Status Init(const HybridModel &model, const NodePtr &node) {
GELOGD("[%s] Done initialization successfully.", node->GetName().c_str());
return SUCCESS;
}

Status UpdateArgs(TaskContext &task_context) override {
GELOGD("[%s] Done update args successfully.", task_context.GetNodeName());
return SUCCESS;
}

static Status GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value);
};

class StreamActiveNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class StreamSwitchNodeTask : public RtsNodeTask {
public:
Status Init(const HybridModel &model, const NodePtr &node) override;
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;

private:
std::function<bool(int64_t, int64_t)> comp_func_{nullptr};
};

class StreamMergeNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class MemcpyAsyncNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class PassThroughNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class LabelSetNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class LabelGotoNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class LabelSwitchNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};
} // namespace hybrid
} // namespace ge
#endif // GE_HYBRID_NODE_EXECUTOR_RTS_RTS_NODE_TASK_H_

+ 46
- 0
ge/hybrid/node_executor/rts/rts_task_factory.cc View File

@@ -0,0 +1,46 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "hybrid/node_executor/rts/rts_task_factory.h"

namespace ge {
namespace hybrid {
NodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const {
auto it = creators_.find(task_type);
if (it == creators_.end()) {
GELOGW("Cannot find task type %s in inner map.", task_type.c_str());
return nullptr;
}

return it->second();
}

void RtsTaskFactory::RegisterCreator(const std::string &task_type, const RtsTaskCreatorFun &creator) {
if (creator == nullptr) {
GELOGW("Register %s creator is null", task_type.c_str());
return;
}

auto it = creators_.find(task_type);
if (it != creators_.end()) {
GELOGW("Task %s creator already exist", task_type.c_str());
return;
}

creators_[task_type] = creator;
}
} // namespace hybrid
} // namespace ge

+ 65
- 0
ge/hybrid/node_executor/rts/rts_task_factory.h View File

@@ -0,0 +1,65 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_
#define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_

#include "hybrid/node_executor/node_executor.h"

namespace ge {
namespace hybrid {
using RtsTaskCreatorFun = std::function<NodeTaskPtr()>;

class RtsTaskFactory {
public:
static RtsTaskFactory &GetInstance() {
static RtsTaskFactory instance;
return instance;
}

NodeTaskPtr Create(const std::string &task_type) const;

class RtsTaskRegistrar {
public:
RtsTaskRegistrar(const std::string &task_type, const RtsTaskCreatorFun &creator) {
RtsTaskFactory::GetInstance().RegisterCreator(task_type, creator);
}
~RtsTaskRegistrar() = default;
};

private:
RtsTaskFactory() = default;
~RtsTaskFactory() = default;

/**
* Register build of executor
* @param executor_type type of executor
* @param builder build function
*/
void RegisterCreator(const std::string &task_type, const RtsTaskCreatorFun &creator);

std::map<std::string, RtsTaskCreatorFun> creators_;
};
} // namespace hybrid
} // namespace ge

#define REGISTER_RTS_TASK_CREATOR(task_type, task_clazz) \
REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz)

#define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> NodeTaskPtr { return MakeShared<clazz>(); })

#endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_

+ 3
- 2
ge/hybrid/node_executor/task_context.cc View File

@@ -418,13 +418,14 @@ Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr
return MEMALLOC_FAILED;
}

GELOGD("Allocating workspace of size = %zu successfully", size);
GELOGD("[%s] Allocating workspace of size = %zu successfully", node_item_->NodeName().c_str(), size);
workspaces_.emplace_back(*buffer);
return SUCCESS;
}

Status TaskContext::PropagateOutputs() {
// propagate outputs
const auto &guard = node_item_->MutexGuard("PropagateOutputs");
for (int i = 0; i < NumOutputs(); ++i) {
auto tensor = MutableOutput(i);
GE_CHECK_NOTNULL(tensor);
@@ -561,7 +562,7 @@ const DumpProperties &TaskContext::GetDumpProperties() const {
}

bool TaskContext::NeedCallback() {
return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0 ||
return node_item_->has_observer || IsDumpEnabled() || GraphExecutionContext::profiling_level > 0 ||
!execution_context_->model->IsSingleOp();
}



+ 5
- 1
inc/framework/common/op/ge_op_utils.h View File

@@ -54,6 +54,10 @@ GE_FUNC_VISIBILITY extern const uint32_t SWITCH_TRUE_OUTPUT;
GE_FUNC_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT;
GE_FUNC_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT;

// Merge
GE_FUNC_VISIBILITY extern const uint32_t MERGE_DATA_OUTPUT;
GE_FUNC_VISIBILITY extern const uint32_t MERGE_INDEX_OUTPUT;

// FunctionOp
GE_FUNC_VISIBILITY extern const uint32_t IF_COND_INPUT;
GE_FUNC_VISIBILITY extern const uint32_t FOR_START_INPUT;
@@ -129,7 +133,7 @@ class GE_FUNC_VISIBILITY OpUtils {
/// @param [out] output Data pointer after conversion. The format is HWCK
///
static void TransDataKCHW2HWCK(const void *input, int64_t K, int64_t C, int64_t H, int64_t W, void *output);
static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node);
static vector<ConstGeTensorPtr> GetWeights(ge::ConstNodePtr node);
static vector<GeTensorPtr> MutableWeights(const ge::Node &node);


+ 9
- 1
tests/depends/error_manager/src/error_manager_stub.cc View File

@@ -48,6 +48,14 @@ int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) {
return 0;
}

std::string ErrorManager::GetErrorMessage() {
return std::string();
}

std::string ErrorManager::GetWarningMessage() {
return std::string();
}

int ErrorManager::ReportInterErrMessage(std::string error_code, const std::string &error_msg) {
return 0;
}
@@ -99,7 +107,7 @@ int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) {
const std::string &ErrorManager::GetLogHeader() { return error_context_.log_header; }

struct error_message::Context &ErrorManager::GetErrorManagerContext() {
struct error_message::Context error_context;
static struct error_message::Context error_context;
return error_context;
}



+ 7
- 1
tests/depends/runtime/CMakeLists.txt View File

@@ -15,7 +15,7 @@

#cmake_minimum_required(VERSION 2.8)

project(STUB_MMPA)
project(runtime_stub)

file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"src/runtime_stub.cc"
@@ -26,7 +26,13 @@ include_directories(${GE_CODE_DIR}/inc/framework)

add_library(runtime_stub SHARED ${SRCS})

target_compile_options(runtime_stub PRIVATE
-g
)

target_link_libraries(runtime_stub PRIVATE
$<BUILD_INTERFACE:intf_pub>
c_sec
)

target_include_directories(runtime_stub INTERFACE ${CMAKE_CURRENT_LIST_DIR}/src)

+ 26
- 15
tests/depends/runtime/src/runtime_stub.cc View File

@@ -17,6 +17,9 @@
#include <cce/dnn.h>
#include <securec.h>

#ifdef __cplusplus
extern "C" {
#endif
#define EVENT_LENTH 10

rtError_t rtCtxSetCurrent(rtContext_t ctx) { return RT_ERROR_NONE; }
@@ -96,15 +99,16 @@ rtError_t rtSetDevice(int32_t device) { return RT_ERROR_NONE; }
rtError_t rtStreamSynchronize(rtStream_t stream) { return RT_ERROR_NONE; }

rtError_t rtMemcpy(void *dst, uint64_t dest_max, const void *src, uint64_t count, rtMemcpyKind_t kind) {
#ifdef OTQT_UT
if (dest_max == 12 && count == 12) { // UTEST_kernelinfo_manager.all_success special treatment
if (dst != nullptr && src != nullptr) {
memcpy_s(dst, dest_max, src, count);
}
#endif
return RT_ERROR_NONE;
}
rtError_t rtMemcpyAsync(void *dst, uint64_t dest_max, const void *src, uint64_t count, rtMemcpyKind_t kind,
rtStream_t stream) {
if (dst != nullptr && src != nullptr) {
memcpy_s(dst, dest_max, src, count);
}
return RT_ERROR_NONE;
}

@@ -125,9 +129,6 @@ rtError_t rtEventElapsedTime(float *time, rtEvent_t start, rtEvent_t end) {
*time = 10.0f;
return RT_ERROR_NONE;
}
rtError_t rtFunctionRegister(void *bin_handle, const void *stub_func, const char *stub_name, const void *dev_func) {
return RT_ERROR_NONE;
}

rtError_t rtFunctionRegister(void *bin_handle, const void *stub_func, const char *stub_name, const void *dev_func,
uint32_t func_mode) {
@@ -156,7 +157,7 @@ rtError_t rtConfigureCall(uint32_t num_blocks, rtSmDesc_t *sm_desc, rtStream_t s

rtError_t rtSetProfDir(char *prof_dir) { return RT_ERROR_NONE; }

rtError_t rtSetProfDirEx(char *prof_dir, char *address, char *job_ctx) { return RT_ERROR_NONE; }
rtError_t rtSetProfDirEx(const char *profDir, const char *address, const char *jobCtx) { return RT_ERROR_NONE; }

rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aicore_memory_size) { return RT_ERROR_NONE; }

@@ -218,9 +219,8 @@ rtError_t rtGetFunctionByName(const char *stub_name, void **stub_func) {
*(char **)stub_func = "func";
return RT_ERROR_NONE;
}
rtError_t rtGetAddrByFun(const void *stubFunc, void **addr)
{
*(char**)addr = "dev_func";
rtError_t rtGetAddrByFun(const void *stubFunc, void **addr) {
*(char **)addr = "dev_func";
return RT_ERROR_NONE;
}
rtError_t rtQueryFunctionRegistered(const char *stub_name) { return RT_ERROR_NONE; }
@@ -244,7 +244,9 @@ rtError_t rtEndGraphEx(rtModel_t model, rtStream_t stream, uint32_t flags)
{
return RT_ERROR_NONE;
}
rtError_t rtProfilerStop(void) { return RT_ERROR_NONE; }
rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) {
return RT_ERROR_NONE;
}

rtError_t rtSetDvfsProfile(DvfsProfileMode mode) { return RT_ERROR_NONE; }

@@ -256,7 +258,9 @@ rtError_t rtCtxDestroy(rtContext_t ctx) { return RT_ERROR_NONE; }

rtError_t rtProfilerInit(const char *prof_dir, const char *address, const char *job_ctx) { return RT_ERROR_NONE; }

rtError_t rtProfilerStart(void) { return RT_ERROR_NONE; }
rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t *deviceList) {
return RT_ERROR_NONE;
}

rtError_t rtLabelCreate(rtLabel_t *label) {
*label = new uint64_t;
@@ -305,7 +309,9 @@ rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream) {
}


rtError_t rtInvalidCache(uint64_t base, uint32_t len) { return RT_ERROR_NONE; }
rtError_t rtInvalidCache(void *base, size_t len) {
return RT_ERROR_NONE;
}

rtError_t rtModelLoadComplete(rtModel_t model) { return RT_ERROR_NONE; }

@@ -314,7 +320,9 @@ rtError_t rtStreamCreateWithFlags(rtStream_t *stream, int32_t priority, uint32_t
return RT_ERROR_NONE;
}

rtError_t rtFlushCache(uint64_t base, uint32_t len) { return RT_ERROR_NONE; }
rtError_t rtFlushCache(void *base, size_t len) {
return RT_ERROR_NONE;
}

rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream_) { return RT_ERROR_NONE; }

@@ -445,4 +453,7 @@ rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void

rtError_t rtDebugUnRegisterForStream(rtStream_t stream) {
return RT_ERROR_NONE;
}
}
#ifdef __cplusplus
}
#endif

+ 20
- 0
tests/depends/slog/src/slog_stub.cc View File

@@ -15,6 +15,7 @@
*/

#include "toolchain/slog.h"
#include "toolchain/plog.h"

#include <stdarg.h>
#include <stdio.h>
@@ -46,3 +47,22 @@ int CheckLogLevel(int moduleId, int logLevel)
{
return 1;
}

/**
* @ingroup plog
* @brief DlogReportInitialize: init log in service process before all device setting.
* @return: 0: SUCCEED, others: FAILED
*/
int DlogReportInitialize() {
return 0;
}

/**
* @ingroup plog
* @brief DlogReportFinalize: release log resource in service process after all device reset.
* @return: 0: SUCCEED, others: FAILED
*/
int DlogReportFinalize() {
return 0;
}


+ 14
- 10
tests/ut/ge/CMakeLists.txt View File

@@ -166,7 +166,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/common/dump/dump_properties.cc"
"${GE_CODE_DIR}/ge/common/helper/model_helper.cc"
"${GE_CODE_DIR}/ge/common/dump/dump_manager.cc"
"${GE_CODE_DIR}/ge/common/dump/exception_dumper.cc"
"${GE_CODE_DIR}/ge/common/dump/exception_dumper.cc"
"${GE_CODE_DIR}/ge/common/dump/opdebug_register.cc"
"${GE_CODE_DIR}/ge/common/dump/dump_op.cc"
"${GE_CODE_DIR}/ge/common/helper/om_file_helper.cc"
@@ -512,8 +512,8 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc"
@@ -621,6 +621,8 @@ set(SINGLE_OP_SRC_FILES
"${GE_CODE_DIR}/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/node_executor/hccl/hccl_node_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_node_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_node_task.cc"
"${GE_CODE_DIR}/ge/hybrid/node_executor/rts/rts_task_factory.cc"
"${GE_CODE_DIR}/ge/hybrid/node_executor/node_executor.cc"
"${GE_CODE_DIR}/ge/hybrid/node_executor/task_context.cc"
"${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model.cc"
@@ -707,8 +709,8 @@ set(PASS_TEST_FILES
"graph/passes/transpose_transdata_pass_unittest.cc"
"graph/passes/parallel_group_pass_unittest.cc"
"graph/passes/buffer_pool_memory_pass_unittest.cc"
"graph/passes/mark_node_unknown_shape_pass_unittest.cc"
"graph/passes/reshape_recovery_pass_unittest.cc"
"graph/passes/mark_node_unknown_shape_pass_unittest.cc"
"graph/passes/reshape_recovery_pass_unittest.cc"
"graph/passes/cast_remove_pass_unittest.cc"
)

@@ -751,12 +753,12 @@ set(KERNEL_TEST_FILES

set(MULTI_PARTS_TEST_FILES
"graph_ir/ge_operator_factory_unittest.cc"
"graph_ir/ge_ir_build_unittest.cc"
"graph_ir/ge_ir_build_unittest.cc"
"graph/transop_util_unittest.cc"
"common/datatype_transfer_unittest.cc"
"common/dump_manager_unittest.cc"
"common/dump_op_unittest.cc"
"common/dump_exception_unittest.cc"
"common/dump_exception_unittest.cc"
"common/opdebug_register_unittest.cc"
"common/format_transfer_unittest.cc"
"common/format_transfer_transpose_unittest.cc"
@@ -775,7 +777,7 @@ set(MULTI_PARTS_TEST_FILES
"common/format_transfer_fracz_nhwc_unittest.cc"
"common/format_transfer_fracz_hwcn_unittest.cc"
"common/ge_format_util_unittest.cc"
"common/ge_auth_file_saver_unittest.cc"
"common/ge_auth_file_saver_unittest.cc"
"graph/variable_accelerate_ctrl_unittest.cc"
"graph/build/logical_stream_allocator_unittest.cc"
"graph/build/model_builder_unittest.cc"
@@ -804,7 +806,7 @@ set(SINGLE_OP_TEST_FILES
"single_op/single_op_manager_unittest.cc"
"single_op/stream_resource_unittest.cc"
"single_op/single_op_task_unittest.cc"
"single_op/single_op_unittest.cc"
"single_op/single_op_unittest.cc"
)

set(PROFILING_MNG_TEST_FILES
@@ -814,7 +816,9 @@ set(PROFILING_MNG_TEST_FILES
set(HYBRID_TEST_FILES
"hybrid/ge_hybrid_unittest.cc"
"hybrid/known_node_executor_unittest.cc"
"hybrid/executor/worker/execution_engine_unittest.cc"
"hybrid/executor/worker/execution_engine_unittest.cc"
"hybrid/model/hybrid_model_builder_unittest.cc"
"hybrid/node_executor/rts/rts_node_task_unittest.cc"
)

set(OTHERS_TEST_FILES


+ 4
- 2
tests/ut/ge/graph/load/davinci_model_unittest.cc View File

@@ -333,8 +333,8 @@ TEST_F(UtestDavinciModel, init_unknown) {
TEST_F(UtestDavinciModel, Init_variable_op) {
DavinciModel model(0, g_local_call_back);
model.ge_model_ = make_shared<GeModel>();
model.runtime_param_.mem_base = (uint8_t *)0x08000000;
model.runtime_param_.mem_size = 5120000;
model.runtime_param_.mem_size = 51200;
model.runtime_param_.mem_base = (uint8_t *)malloc(model.runtime_param_.mem_size);
ComputeGraphPtr graph = make_shared<ComputeGraph>("default");

GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
@@ -365,6 +365,8 @@ TEST_F(UtestDavinciModel, Init_variable_op) {
EXPECT_EQ(model.CopyOutputData(1, output_data, RT_MEMCPY_DEVICE_TO_HOST), SUCCESS);

EXPECT_EQ(model.ReturnResult(1, false, true, &output_data), INTERNAL_ERROR);
free(model.runtime_param_.mem_base);
model.runtime_param_.mem_base = nullptr;
}

TEST_F(UtestDavinciModel, InitRealSizeAndShapeInfo_succ1) {


+ 99
- 3
tests/ut/ge/graph/passes/infershape_pass_unittest.cc View File

@@ -20,9 +20,8 @@
#define private public
#include "graph/passes/infershape_pass.h"

#include "graph/compute_graph.h"
#include "graph/node.h"
#include "graph/operator.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/operator_factory.h"
#include "graph/operator_reg.h"
#include "graph_builder_utils.h"
@@ -36,6 +35,40 @@ class UtestGraphInfershapePass : public testing::Test {
void TearDown() {}
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
op_desc->SetId(index++);

GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
TensorUtils::SetSize(tensor, 512);
vector<int64_t> input_offset;
for (int i = 0; i < in_num; i++) {
op_desc->AddInputDesc(tensor);
input_offset.emplace_back(1024);
}
op_desc->SetInputOffset(input_offset);

vector<int64_t> output_offset;
for (int i = 0; i < out_num; i++) {
op_desc->AddOutputDesc(tensor);
output_offset.emplace_back(1024);
}
op_desc->SetOutputOffset(output_offset);

op_desc->SetWorkspace({});
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
op_desc->AddInferFunc(stub_func);
op_desc->AddInferFormatFunc(stub_func);
op_desc->AddVerifierFunc(stub_func);

return graph.AddNode(op_desc);
}

TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
string type = "AddN";
@@ -62,4 +95,67 @@ TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS);
}

TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
/*******************************************************************************
* Exit Identify
* \ / \.
* \ / \.
* Switch Add
* / | |
* / | |
* / | |
* LoopCond | |
* \ | |
* \ | |
* \ | |
* Less | |
* \ | NextIteration
* \ | |
* \ | |
* Merge <---------|
* |
* |
* Enter
******************************************************************************/
auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2);
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));

GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));

GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));
GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0));

GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));

GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

GEPass ge_passes(graph);
NamesToPass names_to_passes;
InferShapePass infer_shape_pass;
names_to_passes.emplace_back("InferShapePass", &infer_shape_pass);

EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS);
}
} // namespace ge

+ 17
- 17
tests/ut/ge/graph/utils/buffer_pool_graph_builder.cc View File

@@ -114,9 +114,9 @@ void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id,
/// Normal graph
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
@@ -188,10 +188,10 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() {
/// Normal graph with multi buffer pool
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// (pool0) (pool1) (pool0) (pool0) (pool1)
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
@@ -265,9 +265,9 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() {
/// SerialGraph: Buffer pool size only can contain one prefetch node
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
@@ -345,7 +345,7 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() {
/// GraphWithMultiPrefetch: Calc node with more prefetch node
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
/// \ / \ / \ /
/// \ / \ / \ /
@@ -426,9 +426,9 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() {
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
///
///
@@ -540,9 +540,9 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() {
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
///
///
@@ -651,10 +651,10 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() {
/// batch_label_128
///
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
/// / / / / / / \
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \
/// const1 switch_false / / / / / \
/// \ / / / / / / \
/// / / / / / / \.
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \.
/// const1 switch_false / / / / / \.
/// \ / / / / / / \.
/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
/// / \ \ \ \ \ \ /
/// const2 switch_true \ \ \ \ \ /
@@ -809,7 +809,7 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() {
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /
@@ -892,7 +892,7 @@ ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() {
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output
///
/// w1 w2 w3 w4 w5
/// \ / \ / \ / \ / \
/// \ / \ / \ / \ / \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /


+ 17
- 17
tests/ut/ge/graph/utils/buffer_pool_graph_builder.h View File

@@ -54,9 +54,9 @@ class BufferPoolGraphBuilder {
/// Normal graph
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
@@ -72,10 +72,10 @@ class BufferPoolGraphBuilder {
/// Normal graph with multi buffer pool
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// (pool0) (pool1) (pool0) (pool0) (pool1)
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
@@ -92,9 +92,9 @@ class BufferPoolGraphBuilder {
/// SerialGraph: Buffer pool size only can contain one prefetch node
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
///
///
@@ -116,7 +116,7 @@ class BufferPoolGraphBuilder {
/// GraphWithMultiPrefetch: Calc node with more prefetch node
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
/// \ / \ / \ /
/// \ / \ / \ /
@@ -144,9 +144,9 @@ class BufferPoolGraphBuilder {
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
///
///
@@ -168,9 +168,9 @@ class BufferPoolGraphBuilder {
/// Subgraph1: Subgraph2:
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// \ \ \ \ \
/// \ \ \ \ \.
/// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
///
///
@@ -189,10 +189,10 @@ class BufferPoolGraphBuilder {
/// batch_label_128
///
/// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
/// / / / / / / \
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \
/// const1 switch_false / / / / / \
/// \ / / / / / / \
/// / / / / / / \.
/// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \.
/// const1 switch_false / / / / / \.
/// \ / / / / / / \.
/// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
/// / \ \ \ \ \ \ /
/// const2 switch_true \ \ \ \ \ /
@@ -215,7 +215,7 @@ class BufferPoolGraphBuilder {
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output
///
/// w1 w2 w3 w4 w5
/// \ \ \ \ \
/// \ \ \ \ \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /
@@ -238,7 +238,7 @@ class BufferPoolGraphBuilder {
/// GraphWithMultiOutputPrefetch: Prefetch has more than one output
///
/// w1 w2 w3 w4 w5
/// \ / \ / \ / \ / \
/// \ / \ / \ / \ / \.
/// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
/// / \ / \ / \ / \ /
/// / \ / \ / \ / \ /


+ 16
- 5
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -288,7 +288,7 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) {
HybridModel *model_ptr = &model;

uint32_t device_id = 0;
rtStream_t stream;
rtStream_t stream = nullptr;
HybridModelExecutor executor(model_ptr, device_id, stream);
executor.Init();
}
@@ -644,17 +644,28 @@ TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) {
std::unique_ptr<NodeItem> node_item_1;
NodeItem::Create(node_1, node_item_1);
node_item_1->node_id = 1;

node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor());

OpDescPtr op_desc_2 = CreateOpDesc("net_output", NETOUTPUT);
auto node_2 = compute_graph->AddNode(op_desc_2);
std::unique_ptr<NodeItem> node_item_2;
NodeItem::Create(node_2, node_item_2);
node_item_2->node_id = 2;
node_1->GetOutControlAnchor()->LinkTo(node_2->GetInControlAnchor());

GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
model.root_graph_ = compute_graph;
model.node_items_.emplace(node, std::move(node_item));
model.node_items_.emplace(node_1, std::move(node_item_1));
model.node_items_.emplace(node_2, std::move(node_item_2));

HybridModelBuilder builder(model);
std::vector<std::string> deps;
ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS);
ASSERT_TRUE(model.GetNodeItem(node)->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
ASSERT_EQ(builder.ParseDependentInputNodes(*model.node_items_[node_1], deps), SUCCESS);
ASSERT_EQ(builder.ParseDependentInputNodes(*model.node_items_[node_2], deps), SUCCESS);
ASSERT_FALSE(model.GetNodeItem(node)->has_observer);
ASSERT_TRUE(model.GetNodeItem(node_1)->has_observer);
ASSERT_EQ(model.node_items_[node_1]->dependents_for_execution.size(), 0);
ASSERT_EQ(model.node_items_[node_2]->dependents_for_execution.size(), 1);
}

+ 233
- 0
tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc View File

@@ -0,0 +1,233 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include <vector>

#define private public
#define protected public
#include "hybrid/model/hybrid_model_builder.h"
#include "hybrid/node_executor/node_executor.h"

#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"

using namespace std;
using namespace testing;

namespace ge {
using namespace hybrid;

class UtestHybridModelBuilder : public testing::Test {
protected:
void SetUp() {}

void TearDown() { }
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
op_desc->SetId(index++);

GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
TensorUtils::SetSize(tensor, 512);
vector<int64_t> input_offset;
for (int i = 0; i < in_num; i++) {
op_desc->AddInputDesc(tensor);
input_offset.emplace_back(1024);
}
op_desc->SetInputOffset(input_offset);

vector<int64_t> output_offset;
for (int i = 0; i < out_num; i++) {
op_desc->AddOutputDesc(tensor);
output_offset.emplace_back(1024);
}
op_desc->SetOutputOffset(output_offset);

op_desc->SetWorkspace({});
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

return graph.AddNode(op_desc);
}

TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
/*******************************************************************************
* Exit Identify
* \ / \.
* \ / \.
* Switch Add
* / | |
* / | |
* / | |
* LoopCond | |
* \ | |
* \ | |
* \ | |
* Less | |
* \ | NextIteration
* \ | |
* \ | |
* Merge <---------|
* |
* |
* Enter
******************************************************************************/
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
GeModelPtr ge_sub_model = make_shared<GeModel>();
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);

auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2);
auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
less1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine");
auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
auto switch_t = CreateNode(*graph, "switch_t", STREAMSWITCH, 2, 0);
auto switch_f = CreateNode(*graph, "switch_f", STREAMSWITCH, 2, 0);
auto ident1 = CreateNode(*graph, "identity", IDENTITY, 2, 1);
auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
add1->GetOpDesc()->SetOpKernelLibName("AIcoreEngine");
auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
auto active1 = CreateNode(*graph, "active1", STREAMACTIVE, 0, 0);
auto active2 = CreateNode(*graph, "active2", STREAMACTIVE, 0, 0);
auto active3 = CreateNode(*graph, "active3", STREAMACTIVE, 0, 0);
auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));

GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), switch_t->GetInDataAnchor(1));
GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch_f->GetInDataAnchor(0));
GraphUtils::AddEdge(value0->GetOutDataAnchor(0), switch_f->GetInDataAnchor(1));

GraphUtils::AddEdge(switch_f->GetOutControlAnchor(), exit1->GetInControlAnchor());
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));

GraphUtils::AddEdge(switch_t->GetOutControlAnchor(), ident1->GetInControlAnchor());
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), ident1->GetInDataAnchor(0));

GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));

GraphUtils::AddEdge(enter1->GetOutControlAnchor(), active1->GetInControlAnchor());
GraphUtils::AddEdge(active1->GetOutControlAnchor(), merge1->GetInControlAnchor());

GraphUtils::AddEdge(loop1->GetOutControlAnchor(), active2->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_f->GetInControlAnchor());
GraphUtils::AddEdge(active2->GetOutControlAnchor(), switch_t->GetInControlAnchor());

GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor());

GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
AttrUtils::SetStr(merge1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next1->GetName());

AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true);
AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true);
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true);
AttrUtils::SetBool(add1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true);

// Build -> IndexSpecialNodes --> stream_merge_op_nodes_
// Build -> LoadGraph -> RelinkNextIteration
// Build -> LoadGraph -> LoadDynamicSubgraph --> BuildNodeItem --> NodeItem::SetDataSend
// Build -> LoadGraph -> LoadDynamicSubgraph --> BuildControlFlowGroup --> NodeItem::SetCtrlSend
auto &engine_mapping = NodeExecutorManager::GetInstance().engine_mapping_;
engine_mapping.emplace("AIcoreEngine", NodeExecutorManager::ExecutorType::AICORE);
engine_mapping.emplace("DNN_VM_GE_LOCAL_OP_STORE", NodeExecutorManager::ExecutorType::GE_LOCAL);
engine_mapping.emplace("aicpu_tf_kernel", NodeExecutorManager::ExecutorType::AICPU_TF);
engine_mapping.emplace("aicpu_ascend_kernel", NodeExecutorManager::ExecutorType::AICPU_TF);
engine_mapping.emplace("ops_kernel_info_hccl", NodeExecutorManager::ExecutorType::HCCL);
engine_mapping.emplace("DNN_VM_RTS_OP_STORE", NodeExecutorManager::ExecutorType::RTS);
engine_mapping.emplace("DNN_VM_HOST_CPU_OP_STORE", NodeExecutorManager::ExecutorType::HOST_CPU);

auto &task_executor = NodeExecutorManager::GetInstance().executors_;
task_executor.emplace(NodeExecutorManager::ExecutorType::AICORE, std::unique_ptr<NodeExecutor>(new NodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::GE_LOCAL, std::unique_ptr<NodeExecutor>(new NodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::AICPU_TF, std::unique_ptr<NodeExecutor>(new NodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::HCCL, std::unique_ptr<NodeExecutor>(new NodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::RTS, std::unique_ptr<NodeExecutor>(new NodeExecutor()));
task_executor.emplace(NodeExecutorManager::ExecutorType::HOST_CPU, std::unique_ptr<NodeExecutor>(new NodeExecutor()));

HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.Build(), SUCCESS);
engine_mapping.clear();
task_executor.clear();
}

TEST_F(UtestHybridModelBuilder, create_called_invalid) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);

auto node = CreateNode(*graph, "node", PARTITIONEDCALL, 1, 1);
NodeItem node_item(node);

ASSERT_EQ(hybrid_model_builder.CreateStreamActiveGroup(node, &node_item), INTERNAL_ERROR);
ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchGroup(node, &node_item), INTERNAL_ERROR);
ASSERT_EQ(hybrid_model_builder.CreateNextIterationGroup(node, &node_item), INTERNAL_ERROR);
ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(node, &node_item), INTERNAL_ERROR);
ASSERT_EQ(hybrid_model_builder.CreateSwitchGroup(node, &node_item), INTERNAL_ERROR);

ASSERT_EQ(hybrid_model_builder.CreateLabelSetGroup(node, &node_item), INTERNAL_ERROR);
node_item.node_type = LABELSET;
ASSERT_EQ(hybrid_model_builder.CreateLabelSetGroup(node, &node_item), UNSUPPORTED);

ASSERT_EQ(hybrid_model_builder.CreateLabelGotoGroup(node, &node_item), INTERNAL_ERROR);
node_item.node_type = LABELGOTO;
ASSERT_EQ(hybrid_model_builder.CreateLabelGotoGroup(node, &node_item), UNSUPPORTED);

ASSERT_EQ(hybrid_model_builder.CreateLabelSwitchGroup(node, &node_item), INTERNAL_ERROR);
node_item.node_type = LABELSWITCH;
ASSERT_EQ(hybrid_model_builder.CreateLabelSwitchGroup(node, &node_item), UNSUPPORTED);
}

TEST_F(UtestHybridModelBuilder, stream_switch_n_group) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);

auto switch_n = CreateNode(*graph, "switch_n", STREAMSWITCHN, 1, 0);
NodeItem node_item(switch_n);

// no batch_num
ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(switch_n, &node_item), INTERNAL_ERROR);

uint32_t batch_num = 0;
AttrUtils::SetInt(switch_n->GetOpDesc(), ATTR_NAME_BATCH_NUM, batch_num);
ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(switch_n, &node_item), SUCCESS);

batch_num = 3;
AttrUtils::SetInt(switch_n->GetOpDesc(), ATTR_NAME_BATCH_NUM, batch_num);
ASSERT_EQ(hybrid_model_builder.CreateStreamSwitchNGroup(switch_n, &node_item), SUCCESS);
}
} // namespace ge

+ 484
- 0
tests/ut/ge/hybrid/node_executor/rts/rts_node_task_unittest.cc View File

@@ -0,0 +1,484 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>

#define private public
#define protected public
#include "hybrid/executor/subgraph_context.h"
#include "hybrid/node_executor/rts/rts_node_executor.h"
#include "model/ge_root_model.h"

using namespace std;
using namespace testing;

namespace ge {
using namespace hybrid;

class UtestRtsNodeTask : public testing::Test {
protected:
void SetUp() {}
void TearDown() { }
};

static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
op_desc->SetStreamId(0);
static int32_t index = 0;
op_desc->SetId(index++);

GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
TensorUtils::SetSize(tensor, 64);
vector<int64_t> input_offset;
for (int i = 0; i < in_num; i++) {
op_desc->AddInputDesc(tensor);
input_offset.emplace_back(i * 64);
}
op_desc->SetInputOffset(input_offset);

vector<int64_t> output_offset;
for (int i = 0; i < out_num; i++) {
op_desc->AddOutputDesc(tensor);
output_offset.emplace_back(in_num * 64 + i * 64);
}
op_desc->SetOutputOffset(output_offset);

op_desc->SetWorkspace({});
op_desc->SetWorkspaceBytes({});
op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");

return graph.AddNode(op_desc);
}

TEST_F(UtestRtsNodeTask, test_stream_switch_task) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "switch", STREAMSWITCH, 2, 0);
ASSERT_TRUE(AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, 0));

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 0;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 2;
graph_item.total_outputs_ = 2;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
uint64_t value_1 = 120;
TensorValue in_tensor0(&value_0, sizeof(value_0));
TensorValue in_tensor1(&value_1, sizeof(value_1));
subgraph_context.SetInput(*node_item, 0, in_tensor0);
subgraph_context.SetInput(*node_item, 1, in_tensor1);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(node_state->GetSwitchIndex(), -1);
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
ASSERT_EQ(node_state->GetSwitchIndex(), 0); // not equal, active 0

uint64_t value_2 = 110;
TensorValue in_tensor2(&value_2, sizeof(value_2));
subgraph_context.SetInput(*node_item, 1, in_tensor2);
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
ASSERT_EQ(node_state->GetSwitchIndex(), 1); // equal, active 1
}

TEST_F(UtestRtsNodeTask, test_stream_active_task) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "active", STREAMACTIVE, 0, 0);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 0;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(node_state->GetSwitchIndex(), -1);
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
ASSERT_EQ(node_state->GetSwitchIndex(), 0);
}

TEST_F(UtestRtsNodeTask, test_stream_merge_task) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "merge", STREAMMERGE, 2, 2);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 0;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 2;
graph_item.total_outputs_ = 2;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
subgraph_context.SetInput(*node_item, 0, in_tensor0);
uint64_t value_1 = 220;
TensorValue in_tensor1(&value_1, sizeof(value_1));
subgraph_context.SetInput(*node_item, 1, in_tensor1);

uint64_t value_2 = 123;
TensorValue out_tensor0(&value_2, sizeof(value_2));
subgraph_context.SetOutput(*node_item, 0, out_tensor0);
uint64_t value_3 = 223;
TensorValue out_tensor1(&value_3, sizeof(value_3));
subgraph_context.SetOutput(*node_item, 1, out_tensor1);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
node_state->SetMergeIndex(1);
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);
ASSERT_EQ(node_state->GetSwitchIndex(), -1);

uint64_t value_4 = 323;
ASSERT_EQ(node_state->GetTaskContext()->GetOutput(0)->CopyScalarValueToHost(value_4), SUCCESS);
ASSERT_EQ(value_4, value_1);

uint64_t value_5 = 423;
ASSERT_EQ(node_state->GetTaskContext()->GetOutput(1)->CopyScalarValueToHost(value_5), SUCCESS);
ASSERT_EQ(value_5, 1);
}

TEST_F(UtestRtsNodeTask, test_memcpy_async_task) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "memcpy", MEMCPYASYNC, 1, 1);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 0;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 1;
graph_item.total_outputs_ = 1;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
subgraph_context.SetInput(*node_item, 0, in_tensor0);

uint64_t value_1 = 123;
TensorValue out_tensor0(&value_1, sizeof(value_1));
subgraph_context.SetOutput(*node_item, 0, out_tensor0);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);

uint64_t value_4 = 323;
ASSERT_EQ(node_state->GetTaskContext()->GetOutput(0)->CopyScalarValueToHost(value_4), SUCCESS);
ASSERT_EQ(value_4, value_0);
ASSERT_EQ(value_1, value_0);
}

TEST_F(UtestRtsNodeTask, test_pass_through_task) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "enter", ENTER, 1, 1);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 0;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 1;
graph_item.total_outputs_ = 1;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

uint64_t value_0 = 110;
TensorValue in_tensor0(&value_0, sizeof(value_0));
subgraph_context.SetInput(*node_item, 0, in_tensor0);

uint64_t value_1 = 123;
TensorValue out_tensor0(&value_1, sizeof(value_1));
subgraph_context.SetOutput(*node_item, 0, out_tensor0);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), SUCCESS);

uint64_t value_4 = 323;
ASSERT_EQ(node_state->GetTaskContext()->GetOutput(0)->CopyScalarValueToHost(value_4), SUCCESS);
ASSERT_EQ(value_4, value_0);
}

TEST_F(UtestRtsNodeTask, test_unsupport_label_set) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "labelset", LABELSET, 0, 0);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 2;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 2;
graph_item.total_outputs_ = 2;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), UNSUPPORTED);
}

TEST_F(UtestRtsNodeTask, test_unsupport_label_goto) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "labelgoto", LABELGOTO, 0, 0);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 2;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 2;
graph_item.total_outputs_ = 2;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), UNSUPPORTED);
}

TEST_F(UtestRtsNodeTask, test_unsupport_label_switch) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeModelPtr ge_sub_model = std::make_shared<GeModel>();
GeRootModelPtr ge_root_model = std::make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
ge_root_model->SetSubgraphInstanceNameToModel("sub", ge_sub_model);
HybridModel hybrid_model(ge_root_model);

NodePtr node = CreateNode(*graph, "labelswitch", LABELSWITCH, 0, 0);

std::unique_ptr<NodeItem> new_node;
ASSERT_EQ(NodeItem::Create(node, new_node), SUCCESS);
NodeItem *node_item = new_node.get();
hybrid_model.node_items_[node] = std::move(new_node);
node_item->input_start = 0;
node_item->output_start = 2;

GraphItem graph_item;
graph_item.node_items_.emplace_back(node_item);
graph_item.total_inputs_ = 2;
graph_item.total_outputs_ = 2;

GraphExecutionContext graph_context;
SubgraphContext subgraph_context(&graph_item, &graph_context);
ASSERT_EQ(subgraph_context.Init(), SUCCESS);
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());

auto node_state = subgraph_context.GetOrCreateNodeState(node_item);
ASSERT_NE(node_state, nullptr);

auto unique_task_context = TaskContext::Create(node_state.get(), &graph_context, &subgraph_context);
ASSERT_NE(unique_task_context, nullptr);
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
node_state->SetTaskContext(shared_task_context);

NodeTaskPtr task = nullptr;
RtsNodeExecutor node_executor;
ASSERT_EQ(node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
ASSERT_NE(task, nullptr);

std::function<void()> done = []() {};
ASSERT_EQ(task->ExecuteAsync(*node_state->GetTaskContext(), done), UNSUPPORTED);
}
} // namespace ge

Loading…
Cancel
Save