|
|
@@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() { |
|
|
|
if (context_.rt_gen_context != nullptr) { |
|
|
|
(void) rtCtxDestroy(context_.rt_gen_context); |
|
|
|
} |
|
|
|
if (context_.global_step != nullptr) { |
|
|
|
(void) rtFree(context_.global_step); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Status HybridModelExecutor::Init() { |
|
|
@@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { |
|
|
|
auto root_graph_item = model_->GetRootGraphItem(); |
|
|
|
GE_CHECK_NOTNULL(root_graph_item); |
|
|
|
|
|
|
|
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration, |
|
|
|
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream)); |
|
|
|
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); |
|
|
|
auto ret = ExecuteGraphInternal(executor, args); |
|
|
|
Cleanup(); |
|
|
@@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() { |
|
|
|
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); |
|
|
|
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); |
|
|
|
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); |
|
|
|
GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM)); |
|
|
|
|
|
|
|
context_.stream = stream_; |
|
|
|
context_.model = model_; |
|
|
|