diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 3174df80..93458cfe 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -66,7 +66,7 @@ Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &nod } AiCoreTaskBuilder builder(node->GetOpDesc(), *task_defs); - std::unique_ptr node_task; + std::unique_ptr node_task; GE_CHK_STATUS_RET(builder.BuildTask(node_task, true, is_single_op), "[%s] Failed to build op tasks.", node->GetName().c_str()); task = std::move(node_task); @@ -99,7 +99,7 @@ Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key return SUCCESS; } -bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::shared_ptr task) { +bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::shared_ptr &task) { GE_CHECK_NOTNULL(task); std::lock_guard lock(mutex_); auto iter = reg_node_tasks_.find(node_key); @@ -111,7 +111,7 @@ bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::sha return ret.second; } -std::shared_ptr AiCoreNodeTaskRegistry::GetTask(const std::string &node_key) { +std::shared_ptr AiCoreNodeTaskRegistry::GetTask(const std::string &node_key) { std::lock_guard lock(mutex_); auto iter = reg_node_tasks_.find(node_key); return (iter != reg_node_tasks_.end()) ? iter->second : nullptr; @@ -140,9 +140,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, auto node_key = std::to_string(model.GetModelId()) + "/" + shape_key; GELOGD("NodeKey for %s = %s", node->GetName().c_str(), node_key.c_str()); - task = registry.GetTask(node_key); + auto aicore_task = registry.GetTask(node_key); if (task != nullptr) { + // The workspaces needed by a operator may differ with different shapes + op_desc->SetWorkspaceBytes(aicore_task->GetWorkspaceSizes()); GELOGI("AiCoreNodeExecutor(%s) CompileTask Skip.", node->GetName().c_str()); + task = std::move(aicore_task); return SUCCESS; } @@ -153,16 +156,18 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, GELOGD("successfully generated task_defs: %s", node->GetName().c_str()); AiCoreTaskBuilder builder(node->GetOpDesc(), task_defs); - std::unique_ptr node_task; + std::unique_ptr node_task; GE_CHK_STATUS_RET(builder.BuildTask(node_task, false), "[%s] Failed to build op tasks.", node->GetName().c_str()); - task = std::move(node_task); + node_task->SetWorkspaceSizes(op_desc->GetWorkspaceBytes()); + aicore_task = std::move(node_task); GELOGD("successfully created node task: %s", node->GetName().c_str()); - if (!registry.AddTask(node_key, task)) { + if (!registry.AddTask(node_key, aicore_task)) { GELOGE(INTERNAL_ERROR, "Add NodeTask failed, op name = %s.", node->GetName().c_str()); return INTERNAL_ERROR; } + task = std::move(aicore_task); GELOGI("AiCoreNodeExecutor(%s) CompileTask End.", node->GetName().c_str()); return SUCCESS; } @@ -247,6 +252,14 @@ bool AiCoreNodeTask::IsSupportDynamicShape() { return true; } +const vector &AiCoreNodeTask::GetWorkspaceSizes() const { + return workspace_sizes_; +} + +void AiCoreNodeTask::SetWorkspaceSizes(const vector &workspace_sizes) { + workspace_sizes_ = workspace_sizes; +} + TaskCompilerFactory &TaskCompilerFactory::GetInstance() { static TaskCompilerFactory instance; return instance; diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/ge/hybrid/node_executor/aicore/aicore_node_executor.h index f036ce85..2095b41d 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.h +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.h @@ -24,7 +24,6 @@ namespace ge { namespace hybrid { - class TaskCompiler { public: TaskCompiler() = default; @@ -42,11 +41,11 @@ class AiCoreNodeTaskRegistry { return instance; } - std::shared_ptr GetTask(const std::string &node_key); - bool AddTask(const std::string &node_key, const std::shared_ptr task); + std::shared_ptr GetTask(const std::string &node_key); + bool AddTask(const std::string &node_key, const std::shared_ptr &task); private: AiCoreNodeTaskRegistry() = default; - std::map> reg_node_tasks_; + std::map> reg_node_tasks_; std::mutex mutex_; }; @@ -59,8 +58,12 @@ class AiCoreNodeTask : public NodeTask { Status UpdateArgs(TaskContext &context) override; Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + + const vector &GetWorkspaceSizes() const; + void SetWorkspaceSizes(const vector &workspace_sizes); private: std::vector> tasks_; + std::vector workspace_sizes_; }; class AiCoreNodeExecutor : public NodeExecutor { diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/ge/hybrid/node_executor/aicore/aicore_task_builder.cc index c3db378b..966e0910 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_builder.cc +++ b/ge/hybrid/node_executor/aicore/aicore_task_builder.cc @@ -37,7 +37,7 @@ AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector : op_desc_(op_desc), task_defs_(task_defs) { } -Status AiCoreTaskBuilder::BuildTask(std::unique_ptr &node_task, +Status AiCoreTaskBuilder::BuildTask(std::unique_ptr &node_task, bool ignore_failure_on_atomic, bool is_single_op) { GE_CHECK_NOTNULL(op_desc_); diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/ge/hybrid/node_executor/aicore/aicore_task_builder.h index 8f95df15..6a472a21 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_builder.h +++ b/ge/hybrid/node_executor/aicore/aicore_task_builder.h @@ -27,6 +27,7 @@ namespace ge { namespace hybrid { +class AiCoreNodeTask; class AiCoreKernelRegistry { public: ~AiCoreKernelRegistry() = default; @@ -47,7 +48,9 @@ class AiCoreTaskBuilder { AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector &task_defs); ~AiCoreTaskBuilder() = default; - Status BuildTask(std::unique_ptr &node_task, bool ignore_failure_on_atomic, bool is_single_op = false); + Status BuildTask(std::unique_ptr &node_task, + bool ignore_failure_on_atomic, + bool is_single_op = false); private: bool ExpectAtomicAddrCleanTask(); diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc index 26a41737..069c8699 100755 --- a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc +++ b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc @@ -61,11 +61,11 @@ Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vectorGetName().c_str()); - + auto op_desc = node->GetOpDesc(); + op_desc->SetWorkspaceBytes({}); GE_CHK_STATUS_RET_NOLOG(DoCompileOp(node)); GELOGD("successfully compiled op: %s", node->GetName().c_str()); - auto op_desc = node->GetOpDesc(); std::vector input_offsets(op_desc->GetInputsSize(), kMemBase); std::vector output_offsets(op_desc->GetOutputsSize(), kMemBase); op_desc->SetInputOffset(input_offsets);