@@ -35,12 +35,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||
node_item.NodeName().c_str(), | |||
this->num_pending_shapes_); | |||
for (int i = 0; i < node_item.num_inputs; ++i){ | |||
input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i)); | |||
input_tensor_desc.resize(node_item.num_inputs); | |||
for (int i = 0; i < node_item.num_inputs; ++i) { | |||
node_item.GetInputDesc(i, input_tensor_desc[i]); | |||
} | |||
for (int i = 0; i < node_item.num_outputs; ++i){ | |||
output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i)); | |||
output_tensor_desc.resize(node_item.num_outputs); | |||
for (int i = 0; i < node_item.num_outputs; ++i) { | |||
node_item.GetOutputDesc(i, output_tensor_desc[i]); | |||
} | |||
} | |||
@@ -297,7 +297,7 @@ void NodeItem::SetToDynamic() { | |||
} | |||
} | |||
GeTensorDescPtr NodeItem::MutableInputDesc(int index) const { | |||
GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const { | |||
if (!has_optional_inputs) { | |||
return op_desc->MutableInputDesc(static_cast<uint32_t>(index)); | |||
} | |||
@@ -314,6 +314,40 @@ GeTensorDescPtr NodeItem::MutableInputDesc(int index) const { | |||
return op_desc->MutableInputDesc(input_desc_indices_[index]); | |||
} | |||
/// Returns the descriptor pointer for input `index`.
/// mu_ is held only for the duration of the lookup; the returned pointer is
/// handed back to the caller after the lock has been released.
GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
  std::lock_guard<std::mutex> guard(mu_);
  GeTensorDescPtr desc = DoGetInputDesc(index);
  return desc;
}
/// Copies the descriptor of input `index` into `tensor_desc` while holding mu_.
/// @param index        input index, resolved through DoGetInputDesc.
/// @param tensor_desc  receives a copy of the descriptor.
/// @return SUCCESS once the copy is made. A null lookup result is handled by
///         GE_CHECK_NOTNULL (macro defined elsewhere; presumably an early
///         error return — confirm against its definition).
Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
  std::lock_guard<std::mutex> guard(mu_);
  const GeTensorDescPtr input_desc = DoGetInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  // Copy under the lock so the caller gets its own snapshot of the descriptor.
  tensor_desc = *input_desc;
  return SUCCESS;
}
/// Copies the descriptor of output `index` into `tensor_desc` while holding mu_.
/// @param index        output index, forwarded to op_desc as uint32_t.
/// @param tensor_desc  receives a copy of the descriptor.
/// @return SUCCESS once the copy is made. A null lookup result is handled by
///         GE_CHECK_NOTNULL (macro defined elsewhere; presumably an early
///         error return — confirm against its definition).
Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
  std::lock_guard<std::mutex> guard(mu_);
  const auto desc_index = static_cast<uint32_t>(index);
  const GeTensorDescPtr output_desc = op_desc->MutableOutputDesc(desc_index);
  GE_CHECK_NOTNULL(output_desc);
  // Copy under the lock so the caller gets its own snapshot of the descriptor.
  tensor_desc = *output_desc;
  return SUCCESS;
}
/// Returns the descriptor pointer for output `index`.
/// mu_ is held only for the duration of the lookup on op_desc; the returned
/// pointer is handed back to the caller after the lock has been released.
GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
  std::lock_guard<std::mutex> guard(mu_);
  const auto desc_index = static_cast<uint32_t>(index);
  return op_desc->MutableOutputDesc(desc_index);
}
/// Overwrites the descriptor of input `index` with `tensor_desc` while holding mu_.
/// @param index        input index, resolved through DoGetInputDesc.
/// @param tensor_desc  descriptor whose contents are assigned over the stored one.
/// @return SUCCESS once the assignment is made. A null lookup result is handled
///         by GE_CHECK_NOTNULL (macro defined elsewhere; presumably an early
///         error return — confirm against its definition).
Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
  std::lock_guard<std::mutex> guard(mu_);
  const GeTensorDescPtr input_desc = DoGetInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  // Assign through the pointer: the stored descriptor object is updated in place.
  *input_desc = tensor_desc;
  return SUCCESS;
}
Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const { | |||
if (!has_optional_inputs) { | |||
canonical_index = index; | |||
@@ -17,6 +17,7 @@ | |||
#ifndef GE_HYBRID_MODEL_NODE_ITEM_H_ | |||
#define GE_HYBRID_MODEL_NODE_ITEM_H_ | |||
#include <mutex> | |||
#include <vector> | |||
#include "external/ge/ge_api_error_codes.h" | |||
#include "graph/node.h" | |||
@@ -57,12 +58,16 @@ struct NodeItem { | |||
bool IsInputShapeStatic(int index) const; | |||
GeTensorDescPtr MutableOutputDesc(int index) const { | |||
return op_desc->MutableOutputDesc(static_cast<uint32_t>(index)); | |||
} | |||
GeTensorDescPtr MutableOutputDesc(int index) const; | |||
Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc); | |||
GeTensorDescPtr MutableInputDesc(int index) const; | |||
Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const; | |||
Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const; | |||
Status GetCanonicalInputIndex(uint32_t index, int &canonical_index) const; | |||
bool IsControlOp() const; | |||
@@ -113,9 +118,11 @@ struct NodeItem { | |||
Status ResolveDynamicState(); | |||
Status ResolveStaticInputsAndOutputs(); | |||
void ResolveUnknownShapeType(); | |||
GeTensorDescPtr DoGetInputDesc(int index) const; | |||
std::vector<bool> is_input_shape_static_; | |||
std::vector<uint32_t> input_desc_indices_; | |||
mutable std::mutex mu_; | |||
}; | |||
} // namespace hybrid | |||
} // namespace ge | |||
@@ -237,8 +237,8 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun | |||
} | |||
bool is_continue = false; | |||
GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), | |||
"[%s] Failed to execute iteration 0.", | |||
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), | |||
"[%s] Failed to execute cond-subgraph", | |||
task_context.GetNodeName()); | |||
if (!is_continue) { | |||
for (int i = 0; i < task_context.NumInputs(); ++i) { | |||
@@ -259,42 +259,28 @@ Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::fun | |||
} | |||
// backup original input tensor desc | |||
std::vector<GeTensorDesc> ori_input_desc; | |||
std::vector<GeTensorDesc> ori_input_desc(task_context.NumInputs()); | |||
for (int i = 0; i < task_context.NumInputs(); ++i) { | |||
auto tensor_desc = task_context.GetInputDesc(i); | |||
GE_CHECK_NOTNULL(tensor_desc); | |||
ori_input_desc.emplace_back(*tensor_desc); | |||
GE_CHK_STATUS_RET_NOLOG(task_context.GetInputDesc(i, ori_input_desc[i])); | |||
} | |||
int iteration = 1; | |||
while (true) { | |||
int iteration = 0; | |||
while (is_continue) { | |||
++iteration; | |||
GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration); | |||
GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), | |||
"[%s] Failed to execute iteration %d.", | |||
task_context.GetNodeName(), | |||
iteration); | |||
if (!is_continue) { | |||
GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration); | |||
break; | |||
} | |||
++iteration; | |||
} | |||
for (int i = 0; i < task_context.NumInputs(); ++i) { | |||
auto input_tensor = task_context.GetInput(i); | |||
auto tensor_desc = task_context.MutableInputDesc(i); | |||
GE_CHECK_NOTNULL(input_tensor); | |||
GE_CHECK_NOTNULL(tensor_desc); | |||
// restore original input tensor desc | |||
*tensor_desc = std::move(ori_input_desc[i]); | |||
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor)); | |||
} | |||
GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration); | |||
if (done_callback) { | |||
done_callback(); | |||
} | |||
for (int i = 0; i < task_context.NumInputs(); ++i) { | |||
GE_CHK_STATUS_RET_NOLOG(task_context.UpdateInputDesc(i, ori_input_desc[i])); | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -379,13 +365,6 @@ Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) { | |||
} | |||
Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const { | |||
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), | |||
"[%s] Failed to execute cond-subgraph", | |||
task_context.GetNodeName()); | |||
if (!is_continue) { | |||
return SUCCESS; | |||
} | |||
GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); | |||
GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), | |||
"[%s] Failed to execute cond-subgraph", task_context.GetNodeName()); | |||
@@ -396,6 +375,17 @@ Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_conti | |||
"[%s] Failed to move outputs to inputs", | |||
task_context.GetNodeName()); | |||
GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), | |||
"[%s] Failed to execute cond-subgraph", | |||
task_context.GetNodeName()); | |||
if (!is_continue) { | |||
for (int i = 0; i < task_context.NumInputs(); ++i) { | |||
auto input_desc = task_context.GetInput(i); | |||
GE_CHECK_NOTNULL(input_desc); | |||
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_desc)); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -80,7 +80,6 @@ class WhileOpNodeTask : public ControlOpNodeTask { | |||
Status ExecuteCond(TaskContext &task_context, bool &is_continue) const; | |||
static Status MoveOutputs2Inputs(TaskContext &task_context); | |||
Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const; | |||
private: | |||
@@ -554,5 +554,16 @@ NodeState *TaskContext::GetNodeState() const { | |||
return node_state_; | |||
} | |||
/// Copies the descriptor of input `index` into `tensor_desc`.
/// Pure forwarder: the work and the returned status come from
/// node_item_->GetInputDesc.
Status TaskContext::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
  const Status ret = node_item_->GetInputDesc(index, tensor_desc);
  return ret;
}
/// Replaces the descriptor of input `index` with `tensor_desc`.
/// Forwards to NodeItem::UpdateInputDesc and returns its status.
Status TaskContext::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
  // NodeItem::UpdateInputDesc is a non-const member, so constness is cast away
  // from node_item_ (presumably held as a pointer-to-const — confirm in the header).
  auto *mutable_item = const_cast<NodeItem *>(node_item_);
  return mutable_item->UpdateInputDesc(index, tensor_desc);
}
/// Copies the descriptor of output `index` into `tensor_desc`.
/// Pure forwarder: the work and the returned status come from
/// node_item_->GetOutputDesc.
Status TaskContext::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
  const Status ret = node_item_->GetOutputDesc(index, tensor_desc);
  return ret;
}
} // namespace hybrid | |||
} // namespace ge |
@@ -50,9 +50,12 @@ class TaskContext { | |||
const char *GetNodeName() const; | |||
TensorValue *MutableInput(int index); | |||
ConstGeTensorDescPtr GetInputDesc(int index) const; | |||
Status GetInputDesc(int index, GeTensorDesc &tensor_desc) const; | |||
ConstGeTensorDescPtr GetOutputDesc(int index) const; | |||
Status GetOutputDesc(int index, GeTensorDesc &tensor_desc) const; | |||
GeTensorDescPtr MutableInputDesc(int index) const; | |||
GeTensorDescPtr MutableOutputDesc(int index) const; | |||
Status UpdateInputDesc(int index, const GeTensorDesc &tensor_desc); | |||
void ReleaseInputsAndOutputs(); | |||
bool NeedCallback(); | |||
void ReleaseInput(int index); | |||
@@ -383,3 +383,45 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { | |||
HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); | |||
} | |||
// Exercises the TaskContext tensor-descriptor accessors on a two-input,
// one-output "Add" node: out-of-range lookup, pointer accessors, copy-out
// getters, and UpdateInputDesc followed by a read-back.
TEST_F(UtestGeHybrid, TestTaskContext) {
  auto graph = make_shared<ComputeGraph>("graph");
  OpDescPtr op_desc = CreateOpDesc("Add", "Add");
  GeShape shape({2, 16});
  GeTensorDesc tensor_desc(shape);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddInputDesc(tensor_desc);
  op_desc->AddOutputDesc(tensor_desc);
  auto node = graph->AddNode(op_desc);

  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(node, node_item);
  // node_item is dereferenced right below; fail the test cleanly instead of
  // crashing if Create did not populate it.
  ASSERT_TRUE(node_item != nullptr);
  node_item->input_start = 0;
  node_item->output_start = 0;

  GraphExecutionContext execution_context;
  SubgraphContext subgraph_context(nullptr, &execution_context);
  subgraph_context.all_inputs_.resize(2);
  subgraph_context.all_outputs_.resize(1);

  NodeState node_state(*node_item, &subgraph_context);
  auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context);
  ASSERT_TRUE(task_context != nullptr);

  // Index 2 is past the two registered inputs, so no descriptor is expected.
  auto desc = task_context->MutableInputDesc(2);
  ASSERT_TRUE(desc == nullptr);

  desc = task_context->MutableOutputDesc(0);
  ASSERT_TRUE(desc != nullptr);
  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());

  GeTensorDesc output_desc;
  ASSERT_EQ(task_context->GetOutputDesc(0, output_desc), SUCCESS);
  ASSERT_EQ(output_desc.GetShape().GetDims(), shape.GetDims());

  desc = task_context->MutableInputDesc(0);
  ASSERT_TRUE(desc != nullptr);
  ASSERT_EQ(desc->GetShape().GetDims(), shape.GetDims());

  GeShape new_shape({8, 2});
  tensor_desc.SetShape(new_shape);
  // The update status was previously discarded; check it so a failed update is
  // reported here rather than as a confusing shape mismatch below.
  ASSERT_EQ(task_context->UpdateInputDesc(1, tensor_desc), SUCCESS);

  GeTensorDesc new_desc;
  ASSERT_EQ(task_context->GetInputDesc(1, new_desc), SUCCESS);
  ASSERT_EQ(new_desc.GetShape().GetDims(), new_shape.GetDims());
}