From: @lichun30 Reviewed-by: @xchu42,@wqtshg Signed-off-by: @ji_chentags/v1.1.0
| @@ -17,8 +17,6 @@ | |||||
| #include "aicore_node_executor.h" | #include "aicore_node_executor.h" | ||||
| #include "cce/taskdown_common.hpp" | #include "cce/taskdown_common.hpp" | ||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "init/gelib.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -28,19 +26,10 @@ AiCoreNodeTask::AiCoreNodeTask(std::vector<std::unique_ptr<AiCoreOpTask>> &&task | |||||
| } | } | ||||
| Status AiCoreNodeExecutor::Initialize() { | Status AiCoreNodeExecutor::Initialize() { | ||||
| auto ge_lib = GELib::GetInstance(); | |||||
| GE_CHECK_NOTNULL(ge_lib); | |||||
| if (!ge_lib->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); | |||||
| return GE_CLI_GE_NOT_INITIALIZED; | |||||
| compiler_ = TaskCompilerFactory::GetInstance().GetTaskCompiler(); | |||||
| if (compiler_ != nullptr) { | |||||
| GE_CHK_STATUS_RET(compiler_->Initialize(), "Failed to init aicore task compiler."); | |||||
| } | } | ||||
| auto &kernel_manager = ge_lib->OpsKernelManagerObj(); | |||||
| auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); | |||||
| GE_CHECK_NOTNULL(aic_ops_store); | |||||
| compiler_.reset(new(std::nothrow)AiCoreTaskCompiler(aic_ops_store)); | |||||
| GE_CHECK_NOTNULL(compiler_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -120,6 +109,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); | GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); | ||||
| auto ori_node_name = node->GetName(); | |||||
| if (compiler_ == nullptr) { | |||||
| GELOGE(FAILED, "[%s] Can not find any valid aicore task compiler.", ori_node_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); | AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); | ||||
| std::string shape_key; | std::string shape_key; | ||||
| GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); | GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); | ||||
| @@ -133,7 +128,6 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, | |||||
| } | } | ||||
| std::vector<domi::TaskDef> task_defs; | std::vector<domi::TaskDef> task_defs; | ||||
| auto ori_node_name = node->GetName(); | |||||
| op_desc->SetName(ori_node_name + "_" + shape_key); | op_desc->SetName(ori_node_name + "_" + shape_key); | ||||
| GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); | GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); | ||||
| op_desc->SetName(ori_node_name); | op_desc->SetName(ori_node_name); | ||||
| @@ -239,5 +233,23 @@ bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| TaskCompilerFactory &TaskCompilerFactory::GetInstance() { | |||||
| static TaskCompilerFactory instance; | |||||
| return instance; | |||||
| } | |||||
| void TaskCompilerFactory::Register(CreateFn fn) { | |||||
| compiler_func_ = fn; | |||||
| } | |||||
| std::unique_ptr<TaskCompiler> TaskCompilerFactory::GetTaskCompiler() { | |||||
| auto compiler_instance = std::unique_ptr<TaskCompiler>(compiler_func_()); | |||||
| return compiler_instance; | |||||
| } | |||||
| CompilerFunctionRegistrar::CompilerFunctionRegistrar(CreateFn fn) { | |||||
| TaskCompilerFactory::GetInstance().Register(fn); | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,13 +18,21 @@ | |||||
| #define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | #define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | ||||
| #include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
| #include "hybrid/node_executor/aicore/aicore_task_compiler.h" | |||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| #include <map> | #include <map> | ||||
| #include <mutex> | #include <mutex> | ||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| class TaskCompiler { | |||||
| public: | |||||
| TaskCompiler() = default; | |||||
| virtual ~TaskCompiler() = default; | |||||
| virtual Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks) = 0; | |||||
| virtual Status Initialize() = 0; | |||||
| }; | |||||
| class AiCoreNodeTaskRegistry { | class AiCoreNodeTaskRegistry { | ||||
| public: | public: | ||||
| ~AiCoreNodeTaskRegistry() = default; | ~AiCoreNodeTaskRegistry() = default; | ||||
| @@ -65,8 +73,33 @@ class AiCoreNodeExecutor : public NodeExecutor { | |||||
| private: | private: | ||||
| static Status GenNodeKey(const NodePtr &node, std::string &node_key); | static Status GenNodeKey(const NodePtr &node, std::string &node_key); | ||||
| std::unique_ptr<AiCoreTaskCompiler> compiler_; | |||||
| std::unique_ptr<TaskCompiler> compiler_; | |||||
| }; | |||||
| using CreateFn = TaskCompiler *(*)(); | |||||
| class TaskCompilerFactory { | |||||
| public: | |||||
| static TaskCompilerFactory &GetInstance(); | |||||
| void Register(CreateFn fn); | |||||
| std::unique_ptr<TaskCompiler> GetTaskCompiler(); | |||||
| private: | |||||
| CreateFn compiler_func_; | |||||
| }; | |||||
| class CompilerFunctionRegistrar { | |||||
| public: | |||||
| CompilerFunctionRegistrar(CreateFn fn); | |||||
| ~CompilerFunctionRegistrar() = default; | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif //GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | |||||
| #define REGISTER_TASK_COMPILER(compiler) \ | |||||
| static ::ge::hybrid::CompilerFunctionRegistrar register_compiler_function \ | |||||
| __attribute__((unused)) = \ | |||||
| ::ge::hybrid::CompilerFunctionRegistrar([]()->::ge::hybrid::TaskCompiler* { \ | |||||
| return new (std::nothrow) compiler(); \ | |||||
| }) \ | |||||
| #endif //GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | |||||
| @@ -18,6 +18,7 @@ | |||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "opskernel_manager/ops_kernel_builder_manager.h" | #include "opskernel_manager/ops_kernel_builder_manager.h" | ||||
| #include "init/gelib.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -25,11 +26,22 @@ namespace { | |||||
| uintptr_t kWeightBase = 0x10000000; | uintptr_t kWeightBase = 0x10000000; | ||||
| uintptr_t kMemBase = 0x20000000; | uintptr_t kMemBase = 0x20000000; | ||||
| uint64_t kFakeSize = 0x10000000UL; | uint64_t kFakeSize = 0x10000000UL; | ||||
| REGISTER_TASK_COMPILER(AiCoreTaskCompiler); | |||||
| } | } | ||||
| std::mutex AiCoreTaskCompiler::mu_; | std::mutex AiCoreTaskCompiler::mu_; | ||||
| AiCoreTaskCompiler::AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store) | |||||
| : aic_kernel_store_(std::move(aic_kernel_store)) {} | |||||
| Status AiCoreTaskCompiler::Initialize() { | |||||
| auto ge_lib = GELib::GetInstance(); | |||||
| GE_CHECK_NOTNULL(ge_lib); | |||||
| if (!ge_lib->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); | |||||
| return GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| auto &kernel_manager = ge_lib->OpsKernelManagerObj(); | |||||
| aic_kernel_store_ = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); | |||||
| GE_CHECK_NOTNULL(aic_kernel_store_); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status AiCoreTaskCompiler::DoCompileOp(const NodePtr &node) const { | Status AiCoreTaskCompiler::DoCompileOp(const NodePtr &node) const { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| @@ -19,15 +19,17 @@ | |||||
| #include <mutex> | #include <mutex> | ||||
| #include "opskernel_manager/ops_kernel_manager.h" | #include "opskernel_manager/ops_kernel_manager.h" | ||||
| #include "aicore_node_executor.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| class AiCoreTaskCompiler { | |||||
| class AiCoreTaskCompiler : public TaskCompiler { | |||||
| public: | public: | ||||
| explicit AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store); | |||||
| AiCoreTaskCompiler() = default; | |||||
| ~AiCoreTaskCompiler() = default; | ~AiCoreTaskCompiler() = default; | ||||
| Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks); | |||||
| Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks) override; | |||||
| Status Initialize() override; | |||||
| private: | private: | ||||
| Status DoCompileOp(const NodePtr &node) const; | Status DoCompileOp(const NodePtr &node) const; | ||||
| Status DoGenerateTask(const Node &node, std::vector<domi::TaskDef> &tasks); | Status DoGenerateTask(const Node &node, std::vector<domi::TaskDef> &tasks); | ||||