From: @xchu42 Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @liyihan123,@ji_chen tags/v1.2.0
| @@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| this->num_pending_shapes_); | this->num_pending_shapes_); | ||||
| for (int i = 0; i < node_item.num_inputs; ++i){ | |||||
| input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i)); | |||||
| input_tensor_desc.resize(node_item.num_inputs); | |||||
| for (int i = 0; i < node_item.num_inputs; ++i) { | |||||
| node_item.GetInputDesc(i, input_tensor_desc[i]); | |||||
| } | } | ||||
| for (int i = 0; i < node_item.num_outputs; ++i){ | |||||
| output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i)); | |||||
| output_tensor_desc.resize(node_item.num_outputs); | |||||
| for (int i = 0; i < node_item.num_outputs; ++i) { | |||||
| node_item.GetOutputDesc(i, output_tensor_desc[i]); | |||||
| } | } | ||||
| } | } | ||||
| @@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() { | |||||
| } | } | ||||
| } | } | ||||
| GeTensorDescPtr NodeItem::MutableInputDesc(int index) const { | |||||
| GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const { | |||||
| if (!has_optional_inputs) { | if (!has_optional_inputs) { | ||||
| return op_desc->MutableInputDesc(static_cast<uint32_t>(index)); | return op_desc->MutableInputDesc(static_cast<uint32_t>(index)); | ||||
| } | } | ||||
| @@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const { | |||||
| return op_desc->MutableInputDesc(input_desc_indices_[index]); | return op_desc->MutableInputDesc(input_desc_indices_[index]); | ||||
| } | } | ||||
// Returns the descriptor pointer for input `index`, looked up through
// DoGetInputDesc while mu_ is held.
// The lock_guard is released when this function returns, so reads or writes
// made through the returned pointer are not covered by mu_.
| GeTensorDescPtr NodeItem::MutableInputDesc(int index) const { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| return DoGetInputDesc(index); | |||||
| } | |||||
// Copies the descriptor of input `index` into `tensor_desc`.
// The lookup (via DoGetInputDesc) and the copy both happen while mu_ is held,
// so the caller receives a snapshot by value rather than a shared pointer.
// Returns SUCCESS after the copy. GE_CHECK_NOTNULL guards the looked-up pointer
// before it is dereferenced; presumably it returns an error Status on null
// (macro defined outside this view - confirm).
| Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| auto input_desc = DoGetInputDesc(index); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| tensor_desc = *input_desc; | |||||
| return SUCCESS; | |||||
| } | |||||
// Copies the descriptor of output `index` into `tensor_desc` while mu_ is held.
// The descriptor is read directly from op_desc->MutableOutputDesc with `index`
// cast to uint32_t; unlike the input path, DoGetInputDesc is not involved.
// Returns SUCCESS after the copy. GE_CHECK_NOTNULL guards the pointer before it
// is dereferenced; presumably it returns an error Status on null - confirm.
| Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index)); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| tensor_desc = *output_desc; | |||||
| return SUCCESS; | |||||
| } | |||||
// Returns the descriptor pointer for output `index`, fetched from
// op_desc->MutableOutputDesc (index cast to uint32_t) while mu_ is held.
// The lock is released on return, so later access through the returned pointer
// is not covered by mu_.
| GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| return op_desc->MutableOutputDesc(static_cast<uint32_t>(index)); | |||||
| } | |||||
// Overwrites the stored descriptor of input `index` with a copy of `tensor_desc`.
// The lookup (via DoGetInputDesc) and the assignment both happen while mu_ is held.
// Returns SUCCESS after the assignment. GE_CHECK_NOTNULL guards the looked-up
// pointer before it is written through; presumably it returns an error Status
// on null - confirm.
| Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| auto input_desc = DoGetInputDesc(index); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| *input_desc = tensor_desc; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const { | Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const { | ||||
| if (!has_optional_inputs) { | if (!has_optional_inputs) { | ||||
| canonical_index = index; | canonical_index = index; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_ | #ifndef GE_HYBRID_MODEL_NODE_ITEM_H_ | ||||
| #define GE_HYBRID_MODEL_NODE_ITEM_H_ | #define GE_HYBRID_MODEL_NODE_ITEM_H_ | ||||
| #include <mutex> | |||||
| #include <vector> | #include <vector> | ||||
| #include "external/ge/ge_api_error_codes.h" | #include "external/ge/ge_api_error_codes.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| @@ -57,12 +58,16 @@ struct NodeItem { | |||||
| bool IsInputShapeStatic(int index) const; | bool IsInputShapeStatic(int index) const; | ||||
| GeTensorDescPtr MutableOutputDesc(int index) const { | |||||
| return op_desc->MutableOutputDesc(static_cast<uint32_t>(index)); | |||||
| } | |||||
| GeTensorDescPtr MutableOutputDesc(int index) const; | |||||
| Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc); | |||||
| GeTensorDescPtr MutableInputDesc(int index) const; | GeTensorDescPtr MutableInputDesc(int index) const; | ||||
| Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const; | |||||
| Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const; | |||||
| Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const; | Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const; | ||||
| bool IsControlOp() const; | bool IsControlOp() const; | ||||
| @@ -113,9 +118,11 @@ struct NodeItem { | |||||
| Status ResolveDynamicState(); | Status ResolveDynamicState(); | ||||
| Status ResolveStaticInputsAndOutputs(); | Status ResolveStaticInputsAndOutputs(); | ||||
| void ResolveUnknownShapeType(); | void ResolveUnknownShapeType(); | ||||
| GeTensorDescPtr DoGetInputDesc(int index) const; | |||||
| std::vector<bool> is_input_shape_static_; | std::vector<bool> is_input_shape_static_; | ||||
| std::vector<uint32_t> input_desc_indices_; | std::vector<uint32_t> input_desc_indices_; | ||||
| mutable std::mutex mu_; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun | |||||
| } | } | ||||
| bool is_continue = false; | bool is_continue = false; | ||||
| GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), | |||||
| "[%s] Failed to execute iteration 0.", | |||||
| GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), | |||||
| "[%s] Failed to execute cond-subgraph", | |||||
| task_context.GetNodeName()); | task_context.GetNodeName()); | ||||
| if (!is_continue) { | if (!is_continue) { | ||||
| for (int i = 0; i < task_context.NumInputs(); ++i) { | for (int i = 0; i < task_context.NumInputs(); ++i) { | ||||
| @@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun | |||||
| } | } | ||||
| // backup original input tensor desc | // backup original input tensor desc | ||||
| std::vector<GeTensorDesc> ori_input_desc; | |||||
| std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs()); | |||||
| for (int i = 0; i < task_context.NumInputs(); ++i) { | for (int i = 0; i < task_context.NumInputs(); ++i) { | ||||
| auto tensor_desc = task_context.GetInputDesc(i); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| ori_input_desc.emplace_back(*tensor_desc); | |||||
| GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i])); | |||||
| } | } | ||||
| int iteration = 1; | |||||
| while (true) { | |||||
| int iteration = 0; | |||||
| while (is_continue) { | |||||
| ++iteration; | |||||
| GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration); | GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration); | ||||
| GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), | GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), | ||||
| "[%s] Failed to execute iteration %d.", | "[%s] Failed to execute iteration %d.", | ||||
| task_context.GetNodeName(), | task_context.GetNodeName(), | ||||
| iteration); | iteration); | ||||
| if (!is_continue) { | |||||
| GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration); | |||||
| break; | |||||
| } | |||||
| ++iteration; | |||||
| } | } | ||||
| for (int i = 0; i < task_context.NumInputs(); ++i) { | |||||
| auto input_tensor = task_context.GetInput(i); | |||||
| auto tensor_desc = task_context.MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_tensor); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| // restore original input tensor desc | |||||
| *tensor_desc = std::move(ori_input_desc[i]); | |||||
| GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor)); | |||||
| } | |||||
| GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration); | |||||
| if (done_callback) { | if (done_callback) { | ||||
| done_callback(); | done_callback(); | ||||
| } | } | ||||
| for (int i = 0; i < task_context.NumInputs(); ++i) { | |||||
| GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i])); | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) { | |||||
| } | } | ||||
| Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const { | Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const { | ||||
| GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), | |||||
| "[%s] Failed to execute cond-subgraph", | |||||
| task_context.GetNodeName()); | |||||
| if (!is_continue) { | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); | GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); | ||||
| GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), | GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), | ||||
| "[%s] Failed to execute cond-subgraph", task_context.GetNodeName()); | "[%s] Failed to execute cond-subgraph", task_context.GetNodeName()); | ||||
| @@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti | |||||
| "[%s] Failed to move outputs to inputs", | "[%s] Failed to move outputs to inputs", | ||||
| task_context.GetNodeName()); | task_context.GetNodeName()); | ||||
| GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), | |||||
| "[%s] Failed to execute cond-subgraph", | |||||
| task_context.GetNodeName()); | |||||
| if (!is_continue) { | |||||
| for (int i = 0; i < task_context.NumInputs(); ++i) { | |||||
| auto input_desc = task_context.GetInput(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc)); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask { | |||||
| Status ExecuteCond(TaskContext &task_context, bool &is_continue) const; | Status ExecuteCond(TaskContext &task_context, bool &is_continue) const; | ||||
| static Status MoveOutputs2Inputs(TaskContext &task_context); | static Status MoveOutputs2Inputs(TaskContext &task_context); | ||||
| Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const; | Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const; | ||||
| private: | private: | ||||
| @@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const { | |||||
| return node_state_; | return node_state_; | ||||
| } | } | ||||
// Copy-out overload: forwards to node_item_->GetInputDesc, which fills
// `tensor_desc` with the descriptor of input `index`, and returns its Status.
| Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const { | |||||
| return node_item_->GetInputDesc(index, tensor_desc); | |||||
| } | |||||
// Forwards to NodeItem::UpdateInputDesc to replace the descriptor of input
// `index` with `tensor_desc`, and returns its Status.
// The const_cast strips constness from node_item_ so the non-const member can
// be called; presumably node_item_ is held as a pointer to const NodeItem
// (declaration not visible here - confirm).
| Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) { | |||||
| return const_cast<NodeItem *>(node_item_)->UpdateInputDesc(index, tensor_desc); | |||||
| } | |||||
// Copy-out overload: forwards to node_item_->GetOutputDesc, which fills
// `tensor_desc` with the descriptor of output `index`, and returns its Status.
| Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const { | |||||
| return node_item_->GetOutputDesc(index, tensor_desc); | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -50,9 +50,12 @@ class TaskContext { | |||||
| const char *GetNodeName() const; | const char *GetNodeName() const; | ||||
| TensorValue *MutableInput(int index); | TensorValue *MutableInput(int index); | ||||
| ConstGeTensorDescPtr GetInputDesc(int index) const; | ConstGeTensorDescPtr GetInputDesc(int index) const; | ||||
| Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const; | |||||
| ConstGeTensorDescPtr GetOutputDesc(int index) const; | ConstGeTensorDescPtr GetOutputDesc(int index) const; | ||||
| Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const; | |||||
| GeTensorDescPtr MutableInputDesc(int index) const; | GeTensorDescPtr MutableInputDesc(int index) const; | ||||
| GeTensorDescPtr MutableOutputDesc(int index) const; | GeTensorDescPtr MutableOutputDesc(int index) const; | ||||
| Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc); | |||||
| void ReleaseInputsAndOutputs(); | void ReleaseInputsAndOutputs(); | ||||
| bool NeedCallback(); | bool NeedCallback(); | ||||
| void ReleaseInput(int index); | void ReleaseInput(int index); | ||||