Browse Source

fix

pull/1865/head
guopeian 4 years ago
parent
commit
b7909baabd
1 changed file with 17 additions and 19 deletions
  1. +17
    -19
      ge/hybrid/node_executor/aicpu/aicpu_node_executor.h

+ 17
- 19
ge/hybrid/node_executor/aicpu/aicpu_node_executor.h View File

@@ -51,6 +51,10 @@ class AicpuNodeTaskBase : public NodeTask {


virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context); virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context);


virtual Status UpdateShapeAndDataByResultSummary(TaskContext &context);

Status InitForDependComputeTask(bool is_tfkernel);

Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index); Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index);


virtual Status LaunchTask(TaskContext &context) = 0; virtual Status LaunchTask(TaskContext &context) = 0;
@@ -58,8 +62,10 @@ class AicpuNodeTaskBase : public NodeTask {
virtual Status UpdateIoAddr(TaskContext &context) = 0; virtual Status UpdateIoAddr(TaskContext &context) = 0;


static Status AllocTensorBuffer(size_t size, std::unique_ptr<TensorBuffer> &tensor_buffer); static Status AllocTensorBuffer(size_t size, std::unique_ptr<TensorBuffer> &tensor_buffer);

private: private:
Status TaskCallback(TaskContext &context); Status TaskCallback(TaskContext &context);

protected: protected:
const NodeItem *node_item_; const NodeItem *node_item_;
// just reference. // just reference.
@@ -77,6 +83,17 @@ class AicpuNodeTaskBase : public NodeTask {


// ext info addr, device mem // ext info addr, device mem
std::unique_ptr<TensorBuffer> ext_info_addr_dev_; std::unique_ptr<TensorBuffer> ext_info_addr_dev_;

std::vector<std::unique_ptr<TensorBuffer>> output_summary_;
std::vector<aicpu::FWKAdapter::ResultSummary> output_summary_host_;

std::unique_ptr<TensorBuffer> copy_ioaddr_dev_;

std::unique_ptr<TensorBuffer> copy_input_release_flag_dev_;
std::unique_ptr<TensorBuffer> copy_input_data_size_dev_;
std::unique_ptr<TensorBuffer> copy_input_src_dev_;
std::unique_ptr<TensorBuffer> copy_input_dst_dev_;
bool need_sync_ = false;
}; };


class AicpuTfNodeTask : public AicpuNodeTaskBase { class AicpuTfNodeTask : public AicpuNodeTaskBase {
@@ -97,10 +114,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase {
private: private:
Status SetMemCopyTask(const domi::TaskDef &task_def); Status SetMemCopyTask(const domi::TaskDef &task_def);


Status InitForDependComputeTask();

Status UpdateShapeAndDataByResultSummary(TaskContext &context);

/// ///
/// read result summary and prepare copy task memory. /// read result summary and prepare copy task memory.
/// @param context task context /// @param context task context
@@ -132,17 +145,6 @@ class AicpuTfNodeTask : public AicpuNodeTaskBase {
// just used for depend DEPEND_COMPUTE op // just used for depend DEPEND_COMPUTE op
std::unique_ptr<TensorBuffer> copy_task_args_buf_; std::unique_ptr<TensorBuffer> copy_task_args_buf_;


std::vector<std::unique_ptr<TensorBuffer>> output_summary_;
std::vector<aicpu::FWKAdapter::ResultSummary> output_summary_host_;

std::unique_ptr<TensorBuffer> copy_ioaddr_dev_;

std::unique_ptr<TensorBuffer> copy_input_release_flag_dev_;
std::unique_ptr<TensorBuffer> copy_input_data_size_dev_;
std::unique_ptr<TensorBuffer> copy_input_src_dev_;
std::unique_ptr<TensorBuffer> copy_input_dst_dev_;
bool need_sync_ = false;

std::unique_ptr<TensorBuffer> copy_workspace_buf_; std::unique_ptr<TensorBuffer> copy_workspace_buf_;
}; };


@@ -166,10 +168,6 @@ class AicpuNodeTask : public AicpuNodeTaskBase {
private: private:
Status SetMemCopyTask(const domi::TaskDef &task_def); Status SetMemCopyTask(const domi::TaskDef &task_def);


Status InitForDependComputeTask();

Status UpdateShapeAndDataByResultSummary(TaskContext &context);

protected: protected:
// host mem // host mem
std::unique_ptr<uint8_t[]> args_; std::unique_ptr<uint8_t[]> args_;


Loading…
Cancel
Save