| @@ -51,6 +51,10 @@ class AicpuNodeTaskBase : public NodeTask { | |||||
| virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context); | virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context); | ||||
| virtual Status UpdateShapeAndDataByResultSummary(TaskContext &context); | |||||
| Status InitForDependComputeTask(bool is_tfkernel); | |||||
| Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index); | Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index); | ||||
| virtual Status LaunchTask(TaskContext &context) = 0; | virtual Status LaunchTask(TaskContext &context) = 0; | ||||
| @@ -58,8 +62,10 @@ class AicpuNodeTaskBase : public NodeTask { | |||||
| virtual Status UpdateIoAddr(TaskContext &context) = 0; | virtual Status UpdateIoAddr(TaskContext &context) = 0; | ||||
| static Status AllocTensorBuffer(size_t size, std::unique_ptr<TensorBuffer> &tensor_buffer); | static Status AllocTensorBuffer(size_t size, std::unique_ptr<TensorBuffer> &tensor_buffer); | ||||
| private: | private: | ||||
| Status TaskCallback(TaskContext &context); | Status TaskCallback(TaskContext &context); | ||||
| protected: | protected: | ||||
| const NodeItem *node_item_; | const NodeItem *node_item_; | ||||
| // just reference. | // just reference. | ||||
| @@ -77,6 +83,17 @@ class AicpuNodeTaskBase : public NodeTask { | |||||
| // ext info addr, device mem | // ext info addr, device mem | ||||
| std::unique_ptr<TensorBuffer> ext_info_addr_dev_; | std::unique_ptr<TensorBuffer> ext_info_addr_dev_; | ||||
| std::vector<std::unique_ptr<TensorBuffer>> output_summary_; | |||||
| std::vector<aicpu::FWKAdapter::ResultSummary> output_summary_host_; | |||||
| std::unique_ptr<TensorBuffer> copy_ioaddr_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_release_flag_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_data_size_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_src_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_dst_dev_; | |||||
| bool need_sync_ = false; | |||||
| }; | }; | ||||
| class AicpuTfNodeTask : public AicpuNodeTaskBase { | class AicpuTfNodeTask : public AicpuNodeTaskBase { | ||||
| @@ -97,10 +114,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase { | |||||
| private: | private: | ||||
| Status SetMemCopyTask(const domi::TaskDef &task_def); | Status SetMemCopyTask(const domi::TaskDef &task_def); | ||||
| Status InitForDependComputeTask(); | |||||
| Status UpdateShapeAndDataByResultSummary(TaskContext &context); | |||||
| /// | /// | ||||
| /// read result summary and prepare copy task memory. | /// read result summary and prepare copy task memory. | ||||
| /// @param context task context | /// @param context task context | ||||
| @@ -132,17 +145,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase { | |||||
| // just used for depend DEPEND_COMPUTE op | // just used for depend DEPEND_COMPUTE op | ||||
| std::unique_ptr<TensorBuffer> copy_task_args_buf_; | std::unique_ptr<TensorBuffer> copy_task_args_buf_; | ||||
| std::vector<std::unique_ptr<TensorBuffer>> output_summary_; | |||||
| std::vector<aicpu::FWKAdapter::ResultSummary> output_summary_host_; | |||||
| std::unique_ptr<TensorBuffer> copy_ioaddr_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_release_flag_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_data_size_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_src_dev_; | |||||
| std::unique_ptr<TensorBuffer> copy_input_dst_dev_; | |||||
| bool need_sync_ = false; | |||||
| std::unique_ptr<TensorBuffer> copy_workspace_buf_; | std::unique_ptr<TensorBuffer> copy_workspace_buf_; | ||||
| }; | }; | ||||
| @@ -166,10 +168,6 @@ class AicpuNodeTask : public AicpuNodeTaskBase { | |||||
| private: | private: | ||||
| Status SetMemCopyTask(const domi::TaskDef &task_def); | Status SetMemCopyTask(const domi::TaskDef &task_def); | ||||
| Status InitForDependComputeTask(); | |||||
| Status UpdateShapeAndDataByResultSummary(TaskContext &context); | |||||
| protected: | protected: | ||||
| // host mem | // host mem | ||||
| std::unique_ptr<uint8_t[]> args_; | std::unique_ptr<uint8_t[]> args_; | ||||