@@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
          node_item.NodeName().c_str(),
          this->num_pending_shapes_);
-  for (int i = 0; i < node_item.num_inputs; ++i){
-    input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i));
+  input_tensor_desc.resize(node_item.num_inputs);
+  for (int i = 0; i < node_item.num_inputs; ++i) {
+    node_item.GetInputDesc(i, input_tensor_desc[i]);
   }
-  for (int i = 0; i < node_item.num_outputs; ++i){
-    output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i));
+  output_tensor_desc.resize(node_item.num_outputs);
+  for (int i = 0; i < node_item.num_outputs; ++i) {
+    node_item.GetOutputDesc(i, output_tensor_desc[i]);
   }
 }
@@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() {
   }
 }
 
-GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
   if (!has_optional_inputs) {
     return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
   }
@@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
   return op_desc->MutableInputDesc(input_desc_indices_[index]);
 }
 
+GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return DoGetInputDesc(index);
+}
+
+Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  tensor_desc = *input_desc;
+  return SUCCESS;
+}
+
+Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+  GE_CHECK_NOTNULL(output_desc);
+  tensor_desc = *output_desc;
+  return SUCCESS;
+}
+
+GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
+  std::lock_guard<std::mutex> lk(mu_);
+  return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
+}
+
+Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  std::lock_guard<std::mutex> lk(mu_);
+  auto input_desc = DoGetInputDesc(index);
+  GE_CHECK_NOTNULL(input_desc);
+  *input_desc = tensor_desc;
+  return SUCCESS;
+}
+
 Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
   if (!has_optional_inputs) {
     canonical_index = index;
@@ -17,6 +17,7 @@
 #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_
 #define GE_HYBRID_MODEL_NODE_ITEM_H_
 
+#include <mutex>
 #include <vector>
 #include "external/ge/ge_api_error_codes.h"
 #include "graph/node.h"
@@ -57,12 +58,16 @@ struct NodeItem {
   bool IsInputShapeStatic(int index) const;
 
-  GeTensorDescPtr MutableOutputDesc(int index) const {
-    return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
-  }
+  GeTensorDescPtr MutableOutputDesc(int index) const;
+
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
 
   GeTensorDescPtr MutableInputDesc(int index) const;
 
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
+
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
 
   Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const;
 
   bool IsControlOp() const;
@@ -113,9 +118,11 @@ struct NodeItem {
   Status ResolveDynamicState();
   Status ResolveStaticInputsAndOutputs();
   void ResolveUnknownShapeType();
+  GeTensorDescPtr DoGetInputDesc(int index) const;
 
   std::vector<bool> is_input_shape_static_;
   std::vector<uint32_t> input_desc_indices_;
+  mutable std::mutex mu_;
 };
 } // namespace hybrid
 } // namespace ge
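The change above boils down to a lock-guarded, copy-out accessor pattern: readers get a `GeTensorDesc` copy taken while `mu_` is held, and `UpdateInputDesc` overwrites the stored desc under the same mutex, so a desc can no longer change underneath a caller mid-read. A minimal standalone sketch of that pattern (toy types and names, not GE code):

```cpp
// Illustrative only: a stripped-down stand-in for the locked, copy-based
// accessors added to NodeItem. Desc/DescStore are invented for this sketch.
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

struct Desc {                      // stand-in for GeTensorDesc
  std::vector<int64_t> dims;
};

class DescStore {                  // stand-in for the desc storage inside NodeItem
 public:
  explicit DescStore(size_t n) : descs_(n) {}

  // Copy out under the lock, mirroring the new GetInputDesc/GetOutputDesc.
  bool Get(size_t index, Desc &out) const {
    std::lock_guard<std::mutex> lk(mu_);
    if (index >= descs_.size()) {
      return false;
    }
    out = descs_[index];           // caller receives an independent snapshot
    return true;
  }

  // Overwrite under the same lock, mirroring the new UpdateInputDesc.
  bool Update(size_t index, const Desc &in) {
    std::lock_guard<std::mutex> lk(mu_);
    if (index >= descs_.size()) {
      return false;
    }
    descs_[index] = in;
    return true;
  }

 private:
  std::vector<Desc> descs_;
  mutable std::mutex mu_;          // mutable, so const readers can still lock
};

int main() {
  DescStore store(1);
  std::thread writer([&store] {
    for (int64_t i = 0; i < 1000; ++i) {
      store.Update(0, Desc{{i, i}});
    }
  });
  std::thread reader([&store] {
    Desc snapshot;
    for (int i = 0; i < 1000; ++i) {
      store.Get(0, snapshot);      // always observes a complete, consistent desc
    }
  });
  writer.join();
  reader.join();
  std::printf("done\n");
  return 0;
}
```

`MutableInputDesc`/`MutableOutputDesc` still hand out a shared `GeTensorDescPtr`, but only the copy-based calls give callers a stable snapshot.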
@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   bool is_continue = false;
-  GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
-                    "[%s] Failed to execute iteration 0.",
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
                     task_context.GetNodeName());
   if (!is_continue) {
     for (int i = 0; i < task_context.NumInputs(); ++i) {
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun
   }
 
   // backup original input tensor desc
-  std::vector<GeTensorDesc> ori_input_desc;
+  std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs());
   for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto tensor_desc = task_context.GetInputDesc(i);
-    GE_CHECK_NOTNULL(tensor_desc);
-    ori_input_desc.emplace_back(*tensor_desc);
+    GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i]));
   }
 
-  int iteration = 1;
-  while (true) {
+  int iteration = 0;
+  while (is_continue) {
+    ++iteration;
     GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration);
     GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue),
                       "[%s] Failed to execute iteration %d.",
                       task_context.GetNodeName(),
                       iteration);
-    if (!is_continue) {
-      GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
-      break;
-    }
-    ++iteration;
   }
 
-  for (int i = 0; i < task_context.NumInputs(); ++i) {
-    auto input_tensor = task_context.GetInput(i);
-    auto tensor_desc = task_context.MutableInputDesc(i);
-    GE_CHECK_NOTNULL(input_tensor);
-    GE_CHECK_NOTNULL(tensor_desc);
-    // restore original input tensor desc
-    *tensor_desc = std::move(ori_input_desc[i]);
-    GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor));
-  }
+  GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration);
   if (done_callback) {
     done_callback();
   }
+
+  for (int i = 0; i < task_context.NumInputs(); ++i) {
+    GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i]));
+  }
   return SUCCESS;
 }
@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) {
 }
 
 Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const {
-  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
-                    "[%s] Failed to execute cond-subgraph",
-                    task_context.GetNodeName());
-  if (!is_continue) {
-    return SUCCESS;
-  }
-
   GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName());
   GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr),
                     "[%s] Failed to execute cond-subgraph", task_context.GetNodeName());
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti
   GE_CHK_STATUS_RET(MoveOutputs2Inputs(task_context),
                     "[%s] Failed to move outputs to inputs",
                     task_context.GetNodeName());
+
+  GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue),
+                    "[%s] Failed to execute cond-subgraph",
+                    task_context.GetNodeName());
+  if (!is_continue) {
+    for (int i = 0; i < task_context.NumInputs(); ++i) {
+      auto input_desc = task_context.GetInput(i);
+      GE_CHECK_NOTNULL(input_desc);
+      GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc));
+    }
+  }
 
   return SUCCESS;
 }
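Taken together, these hunks move the cond evaluation: it now runs once before the loop in `DoExecuteAsync`, and again at the end of every `ExecuteOneLoop` pass (after the body and `MoveOutputs2Inputs`), so `while (is_continue)` is driven by the most recent cond result; whenever cond comes back false, the current inputs are forwarded to the node's outputs. A simplified, self-contained restatement of that control flow (stub functions stand in for the subgraph executions and GE status macros):

```cpp
#include <cstdio>

// Stubs standing in for the cond/body subgraph executions of the While node.
static bool ExecuteCondStub() {
  static int calls = 0;
  return ++calls < 4;                    // pretend cond allows three body iterations
}
static void ExecuteBodyStub() { std::printf("body\n"); }
static void MoveOutputsToInputsStub() {}

// One pass now runs body -> move outputs to inputs -> cond, so the loop
// predicate is refreshed at the end of every iteration.
static void ExecuteOneLoop(bool &is_continue) {
  ExecuteBodyStub();
  MoveOutputsToInputsStub();
  is_continue = ExecuteCondStub();
}

static void DoExecuteAsync() {
  bool is_continue = ExecuteCondStub();  // cond is evaluated once up front
  int iteration = 0;
  while (is_continue) {                  // loop only while the last cond said "continue"
    ++iteration;
    ExecuteOneLoop(is_continue);
  }
  std::printf("quit after %d iteration(s)\n", iteration);
  // the real code then restores the saved input descs via UpdateInputDesc
}

int main() {
  DoExecuteAsync();
  return 0;
}
```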
@@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask {
   Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;
   static Status MoveOutputs2Inputs(TaskContext &task_context);
   Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;
-
  private:
@@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const {
   return node_state_;
 }
 
+Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetInputDesc(index, tensor_desc);
+}
+
+Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
+  return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc);
+}
+
+Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
+  return node_item_->GetOutputDesc(index, tensor_desc);
+}
 } // namespace hybrid
 } // namespace ge
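These three TaskContext methods simply delegate to the NodeItem accessors above; the intended call pattern is the snapshot/restore sequence now used by WhileOpNodeTask. A hedged usage sketch based only on the signatures added in this change (assumes a `TaskContext &task_context` in scope plus the GE status macros, so it is not compilable on its own):

```cpp
// Illustrative call pattern only.
// Take value snapshots of the input descs, run work that may reshape them,
// then write the originals back through UpdateInputDesc.
std::vector<GeTensorDesc> saved(task_context.NumInputs());
for (int i = 0; i < task_context.NumInputs(); ++i) {
  GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, saved[i]));    // deep copy under NodeItem's mutex
}

// ... execute subgraphs that may update the input shapes in place ...

for (int i = 0; i < task_context.NumInputs(); ++i) {
  GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, saved[i])); // restore the saved descs
}
```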
@@ -50,9 +50,12 @@ class TaskContext {
   const char *GetNodeName() const;
   TensorValue *MutableInput(int index);
   ConstGeTensorDescPtr GetInputDesc(int index) const;
+  Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const;
   ConstGeTensorDescPtr GetOutputDesc(int index) const;
+  Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const;
   GeTensorDescPtr MutableInputDesc(int index) const;
   GeTensorDescPtr MutableOutputDesc(int index) const;
+  Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc);
   void ReleaseInputsAndOutputs();
   bool NeedCallback();
   void ReleaseInput(int index);
@@ -383,3 +383,45 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
   HybridModelBuilder hybrid_model_builder(hybrid_model);
   EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
 }
+
+TEST_F(UtestGeHybrid, TestTaskContext) {
+  auto graph = make_shared<ComputeGraph>("graph");
+  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
+  GeShape shape({2, 16});
+  GeTensorDesc tensor_desc(shape);
+  op_desc->AddInputDesc(tensor_desc);
+  op_desc->AddInputDesc(tensor_desc);
+  op_desc->AddOutputDesc(tensor_desc);
+  auto node = graph->AddNode(op_desc);
+  std::unique_ptr<NodeItem> node_item;
+  NodeItem::Create(node, node_item);
+  node_item->input_start = 0;
+  node_item->output_start = 0;
+
+  GraphExecutionContext execution_context;
+  SubgraphContext subgraph_context(nullptr, &execution_context);
+  subgraph_context.all_inputs_.resize(2);
+  subgraph_context.all_outputs_.resize(1);
+  NodeState node_state(*node_item, &subgraph_context);
+  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
+  ASSERT_TRUE(task_context != nullptr);
+
+  auto desc = task_context->MutableInputDesc(2);
+  ASSERT_TRUE(desc == nullptr);
+  desc = task_context->MutableOutputDesc(0);
+  ASSERT_TRUE(desc != nullptr);
+  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
+  GeTensorDesc output_desc;
+  ASSERT_EQ(task_context->GetOutputDesc(0, output_desc), SUCCESS);
+  ASSERT_EQ(output_desc.GetShape().GetDims(), shape.GetDims());
+  desc = task_context->MutableInputDesc(0);
+  ASSERT_TRUE(desc != nullptr);
+  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());
+
+  GeShape new_shape({8, 2});
+  tensor_desc.SetShape(new_shape);
+  task_context->UpdateInputDesc(1, tensor_desc);
+  GeTensorDesc new_desc;
+  ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS);
+  ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims());
+}