Browse Source

!1268 Bugfix: online inference needs to get const tensors from the context first

From: @hugo1
Reviewed-by: @xchu42
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
9edd57eef2
7 changed files with 11 additions and 8 deletions
  1. +3
    -2
      ge/host_kernels/gather_v2_kernel.cc
  2. +2
    -0
      ge/hybrid/executor/hybrid_model_async_executor.cc
  3. +2
    -2
      ge/hybrid/executor/subgraph_executor.cc
  4. +0
    -2
      ge/hybrid/node_executor/aicore/aicore_op_task.cc
  5. +1
    -1
      metadef
  6. +2
    -1
      tests/ut/common/graph/CMakeLists.txt
  7. +1
    -0
      tests/ut/ge/CMakeLists.txt

+ 3
- 2
ge/host_kernels/gather_v2_kernel.cc View File

@@ -373,7 +373,7 @@ void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeSh

Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
vector<GeTensorPtr> &v_output) {
GELOGI("Enter GatherV2Kernel Process");
GELOGI("Enter GatherV2Kernel Process.");
Status ret = Check(op_desc_ptr, input, v_output);
if (ret != SUCCESS) {
GELOGW("param check failed");
@@ -407,7 +407,8 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
// check input data type
auto x_data_type = tensor0->GetTensorDesc().GetDataType();
if (supported_type.find(x_data_type) == supported_type.end()) {
GELOGI("GatherV2Kernel does not support this Data type:%s.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
GELOGI("GatherV2Kernel does not support this Data type:%s.",
TypeUtils::DataTypeToSerialString(x_data_type).c_str());
return NOT_CHANGED;
}
// calc output shape


+ 2
- 0
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -67,6 +67,7 @@ Status HybridModelAsyncExecutor::Start(const std::shared_ptr<ModelListener> &lis
future_ = std::async(std::launch::async, [&]() -> Status {
GetThreadLocalContext() = *executor_->GetContext()->ge_context;
GetContext().SetSessionId(executor_->GetContext()->session_id);
GetContext().SetContextId(executor_->GetContext()->context_id);
return RunInternal();
});

@@ -166,6 +167,7 @@ Status HybridModelAsyncExecutor::RunInternal() {
} else {
GELOGI("HybridModel will execute in singleline mode");
ge::GetContext().SetSessionId(executor_->GetContext()->session_id);
ge::GetContext().SetContextId(executor_->GetContext()->context_id);
ret = executor_->Execute(args);
}
ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput());


+ 2
- 2
ge/hybrid/executor/subgraph_executor.cc View File

@@ -227,6 +227,7 @@ Status SubgraphExecutor::PrepareNodes(int group) {
if (node_item.is_dynamic) {
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status {
GetContext().SetSessionId(context_->session_id);
GetContext().SetContextId(context_->context_id);
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state));
return PrepareForExecution(context_, *p_node_state);
});
@@ -273,10 +274,8 @@ Status SubgraphExecutor::PrepareNodes(int group) {
}

Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const {
GetContext().SetSessionId(context_->context_id);
HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state),
"[%s] Failed to InferShape.", node_state.GetName().c_str());
GetContext().SetSessionId(context_->session_id);
HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_state),
"[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str());
return SUCCESS;
@@ -345,6 +344,7 @@ Status SubgraphExecutor::ScheduleTasks(int group) {
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str());
auto prepare_future = std::async(std::launch::async, [&]() -> Status {
GetContext().SetSessionId(context_->session_id);
GetContext().SetContextId(context_->context_id);
auto ret = PrepareNodes(group);
ready_queue_.Push(nullptr);
return ret;


+ 0
- 2
ge/hybrid/node_executor/aicore/aicore_op_task.cc View File

@@ -307,11 +307,9 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) {

auto execution_context = context.GetExecutionContext();

GetContext().SetSessionId(execution_context->context_id);
RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start");
GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info));
RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End");
GetContext().SetSessionId(execution_context->session_id);

// update op args by tiling info
block_dim_ = static_cast<uint32_t>(tiling_info.block_dim);


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit eef990b3d8669065a969dfa6b1097eac09d601d4
Subproject commit 3a4c3b746cffcb2e1e5cc1c8a7559a07da3dd84e

+ 2
- 1
tests/ut/common/graph/CMakeLists.txt View File

@@ -96,11 +96,12 @@ set(SRC_FILES
"${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc"
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp"
"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc"
"${GE_CODE_DIR}/metadef/graph/utils/transformer_utils.cc"
"${GE_CODE_DIR}/metadef/graph/runtime_inference_context.cc"
"${GE_CODE_DIR}/metadef/graph/ref_relation.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/axis_util.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/expand_dimension.cc"
"${GE_CODE_DIR}/metadef/graph/utils/transformer_utils.cc"
)

#add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS})


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -90,6 +90,7 @@ set(GRAPH_SRC_FILES
"${GE_CODE_DIR}/metadef/graph/op_desc.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/axis_util.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/expand_dimension.cc"
"${GE_CODE_DIR}/metadef/graph/operator.cc"
"${GE_CODE_DIR}/metadef/graph/operator_factory.cc"
"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc"


Loading…
Cancel
Save