From: @xchu42 Reviewed-by: @ji_chen,@wqtshg Signed-off-by: @ji_chentags/v1.2.0
| @@ -18,6 +18,7 @@ | |||||
| #include <chrono> | #include <chrono> | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "hybrid_execution_context.h" | #include "hybrid_execution_context.h" | ||||
| #include "subgraph_context.h" | #include "subgraph_context.h" | ||||
| @@ -35,29 +36,31 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
| this->num_pending_shapes_); | this->num_pending_shapes_); | ||||
| } | } | ||||
| Status ShapeInferenceState::UpdateInputShape(int idx, | |||||
| const GeShape &ori_shape, | |||||
| const GeShape &shape) { | |||||
| Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | |||||
| if (node_item.IsInputShapeStatic(idx)) { | if (node_item.IsInputShapeStatic(idx)) { | ||||
| GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(), | node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(), | ||||
| shape.ToString().c_str()); | |||||
| target.GetShape().ToString().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s]", | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(target, tensor_size); | |||||
| GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| shape.ToString().c_str(), | |||||
| ori_shape.ToString().c_str()); | |||||
| target.GetShape().ToString().c_str(), | |||||
| target.GetOriginShape().ToString().c_str(), | |||||
| tensor_size); | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| auto tensor_desc = node_item.MutableInputDesc(idx); | auto tensor_desc = node_item.MutableInputDesc(idx); | ||||
| GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
| tensor_desc->SetShape(shape); | |||||
| tensor_desc->SetOriginShape(ori_shape); | |||||
| tensor_desc->SetShape(target.GetShape()); | |||||
| tensor_desc->SetOriginShape(target.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| if (--num_pending_shapes_ == 0) { | if (--num_pending_shapes_ == 0) { | ||||
| ready_cv_.notify_all(); | ready_cv_.notify_all(); | ||||
| } | } | ||||
| @@ -110,24 +113,24 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| for (auto &p : shape_futures) { | for (auto &p : shape_futures) { | ||||
| auto idx = p.first; | auto idx = p.first; | ||||
| auto &future = p.second; | auto &future = p.second; | ||||
| GeShape shape; | |||||
| GeShape ori_shape; | |||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | ||||
| GE_CHK_STATUS_RET(future.Get(ori_shape, shape), | |||||
| "[%s] Get shape failed. index = %u", | |||||
| node_item.NodeName().c_str(), | |||||
| idx); | |||||
| auto src_tensor_desc = future.GetTensorDesc(); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | |||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | ||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", | |||||
| node_item.NodeName().c_str(), | |||||
| idx, | |||||
| shape.ToString().c_str(), | |||||
| ori_shape.ToString().c_str()); | |||||
| auto input_desc = node_item.MutableInputDesc(idx); | auto input_desc = node_item.MutableInputDesc(idx); | ||||
| GE_CHECK_NOTNULL(input_desc); | GE_CHECK_NOTNULL(input_desc); | ||||
| input_desc->SetShape(std::move(shape)); | |||||
| input_desc->SetOriginShape(ori_shape); | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | |||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| idx, | |||||
| src_tensor_desc->GetShape().ToString().c_str(), | |||||
| src_tensor_desc->GetOriginShape().ToString().c_str(), | |||||
| tensor_size); | |||||
| input_desc->SetShape(src_tensor_desc->GetShape()); | |||||
| input_desc->SetOriginShape(src_tensor_desc->GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*input_desc, tensor_size); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -190,5 +193,14 @@ Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | |||||
| GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GeTensorDescPtr ShapeFuture::GetTensorDesc() { | |||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | |||||
| if (!subgraph_context_->Await(src_node_)) { | |||||
| GELOGE(INTERNAL_ERROR, "cancelled"); | |||||
| return nullptr; | |||||
| } | |||||
| return src_node_->GetOpDesc()->MutableOutputDesc(src_index_); | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -35,6 +35,7 @@ class ShapeFuture { | |||||
| ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); | ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); | ||||
| ~ShapeFuture() = default; | ~ShapeFuture() = default; | ||||
| Status Get(GeShape &ori_shape, GeShape &shape); | Status Get(GeShape &ori_shape, GeShape &shape); | ||||
| GeTensorDescPtr GetTensorDesc(); | |||||
| private: | private: | ||||
| NodePtr src_node_; | NodePtr src_node_; | ||||
| @@ -45,7 +46,7 @@ class ShapeFuture { | |||||
| struct ShapeInferenceState { | struct ShapeInferenceState { | ||||
| explicit ShapeInferenceState(const NodeItem &node_item); | explicit ShapeInferenceState(const NodeItem &node_item); | ||||
| Status UpdateInputShape(int idx, const GeShape &ori_shape, const GeShape &shape); | |||||
| Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc); | |||||
| void UpdateInputShapeFuture(int idx, ShapeFuture &&future); | void UpdateInputShapeFuture(int idx, ShapeFuture &&future); | ||||
| @@ -96,7 +96,7 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
| GE_CHECK_NOTNULL(tensor_desc); | GE_CHECK_NOTNULL(tensor_desc); | ||||
| auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); | ||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape()); | |||||
| node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc); | |||||
| } | } | ||||
| } | } | ||||
| @@ -268,13 +268,6 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||||
| } else { | } else { | ||||
| node_state.SetKernelTask(node_item.kernel_task); | node_state.SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| GELOGD("[%s] Start to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); | |||||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); | |||||
| GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), | |||||
| "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); | |||||
| RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); | |||||
| GELOGD("[%s] Done invoking CalcOpRunningParam successfully.", node_item.NodeName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -20,12 +20,9 @@ | |||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "common/dump/dump_manager.h" | |||||
| #include "hybrid/executor//worker//shape_inference_engine.h" | |||||
| #include "common/dump/dump_op.h" | #include "common/dump/dump_op.h" | ||||
| #include "common/types.h" | |||||
| #include "common/ge_types.h" | |||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| #include "runtime/base.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -348,6 +345,10 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| } | } | ||||
| GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); | GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); | ||||
| if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | |||||
| // update output tensor sizes | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | |||||
| } | |||||
| // PropagateOutputs for type == DEPEND_COMPUTE | // PropagateOutputs for type == DEPEND_COMPUTE | ||||
| if (node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| if (graph_context_->trace_enabled) { | if (graph_context_->trace_enabled) { | ||||
| @@ -17,9 +17,15 @@ | |||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const int kAlignment = 32; | |||||
| } | |||||
| namespace hybrid { | namespace hybrid { | ||||
| ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) | ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) | ||||
| : execution_context_(execution_context), | : execution_context_(execution_context), | ||||
| @@ -40,7 +46,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| } | } | ||||
| if (node_item.fused_subgraph != nullptr) { | if (node_item.fused_subgraph != nullptr) { | ||||
| return InferShapeForSubgraph(node_item, *node_item.fused_subgraph); | |||||
| GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph)); | |||||
| GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item)); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| // Skip shape inference for node of type DEPEND_COMPUTE | // Skip shape inference for node of type DEPEND_COMPUTE | ||||
| @@ -63,21 +71,15 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); | ||||
| GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), | ||||
| "Invoke InferShapeAndType failed."); | |||||
| "Invoke InferShapeAndType failed."); | |||||
| RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); | ||||
| } | } | ||||
| // Check again to make sure shape is valid after shape inference | |||||
| if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) { | |||||
| bool is_unknown_shape = false; | |||||
| GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape), | |||||
| "Failed to get shape status. node = %s", | |||||
| node_item.NodeName().c_str()); | |||||
| GE_CHK_BOOL_RET_STATUS(!is_unknown_shape, | |||||
| INTERNAL_ERROR, | |||||
| "[%s] Shape is still unknown after shape inference.", | |||||
| node_item.NodeName().c_str()); | |||||
| } | |||||
| // update output tensor sizes after shape inference | |||||
| // error if shape is still unknown and not of type DEPEND_SHAPE_RANGE | |||||
| RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); | |||||
| GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item, node_item.shape_inference_type == DEPEND_SHAPE_RANGE)); | |||||
| RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); | |||||
| GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", | GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| @@ -127,8 +129,6 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| // propagate each output | // propagate each output | ||||
| for (int i = 0; i < node_item.num_outputs; ++i) { | for (int i = 0; i < node_item.num_outputs; ++i) { | ||||
| auto output_desc = node_item.op_desc->MutableOutputDesc(i); | auto output_desc = node_item.op_desc->MutableOutputDesc(i); | ||||
| const auto &shape = output_desc->MutableShape(); | |||||
| const auto &ori_shape = output_desc->GetOriginShape(); | |||||
| auto &output_nodes = node_item.outputs[i]; | auto &output_nodes = node_item.outputs[i]; | ||||
| // propagate output to all sub-inputs | // propagate output to all sub-inputs | ||||
| @@ -149,9 +149,7 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, | infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, | ||||
| std::move(future)); | std::move(future)); | ||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, | |||||
| ori_shape, | |||||
| shape)); | |||||
| GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -230,5 +228,92 @@ Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeInferenceEngine::CanonicalizeShape(GeTensorDesc &tensor_desc, | |||||
| std::vector<int64_t> &shape, | |||||
| bool fallback_with_range) { | |||||
| const auto &tensor_shape = tensor_desc.MutableShape(); | |||||
| if (tensor_shape.IsUnknownShape()) { | |||||
| if (!fallback_with_range) { | |||||
| GELOGE(INTERNAL_ERROR, "Output shape is still unknown after shape inference. shape = [%s]", | |||||
| tensor_shape.ToString().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Calc output size by range"); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| GE_CHK_GRAPH_STATUS_RET(tensor_desc.GetShapeRange(shape_range), "Failed to get shape range"); | |||||
| if (shape_range.size() != shape.size()) { | |||||
| GELOGE(INTERNAL_ERROR, "Number of shape ranges (%zu) mismatches that of dims (%zu)", | |||||
| shape_range.size(), | |||||
| shape.size()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (size_t dim_index = 0; dim_index < shape.size(); ++dim_index) { | |||||
| if (shape[dim_index] == ge::UNKNOWN_DIM) { | |||||
| shape[dim_index] = shape_range[dim_index].second; | |||||
| } | |||||
| } | |||||
| GELOGD("After canonicalization, shape = [%s], before = [%s]", | |||||
| GeShape(shape).ToString().c_str(), | |||||
| tensor_shape.ToString().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceEngine::CalcTensorSize(DataType data_type, | |||||
| const std::vector<int64_t> &shape, | |||||
| int64_t &tensor_size) { | |||||
| GELOGD("To calc tensor size by shape = [%s]", GeShape(shape).ToString().c_str()); | |||||
| uint32_t type_size; | |||||
| if (!TypeUtils::GetDataTypeLength(data_type, type_size)) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get data type size"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| tensor_size = type_size; | |||||
| for (const auto &dim : shape) { | |||||
| GE_CHECK_GE(dim, 0); | |||||
| GE_CHK_STATUS_RET(Int64MulCheckOverflow(tensor_size, dim), | |||||
| "Shape size overflow, shape = [%s]", | |||||
| GeShape(shape).ToString().c_str()); | |||||
| tensor_size *= dim; | |||||
| } | |||||
| GE_CHK_STATUS_RET(CheckInt64AddOverflow(tensor_size, kAlignment - 1), | |||||
| "Tensor size is too large: %ld, shape = [%s]", | |||||
| tensor_size, | |||||
| GeShape(shape).ToString().c_str()); | |||||
| tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceEngine::CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range) { | |||||
| auto op_desc = node_item.GetOpDesc(); | |||||
| for (size_t output_index = 0; output_index < op_desc->GetOutputsSize(); ++output_index) { | |||||
| auto tensor_desc = op_desc->MutableOutputDesc(output_index); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| const auto &shape = tensor_desc->MutableShape(); | |||||
| // modify on copy | |||||
| auto dims = shape.GetDims(); | |||||
| GE_CHK_STATUS_RET(CanonicalizeShape(*tensor_desc, dims, fallback_with_range), | |||||
| "[%s] Failed to canonicalize shape for output %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index); | |||||
| int64_t tensor_size; | |||||
| GE_CHK_STATUS_RET(CalcTensorSize(tensor_desc->GetDataType(), dims, tensor_size), | |||||
| "[%s] Failed to calc tensor size for output %zu", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index); | |||||
| GELOGD("[%s] Tensor size of output %zu = %ld", node_item.NodeName().c_str(), output_index, tensor_size); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -34,7 +34,11 @@ class ShapeInferenceEngine { | |||||
| Status PropagateOutputShapes(const NodeItem &node_item); | Status PropagateOutputShapes(const NodeItem &node_item); | ||||
| static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | |||||
| private: | private: | ||||
| static Status CanonicalizeShape(GeTensorDesc &tensor_desc, std::vector<int64_t> &shape, bool fallback_with_range); | |||||
| static Status CalcTensorSize(DataType data_type, const std::vector<int64_t> &shape, int64_t &tensor_size); | |||||
| static Status UpdatePeerNodeShape(const Node &node); | static Status UpdatePeerNodeShape(const Node &node); | ||||
| Status AwaitDependentNodes(NodeState &node_state); | Status AwaitDependentNodes(NodeState &node_state); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -47,7 +48,7 @@ Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgr | |||||
| GE_CHECK_NOTNULL(dst_op_desc); | GE_CHECK_NOTNULL(dst_op_desc); | ||||
| auto in_idx = node_and_anchor.second->GetIdx(); | auto in_idx = node_and_anchor.second->GetIdx(); | ||||
| auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx); | auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx); | ||||
| fused_subgraph.input_mapping[parent_index].emplace_back(tensor_desc); | |||||
| fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc); | |||||
| GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx); | GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx); | ||||
| } | } | ||||
| @@ -64,7 +65,7 @@ Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgrap | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| fused_subgraph.output_mapping.emplace(parent_index, op_desc); | |||||
| fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -126,12 +127,7 @@ Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_ite | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NodeItem::Init() { | |||||
| GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX); | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | |||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | |||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | |||||
| void NodeItem::ResolveOptionalInputs() { | |||||
| if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { | if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) { | ||||
| has_optional_inputs = true; | has_optional_inputs = true; | ||||
| for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | ||||
| @@ -143,7 +139,18 @@ Status NodeItem::Init() { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| Status NodeItem::InitInputsAndOutputs() { | |||||
| GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX); | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | |||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | |||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | |||||
| ResolveOptionalInputs(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status NodeItem::ResolveDynamicState() { | |||||
| (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); | (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic); | ||||
| GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); | GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic); | ||||
| if (!is_dynamic) { | if (!is_dynamic) { | ||||
| @@ -151,38 +158,54 @@ Status NodeItem::Init() { | |||||
| "[%s] Failed to get shape status.", | "[%s] Failed to get shape status.", | ||||
| node->GetName().c_str()); | node->GetName().c_str()); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| if (is_dynamic) { | |||||
| for (int i = 0; i < num_inputs; ++i) { | |||||
| const auto &input_desc = MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| if (input_desc->MutableShape().IsUnknownShape()) { | |||||
| is_input_shape_static_.push_back(false); | |||||
| } else { | |||||
| num_static_input_shapes++; | |||||
| is_input_shape_static_.push_back(true); | |||||
| GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", | |||||
| NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str()); | |||||
| } | |||||
| Status NodeItem::ResolveStaticInputsAndOutputs() { | |||||
| for (int i = 0; i < num_inputs; ++i) { | |||||
| const auto &input_desc = MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| if (input_desc->MutableShape().IsUnknownShape()) { | |||||
| is_input_shape_static_.push_back(false); | |||||
| } else { | |||||
| num_static_input_shapes++; | |||||
| is_input_shape_static_.push_back(true); | |||||
| GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", | |||||
| NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str()); | |||||
| } | } | ||||
| } | |||||
| for (int i = 0; i < num_outputs; ++i) { | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| if (output_desc->MutableShape().IsUnknownShape()) { | |||||
| is_output_shape_static = false; | |||||
| break; | |||||
| } | |||||
| for (int i = 0; i < num_outputs; ++i) { | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| if (output_desc->MutableShape().IsUnknownShape()) { | |||||
| is_output_shape_static = false; | |||||
| break; | |||||
| } | } | ||||
| } | |||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | |||||
| } else { | |||||
| int32_t unknown_shape_type_val = 0; | |||||
| (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); | |||||
| shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val); | |||||
| } | |||||
| if (is_output_shape_static) { | |||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void NodeItem::ResolveUnknownShapeType() { | |||||
| if (IsControlOp() || node_type == PARTITIONEDCALL) { | |||||
| shape_inference_type = DEPEND_COMPUTE; | |||||
| } else { | |||||
| int32_t unknown_shape_type_val = 0; | |||||
| (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); | |||||
| shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val); | |||||
| } | |||||
| } | |||||
| Status NodeItem::Init() { | |||||
| GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | |||||
| if (is_dynamic) { | |||||
| ResolveUnknownShapeType(); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs()); | |||||
| GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | ||||
| } | } | ||||
| @@ -103,6 +103,11 @@ struct NodeItem { | |||||
| private: | private: | ||||
| explicit NodeItem(NodePtr node); | explicit NodeItem(NodePtr node); | ||||
| Status Init(); | Status Init(); | ||||
| Status InitInputsAndOutputs(); | |||||
| void ResolveOptionalInputs(); | |||||
| Status ResolveDynamicState(); | |||||
| Status ResolveStaticInputsAndOutputs(); | |||||
| void ResolveUnknownShapeType(); | |||||
| std::vector<bool> is_input_shape_static_; | std::vector<bool> is_input_shape_static_; | ||||
| std::vector<uint32_t> input_desc_indices_; | std::vector<uint32_t> input_desc_indices_; | ||||
| @@ -148,6 +148,10 @@ Status TaskContext::AllocateWorkspaces() { | |||||
| } | } | ||||
| Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const { | Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const { | ||||
| if (callback_fun == nullptr) { | |||||
| GELOGW("[%s] Callback is NULL", GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); | auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | ||||
| @@ -384,6 +388,20 @@ const char *TaskContext::GetNodeName() const { | |||||
| return node_item_->NodeName().c_str(); | return node_item_->NodeName().c_str(); | ||||
| } | } | ||||
| void TaskContext::ReleaseInputsAndOutputs() { | |||||
| for (int i = 0; i < node_item_->num_inputs; ++i) { | |||||
| auto tensor = inputs_start_ + i; | |||||
| tensor->Destroy(); | |||||
| GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), i); | |||||
| } | |||||
| for (int i = 0; i < node_item_->num_outputs; ++i) { | |||||
| auto tensor = outputs_start_ + i; | |||||
| tensor->Destroy(); | |||||
| GELOGD("[%s] Tensor of output[%d] released", GetNodeName(), i); | |||||
| } | |||||
| } | |||||
| void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
| auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
| if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
| @@ -456,5 +474,9 @@ Status TaskContext::TryExecuteCallback(const function<void()> &callback_fun) con | |||||
| const DumpProperties &TaskContext::GetDumpProperties() const { | const DumpProperties &TaskContext::GetDumpProperties() const { | ||||
| return execution_context_->dump_properties; | return execution_context_->dump_properties; | ||||
| } | } | ||||
| bool TaskContext::NeedCallback() { | |||||
| return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -50,6 +50,8 @@ class TaskContext { | |||||
| ConstGeTensorDescPtr GetOutputDesc(int index) const; | ConstGeTensorDescPtr GetOutputDesc(int index) const; | ||||
| GeTensorDescPtr MutableInputDesc(int index) const; | GeTensorDescPtr MutableInputDesc(int index) const; | ||||
| GeTensorDescPtr MutableOutputDesc(int index) const; | GeTensorDescPtr MutableOutputDesc(int index) const; | ||||
| void ReleaseInputsAndOutputs(); | |||||
| bool NeedCallback(); | |||||
| void ReleaseInput(int index); | void ReleaseInput(int index); | ||||
| const TensorValue *GetInput(int index) const; | const TensorValue *GetInput(int index) const; | ||||
| const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||