From b7909baabd5a5287ffd5bd673a29176433e826cd Mon Sep 17 00:00:00 2001 From: guopeian Date: Fri, 25 Jun 2021 17:13:13 +0800 Subject: [PATCH] fix --- .../node_executor/aicpu/aicpu_node_executor.h | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index 8980d41b..8f38ffcd 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -51,6 +51,10 @@ class AicpuNodeTaskBase : public NodeTask { virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context); + virtual Status UpdateShapeAndDataByResultSummary(TaskContext &context); + + Status InitForDependComputeTask(bool is_tfkernel); + Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index); virtual Status LaunchTask(TaskContext &context) = 0; @@ -58,8 +62,10 @@ class AicpuNodeTaskBase : public NodeTask { virtual Status UpdateIoAddr(TaskContext &context) = 0; static Status AllocTensorBuffer(size_t size, std::unique_ptr &tensor_buffer); + private: Status TaskCallback(TaskContext &context); + protected: const NodeItem *node_item_; // just reference. @@ -77,6 +83,17 @@ class AicpuNodeTaskBase : public NodeTask { // ext info addr, device mem std::unique_ptr ext_info_addr_dev_; + + std::vector> output_summary_; + std::vector output_summary_host_; + + std::unique_ptr copy_ioaddr_dev_; + + std::unique_ptr copy_input_release_flag_dev_; + std::unique_ptr copy_input_data_size_dev_; + std::unique_ptr copy_input_src_dev_; + std::unique_ptr copy_input_dst_dev_; + bool need_sync_ = false; }; class AicpuTfNodeTask : public AicpuNodeTaskBase { @@ -97,10 +114,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase { private: Status SetMemCopyTask(const domi::TaskDef &task_def); - Status InitForDependComputeTask(); - - Status UpdateShapeAndDataByResultSummary(TaskContext &context); - /// /// read result summary and prepare copy task memory. /// @param context task context @@ -132,17 +145,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase { // just used for depend DEPEND_COMPUTE op std::unique_ptr copy_task_args_buf_; - std::vector> output_summary_; - std::vector output_summary_host_; - - std::unique_ptr copy_ioaddr_dev_; - - std::unique_ptr copy_input_release_flag_dev_; - std::unique_ptr copy_input_data_size_dev_; - std::unique_ptr copy_input_src_dev_; - std::unique_ptr copy_input_dst_dev_; - bool need_sync_ = false; - std::unique_ptr copy_workspace_buf_; }; @@ -166,10 +168,6 @@ class AicpuNodeTask : public AicpuNodeTaskBase { private: Status SetMemCopyTask(const domi::TaskDef &task_def); - Status InitForDependComputeTask(); - - Status UpdateShapeAndDataByResultSummary(TaskContext &context); - protected: // host mem std::unique_ptr args_;