Browse Source

while loop failed to restore origin input after execution

tags/v1.3.0
chuxing 3 years ago
parent
commit
4b4ed2e1c5
8 changed files with 130 additions and 42 deletions
  1. +6
    -4
      ge/hybrid/executor/node_state.cc
  2. +35
    -1
      ge/hybrid/model/node_item.cc
  3. +10
    -3
      ge/hybrid/model/node_item.h
  4. +23
    -33
      ge/hybrid/node_executor/controlop/control_op_executor.cc
  5. +0
    -1
      ge/hybrid/node_executor/controlop/control_op_executor.h
  6. +11
    -0
      ge/hybrid/node_executor/task_context.cc
  7. +3
    -0
      ge/hybrid/node_executor/task_context.h
  8. +42
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 6
- 4
ge/hybrid/executor/node_state.cc View File

@@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
node_item.NodeName().c_str(),
this->num_pending_shapes_);

for (int i = 0; i < node_item.num_inputs; ++i){
input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i));
input_tensor_desc.resize(node_item.num_inputs);
for (int i = 0; i < node_item.num_inputs; ++i) {
node_item.GetInputDesc(i, input_tensor_desc[i]);
}

for (int i = 0; i < node_item.num_outputs; ++i){
output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i));
output_tensor_desc.resize(node_item.num_outputs);
for (int i = 0; i < node_item.num_outputs; ++i) {
node_item.GetOutputDesc(i, output_tensor_desc[i]);
}
}



+ 35
- 1
ge/hybrid/model/node_item.cc View File

@@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() {
}
}

GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
if (!has_optional_inputs) {
return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
}
@@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
return op_desc->MutableInputDesc(input_desc_indices_[index]);
}

GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
std::lock_guard<std::mutex> lk(mu_);
return DoGetInputDesc(index);
}

Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
std::lock_guard<std::mutex> lk(mu_);
auto input_desc = DoGetInputDesc(index);
GE_CHECK_NOTNULL(input_desc);
tensor_desc = *input_desc;
return SUCCESS;
}

Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
std::lock_guard<std::mutex> lk(mu_);
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
GE_CHECK_NOTNULL(output_desc);
tensor_desc = *output_desc;
return SUCCESS;
}

GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
std::lock_guard<std::mutex> lk(mu_);
return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
}

Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
std::lock_guard<std::mutex> lk(mu_);
auto input_desc = DoGetInputDesc(index);
GE_CHECK_NOTNULL(input_desc);
*input_desc = tensor_desc;
return SUCCESS;
}

Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
if (!has_optional_inputs) {
canonical_index = index;


+ 10
- 3
ge/hybrid/model/node_item.h View File

@@ -17,6 +17,7 @@
#ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
#define GE_HYBRID_MODEL_NODE_ITEM_H_

#include <mutex>
#include <vector>
#include "external/ge/ge_api_error_codes.h"
#include "graph/node.h"
@@ -57,12 +58,16 @@ struct NodeItem {

bool IsInputShapeStatic(int index) const;

GeTensorDescPtr MutableOutputDesc(int index) const {
return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
}
GeTensorDescPtr MutableOutputDesc(int index) const;
Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);

GeTensorDescPtr MutableInputDesc(int index) const;

Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;

Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;

Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;

bool IsControlOp() const;
@@ -113,9 +118,11 @@ struct NodeItem {
Status ResolveDynamicState();
Status ResolveStaticInputsAndOutputs();
void ResolveUnknownShapeType();
GeTensorDescPtr DoGetInputDesc(int index) const;

std::vector<bool> is_input_shape_static_;
std::vector<uint32_t> input_desc_indices_;
mutable std::mutex mu_;
};
} // namespace hybrid
} // namespace ge


+ 23
- 33
ge/hybrid/node_executor/controlop/control_op_executor.cc View File

@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
}

bool is_continue = false;
GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
"[%s] Failed to execute iteration 0.",
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
"[%s] Failed to execute cond-subgraph",
task_context.GetNodeName());
if (!is_continue) {
for (int i = 0; i < task_context.NumInputs(); ++i) {
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
}

// backup original input tensor desc
std::vector<GeTensorDesc> ori_input_desc;
std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
for (int i = 0; i < task_context.NumInputs(); ++i) {
auto tensor_desc = task_context.GetInputDesc(i);
GE_CHECK_NOTNULL(tensor_desc);
ori_input_desc.emplace_back(*tensor_desc);
GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
}

int iteration = 1;
while (true) {
int iteration = 0;
while (is_continue) {
++iteration;
GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
"[%s] Failed to execute iteration %d.",
task_context.GetNodeName(),
iteration);

if (!is_continue) {
GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
break;
}

++iteration;
}

for (int i = 0; i < task_context.NumInputs(); ++i) {
auto input_tensor = task_context.GetInput(i);
auto tensor_desc = task_context.MutableInputDesc(i);
GE_CHECK_NOTNULL(input_tensor);
GE_CHECK_NOTNULL(tensor_desc);
// restore original input tensor desc
*tensor_desc = std::move(ori_input_desc[i]);
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
}

GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
if (done_callback) {
done_callback();
}

for (int i = 0; i < task_context.NumInputs(); ++i) {
GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
}
return SUCCESS;
}

@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
}

Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
"[%s] Failed to execute cond-subgraph",
task_context.GetNodeName());
if (!is_continue) {
return SUCCESS;
}

GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName());
GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr),
"[%s] Failed to execute cond-subgraph", task_context.GetNodeName());
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
"[%s] Failed to move outputs to inputs",
task_context.GetNodeName());

GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
"[%s] Failed to execute cond-subgraph",
task_context.GetNodeName());

if (!is_continue) {
for (int i = 0; i < task_context.NumInputs(); ++i) {
auto input_desc = task_context.GetInput(i);
GE_CHECK_NOTNULL(input_desc);
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
}
}
return SUCCESS;
}



+ 0
- 1
ge/hybrid/node_executor/controlop/control_op_executor.h View File

@@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask {
Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;

static Status MoveOutputs2Inputs(TaskContext &task_context);

Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;

private:


+ 11
- 0
ge/hybrid/node_executor/task_context.cc View File

@@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const {
return node_state_;
}

Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
return node_item_->GetInputDesc(index, tensor_desc);
}

Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc);
}

Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
return node_item_->GetOutputDesc(index, tensor_desc);
}
} // namespace hybrid
} // namespace ge

+ 3
- 0
ge/hybrid/node_executor/task_context.h View File

@@ -50,9 +50,12 @@ class TaskContext {
const char *GetNodeName() const;
TensorValue *MutableInput(int index);
ConstGeTensorDescPtr GetInputDesc(int index) const;
Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
ConstGeTensorDescPtr GetOutputDesc(int index) const;
Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
GeTensorDescPtr MutableInputDesc(int index) const;
GeTensorDescPtr MutableOutputDesc(int index) const;
Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
void ReleaseInputsAndOutputs();
bool NeedCallback();
void ReleaseInput(int index);


+ 42
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -383,3 +383,45 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
HybridModelBuilder hybrid_model_builder(hybrid_model);
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
}

TEST_F(UtestGeHybrid, TestTaskContext) {
auto graph = make_shared<ComputeGraph>("graph");
OpDescPtr op_desc = CreateOpDesc("Add", "Add");
GeShape shape({2, 16});
GeTensorDesc tensor_desc(shape);
op_desc->AddInputDesc(tensor_desc);
op_desc->AddInputDesc(tensor_desc);
op_desc->AddOutputDesc(tensor_desc);
auto node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> node_item;
NodeItem::Create(node, node_item);
node_item->input_start = 0;
node_item->output_start = 0;

GraphExecutionContext execution_context;
SubgraphContext subgraph_context(nullptr, &execution_context);
subgraph_context.all_inputs_.resize(2);
subgraph_context.all_outputs_.resize(1);

NodeState node_state(*node_item, &subgraph_context);
auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
ASSERT_TRUE(task_context != nullptr);
auto desc = task_context->MutableInputDesc(2);
ASSERT_TRUE(desc == nullptr);
desc = task_context->MutableOutputDesc(0);
ASSERT_TRUE(desc != nullptr);
ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
GeTensorDesc output_desc;
ASSERT_EQ(task_context->GetOutputDesc(0, output_desc), SUCCESS);
ASSERT_EQ(output_desc.GetShape().GetDims(), shape.GetDims());

desc = task_context->MutableInputDesc(0);
ASSERT_TRUE(desc != nullptr);
ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
GeShape new_shape({8, 2});
tensor_desc.SetShape(new_shape);
task_context->UpdateInputDesc(1, tensor_desc);
GeTensorDesc new_desc;
ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS);
ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims());
}

Loading…
Cancel
Save