You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

davinci_model.h 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
  17. #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
  18. #include <map>
  19. #include <memory>
  20. #include <set>
  21. #include <string>
  22. #include <thread>
  23. #include <vector>
  24. #include "common/ge_types.h"
  25. #include "common/helper/model_helper.h"
  26. #include "common/helper/om_file_helper.h"
  27. #include "common/opskernel/ge_task_info.h"
  28. #include "common/properties_manager.h"
  29. #include "common/types.h"
  30. #include "framework/common/util.h"
  31. #include "graph/debug/ge_attr_define.h"
  32. #include "graph/load/new_model_manager/aipp_utils.h"
  33. #include "graph/load/new_model_manager/data_dumper.h"
  34. #include "graph/load/new_model_manager/data_inputer.h"
  35. #include "graph/load/new_model_manager/model_utils.h"
  36. #include "graph/load/new_model_manager/zero_copy_offset.h"
  37. #include "graph/load/new_model_manager/zero_copy_task.h"
  38. #include "graph/model.h"
  39. #include "graph/node.h"
  40. #include "graph/op_desc.h"
  41. #include "graph/operator.h"
  42. #include "graph/utils/attr_utils.h"
  43. #include "graph/utils/tensor_utils.h"
  44. #include "mmpa/mmpa_api.h"
  45. #include "proto/task.pb.h"
  46. #include "task_info/task_info.h"
  47. namespace ge {
  48. // op debug need 2048 bits buffer
  49. const size_t kOpDebugMemorySize = 2048UL;
  50. const size_t kDebugP2pSize = 8UL;
// Stages of one model-execution cycle; passed to SetProfileTime() to record
// per-stage timestamps. C-style typedef'd enum: callers use the unscoped
// enumerator names directly, so the shape must not change.
typedef enum tagModelProcStage {
  MODEL_LOAD_START = 1,    // model loading begins (explicitly 1, not 0)
  MODEL_LOAD_END,          // model loading finished
  MODEL_PRE_PROC_START,    // input pre-processing begins
  MODEL_PRE_PROC_END,      // input pre-processing finished
  MODEL_INFER_START,       // inference begins
  MODEL_INFER_END,         // inference finished
  MODEL_AFTER_PROC_START,  // output post-processing begins
  MODEL_AFTER_PROC_END,    // output post-processing finished
  MODEL_PROC_INVALID,      // sentinel: not a real stage
} ModelProcStage;
// Per-run timing record for one model.
// NOTE(review): field names suggest absolute begin/end timestamps in a single
// clock domain — confirm units and clock source at the producer
// (SetProfileTime / SinkTimeProfile).
struct timeInfo {
  uint32_t modelId;            // id of the model these timings belong to
  int64_t processBeginTime;    // whole-request processing window start
  int64_t processEndTime;      // whole-request processing window end
  int64_t inferenceBeginTime;  // inference-only window start
  int64_t inferenceEndTime;    // inference-only window end
  int64_t dumpBeginTime;       // data-dump window start
  int64_t dumpEndTime;         // data-dump window end
};
// Execution mode of a model run.
// NOTE(review): semantics inferred from names only — INITIALIZATION before the
// first run, SYNCHRONIZATION for blocking execution, ASYNCHRONIZATION for
// non-blocking execution; confirm against the call sites.
enum ExecuteMode {
  INITIALIZATION,
  SYNCHRONIZATION,
  ASYNCHRONIZATION,
};
  76. // DavinciModel: runtime representation of a loaded GE model; owns its streams, events, labels and memory (see members below).
  77. class DavinciModel {
  78. public:
  79. ///
  80. /// @ingroup ge
  81. /// @brief DavinciModel constructor
  82. /// @author
  83. ///
  84. DavinciModel(int32_t priority, const std::shared_ptr<ModelListener> &listener);
  85. ///
  86. /// @ingroup ge
  87. /// @brief DavinciModel desctructor, free Parse and Init resources
  88. /// @author
  89. ///
  90. ~DavinciModel();
  91. ///
  92. /// @ingroup ge
  93. /// @brief apply model to model_def_
  94. ///
  95. Status Assign(const GeModelPtr &ge_model);
  96. ///
  97. /// @ingroup ge
  98. /// @brief DavinciModel initialization, including Stream, ccHandle, Event, DataInputer, etc
  99. /// @return execute result
  100. /// @author
  101. ///
  102. Status Init(void *dev_ptr = nullptr, size_t memsize = 0, void *weight_ptr = nullptr, size_t weightsize = 0);
  103. ///
  104. /// @ingroup ge
  105. /// @brief ACL case, Load task list with queue.
  106. /// @param [in] input_que_ids: input queue ids from user, nums equal Data Op.
  107. /// @param [in] output_queue_ids: output queue ids from user, nums equal NetOutput Op.
  108. /// @return: 0 for success / others for fail
  109. ///
  110. Status SetQueIds(const std::vector<uint32_t> &input_queue_ids, const std::vector<uint32_t> &output_queue_ids);
  111. ///
  112. /// @ingroup ge
  113. /// @brief Get model ID
  114. /// @return model ID
  115. ///
  116. uint32_t Id() const { return model_id_; }
  117. ///
  118. /// @ingroup ge
  119. /// @brief Set model ID
  120. /// @return None
  121. ///
  122. void SetId(uint32_t model_id) { model_id_ = model_id; }
  123. static void *Run(DavinciModel *model_pointer);
  124. ///
  125. /// @ingroup ge
  126. /// @brief NnExecute
  127. /// @param [in] stream execute stream
  128. /// @param [in] async_mode is asynchronize mode.
  129. /// @param [in] input_data model input data
  130. /// @param [out] output_data model output data
  131. ///
  132. Status NnExecute(rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data);
  133. ///
  134. /// @ingroup ge
  135. /// @brief lock mutex run flag
  136. /// @author
  137. ///
  138. void LockRunFlg() { mux_run_flg_.lock(); }
  139. ///
  140. /// @ingroup ge
  141. /// @brief unlock mutex run flag
  142. /// @author
  143. ///
  144. void UnlockRunFlg() { mux_run_flg_.unlock(); }
  145. ///
  146. /// @ingroup ge
  147. /// @brief get DataInputer
  148. /// @return DataInputer pointer
  149. ///
  150. DataInputer *const GetDataInputer() const { return data_inputer_; }
  151. // get Stream number
  152. uint32_t StreamNum() const { return runtime_param_.stream_num; }
  153. // get Event number
  154. uint32_t EventNum() const { return runtime_param_.event_num; }
  155. // get Label number
  156. uint32_t LabelNum() const { return runtime_param_.label_num; }
  157. // get batch number
  158. uint32_t BatchNum() const { return runtime_param_.batch_num; }
  159. // get session id
  160. uint64_t SessionId() const { return runtime_param_.session_id; }
  161. // get model priority
  162. int32_t Priority() const { return priority_; }
  163. // get total mem size
  164. size_t TotalMemSize() const { return runtime_param_.mem_size; }
  165. // model name
  166. string Name() const { return name_; }
  167. // om_name
  168. string OmName() const { return om_name_; }
  169. // version
  170. uint32_t Version() const { return version_; }
  171. // get total weights mem size
  172. size_t TotalWeightsMemSize() const { return runtime_param_.weight_size; }
  173. size_t TotalVarMemSize() const { return runtime_param_.var_size; }
  174. // get base memory address
  175. uint8_t *MemBase() { return mem_base_; }
  176. // get weight base memory address
  177. uint8_t *WeightsMemBase() { return weights_mem_base_; }
  178. uint8_t *VarMemBase() { return var_mem_base_; }
  179. // get Event list
  180. const vector<rtEvent_t> &GetEventList() const { return event_list_; }
  181. const vector<rtStream_t> &GetStreamList() const { return stream_list_; }
  182. const vector<rtLabel_t> &GetLabelList() const { return label_list_; }
  183. Status DestroyThread();
  184. // Get Data Op.
  185. const vector<OpDescPtr> &GetDataList() const { return data_op_list_; }
  186. // get Op
  187. const map<uint32_t, OpDescPtr> &GetOpList() const { return op_list_; }
  188. OpDescPtr GetOpByIndex(uint32_t index) const {
  189. if (op_list_.find(index) == op_list_.end()) {
  190. return nullptr;
  191. }
  192. return op_list_.at(index);
  193. }
  194. OpDescPtr GetVariableOp(const string &name) {
  195. for (auto op_desc : variable_op_list_) {
  196. if (op_desc != nullptr && op_desc->GetName() == name) {
  197. return op_desc;
  198. }
  199. }
  200. return nullptr;
  201. }
  202. // get task info for profiling
  203. const std::vector<TaskDescInfo> &GetTaskDescInfo() const { return task_desc_info_; }
  204. // get updated task info list
  205. std::vector<TaskInfoPtr> GetTaskList() { return task_list_; }
  206. ///
  207. /// @ingroup ge
  208. /// @brief get model input and output format
  209. /// @return ccTensorFormat_t current model input and output format
  210. ///
  211. Format GetFormat();
  212. rtModel_t GetRtModelHandle() const { return rt_model_handle_; }
  213. rtStream_t GetRtModelStream() const { return rt_model_stream_; }
  214. uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; }
  215. uint64_t GetRtWeightAddr() const { return runtime_param_.logic_weight_base; }
  216. uint64_t GetRtVarAddr() const { return runtime_param_.logic_var_base; }
  217. uint32_t GetFlowctrlIndex(uint32_t op_index);
  218. void PushHcclStream(rtStream_t value);
  219. bool IsBroadCastOpData(const NodePtr &var_node);
  220. ///
  221. /// @ingroup ge
  222. /// @brief For TVM Op, avoid Addr Reuse.
  223. /// @return void*
  224. ///
  225. const char *GetRegisterStub(const string &tvm_binfile_key, const string &session_graph_model_id = "");
  226. ///
  227. /// @ingroup ge
  228. /// @brief get model input and output desc info
  229. /// @param [out] input_shape model input size
  230. /// @param [out] output_shape model output size
  231. /// @return execute result
  232. ///
  233. Status GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc, vector<InputOutputDescInfo> &output_desc);
  234. Status GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc, vector<InputOutputDescInfo> &output_desc,
  235. std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats);
  236. ///
  237. /// @ingroup ge
  238. /// @brief Get dynamic batch_info
  239. /// @param [out] batch_info
  240. /// @param [out] dynamic_type
  241. /// @return execute result
  242. ///
  243. Status GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) const;
  244. ///
  245. /// @ingroup ge
  246. /// @brief Get combined dynamic dims info
  247. /// @param [out] batch_info
  248. /// @return None
  249. ///
  250. void GetCombinedDynamicDims(std::vector<std::vector<int64_t>> &batch_info) const;
  251. void GetUserDesignateShapeOrder(std::vector<std::string> &user_input_shape_order) const;
  252. void GetCurShape(std::vector<int64_t> &batch_info, int32_t &dynamic_type);
  253. void GetModelAttr(std::vector<std::string> &dynamic_output_shape_info);
  254. ///
  255. /// @ingroup ge
  256. /// @brief Get AIPP input info
  257. /// @param [in] index
  258. /// @param [out] aipp_info
  259. /// @return execute result
  260. ///
  261. Status GetAIPPInfo(uint32_t index, AippConfigInfo &aipp_info);
  262. Status GetAippType(uint32_t index, InputAippType &type, size_t &aipp_index);
  263. ///
  264. /// @ingroup ge
  265. /// @brief Get model_id.
  266. /// @return model_id
  267. ///
  268. uint32_t GetModelId() const { return model_id_; }
  269. ///
  270. /// @ingroup ge
  271. /// @brief get unique identification for op when load two or more models
  272. /// @param [in] op_desc : current op.
  273. /// @param [in] string identification: unique identification for current op.
  274. /// @return None
  275. ///
  276. void GetUniqueId(const OpDescPtr &op_desc, std::string &unique_identification);
  277. ///
  278. /// @ingroup ge
  279. /// @brief get model input and output desc for zero copy
  280. /// @param [out] input_shape model input size
  281. /// @param [out] output_shape model output size
  282. /// @return execute result
  283. ///
  284. Status GetInputOutputDescInfoForZeroCopy(vector<InputOutputDescInfo> &input_desc,
  285. vector<InputOutputDescInfo> &output_desc,
  286. std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats);
  287. Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data);
  288. Status ReturnNoOutput(uint32_t data_id);
  289. Status ModelRunStart();
  290. ///
  291. /// @ingroup ge
  292. /// @brief stop run model
  293. /// @return Status
  294. ///
  295. Status ModelRunStop();
  296. ///
  297. /// @ingroup ge
  298. /// @brief model run flag
  299. /// @return Status
  300. ///
  301. bool RunFlag() const { return run_flg_; }
  302. Status GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats);
  303. ///
  304. /// @ingroup ge
  305. /// @brief Set Session Id
  306. /// @return void
  307. ///
  308. void SetSessionId(uint64_t session_id) { session_id_ = session_id; }
  309. ///
  310. /// @ingroup ge
  311. /// @brief Get Session Id
  312. /// @return sessionID
  313. ///
  314. uint64_t GetSessionId() const { return session_id_; }
  315. ///
  316. /// @ingroup ge
  317. /// @brief SetDeviceId
  318. /// @return void
  319. ///
  320. void SetDeviceId(uint32_t device_id) { device_id_ = device_id; }
  321. ///
  322. /// @ingroup ge
  323. /// @brief Get device Id
  324. /// @return device id
  325. ///
  326. uint32_t GetDeviceId() const { return device_id_; }
  327. bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; }
  328. Status UpdateSessionId(uint64_t session_id);
  329. const RuntimeParam &GetRuntimeParam() { return runtime_param_; }
  330. int32_t GetDataInputTid() const { return dataInputTid; }
  331. void SetDataInputTid(int32_t data_input_tid) { dataInputTid = data_input_tid; }
  332. void DisableZeroCopy(const void *addr);
  333. ///
  334. /// @ingroup ge
  335. /// @brief Save outside address of Data or NetOutput used info for ZeroCopy.
  336. /// @param [in] const OpDescPtr &op_desc: current op desc
  337. /// @param [in] const std::vector<void *> &outside_addrs: address of task
  338. /// @param [in] const void *args_offset: arguments address save the address.
  339. /// @return None.
  340. ///
  341. void SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<void *> &outside_addrs, const void *info, void *args,
  342. size_t size, size_t offset);
  343. void SetDynamicSize(const std::vector<uint64_t> &batch_num, int32_t dynamic_type);
  344. bool GetL1FusionEnableOption() { return is_l1_fusion_enable_; }
  345. void SetProfileTime(ModelProcStage stage, int64_t endTime = 0);
  346. int64_t GetLoadBeginTime() { return load_begin_time_; }
  347. int64_t GetLoadEndTime() { return load_end_time_; }
  348. Status SinkModelProfile();
  349. Status SinkTimeProfile(const InputData &current_data);
  350. void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) {
  351. data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id);
  352. }
  353. void SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr<OpDesc> &op_desc, uintptr_t args) {
  354. data_dumper_.SaveDumpTask(task_id, stream_id, op_desc, args);
  355. }
  356. void SetEndGraphId(uint32_t task_id, uint32_t stream_id);
  357. DavinciModel &operator=(const DavinciModel &model) = delete;
  358. DavinciModel(const DavinciModel &model) = delete;
  359. const map<int64_t, std::vector<rtStream_t>> &GetHcclFolowStream() { return main_follow_stream_mapping_; }
  360. void SaveHcclFollowStream(int64_t main_stream_id, rtStream_t stream);
  361. void InitRuntimeParams();
  362. Status InitVariableMem();
/// @brief Reset the feature-map base address, keeping the runtime params and
/// the cached member in sync. Does not free or reallocate the old base.
void UpdateMemBase(uint8_t *mem_base) {
runtime_param_.mem_base = mem_base;
mem_base_ = mem_base;
}
  367. void SetTotalArgsSize(uint32_t args_size) { total_args_size_ += args_size; }
  368. uint32_t GetTotalArgsSize() { return total_args_size_; }
///
/// @ingroup ge
/// @brief Address at byte `offset` inside the shared args block (args_).
/// @param [in] offset: byte offset from the start of args_.
/// @return args_ + offset. No bounds checking is performed.
///
void *GetCurrentArgsAddr(uint32_t offset) {
void *cur_args = static_cast<char *>(args_) + offset;
return cur_args;
}
  373. void SetTotalIOAddrs(vector<void *> &io_addrs) {
  374. total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end());
  375. }
  376. void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size);
  377. int64_t GetFixedAddrsSize(string tensor_name);
  378. void *GetCurrentFixedAddr(int64_t offset) const {
  379. void *cur_addr = static_cast<char *>(fixed_addrs_) + offset;
  380. return cur_addr;
  381. }
  382. uint32_t GetFixedAddrOutputIndex(string tensor_name) {
  383. if (tensor_name_to_peer_output_index_.find(tensor_name) != tensor_name_to_peer_output_index_.end()) {
  384. return tensor_name_to_peer_output_index_[tensor_name];
  385. }
  386. return UINT32_MAX;
  387. }
  388. void SetKnownNode(bool known_node) { known_node_ = known_node; }
  389. bool IsKnownNode() { return known_node_; }
  390. Status MallocKnownArgs();
  391. Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs);
  392. Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs);
  393. Status UpdateKnownZeroCopyAddr();
  394. void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; }
  395. Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info);
  396. Status GetAllAippInputOutputDims(uint32_t index, std::vector<InputOutputDims> &input_dims,
  397. std::vector<InputOutputDims> &output_dims);
  398. void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; }
  399. // om file name
  400. void SetOmName(string om_name) { om_name_ = om_name; }
  401. void SetDumpProperties(const DumpProperties &dump_properties) { data_dumper_.SetDumpProperties(dump_properties); }
  402. const DumpProperties &GetDumpProperties() const { return data_dumper_.GetDumpProperties(); }
  403. void SetMemcpyOffsetAndAddr(map<int64_t, void *> &memcpy_4g_offset_addr) {
  404. memcpy_4g_offset_addr_.insert(memcpy_4g_offset_addr.begin(), memcpy_4g_offset_addr.end());
  405. }
  406. const map<int64_t, void *> &GetMemcpyOffsetAndAddr() const { return memcpy_4g_offset_addr_; }
  407. bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const {
  408. return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info);
  409. }
  410. Status InitInputOutputForDynamic(const ComputeGraphPtr &compute_graph);
  411. private:
  412. // memory address of weights
  413. uint8_t *weights_mem_base_;
  414. uint8_t *var_mem_base_;
  415. // memory address of model
  416. uint8_t *mem_base_;
  417. bool is_inner_mem_base_;
  418. bool is_inner_weight_base_;
  419. // input data manager
  420. DataInputer *data_inputer_;
  421. int64_t load_begin_time_;
  422. int64_t load_end_time_;
  423. struct timeInfo time_info_;
  424. int32_t dataInputTid;
  425. ///
  426. /// @ingroup ge
  427. /// @brief Save Batch label Info.
  428. /// @param [in] const OpDescPtr &op_desc
  429. /// @param [in] uintptr_t addr: address value in args block.
  430. /// @return None.
  431. ///
  432. void SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr);
  433. ///
  434. /// @ingroup ge
  435. /// @brief Copy Check input size and model op size.
  436. /// @param [in] const int64_t &input_size: input size.
  437. /// @param [in] const int64_t &op_size: model op size.
  438. /// @param [in] is_dynamic: dynamic batch input flag.
  439. /// @return true if success
  440. ///
  441. bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic);
  442. ///
  443. /// @ingroup ge
  444. /// @brief Set copy only for No task feed NetOutput address.
  445. /// @return None.
  446. ///
  447. void SetCopyOnlyOutput();
  448. ///
  449. /// @ingroup ge
  450. /// @brief Copy Input/Output to model for direct use.
  451. /// @param [in] const InputData &input_data: user input data info.
  452. /// @param [in/out] OutputData &output_data: user output data info.
  453. /// @param [in] bool is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  454. /// @return SUCCESS handle successfully / others handle failed
  455. ///
  456. Status CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic);
  457. ///
  458. /// @ingroup ge
  459. /// @brief Copy Data addr to model for direct use.
  460. /// @param [in] data_info: model memory addr/size map { data_index, { tensor_size, tensor_addr } }.
  461. /// @param [in] is_input: input data or output data
  462. /// @param [in] blobs: user input/output data list.
  463. /// @param [in] is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  464. /// @param [in] batch_label: batch label for multi-batch scenes
  465. /// @return SUCCESS handle successfully / others handle failed
  466. ///
  467. Status UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> &data_info, bool is_input,
  468. const vector<DataBuffer> &blobs, bool is_dynamic, const string &batch_label);
  469. Status CopyInputData(const InputData &input_data, bool device_data = false);
  470. Status CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind);
  471. Status SyncVarData();
  472. Status InitModelMem(void *dev_ptr, size_t memsize, void *weight_ptr, size_t weightsize);
  473. void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input);
  474. void SetInputDimsInfo(const vector<int64_t> &model_input_dims, Format &format, InputOutputDescInfo &input);
  475. Status GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, std::vector<uint32_t> &formats);
  476. Status InitTaskInfo(domi::ModelTaskDef &modelTaskInfo);
  477. void UnbindHcomStream();
  478. Status DistributeTask();
  479. uint8_t *MallocFeatureMapMem(size_t data_size);
  480. uint8_t *MallocWeightsMem(size_t weights_size);
  481. void FreeFeatureMapMem();
  482. void FreeWeightsMem();
  483. void ReleaseTask();
  484. void UnbindTaskSinkStream();
  485. bool IsAicpuKernelConnectSpecifiedLayer();
  486. ///
  487. /// @ingroup ge
  488. /// @brief Reduce memory usage after task sink.
  489. /// @return: void
  490. ///
  491. void Shrink();
  492. ///
  493. /// @ingroup ge
  494. /// @brief Travel all nodes and do some init.
  495. /// @param [in] compute_graph: ComputeGraph to load.
  496. /// @return Status
  497. ///
  498. Status InitNodes(const ComputeGraphPtr &compute_graph);
  499. ///
  500. /// @ingroup ge
  501. /// @brief Data Op Initialize.
  502. /// @param [in] NodePtr: Data Op.
  503. /// @param [in/out] data_op_index: NetOutput addr size info.
  504. /// @return Status
  505. ///
  506. Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, map<uint32_t, OpDescPtr> &data_by_index);
  507. ///
  508. /// @ingroup ge
  509. /// @brief Sort Data op list by index.
  510. /// @param [in] data_by_index: map of Data Op.
  511. /// @return
  512. ///
  513. void AdjustDataOpList(const map<uint32_t, OpDescPtr> &data_by_index);
  514. ///
  515. /// @ingroup ge
  516. /// @brief input zero copy node Initialize.
  517. /// @param [in] NodePtr: Data Op.
  518. /// @return Status
  519. ///
  520. Status InitInputZeroCopyNodes(const NodePtr &node);
  521. ///
  522. /// @ingroup ge
  523. /// @brief NetOutput Op Initialize.
  524. /// @param [in] NodePtr: NetOutput Op.
  525. /// @return Status
  526. ///
  527. Status InitNetOutput(const NodePtr &node);
  528. ///
  529. /// @ingroup ge
  530. /// @brief output zero copy node Initialize.
  531. /// @param [in] NodePtr: Data Op.
  532. /// @return Status
  533. ///
  534. Status InitOutputZeroCopyNodes(const NodePtr &node);
  535. ///
  536. /// @ingroup ge
  537. /// @brief Constant Op Init.
  538. /// @return Status
  539. ///
  540. Status InitConstant(const OpDescPtr &op_desc);
  541. Status InitVariable(const OpDescPtr &op_desc);
  542. /// @ingroup ge
  543. /// @brief LabelSet Op Initialize.
  544. /// @param [in] op_desc: LabelSet Op descriptor.
  545. /// @return Status
  546. Status InitLabelSet(const OpDescPtr &op_desc);
  547. Status InitStreamSwitch(const OpDescPtr &op_desc);
  548. Status InitStreamActive(const OpDescPtr &op_desc);
  549. Status InitStreamSwitchN(const OpDescPtr &op_desc);
  550. ///
  551. /// @ingroup ge
  552. /// @brief Case Op Init.
  553. /// @return Status
  554. ///
  555. Status InitCase(const OpDescPtr &op_desc);
  556. Status SetDynamicBatchInfo(const OpDescPtr &op_desc, uint32_t batch_num);
  557. ///
  558. /// @ingroup ge
  559. /// @brief TVM Op Init.
  560. /// @return Status
  561. ///
  562. Status InitTbeHandle(const OpDescPtr &op_desc);
  563. void StoreTbeHandle(const std::string &handle_key);
  564. void CleanTbeHandle();
  565. ///
  566. /// @ingroup ge
  567. /// @brief Make active stream list and bind to model.
  568. /// @return: 0 for success / others for fail
  569. ///
  570. Status BindModelStream();
  571. ///
  572. /// @ingroup ge
  573. /// @brief Init model stream for NN model.
  574. /// @return Status
  575. ///
  576. Status InitModelStream(rtStream_t stream);
  577. ///
  578. /// @ingroup ge
  579. /// @brief ACL, Load task list with queue entrance.
  580. /// @return: 0 for success / others for fail
  581. ///
  582. Status LoadWithQueue();
  583. ///
  584. /// @ingroup ge
  585. /// @brief ACL, Bind Data Op addr to input queue.
  586. /// @return: 0 for success / others for fail
  587. ///
  588. Status BindInputQueue();
  589. Status CpuTaskModelZeroCopy(std::vector<uintptr_t> &mbuf_list, std::map<const void *, ZeroCopyOffset> &outside_addrs);
  590. ///
  591. /// @ingroup ge
  592. /// @brief ACL, Bind NetOutput Op addr to output queue.
  593. /// @return: 0 for success / others for fail
  594. ///
  595. Status BindOutputQueue();
  596. Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size);
  597. ///
  598. /// @ingroup ge
  599. /// @brief definiteness queue schedule, bind input queue to task.
  600. /// @param [in] queue_id: input queue id from user.
  601. /// @param [in] addr: Data Op output tensor address.
  602. /// @param [in] size: Data Op output tensor size.
  603. /// @return: 0 for success / others for fail
  604. ///
  605. Status CpuModelDequeue(uint32_t queue_id);
  606. ///
  607. /// @ingroup ge
  608. /// @brief definiteness queue schedule, bind output queue to task.
  609. /// @param [in] queue_id: output queue id from user.
  610. /// @param [in] addr: NetOutput Op input tensor address.
  611. /// @param [in] size: NetOutput Op input tensor size.
  612. /// @return: 0 for success / others for fail
  613. ///
  614. Status CpuModelEnqueue(uint32_t queue_id, uintptr_t addr, uint32_t size);
  615. ///
  616. /// @ingroup ge
  617. /// @brief definiteness queue schedule, active original model stream.
  618. /// @return: 0 for success / others for fail
  619. ///
  620. Status CpuActiveStream();
  621. ///
  622. /// @ingroup ge
  623. /// @brief definiteness queue schedule, wait for end graph.
  624. /// @return: 0 for success / others for fail
  625. ///
  626. Status CpuWaitEndGraph();
  627. Status BindEnqueue();
  628. Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf);
  629. ///
  630. /// @ingroup ge
  631. /// @brief definiteness queue schedule, repeat run model.
  632. /// @return: 0 for success / others for fail
  633. ///
  634. Status CpuModelRepeat();
  635. Status InitEntryTask();
  636. Status AddHeadStream();
  637. ///
  638. /// @ingroup ge
  639. /// @brief set ts device.
  640. /// @return: 0 for success / others for fail
  641. ///
  642. Status SetTSDevice();
  643. Status OpDebugRegister();
  644. void OpDebugUnRegister();
  645. void CheckHasHcomOp();
  646. Status DoTaskSink();
  647. void CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputDescInfo &output, uint32_t &format_result);
  648. Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id);
  649. // get desc info of graph for profiling
  650. Status GetComputeGraphInfo(const ComputeGraphPtr &graph, vector<ComputeGraphDescInfo> &graph_desc_info);
  651. void SetDataDumperArgs(const ComputeGraphPtr &compute_graph);
  652. Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data,
  653. std::vector<ge::OutputTensorInfo> &outputs);
  654. void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info);
  655. void GetFixedAddrAttr(const OpDescPtr &op_desc);
656. bool is_model_has_inited_;
657. uint32_t model_id_;
658. uint32_t runtime_model_id_;
659. string name_;
660. // used for inference data dump
661. string om_name_;
662. uint32_t version_;
663. GeModelPtr ge_model_;
664. bool need_destroy_aicpu_kernel_{false};
665. vector<std::string> out_node_name_;
666. map<uint32_t, OpDescPtr> op_list_;
667. // data op_desc
668. vector<OpDescPtr> data_op_list_;
669. vector<OpDescPtr> output_op_list_;
670. vector<OpDescPtr> variable_op_list_;
// Zero-copy bookkeeping: per-index data info and user-address mappings.
671. std::map<uint32_t, ZeroCopyOffset> new_input_data_info_;
672. std::map<uint32_t, ZeroCopyOffset> new_output_data_info_;
673. std::map<const void *, ZeroCopyOffset> new_input_outside_addrs_;
674. std::map<const void *, ZeroCopyOffset> new_output_outside_addrs_;
675. std::vector<void *> real_virtual_addrs_;
676. // output op: save cce op actual needed memory size
677. vector<int64_t> output_memory_size_list_;
678. std::thread thread_id_;
679. std::shared_ptr<ModelListener> listener_;
// NOTE(review): 'flg' is a typo for 'flag' (also in mux_run_flg_); members are
// referenced from the .cc, so renaming belongs in a dedicated refactor.
680. bool run_flg_;
681. std::mutex mux_run_flg_;
682. int32_t priority_;
683. vector<rtStream_t> stream_list_;
684. std::mutex all_hccl_stream_list_mutex_;
685. vector<rtStream_t> all_hccl_stream_list_;
686. // for reuse hccl_follow_stream
687. std::mutex capacity_of_stream_mutex_;
688. std::map<int64_t, std::vector<rtStream_t>> main_follow_stream_mapping_;
689. vector<rtEvent_t> event_list_;
690. vector<rtLabel_t> label_list_;
691. set<uint32_t> label_id_indication_;
692. std::mutex outside_addrs_mutex_;
693. std::vector<ZeroCopyTask> zero_copy_tasks_; // Task used Data or NetOutput addr.
694. std::set<const void *> copy_only_addrs_; // Address need copy to original place.
695. // {op_id, batch_label}
696. std::map<int64_t, std::string> zero_copy_op_id_batch_label_;
697. // {batch_label, addrs}
698. std::map<std::string, std::set<uintptr_t>> zero_copy_batch_label_addrs_;
699. std::vector<TaskInfoPtr> task_list_;
700. // handle of the runtime model instance (fixed comment typo: was "rt_moodel_handle")
701. rtModel_t rt_model_handle_;
702. rtStream_t rt_model_stream_;
703. bool is_inner_model_stream_;
704. bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_.
705. ExecuteMode last_execute_mode_;
706. bool is_stream_list_bind_{false};
707. bool is_pure_head_stream_{false};
708. rtStream_t rt_head_stream_{nullptr};
709. rtStream_t rt_entry_stream_{nullptr};
710. rtAicpuDeployType_t deploy_type_{AICPU_DEPLOY_RESERVED};
711. // ACL queue schedule, save queue ids for Init.
712. std::vector<TaskInfoPtr> cpu_task_list_;
713. std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller.
714. std::vector<uint32_t> output_queue_ids_; // output queue ids created by caller.
715. std::vector<uintptr_t> input_mbuf_list_; // input mbuf created by dequeue task.
716. std::vector<uintptr_t> output_mbuf_list_; // output mbuf created by dequeue task.
717. uint64_t session_id_;
718. uint32_t device_id_;
719. std::mutex flowctrl_op_index_internal_map_mutex_;
720. std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_;
721. std::vector<rtStream_t> active_stream_list_;
722. std::set<uint32_t> active_stream_indication_;
723. std::set<uint32_t> hcom_streams_;
724. RuntimeParam runtime_param_;
// Shared across all DavinciModel instances (static): guards the TBE/TVM kernel bins.
725. static std::mutex tvm_bin_mutex_;
726. std::set<std::string> tvm_bin_kernel_;
727. std::map<std::string, uint32_t> used_tbe_handle_map_;
728. // for profiling task and graph info
729. std::map<uint32_t, std::string> op_name_map_;
730. std::vector<TaskDescInfo> task_desc_info_;
// NOTE(review): camelCase name is inconsistent with the snake_case_ members above.
731. int64_t maxDumpOpNum_;
732. // for data dump
733. DataDumper data_dumper_;
734. uint64_t iterator_count_;
735. bool is_l1_fusion_enable_;
736. std::map<OpDescPtr, void *> saved_task_addrs_;
737. void *l1_fusion_addr_ = nullptr;
// Known-shape ("known node") execution path state: pre-assembled args buffers.
738. bool known_node_ = false;
739. uint32_t total_args_size_ = 0;
740. void *args_ = nullptr;
741. void *args_host_ = nullptr;
742. void *fixed_addrs_ = nullptr;
743. int64_t total_fixed_addr_size_ = 0;
// NOTE(review): 'knonw_' is a typo for 'known_' in the two members below; they are
// referenced from the .cc, so fix in a coordinated rename, not header-only.
744. std::map<const void *, void *> knonw_input_data_info_;
745. std::map<const void *, void *> knonw_output_data_info_;
746. vector<void *> total_io_addrs_;
747. vector<void *> orig_total_io_addrs_;
748. bool base_addr_not_changed_ = false;
// Dynamic-batch / dynamic-shape metadata.
749. vector<vector<int64_t>> batch_info_;
750. std::vector<std::vector<int64_t>> combined_batch_info_;
751. vector<string> user_designate_shape_order_;
752. int32_t dynamic_type_ = 0;
753. bool is_dynamic_ = false;
754. vector<uint64_t> batch_size_;
755. // key: input tensor name, generally rts op;
756. // value: the fixed addr of input anchor, same as the peer output anchor addr of the peer op
757. std::map<string, int64_t> tensor_name_to_fixed_addr_size_;
758. // key: input tensor name, generally rts op; value: the peer output anchor of the peer op
759. std::map<string, int64_t> tensor_name_to_peer_output_index_;
760. // whether this is the model's first execution
761. bool is_first_execute_;
762. // for op debug
763. std::mutex debug_reg_mutex_;
764. bool is_op_debug_reg_ = false;
765. void *op_debug_addr_ = nullptr;
766. void *p2p_debug_addr_ = nullptr;
767. bool is_new_model_desc_{false};
768. std::map<int64_t, void *> memcpy_4g_offset_addr_;
769. };
  770. } // namespace ge
  771. #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示