Browse Source

!1912 FindLastBpFromBpNode

Merge pull request !1912 from 梁昊/lh2
tags/v1.5.1
i-robot Gitee 3 years ago
parent
commit
e3fb1634d3
3 changed files with 24 additions and 26 deletions
  1. +20
    -24
      ge/graph/build/task_generator.cc
  2. +1
    -1
      ge/graph/build/task_generator.h
  3. +3
    -1
      tests/ut/ge/graph/build/task_generator_unittest.cc

+ 20
- 24
ge/graph/build/task_generator.cc View File

@@ -793,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
GELOGI("Start AutoFindBpOpIndex");
NodePtr bp_node = nullptr;
uint32_t current_idx = 0;
uint32_t netoutput_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@@ -811,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) {
if (bp_node == nullptr) {
bp_node = node;
netoutput_idx = current_idx - 1;
}
}
if (graph->GetNeedIteration()) {
@@ -836,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (bp_node == nullptr) {
GELOGW("not find bp_node.");
return SUCCESS;
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) {
profiling_point.bp_index = netoutput_idx;
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx);
} else {
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node);
}

return SUCCESS;
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index);
}

uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const {
uint32_t last_bp = 0;
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node,
uint32_t &bp_index) const {
bp_index = 0;
auto target_desc = target_node->GetOpDesc();
GE_CHECK_NOTNULL(target_desc);
OpDescPtr bp_op_desc = nullptr;
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
continue;
}
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL(out_node_desc);
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) {
bp_op_desc = out_node_desc;
for (auto &in_node : target_node->GetInAllNodes()) {
GE_CHECK_NOTNULL(in_node);
auto in_node_desc = in_node->GetOpDesc();
GE_CHECK_NOTNULL(in_node_desc);
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) &&
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){
bp_op_desc = in_node_desc;
}
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId());
}

if (bp_op_desc == nullptr) {
return last_bp;
GELOGI("Did not find bp node.");
return SUCCESS;
}
uint32_t current_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
@@ -871,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const
GE_CHECK_NOTNULL(op_desc);
current_idx++;
if (op_desc->GetName() == bp_op_desc->GetName()) {
last_bp = current_idx;
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp);
bp_index = current_idx;
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index);
break;
}
}
return last_bp;
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(),
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId());
return SUCCESS;
}

Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,


+ 1
- 1
ge/graph/build/task_generator.h View File

@@ -116,7 +116,7 @@ class TaskGenerator {
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
vector<uint32_t> &all_reduce_nodes) const;
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const;

Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
ProfilingPoint &profiling_point) const;


+ 3
- 1
tests/ut/ge/graph/build/task_generator_unittest.cc View File

@@ -116,7 +116,9 @@ TEST_F(UtestTaskGeneratorTest, FindLastBpFromBpNode) {
TaskGenerator task_generator(nullptr, 0);
auto net_output = graph->FindNode("Node_Output");
// netoutput has no data input, return default value 0
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output), 0);
uint32_t bp_index = 0;
EXPECT_EQ(task_generator.FindLastBpFromBpNode(graph, net_output, bp_index), 0);
EXPECT_EQ(bp_index, 2);
}

TEST_F(UtestTaskGeneratorTest, UpdateOpIsVarAttr) {


Loading…
Cancel
Save