Browse Source

fix

pull/2052/head
guopeian 4 years ago
parent
commit
379ab5d11d
1 changed files with 11 additions and 12 deletions
  1. +11
    -12
      ge/single_op/task/op_task.h

+ 11
- 12
ge/single_op/task/op_task.h View File

@@ -77,6 +77,11 @@ class OpTask {
class TbeOpTask : public OpTask { class TbeOpTask : public OpTask {
public: public:
~TbeOpTask() override; ~TbeOpTask() override;
Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;
Status LaunchKernel(rtStream_t stream) override; Status LaunchKernel(rtStream_t stream) override;
void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override;
void SetSmDesc(void *sm_desc); void SetSmDesc(void *sm_desc);
@@ -162,7 +167,11 @@ class AiCpuBaseTask : public OpTask {
UnknowShapeOpType GetUnknownType() const { return unknown_type_; } UnknowShapeOpType GetUnknownType() const { return unknown_type_; }
Status UpdateArgTable(const SingleOpModelParam &param) override; Status UpdateArgTable(const SingleOpModelParam &param) override;
const std::string &GetTaskType() const override; const std::string &GetTaskType() const override;

Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;
protected: protected:
Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs); Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
Status SetInputConst(); Status SetInputConst();
@@ -198,11 +207,7 @@ class AiCpuTask : public AiCpuBaseTask {
Status LaunchKernel(rtStream_t stream) override; Status LaunchKernel(rtStream_t stream) override;
void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override;


Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;
Status SetMemCopyTask(const domi::KernelExDef &kernel_def); Status SetMemCopyTask(const domi::KernelExDef &kernel_def);


private: private:
@@ -266,12 +271,6 @@ class AiCpuCCTask : public AiCpuBaseTask {
void SetIoAddr(uintptr_t *io_addr); void SetIoAddr(uintptr_t *io_addr);
size_t GetArgSize() const; size_t GetArgSize() const;


Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
const std::vector<DataBuffer> &input_buffers,
std::vector<GeTensorDesc> &output_desc,
std::vector<DataBuffer> &output_buffers,
rtStream_t stream) override;

private: private:
friend class AiCpuCCTaskBuilder; friend class AiCpuCCTaskBuilder;
std::string so_name_; std::string so_name_;


Loading…
Cancel
Save