From 4b4ed2e1c5b5701bdef4bdf091c34da305850d95 Mon Sep 17 00:00:00 2001
From: chuxing
Date: Mon, 29 Mar 2021 17:11:24 +0800
Subject: [PATCH] while loop failed to restore origin input after execution

---
Reviewer notes (ignored by git am):
 * Line breaks and indentation were reconstructed from a copy in which all
   newlines had been collapsed; blank context lines were placed so that each
   hunk's line counts and the diffstat totals match. Exact leading whitespace
   of context lines is a best effort -- if a hunk is rejected, retry with
   `git apply --ignore-whitespace`.
 * Angle-bracket tokens had been stripped and were restored:
   std::lock_guard<std::mutex>, <mutex>/<vector>, std::vector<GeTensorDesc>,
   const_cast<NodeItem *>, std::unique_ptr<NodeItem>. The following are
   inferred and should be confirmed against the tree: static_cast<uint32_t>,
   std::vector<bool>, std::vector<uint32_t>, make_shared<ComputeGraph>.
 * The author e-mail in the From: header was also stripped and is unknown.
 * What the patch does, as visible in the hunks: NodeItem gains a mutex and
   copy-in/copy-out accessors (GetInputDesc/GetOutputDesc/UpdateInputDesc);
   WhileOpNodeTask evaluates the cond-subgraph once up front, then
   body + cond per iteration, and writes the backed-up input descs back via
   UpdateInputDesc after the loop.

 ge/hybrid/executor/node_state.cc         | 10 ++--
 ge/hybrid/model/node_item.cc             | 36 +++++++++++-
 ge/hybrid/model/node_item.h              | 13 ++++-
 .../controlop/control_op_executor.cc     | 56 ++++++++-----------
 .../controlop/control_op_executor.h      |  1 -
 ge/hybrid/node_executor/task_context.cc  | 11 ++++
 ge/hybrid/node_executor/task_context.h   |  3 +
 tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 42 ++++++++++++++
 8 files changed, 130 insertions(+), 42 deletions(-)

diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc
index 3834478c..99fe8593 100644
--- a/ge/hybrid/executor/node_state.cc
+++ b/ge/hybrid/executor/node_state.cc
@@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
          node_item.NodeName().c_str(),
          this->num_pending_shapes_);
 
-  for (int i = 0; i < node_item.num_inputs; ++i){
-    input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i));
+  input_tensor_desc.resize(node_item.num_inputs);
+  for (int i = 0; i < node_item.num_inputs; ++i) {
+    node_item.GetInputDesc(i, input_tensor_desc[i]);
   }
 
-  for (int i = 0; i < node_item.num_outputs; ++i){
-    output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i));
+  output_tensor_desc.resize(node_item.num_outputs);
+  for (int i = 0; i < node_item.num_outputs; ++i) {
+    node_item.GetOutputDesc(i, output_tensor_desc[i]);
   }
 }
 
diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc
index 06d654cf..f14e9a21 100644
--- a/ge/hybrid/model/node_item.cc
+++ b/ge/hybrid/model/node_item.cc
@@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() {
   }
 }
 
-GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
   if (!has_optional_inputs) {
     return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
   }
@@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
   return op_desc->MutableInputDesc(input_desc_indices_[index]);
 }
 
+GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return DoGetInputDesc(index);
+}
+
+Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  tensor_desc = *input_desc;
+  return SUCCESS;
+}
+
+Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+  GE_CHECK_NOTNULL(output_desc);
+  tensor_desc = *output_desc;
+  return SUCCESS;
+}
+
+GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+}
+
+Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  *input_desc = tensor_desc;
+  return SUCCESS;
+}
+
 Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
   if (!has_optional_inputs) {
     canonical_index = index;
diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h
index 474a1da4..54c5e938 100644
--- a/ge/hybrid/model/node_item.h
+++ b/ge/hybrid/model/node_item.h
@@ -17,6 +17,7 @@
 #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
 #define GE_HYBRID_MODEL_NODE_ITEM_H_
 
+#include <mutex>
 #include <vector>
 #include "external/ge/ge_api_error_codes.h"
 #include "graph/node.h"
@@ -57,12 +58,16 @@ struct NodeItem {
 
   bool IsInputShapeStatic(int index) const;
 
-  GeTensorDescPtr MutableOutputDesc(int index) const {
-    return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
-  }
+  GeTensorDescPtr MutableOutputDesc(int index) const;
+
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
 
   GeTensorDescPtr MutableInputDesc(int index) const;
 
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
+
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
+
   Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;
 
   bool IsControlOp() const;
@@ -113,9 +118,11 @@ struct NodeItem {
   Status ResolveDynamicState();
   Status ResolveStaticInputsAndOutputs();
   void ResolveUnknownShapeType();
+  GeTensorDescPtr DoGetInputDesc(int index) const;
 
   std::vector<bool> is_input_shape_static_;
   std::vector<uint32_t> input_desc_indices_;
+  mutable std::mutex mu_;
 };
 }  // namespace hybrid
 }  // namespace ge
diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.cc b/ge/hybrid/node_executor/controlop/control_op_executor.cc
index 74920b22..4e7e71f1 100644
--- a/ge/hybrid/node_executor/controlop/control_op_executor.cc
+++ b/ge/hybrid/node_executor/controlop/control_op_executor.cc
@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   bool is_continue = false;
-  GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
-                    "[%s] Failed to execute iteration 0.",
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
                     task_context.GetNodeName());
   if (!is_continue) {
     for (int i = 0; i < task_context.NumInputs(); ++i) {
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   // backup original input tensor desc
-  std::vector<GeTensorDesc> ori_input_desc;
+  std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
   for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto tensor_desc = task_context.GetInputDesc(i);
-    GE_CHECK_NOTNULL(tensor_desc);
-    ori_input_desc.emplace_back(*tensor_desc);
+    GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
   }
 
-  int iteration = 1;
-  while (true) {
+  int iteration = 0;
+  while (is_continue) {
+    ++iteration;
     GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
     GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
                       "[%s] Failed to execute iteration %d.",
                       task_context.GetNodeName(),
                       iteration);
-
-    if (!is_continue) {
-      GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
-      break;
-    }
-
-    ++iteration;
   }
-
-  for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto input_tensor = task_context.GetInput(i);
-    auto tensor_desc = task_context.MutableInputDesc(i);
-    GE_CHECK_NOTNULL(input_tensor);
-    GE_CHECK_NOTNULL(tensor_desc);
-    // restore original input tensor desc
-    *tensor_desc = std::move(ori_input_desc[i]);
-    GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
-  }
-
+  GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
   if (done_callback) {
     done_callback();
   }
+
+  for (int i = 0; i < task_context.NumInputs(); ++i) {
+    GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
+  }
   return SUCCESS;
 }
 
@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
 }
 
 Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
-  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
-                    "[%s] Failed to execute cond-subgraph",
-                    task_context.GetNodeName());
-  if (!is_continue) {
-    return SUCCESS;
-  }
-
   GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName());
   GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr),
                     "[%s] Failed to execute cond-subgraph", task_context.GetNodeName());
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
                     "[%s] Failed to move outputs to inputs",
                     task_context.GetNodeName());
 
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
+                    task_context.GetNodeName());
+
+  if (!is_continue) {
+    for (int i = 0; i < task_context.NumInputs(); ++i) {
+      auto input_desc = task_context.GetInput(i);
+      GE_CHECK_NOTNULL(input_desc);
+      GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
+    }
+  }
   return SUCCESS;
 }
 
diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.h b/ge/hybrid/node_executor/controlop/control_op_executor.h
index 3becfaaa..fd02bd25 100644
--- a/ge/hybrid/node_executor/controlop/control_op_executor.h
+++ b/ge/hybrid/node_executor/controlop/control_op_executor.h
@@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask {
   Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;
 
   static Status MoveOutputs2Inputs(TaskContext &task_context);
-
   Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;
 
  private:
diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc
index f4271551..4e1b367b 100644
--- a/ge/hybrid/node_executor/task_context.cc
+++ b/ge/hybrid/node_executor/task_context.cc
@@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const {
   return node_state_;
 }
 
+Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetInputDesc(index, tensor_desc);
+}
+
+Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc);
+}
+
+Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetOutputDesc(index, tensor_desc);
+}
 }  // namespace hybrid
 }  // namespace ge
diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h
index e00c5048..ba4c62e6 100644
--- a/ge/hybrid/node_executor/task_context.h
+++ b/ge/hybrid/node_executor/task_context.h
@@ -50,9 +50,12 @@ class TaskContext {
   const char *GetNodeName() const;
   TensorValue *MutableInput(int index);
   ConstGeTensorDescPtr GetInputDesc(int index) const;
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
   ConstGeTensorDescPtr GetOutputDesc(int index) const;
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
   GeTensorDescPtr MutableInputDesc(int index) const;
   GeTensorDescPtr MutableOutputDesc(int index) const;
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
   void ReleaseInputsAndOutputs();
   bool NeedCallback();
   void ReleaseInput(int index);
diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
index 8c4517c7..57230f30 100644
--- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
+++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc
@@ -383,3 +383,45 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
   HybridModelBuilder hybrid_model_builder(hybrid_model);
   EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
 }
+
+TEST_F(UtestGeHybrid, TestTaskContext) {
+  auto graph = make_shared<ComputeGraph>("graph");
+  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
+  GeShape shape({2, 16});
+  GeTensorDesc tensor_desc(shape);
+  op_desc->AddInputDesc(tensor_desc);
+  op_desc->AddInputDesc(tensor_desc);
+  op_desc->AddOutputDesc(tensor_desc);
+  auto node = graph->AddNode(op_desc);
+  std::unique_ptr<NodeItem> node_item;
+  NodeItem::Create(node, node_item);
+  node_item->input_start = 0;
+  node_item->output_start = 0;
+
+  GraphExecutionContext execution_context;
+  SubgraphContext subgraph_context(nullptr, &execution_context);
+  subgraph_context.all_inputs_.resize(2);
+  subgraph_context.all_outputs_.resize(1);
+
+  NodeState node_state(*node_item, &subgraph_context);
+  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
+  ASSERT_TRUE(task_context != nullptr);
+  auto desc = task_context->MutableInputDesc(2);
+  ASSERT_TRUE(desc == nullptr);
+  desc = task_context->MutableOutputDesc(0);
+  ASSERT_TRUE(desc != nullptr);
+  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
+  GeTensorDesc output_desc;
+  ASSERT_EQ(task_context->GetOutputDesc(0, output_desc), SUCCESS);
+  ASSERT_EQ(output_desc.GetShape().GetDims(), shape.GetDims());
+
+  desc = task_context->MutableInputDesc(0);
+  ASSERT_TRUE(desc != nullptr);
+  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
+  GeShape new_shape({8, 2});
+  tensor_desc.SetShape(new_shape);
+  task_context->UpdateInputDesc(1, tensor_desc);
+  GeTensorDesc new_desc;
+  ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS);
+  ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims());
+}