
!613 Decouple ops kernel builder, CalcOpRunningParam by executor

From: @xchu42
Reviewed-by: @ji_chen, @wqtshg
Signed-off-by: @ji_chen
tags/v1.2.0
mindspore-ci-bot committed 3 years ago · commit e827ba733b
10 changed files with 235 additions and 87 deletions
  1. ge/hybrid/executor/node_state.cc  (+34 -22)
  2. ge/hybrid/executor/node_state.h  (+2 -1)
  3. ge/hybrid/executor/subgraph_executor.cc  (+1 -8)
  4. ge/hybrid/executor/worker/execution_engine.cc  (+5 -4)
  5. ge/hybrid/executor/worker/shape_inference_engine.cc  (+103 -18)
  6. ge/hybrid/executor/worker/shape_inference_engine.h  (+4 -0)
  7. ge/hybrid/model/node_item.cc  (+57 -34)
  8. ge/hybrid/model/node_item.h  (+5 -0)
  9. ge/hybrid/node_executor/task_context.cc  (+22 -0)
  10. ge/hybrid/node_executor/task_context.h  (+2 -0)
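
In short, this change moves output-tensor sizing out of the executor: SubgraphExecutor::PrepareForExecution no longer calls NodeExecutorManager::CalcOpRunningParam. Instead, ShapeInferenceEngine gains a static CalcOutputTensorSizes (built on the new CanonicalizeShape and CalcTensorSize helpers) that runs right after shape inference, with a post-execution pass in NodeDoneCallback::OnNodeDone for DEPEND_SHAPE_RANGE and DEPEND_COMPUTE nodes. Shape updates now carry a whole GeTensorDesc, so shape, origin shape, and tensor size propagate together; NodeItem::Init is split into focused helpers; and TaskContext gains ReleaseInputsAndOutputs and NeedCallback.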

ge/hybrid/executor/node_state.cc  (+34 -22)

@@ -18,6 +18,7 @@
#include <chrono>
#include "framework/common/debug/log.h"
#include "graph/compute_graph.h"
#include "graph/utils/tensor_utils.h"
#include "hybrid_execution_context.h"
#include "subgraph_context.h"

@@ -35,29 +36,31 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
this->num_pending_shapes_);
}

Status ShapeInferenceState::UpdateInputShape(int idx,
const GeShape &ori_shape,
const GeShape &shape) {
Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) {
if (node_item.IsInputShapeStatic(idx)) {
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]",
node_item.NodeName().c_str(),
idx,
node_item.MutableInputDesc(idx)->GetShape().ToString().c_str(),
shape.ToString().c_str());
target.GetShape().ToString().c_str());
return SUCCESS;
}

GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s]",
int64_t tensor_size = -1;
(void) TensorUtils::GetSize(target, tensor_size);
GELOGD("[%s] Update input shape [%d] with Shape: [%s] and OriginalShape: [%s], size = %ld",
node_item.NodeName().c_str(),
idx,
shape.ToString().c_str(),
ori_shape.ToString().c_str());
target.GetShape().ToString().c_str(),
target.GetOriginShape().ToString().c_str(),
tensor_size);

std::lock_guard<std::mutex> lk(mu_);
auto tensor_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(tensor_desc);
tensor_desc->SetShape(shape);
tensor_desc->SetOriginShape(ori_shape);
tensor_desc->SetShape(target.GetShape());
tensor_desc->SetOriginShape(target.GetOriginShape());
(void) TensorUtils::SetSize(*tensor_desc, tensor_size);
if (--num_pending_shapes_ == 0) {
ready_cv_.notify_all();
}
@@ -110,24 +113,24 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex
for (auto &p : shape_futures) {
auto idx = p.first;
auto &future = p.second;
GeShape shape;
GeShape ori_shape;
RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx);
GE_CHK_STATUS_RET(future.Get(ori_shape, shape),
"[%s] Get shape failed. index = %u",
node_item.NodeName().c_str(),
idx);
auto src_tensor_desc = future.GetTensorDesc();
GE_CHECK_NOTNULL(src_tensor_desc);
RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx);

GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]",
node_item.NodeName().c_str(),
idx,
shape.ToString().c_str(),
ori_shape.ToString().c_str());
auto input_desc = node_item.MutableInputDesc(idx);
GE_CHECK_NOTNULL(input_desc);
input_desc->SetShape(std::move(shape));
input_desc->SetOriginShape(ori_shape);
int64_t tensor_size = -1;
(void) TensorUtils::GetSize(*src_tensor_desc, tensor_size);
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu",
node_item.NodeName().c_str(),
idx,
src_tensor_desc->GetShape().ToString().c_str(),
src_tensor_desc->GetOriginShape().ToString().c_str(),
tensor_size);
input_desc->SetShape(src_tensor_desc->GetShape());
input_desc->SetOriginShape(src_tensor_desc->GetOriginShape());
(void) TensorUtils::SetSize(*input_desc, tensor_size);
}

return SUCCESS;
@@ -190,5 +193,14 @@ Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) {
GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str());
return SUCCESS;
}

GeTensorDescPtr ShapeFuture::GetTensorDesc() {
GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str());
if (!subgraph_context_->Await(src_node_)) {
GELOGE(INTERNAL_ERROR, "cancelled");
return nullptr;
}
return src_node_->GetOpDesc()->MutableOutputDesc(src_index_);
}
} // namespace hybrid
} // namespace ge
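
For orientation, a minimal sketch of the new call pattern (input_index stands in for the destination input slot; the other identifiers appear in this diff; this is a fragment, not a standalone program):

// The producer now hands a whole GeTensorDesc to the consumer's
// ShapeInferenceState, so shape, origin shape, and the byte size written
// by TensorUtils::SetSize travel together in one update. (For
// DEPEND_SHAPE_RANGE / DEPEND_COMPUTE producers, a ShapeFuture is
// registered instead and GetTensorDesc() is awaited later.)
auto &infer_state = node_state->GetShapeInferenceState();
GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(input_index, *output_desc));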

ge/hybrid/executor/node_state.h  (+2 -1)

@@ -35,6 +35,7 @@ class ShapeFuture {
ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context);
~ShapeFuture() = default;
Status Get(GeShape &ori_shape, GeShape &shape);
GeTensorDescPtr GetTensorDesc();

private:
NodePtr src_node_;
@@ -45,7 +46,7 @@ class ShapeFuture {
struct ShapeInferenceState {
explicit ShapeInferenceState(const NodeItem &node_item);

Status UpdateInputShape(int idx, const GeShape &ori_shape, const GeShape &shape);
Status UpdateInputShape(int idx, const GeTensorDesc &tensor_desc);

void UpdateInputShapeFuture(int idx, ShapeFuture &&future);



ge/hybrid/executor/subgraph_executor.cc  (+1 -8)

@@ -96,7 +96,7 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue
GE_CHECK_NOTNULL(tensor_desc);
auto node_state = subgraph_context_->GetOrCreateNodeState(input_node);
GE_CHECK_NOTNULL(node_state);
node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape());
node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc);
}
}

@@ -268,13 +268,6 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
} else {
node_state.SetKernelTask(node_item.kernel_task);
}

GELOGD("[%s] Start to invoke CalcOpRunningParam.", node_item.NodeName().c_str());
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start");
GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node),
"[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str());
RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] End");
GELOGD("[%s] Done invoking CalcOpRunningParam successfully.", node_item.NodeName().c_str());
return SUCCESS;
}
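
(With this hunk, PrepareForExecution only resolves the kernel task; output sizing now happens in ShapeInferenceEngine::CalcOutputTensorSizes, shown below in shape_inference_engine.cc.)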



ge/hybrid/executor/worker/execution_engine.cc  (+5 -4)

@@ -20,12 +20,9 @@
#include "graph/utils/tensor_adapter.h"
#include "graph/debug/ge_attr_define.h"
#include "hybrid/node_executor/node_executor.h"
#include "common/dump/dump_manager.h"
#include "hybrid/executor//worker//shape_inference_engine.h"
#include "common/dump/dump_op.h"
#include "common/types.h"
#include "common/ge_types.h"
#include "common/profiling/profiling_manager.h"
#include "runtime/base.h"

namespace ge {
namespace hybrid {
@@ -348,6 +345,10 @@ Status NodeDoneCallback::OnNodeDone() {
}

GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item));
if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) {
// update output tensor sizes
GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item));
}
// PropagateOutputs for type == DEPEND_COMPUTE
if (node_item.shape_inference_type == DEPEND_COMPUTE) {
if (graph_context_->trace_enabled) {


ge/hybrid/executor/worker/shape_inference_engine.cc  (+103 -18)

@@ -17,9 +17,15 @@
#include "hybrid/executor/worker/shape_inference_engine.h"
#include "graph/shape_refiner.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "common/math/math_util.h"
#include "hybrid/node_executor/node_executor.h"

namespace ge {
namespace {
const int kAlignment = 32;
}
namespace hybrid {
ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context)
: execution_context_(execution_context),
@@ -40,7 +46,9 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
}

if (node_item.fused_subgraph != nullptr) {
return InferShapeForSubgraph(node_item, *node_item.fused_subgraph);
GE_CHK_STATUS_RET_NOLOG(InferShapeForSubgraph(node_item, *node_item.fused_subgraph));
GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item));
return SUCCESS;
}

// Skip shape inference for node of type DEPEND_COMPUTE
@@ -63,21 +71,15 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
std::lock_guard<std::mutex> lk(mu_);
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true),
"Invoke InferShapeAndType failed.");
"Invoke InferShapeAndType failed.");
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End");
}
// Check again to make sure shape is valid after shape inference
if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) {
bool is_unknown_shape = false;
GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape),
"Failed to get shape status. node = %s",
node_item.NodeName().c_str());

GE_CHK_BOOL_RET_STATUS(!is_unknown_shape,
INTERNAL_ERROR,
"[%s] Shape is still unknown after shape inference.",
node_item.NodeName().c_str());
}
// update output tensor sizes after shape inference
// error if shape is still unknown and not of type DEPEND_SHAPE_RANGE
RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start");
GE_CHK_STATUS_RET_NOLOG(CalcOutputTensorSizes(node_item, node_item.shape_inference_type == DEPEND_SHAPE_RANGE));
RECORD_COMPILE_EVENT(execution_context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End");

GELOGD("[%s] [HybridTrace] After shape inference. Node = %s",
node_item.NodeName().c_str(),
@@ -127,8 +129,6 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) {
// propagate each output
for (int i = 0; i < node_item.num_outputs; ++i) {
auto output_desc = node_item.op_desc->MutableOutputDesc(i);
const auto &shape = output_desc->MutableShape();
const auto &ori_shape = output_desc->GetOriginShape();
auto &output_nodes = node_item.outputs[i];

// propagate output to all sub-inputs
@@ -149,9 +149,7 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) {
infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first,
std::move(future));
} else {
GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first,
ori_shape,
shape));
GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc));
}
}
}
@@ -230,5 +228,92 @@ Status ShapeInferenceEngine::UpdatePeerNodeShape(const Node &node) {
}
return SUCCESS;
}

Status ShapeInferenceEngine::CanonicalizeShape(GeTensorDesc &tensor_desc,
std::vector<int64_t> &shape,
bool fallback_with_range) {
const auto &tensor_shape = tensor_desc.MutableShape();
if (tensor_shape.IsUnknownShape()) {
if (!fallback_with_range) {
GELOGE(INTERNAL_ERROR, "Output shape is still unknown after shape inference. shape = [%s]",
tensor_shape.ToString().c_str());
return INTERNAL_ERROR;
}

GELOGD("Calc output size by range");
std::vector<std::pair<int64_t, int64_t>> shape_range;
GE_CHK_GRAPH_STATUS_RET(tensor_desc.GetShapeRange(shape_range), "Failed to get shape range");
if (shape_range.size() != shape.size()) {
GELOGE(INTERNAL_ERROR, "Number of shape ranges (%zu) mismatches that of dims (%zu)",
shape_range.size(),
shape.size());
return INTERNAL_ERROR;
}

for (size_t dim_index = 0; dim_index < shape.size(); ++dim_index) {
if (shape[dim_index] == ge::UNKNOWN_DIM) {
shape[dim_index] = shape_range[dim_index].second;
}
}

GELOGD("After canonicalization, shape = [%s], before = [%s]",
GeShape(shape).ToString().c_str(),
tensor_shape.ToString().c_str());
}

return SUCCESS;
}

Status ShapeInferenceEngine::CalcTensorSize(DataType data_type,
const std::vector<int64_t> &shape,
int64_t &tensor_size) {
GELOGD("To calc tensor size by shape = [%s]", GeShape(shape).ToString().c_str());
uint32_t type_size;
if (!TypeUtils::GetDataTypeLength(data_type, type_size)) {
GELOGE(INTERNAL_ERROR, "Failed to get data type size");
return INTERNAL_ERROR;
}

tensor_size = type_size;
for (const auto &dim : shape) {
GE_CHECK_GE(dim, 0);
GE_CHK_STATUS_RET(Int64MulCheckOverflow(tensor_size, dim),
"Shape size overflow, shape = [%s]",
GeShape(shape).ToString().c_str());
tensor_size *= dim;
}

GE_CHK_STATUS_RET(CheckInt64AddOverflow(tensor_size, kAlignment - 1),
"Tensor size is too large: %ld, shape = [%s]",
tensor_size,
GeShape(shape).ToString().c_str());
tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment;
return SUCCESS;
}

Status ShapeInferenceEngine::CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range) {
auto op_desc = node_item.GetOpDesc();
for (size_t output_index = 0; output_index < op_desc->GetOutputsSize(); ++output_index) {
auto tensor_desc = op_desc->MutableOutputDesc(output_index);
GE_CHECK_NOTNULL(tensor_desc);
const auto &shape = tensor_desc->MutableShape();
// modify on copy
auto dims = shape.GetDims();
GE_CHK_STATUS_RET(CanonicalizeShape(*tensor_desc, dims, fallback_with_range),
"[%s] Failed to canonicalize shape for output %zu",
node_item.NodeName().c_str(),
output_index);

int64_t tensor_size;
GE_CHK_STATUS_RET(CalcTensorSize(tensor_desc->GetDataType(), dims, tensor_size),
"[%s] Failed to calc tensor size for output %zu",
node_item.NodeName().c_str(),
output_index);
GELOGD("[%s] Tensor size of output %zu = %ld", node_item.NodeName().c_str(), output_index, tensor_size);
(void) TensorUtils::SetSize(*tensor_desc, tensor_size);
}

return SUCCESS;
}
} // namespace hybrid
} // namespace ge
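
As a sanity check on the size formula above, here is a self-contained sketch of the same arithmetic (plain C++, independent of GE): the element count times the type size is rounded up to the 32-byte kAlignment. For DEPEND_SHAPE_RANGE nodes, CanonicalizeShape first replaces each unknown dim with the upper bound of its shape range, then the same formula applies.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t kAlignment = 32;           // mirrors the constant in this file
  int64_t tensor_size = 4;                 // type size of a 4-byte FLOAT
  for (int64_t dim : {2, 3}) {             // shape [2, 3]
    tensor_size *= dim;                    // 4 * 2 * 3 = 24 payload bytes
  }
  // Round up to the next multiple of kAlignment, as CalcTensorSize does.
  tensor_size = (tensor_size + kAlignment - 1) / kAlignment * kAlignment;
  assert(tensor_size == 32);               // 24 rounds up to one 32-byte block
  return 0;
}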

ge/hybrid/executor/worker/shape_inference_engine.h  (+4 -0)

@@ -34,7 +34,11 @@ class ShapeInferenceEngine {

Status PropagateOutputShapes(const NodeItem &node_item);

static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false);

private:
static Status CanonicalizeShape(GeTensorDesc &tensor_desc, std::vector<int64_t> &shape, bool fallback_with_range);
static Status CalcTensorSize(DataType data_type, const std::vector<int64_t> &shape, int64_t &tensor_size);
static Status UpdatePeerNodeShape(const Node &node);
Status AwaitDependentNodes(NodeState &node_state);



ge/hybrid/model/node_item.cc  (+57 -34)

@@ -22,6 +22,7 @@
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/node_utils.h"
#include "hybrid/node_executor/node_executor.h"
#include "hybrid/executor/worker/shape_inference_engine.h"

namespace ge {
namespace hybrid {
@@ -47,7 +48,7 @@ Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgr
GE_CHECK_NOTNULL(dst_op_desc);
auto in_idx = node_and_anchor.second->GetIdx();
auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx);
fused_subgraph.input_mapping[parent_index].emplace_back(tensor_desc);
fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc);
GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx);
}

@@ -64,7 +65,7 @@ Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgrap
return FAILED;
}

fused_subgraph.output_mapping.emplace(parent_index, op_desc);
fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc);
return SUCCESS;
}

@@ -126,12 +127,7 @@ Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_ite
return SUCCESS;
}

Status NodeItem::Init() {
GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
num_inputs = static_cast<int>(op_desc->GetInputsSize());
num_outputs = static_cast<int>(op_desc->GetOutputsSize());

void NodeItem::ResolveOptionalInputs() {
if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) {
has_optional_inputs = true;
for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) {
@@ -143,7 +139,18 @@ Status NodeItem::Init() {
}
}
}
}

Status NodeItem::InitInputsAndOutputs() {
GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
num_inputs = static_cast<int>(op_desc->GetInputsSize());
num_outputs = static_cast<int>(op_desc->GetOutputsSize());
ResolveOptionalInputs();
return SUCCESS;
}

Status NodeItem::ResolveDynamicState() {
(void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic);
GELOGD("node name = %s, is_dynamic = %d.", this->node_name.c_str(), is_dynamic);
if (!is_dynamic) {
@@ -151,38 +158,54 @@ Status NodeItem::Init() {
"[%s] Failed to get shape status.",
node->GetName().c_str());
}
return SUCCESS;
}

if (is_dynamic) {
for (int i = 0; i < num_inputs; ++i) {
const auto &input_desc = MutableInputDesc(i);
GE_CHECK_NOTNULL(input_desc);
if (input_desc->MutableShape().IsUnknownShape()) {
is_input_shape_static_.push_back(false);
} else {
num_static_input_shapes++;
is_input_shape_static_.push_back(true);
GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
}
Status NodeItem::ResolveStaticInputsAndOutputs() {
for (int i = 0; i < num_inputs; ++i) {
const auto &input_desc = MutableInputDesc(i);
GE_CHECK_NOTNULL(input_desc);
if (input_desc->MutableShape().IsUnknownShape()) {
is_input_shape_static_.push_back(false);
} else {
num_static_input_shapes++;
is_input_shape_static_.push_back(true);
GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
}
}

for (int i = 0; i < num_outputs; ++i) {
const auto &output_desc = op_desc->MutableOutputDesc(i);
GE_CHECK_NOTNULL(output_desc);
if (output_desc->MutableShape().IsUnknownShape()) {
is_output_shape_static = false;
break;
}
for (int i = 0; i < num_outputs; ++i) {
const auto &output_desc = op_desc->MutableOutputDesc(i);
GE_CHECK_NOTNULL(output_desc);
if (output_desc->MutableShape().IsUnknownShape()) {
is_output_shape_static = false;
break;
}
}

if (IsControlOp() || node_type == PARTITIONEDCALL) {
shape_inference_type = DEPEND_COMPUTE;
} else {
int32_t unknown_shape_type_val = 0;
(void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
}
if (is_output_shape_static) {
GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this));
}
return SUCCESS;
}

void NodeItem::ResolveUnknownShapeType() {
if (IsControlOp() || node_type == PARTITIONEDCALL) {
shape_inference_type = DEPEND_COMPUTE;
} else {
int32_t unknown_shape_type_val = 0;
(void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
}
}

Status NodeItem::Init() {
GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs());
GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState());
if (is_dynamic) {
ResolveUnknownShapeType();
GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs());
GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str());
}
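
Net effect of the refactor: Init() is now a short pipeline (InitInputsAndOutputs → ResolveDynamicState → ResolveUnknownShapeType → ResolveStaticInputsAndOutputs → ParseFusedSubgraph), and dynamic nodes whose output shapes are fully static get their output sizes computed once here via CalcOutputTensorSizes, rather than per execution.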



ge/hybrid/model/node_item.h  (+5 -0)

@@ -103,6 +103,11 @@ struct NodeItem {
private:
explicit NodeItem(NodePtr node);
Status Init();
Status InitInputsAndOutputs();
void ResolveOptionalInputs();
Status ResolveDynamicState();
Status ResolveStaticInputsAndOutputs();
void ResolveUnknownShapeType();

std::vector<bool> is_input_shape_static_;
std::vector<uint32_t> input_desc_indices_;


ge/hybrid/node_executor/task_context.cc  (+22 -0)

@@ -148,6 +148,10 @@ Status TaskContext::AllocateWorkspaces() {
}

Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) const {
if (callback_fun == nullptr) {
GELOGW("[%s] Callback is NULL", GetNodeName());
return SUCCESS;
}
auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun);
if (ret != SUCCESS) {
GELOGE(ret, "[%s] Failed to register callback", GetNodeName());
@@ -384,6 +388,20 @@ const char *TaskContext::GetNodeName() const {
return node_item_->NodeName().c_str();
}

void TaskContext::ReleaseInputsAndOutputs() {
for (int i = 0; i < node_item_->num_inputs; ++i) {
auto tensor = inputs_start_ + i;
tensor->Destroy();
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), i);
}

for (int i = 0; i < node_item_->num_outputs; ++i) {
auto tensor = outputs_start_ + i;
tensor->Destroy();
GELOGD("[%s] Tensor of output[%d] released", GetNodeName(), i);
}
}

void TaskContext::ReleaseInput(int index) {
auto input_tensor = MutableInput(index);
if (input_tensor != nullptr) {
@@ -456,5 +474,9 @@ Status TaskContext::TryExecuteCallback(const function<void()> &callback_fun) con
const DumpProperties &TaskContext::GetDumpProperties() const {
return execution_context_->dump_properties;
}

bool TaskContext::NeedCallback() {
return node_item_->has_observer || IsDumpEnabled() || execution_context_->profiling_level > 0;
}
} // namespace hybrid
} // namespace ge
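
A hypothetical caller-side sketch of how the two new members could compose; the actual call sites are outside this diff, so treat the control flow as an assumption:

// Hypothetical usage: register the done-callback only when something
// observes this node (or dump/profiling needs it); otherwise the inputs
// and outputs can be reclaimed as soon as the task finishes.
if (task_context->NeedCallback()) {
  GE_CHK_STATUS_RET(task_context->RegisterCallback(done_callback),
                    "[%s] Failed to register callback", task_context->GetNodeName());
} else {
  task_context->ReleaseInputsAndOutputs();
}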

ge/hybrid/node_executor/task_context.h  (+2 -0)

@@ -50,6 +50,8 @@ class TaskContext {
ConstGeTensorDescPtr GetOutputDesc(int index) const;
GeTensorDescPtr MutableInputDesc(int index) const;
GeTensorDescPtr MutableOutputDesc(int index) const;
void ReleaseInputsAndOutputs();
bool NeedCallback();
void ReleaseInput(int index);
const TensorValue *GetInput(int index) const;
const TensorValue *GetOutput(int index) const;

