| @@ -77,6 +77,11 @@ class OpTask { | |||||
| class TbeOpTask : public OpTask { | class TbeOpTask : public OpTask { | ||||
| public: | public: | ||||
| ~TbeOpTask() override; | ~TbeOpTask() override; | ||||
| Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
| const std::vector<DataBuffer> &input_buffers, | |||||
| std::vector<GeTensorDesc> &output_desc, | |||||
| std::vector<DataBuffer> &output_buffers, | |||||
| rtStream_t stream) override; | |||||
| Status LaunchKernel(rtStream_t stream) override; | Status LaunchKernel(rtStream_t stream) override; | ||||
| void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; | void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; | ||||
| void SetSmDesc(void *sm_desc); | void SetSmDesc(void *sm_desc); | ||||
| @@ -162,7 +167,11 @@ class AiCpuBaseTask : public OpTask { | |||||
| UnknowShapeOpType GetUnknownType() const { return unknown_type_; } | UnknowShapeOpType GetUnknownType() const { return unknown_type_; } | ||||
| Status UpdateArgTable(const SingleOpModelParam ¶m) override; | Status UpdateArgTable(const SingleOpModelParam ¶m) override; | ||||
| const std::string &GetTaskType() const override; | const std::string &GetTaskType() const override; | ||||
| Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
| const std::vector<DataBuffer> &input_buffers, | |||||
| std::vector<GeTensorDesc> &output_desc, | |||||
| std::vector<DataBuffer> &output_buffers, | |||||
| rtStream_t stream) override; | |||||
| protected: | protected: | ||||
| Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs); | Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs); | ||||
| Status SetInputConst(); | Status SetInputConst(); | ||||
| @@ -198,11 +207,7 @@ class AiCpuTask : public AiCpuBaseTask { | |||||
| Status LaunchKernel(rtStream_t stream) override; | Status LaunchKernel(rtStream_t stream) override; | ||||
| void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; | void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; | ||||
| Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
| const std::vector<DataBuffer> &input_buffers, | |||||
| std::vector<GeTensorDesc> &output_desc, | |||||
| std::vector<DataBuffer> &output_buffers, | |||||
| rtStream_t stream) override; | |||||
| Status SetMemCopyTask(const domi::KernelExDef &kernel_def); | Status SetMemCopyTask(const domi::KernelExDef &kernel_def); | ||||
| private: | private: | ||||
| @@ -266,12 +271,6 @@ class AiCpuCCTask : public AiCpuBaseTask { | |||||
| void SetIoAddr(uintptr_t *io_addr); | void SetIoAddr(uintptr_t *io_addr); | ||||
| size_t GetArgSize() const; | size_t GetArgSize() const; | ||||
| Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
| const std::vector<DataBuffer> &input_buffers, | |||||
| std::vector<GeTensorDesc> &output_desc, | |||||
| std::vector<DataBuffer> &output_buffers, | |||||
| rtStream_t stream) override; | |||||
| private: | private: | ||||
| friend class AiCpuCCTaskBuilder; | friend class AiCpuCCTaskBuilder; | ||||
| std::string so_name_; | std::string so_name_; | ||||