You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

davinci_model.h 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
  17. #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
#include "common/ge_types.h"
#include "common/helper/model_helper.h"
#include "common/helper/om_file_helper.h"
#include "graph/debug/ge_attr_define.h"
#include "common/opskernel/ge_task_info.h"
#include "common/types.h"
#include "framework/common/util.h"
#include "graph/load/new_model_manager/data_dumper.h"
#include "graph/load/new_model_manager/data_inputer.h"
#include "graph/load/new_model_manager/model_utils.h"
#include "graph/model.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"
#include "mmpa/mmpa_api.h"
#include "proto/task.pb.h"
#include "task_info/task_info.h"
  43. #define WEIGHTS_ADDR_TO_CCE(var)
  44. namespace ge {
  45. using std::vector;
  46. enum ZeroCopyMode {
  47. kInputZeroCopy,
  48. kOutputZeroCopy,
  49. };
  50. typedef enum tagModelProcStage {
  51. MODEL_LOAD_START = 1,
  52. MODEL_LOAD_END,
  53. MODEL_PRE_PROC_START,
  54. MODEL_PRE_PROC_END,
  55. MODEL_INFER_START,
  56. MODEL_INFER_END,
  57. MODEL_AFTER_PROC_START,
  58. MODEL_AFTER_PROC_END,
  59. MODEL_PROC_INVALID,
  60. } ModelProcStage;
// Timestamps collected for a single run of one model, grouped into three
// windows: overall processing, device inference, and data dump.
// NOTE(review): the time unit is not visible in this header — presumably
// microseconds; confirm against the implementation that fills this struct.
struct timeInfo {
uint32_t modelId;            // id of the model these timestamps belong to
int64_t processBeginTime;    // start of the whole processing window
int64_t processEndTime;      // end of the whole processing window
int64_t inferenceBeginTime;  // start of the inference window
int64_t inferenceEndTime;    // end of the inference window
int64_t dumpBeginTime;       // start of the data-dump window
int64_t dumpEndTime;         // end of the data-dump window
};
  70. // comments
  71. class DavinciModel {
  72. public:
  73. ///
  74. /// @ingroup domi_ome
  75. /// @brief DavinciModel constructor
  76. /// @author
  77. ///
  78. DavinciModel(int32_t priority, const std::shared_ptr<ModelListener> &listener);
  79. ///
  80. /// @ingroup domi_ome
  81. /// @brief DavinciModel desctructor, free Parse and Init resources
  82. /// @author
  83. ///
  84. ~DavinciModel();
  85. ///
  86. /// @ingroup domi_ome
  87. /// @brief apply model to model_def_
  88. ///
  89. Status Assign(const GeModelPtr &ge_model);
  90. ///
  91. /// @ingroup domi_ome
  92. /// @brief DavinciModel initialization, including Stream, ccHandle, Event, DataInputer, etc
  93. /// @return execute result
  94. /// @author
  95. ///
  96. Status Init(void *dev_ptr = nullptr, size_t memsize = 0, void *weight_ptr = nullptr, size_t weightsize = 0);
  97. ///
  98. /// @ingroup ge
  99. /// @brief ACL case, Load task list with queue.
  100. /// @param [in] input_que_ids: input queue ids from user, nums equal Data Op.
  101. /// @param [in] output_que_ids: input queue ids from user, nums equal NetOutput Op.
  102. /// @return: 0 for success / others for fail
  103. ///
  104. Status SetQueIds(const std::vector<uint32_t> &input_queue_ids, const std::vector<uint32_t> &output_queue_ids);
  105. ///
  106. /// @ingroup domi_ome
  107. /// @brief Get DataInputer
  108. /// @return model ID
  109. ///
  110. uint32_t Id() const { return model_id_; }
  111. ///
  112. /// @ingroup domi_ome
  113. /// @brief Get DataInputer
  114. /// @return model ID
  115. ///
  116. void SetId(uint32_t model_id) { model_id_ = model_id; }
  117. static void *Run(DavinciModel *model_pointer);
  118. ///
  119. /// @ingroup domi_ome
  120. /// @brief NnExecute
  121. /// @param [in] stream execute stream
  122. /// @param [in] async_mode is asynchronize mode.
  123. /// @param [in] input_data model input data
  124. /// @param [out] output_data model output data
  125. ///
  126. Status NnExecute(rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data);
  127. ///
  128. /// @ingroup domi_ome
  129. /// @brief get sys mode
  130. /// @return SysMode
  131. ///
  132. static SysMode GetSysMode();
  133. ///
  134. /// @ingroup domi_ome
  135. /// @brief set sys mode
  136. /// @return Status
  137. ///
  138. static Status SetSysMode(SysMode mode);
  139. ///
  140. /// @ingroup domi_ome
  141. /// @brief lock mutex run flag
  142. /// @author
  143. ///
  144. void LockRunFlg() { mux_run_flg_.lock(); }
  145. ///
  146. /// @ingroup domi_ome
  147. /// @brief unlock mutex run flag
  148. /// @author
  149. ///
  150. void UnlockRunFlg() { mux_run_flg_.unlock(); }
  151. ///
  152. /// @ingroup domi_ome
  153. /// @brief get DataInputer
  154. /// @return DataInputer pointer
  155. ///
  156. DataInputer *const GetDataInputer() const { return data_inputer_; }
  157. // get Stream number
  158. uint32_t StreamNum() const { return runtime_param_.stream_num; }
  159. // get Event number
  160. uint32_t EventNum() const { return runtime_param_.event_num; }
  161. // get Lable number
  162. uint32_t LabelNum() const { return runtime_param_.label_num; }
  163. // get batch number
  164. uint32_t BatchNum() const { return runtime_param_.batch_num; }
  165. // get session id
  166. uint64_t SessionId() const { return runtime_param_.session_id; }
  167. vector<ge::OpDescPtr> GetOpDesc() {
  168. vector<ge::OpDescPtr> opDescVector;
  169. GE_IF_BOOL_EXEC(ge::AttrUtils::GetListOpDesc(GetGeModel(), MODEL_ATTR_FUSION_MODEL_DEF, opDescVector),
  170. GELOGI("get opDesc of opDescVector"));
  171. return opDescVector;
  172. }
  173. // get model priority
  174. int32_t Priority() const { return priority_; }
  175. // get total mem size
  176. size_t TotalMemSize() const { return runtime_param_.mem_size; }
  177. // model name
  178. string Name() { return name_; }
  179. // version
  180. uint32_t Version() const { return version_; }
  181. // get total weights mem size
  182. size_t TotalWeightsMemSize() const { return runtime_param_.weight_size; }
  183. size_t TotalVarMemSize() const { return runtime_param_.var_size; }
  184. // get base memory address
  185. uint8_t *MemBase() { return mem_base_; }
  186. // get weight base memory address
  187. uint8_t *WeightsMemBase() { return weights_mem_base_; }
  188. uint8_t *VarMemBase() { return var_mem_base_; }
  189. // get Event list
  190. const vector<rtEvent_t> &GetEventList() const { return event_list_; }
  191. const vector<rtStream_t> &GetStreamList() const { return stream_list_; }
  192. const vector<rtLabel_t> &GetLabelList() const { return label_list_; }
  193. Status DestroyThread();
  194. // Get Data Op.
  195. const vector<OpDescPtr> &GetDataList() const { return data_op_list_; }
  196. // get Op
  197. map<uint32_t, OpDescPtr> GetOpList() const { return op_list_; }
  198. OpDescPtr GetOpByIndex(uint32_t index) {
  199. if (op_list_.find(index) == op_list_.end()) {
  200. return nullptr;
  201. }
  202. return op_list_.at(index);
  203. }
  204. OpDescPtr GetVariableOp(const string &name) {
  205. for (auto op_desc : variable_op_list_) {
  206. if (op_desc != nullptr && op_desc->GetName() == name) {
  207. return op_desc;
  208. }
  209. }
  210. return nullptr;
  211. }
  212. // get task info for profiling
  213. const std::vector<TaskDescInfo> &GetTaskDescInfo() const { return task_desc_info_; }
  214. // get updated task info list
  215. std::vector<TaskInfoPtr> GetTaskList() { return task_list_; }
  216. ///
  217. /// @ingroup domi_ome
  218. /// @brief get model input and output format
  219. /// @return ccTensorFormat_t current model input and output format
  220. ///
  221. ge::Format GetFormat();
  222. rtModel_t GetRtModelHandle() {
  223. rtModel_t res = rt_model_handle_;
  224. return res;
  225. }
  226. uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; }
  227. uint64_t GetRtWeightAddr() const { return runtime_param_.logic_weight_base; }
  228. uint64_t GetRtVarAddr() const { return runtime_param_.logic_var_base; }
  229. uint32_t GetFlowctrlIndex(uint32_t op_index);
  230. void PushHcclStream(rtStream_t value);
  231. bool IsBroadCastOpData(const ge::NodePtr &var_node);
  232. ///
  233. /// @ingroup domi_ome
  234. /// @brief For TVM Op, avoid Addr Reuse.
  235. /// @return void*
  236. ///
  237. static const char *GetRegisterStub(const string &tvm_binfile_key, const string &session_graph_model_id = "");
  238. ///
  239. /// @ingroup domi_ome
  240. /// @brief get model input and output desc info
  241. /// @param [out] input_shape model input size
  242. /// @param [out] output_shape model output size
  243. /// @return execute result
  244. ///
  245. Status GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc, vector<InputOutputDescInfo> &output_desc);
  246. Status GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc, vector<InputOutputDescInfo> &output_desc,
  247. std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats);
  248. ///
  249. /// @ingroup domi_ome
  250. /// @brief Get dynamic batch_info
  251. /// @param [out] batch_info
  252. /// @return execute result
  253. ///
  254. Status GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info);
  255. ///
  256. /// @ingroup domi_ome
  257. /// @brief Get model_id.
  258. /// @return model_id
  259. ///
  260. uint32_t GetModelId() const { return model_id_; }
  261. ///
  262. /// @ingroup domi_ome
  263. /// @brief get unique identification for op when load two or more models
  264. /// @param [in] op_desc : current op.
  265. /// @param [in] string identification: unique identification for current op.
  266. /// @return None
  267. ///
  268. void GetUniqueId(const OpDescPtr &op_desc, std::string &unique_identification);
  269. ///
  270. /// @ingroup domi_ome
  271. /// @brief get model input and output desc for zero copy
  272. /// @param [out] input_shape model input size
  273. /// @param [out] output_shape model output size
  274. /// @return execute result
  275. ///
  276. Status GetInputOutputDescInfoForZeroCopy(vector<InputOutputDescInfo> &input_desc,
  277. vector<InputOutputDescInfo> &output_desc);
  278. Status GetInputOutputDescInfoForZeroCopy(vector<InputOutputDescInfo> &input_desc,
  279. vector<InputOutputDescInfo> &output_desc,
  280. std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats);
  281. Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data);
  282. Status ReturnNoOutput(uint32_t data_id);
  283. ///
  284. /// @ingroup domi_ome
  285. /// @brief dump all op input and output information
  286. /// @param [in] op_list model_id
  287. /// @return Status
  288. ///
  289. Status DumpOpInputOutput();
  290. ///
  291. /// @ingroup domi_ome
  292. /// @brief dump single op input and output information
  293. /// @param [in] dump_op model_id
  294. /// @return Status
  295. ///
  296. Status DumpSingleOpInputOutput(const OpDescPtr &dump_op);
  297. Status ModelRunStart();
  298. ///
  299. /// @ingroup domi_ome
  300. /// @brief stop run model
  301. /// @return Status
  302. ///
  303. Status ModelRunStop();
  304. ///
  305. /// @ingroup domi_ome
  306. /// @brief model run flag
  307. /// @return Status
  308. ///
  309. bool RunFlag() const { return run_flg_; }
  310. Status GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats);
  311. ///
  312. /// @ingroup domi_ome
  313. /// @brief Set Session Id
  314. /// @return void
  315. ///
  316. void SetSessionId(uint64_t session_id) { session_id_ = session_id; }
  317. ///
  318. /// @ingroup domi_ome
  319. /// @brief Get Session Id
  320. /// @return sessionID
  321. ///
  322. uint64_t GetSessionId() const { return session_id_; }
  323. ///
  324. /// @ingroup domi_ome
  325. /// @brief SetDeviceId
  326. /// @return void
  327. ///
  328. void SetDeviceId(uint32_t device_id) { device_id_ = device_id; }
  329. ///
  330. /// @ingroup domi_ome
  331. /// @brief Get device Id
  332. /// @return device id
  333. ///
  334. uint32_t GetDeviceId() const { return device_id_; }
  335. GeModelPtr GetGeModel() { return ge_model_; }
  336. const RuntimeParam &GetRuntimeParam() { return runtime_param_; }
  337. int32_t GetDataInputTid() const { return dataInputTid; }
  338. void SetDataInputTid(int32_t data_input_tid) { dataInputTid = data_input_tid; }
  339. ///
  340. /// @ingroup domi_ome
  341. /// @brief Save outside address of Data or NetOutput used info for ZeroCopy.
  342. /// @param [in] const OpDescPtr &op_desc: current op desc
  343. /// @param [in] const std::vector<void *> &outside_addrs: address of task
  344. /// @param [in] const void *args_offset: arguments address save the address.
  345. /// @return None.
  346. ///
  347. void SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<void *> &outside_addrs_, void *args_offset);
  348. bool GetL1FusionEnableOption() { return is_l1_fusion_enable_; }
  349. void SetProfileTime(ModelProcStage stage, int64_t endTime = 0);
  350. int64_t GetLoadBeginTime() { return load_begin_time_; }
  351. int64_t GetLoadEndTime() { return load_end_time_; }
  352. Status SinkModelProfile(std::shared_ptr<DavinciModel> &model);
  353. Status SinkTimeProfile(const InputData &current_data);
  354. void SaveDumpTask(uint32_t task_id, const std::shared_ptr<OpDesc> &op_desc, uintptr_t args) {
  355. data_dumper_.SaveDumpTask(task_id, op_desc, args);
  356. }
  357. DavinciModel &operator=(const DavinciModel &model) = delete;
  358. DavinciModel(const DavinciModel &model) = delete;
  359. private:
  360. // memory address of weights
  361. uint8_t *weights_mem_base_;
  362. uint8_t *var_mem_base_;
  363. // memory address of model
  364. uint8_t *mem_base_;
  365. bool is_inner_mem_base_;
  366. bool is_inner_weight_base_;
  367. // input data manager
  368. DataInputer *data_inputer_;
  369. int64_t load_begin_time_;
  370. int64_t load_end_time_;
  371. struct timeInfo time_info_;
  372. int32_t dataInputTid;
  373. ///
  374. /// @ingroup domi_ome
  375. /// @brief Save Data address info for ZeroCopy.
  376. /// @param [in] const std::vector<void *> &outside_addrs
  377. /// @return None.
  378. ///
  379. void SetInputOutsideAddr(const std::vector<void *> &outside_addrs);
  380. ///
  381. /// @ingroup domi_ome
  382. /// @brief Save NetOutput address info for ZeroCopy.
  383. /// @param [in] const std::vector<void *> &outside_addrs
  384. /// @return None.
  385. ///
  386. void SetOutputOutsideAddr(const std::vector<void *> &outside_addrs);
  387. ///
  388. /// @ingroup ge
  389. /// @brief Copy Check input size and model op size.
  390. /// @param [in] const int64_t &input_size: input size.
  391. /// @param [in] const int64_t &op_size: model op size.
  392. /// @param [in] is_dynamic_input: dynamic batch input flag.
  393. /// @return true if success
  394. ///
  395. bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic_input);
  396. ///
  397. /// @ingroup ge
  398. /// @brief Copy Input/Output to model for direct use.
  399. /// @param [in] const InputData &input_data: user input data info.
  400. /// @param [in/out] OutputData &output_data: user output data info.
  401. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  402. /// @return SUCCESS handle successfully / others handle failed
  403. ///
  404. Status CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic_input);
  405. ///
  406. /// @ingroup ge
  407. /// @brief Copy Data addr to model for direct use.
  408. /// @param [in] const std::map<uint32_t, std::pair<int64_t, void *>> &data_info: model memory addr/size list.
  409. /// @param [in] const std::vector<DataBuffer> &blobs: user input data list.
  410. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  411. /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy
  412. /// @param [in] string batch_label: batch label for multi-batch scenes
  413. /// @return SUCCESS handle successfully / others handle failed
  414. ///
  415. Status ZeroCopyBlobs(const std::map<uint32_t, std::pair<int64_t, void *>> &data_info,
  416. const std::vector<DataBuffer> &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode,
  417. string batch_label);
  418. ///
  419. /// @ingroup ge
  420. /// @brief Copy input addr to model for direct use.
  421. /// @param [in] void *addr: model input memory addr.
  422. /// @param [in] uint32_t size: model input memory size.
  423. /// @param [in] const DataBuffer &data_buffer: user input data.
  424. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  425. /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy
  426. /// @param [in] string batch_label: batch label for multi-batch scenes
  427. /// @return SUCCESS handle successfully / others handle failed
  428. ///
  429. Status ZeroCopyInputBlobs(void *addr, int64_t size, const DataBuffer &data_buffer, ZeroCopyMode zero_copy_mode,
  430. string batch_label);
  431. ///
  432. /// @ingroup ge
  433. /// @brief Copy address to args_ space for direct use.
  434. /// @param [in] const void *src_addr: source address of the Op.
  435. /// @param [in] const void *dst_addr: destination address of user data.
  436. /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy
  437. /// @param [in] string batch_label: batch label for multi-batch scenes
  438. /// @return SUCCESS handle successfully / others handle failed
  439. ///
  440. Status ZeroCopyImpl(const void *src_addr, const DataBuffer &data_buf, ZeroCopyMode zero_copy_mode,
  441. string batch_label);
  442. Status CopyInputData(const InputData &current_data, bool device_data = false);
  443. Status CopyTransData(const std::vector<DataBuffer> &data, uint32_t data_index, uint32_t data_op_index,
  444. const std::vector<GeAttrValue::INT> &outputs);
  445. Status CopyPlainData(const std::vector<DataBuffer> &data, uint32_t data_index, uint32_t data_op_index,
  446. const std::vector<GeAttrValue::INT> &outputs, rtMemcpyKind_t kind);
  447. Status CopyOutputData(uint32_t data_id, OutputData &output_data);
  448. Status CopyOutputDataToUser(OpDescPtr &op_desc, std::vector<DataBuffer> &blobs, uint32_t &data_index);
  449. Status SyncVarData();
  450. Status SyncDataAndDump();
  451. Status InitModelMem(void *dev_ptr, size_t memsize, void *weight_ptr, size_t weightsize);
  452. Status GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, std::vector<uint32_t> &formats);
  453. Status InitTaskInfo(domi::ModelTaskDef &modelTaskInfo);
  454. void UnbindHcomStream();
  455. Status DistributeTask();
  456. uint8_t *MallocFeatureMapMem(uint64_t data_size);
  457. uint8_t *MallocWeightsMem(uint32_t weights_size);
  458. void FreeFeatureMapMem();
  459. void FreeWeightsMem();
  460. void ReleaseTask();
  461. void UnbindTaskSinkStream();
  462. void AddEndGraphToTaskList();
  463. ///
  464. /// @ingroup ge
  465. /// @brief Travel all nodes and do some init.
  466. /// @param [in] compute_graph: ComputeGraph to load.
  467. /// @return Status
  468. ///
  469. Status InitNodes(const ComputeGraphPtr &compute_graph);
  470. ///
  471. /// @ingroup ge
  472. /// @brief Data Op Initialize.
  473. /// @param [in] NodePtr: Data Op.
  474. /// @param [in/out] data_op_index: NetOutput addr size info.
  475. /// @return Status
  476. ///
  477. Status InitDataOp(const NodePtr &node, uint32_t &data_op_index);
  478. ///
  479. /// @ingroup ge
  480. /// @brief input zero copy node Initialize.
  481. /// @param [in] NodePtr: Data Op.
  482. /// @return Status
  483. ///
  484. Status InitInputZeroCopyNodes(const NodePtr &node);
  485. ///
  486. /// @ingroup ge
  487. /// @brief NetOutput Op Initialize.
  488. /// @param [in] op_desc: NetOutput Op descriptor.
  489. /// @return Status
  490. ///
  491. Status InitNetOutput(const OpDescPtr &op_desc);
  492. ///
  493. /// @ingroup domi_ome
  494. /// @brief Constant Op Init.
  495. /// @return Status
  496. ///
  497. Status InitConstant(const OpDescPtr &op_desc);
  498. Status InitVariable(const OpDescPtr &op_desc);
  499. Status InitEndGraph(const OpDescPtr &op_desc);
  500. /// @ingroup ge
  501. /// @brief LabelSet Op Initialize.
  502. /// @param [in] op_desc: LabelSet Op descriptor.
  503. /// @return Status
  504. Status InitLabelSet(const OpDescPtr &op_desc);
  505. Status InitStreamSwitch(const OpDescPtr &op_desc);
  506. Status InitStreamActive(const OpDescPtr &op_desc);
  507. Status InitStreamSwitchN(const OpDescPtr &op_desc);
  508. ///
  509. /// @ingroup domi_ome
  510. /// @brief TVM Op Init.
  511. /// @return Status
  512. ///
  513. Status InitTbeHandle(const OpDescPtr &op_desc);
  514. void StoreTbeHandle(const std::string &handle_key);
  515. void CleanTbeHandle();
  516. ///
  517. /// @ingroup domi_ome
  518. /// @brief Init model stream for NN model.
  519. /// @return Status
  520. ///
  521. Status InitModelStream(rtStream_t stream);
  522. ///
  523. /// @ingroup ge
  524. /// @brief ACL, Load task list with queue entrance.
  525. /// @return: 0 for success / others for fail
  526. ///
  527. Status LoadWithQueue();
  528. ///
  529. /// @ingroup ge
  530. /// @brief ACL, Bind Data Op addr to input queue.
  531. /// @return: 0 for success / others for fail
  532. ///
  533. Status BindInputQueue();
  534. Status CpuTaskModelZeroCopy(std::vector<uintptr_t> &mbuf_list,
  535. std::map<const void *, std::vector<void *>> &outside_addrs);
  536. ///
  537. /// @ingroup ge
  538. /// @brief ACL, Bind NetOutput Op addr to output queue.
  539. /// @return: 0 for success / others for fail
  540. ///
  541. Status BindOutputQueue();
  542. Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size);
  543. ///
  544. /// @ingroup ge
  545. /// @brief ACL, Make active stream for S0.
  546. /// @return: 0 for success / others for fail
  547. ///
  548. Status BindActiveStream();
  549. ///
  550. /// @ingroup ge
  551. /// @brief definiteness queue schedule, bind input queue to task.
  552. /// @param [in] queue_id: input queue id from user.
  553. /// @param [in] addr: Data Op output tensor address.
  554. /// @param [in] size: Data Op output tensor size.
  555. /// @return: 0 for success / others for fail
  556. ///
  557. Status CpuModelDequeue(uint32_t queue_id);
  558. ///
  559. /// @ingroup ge
  560. /// @brief definiteness queue schedule, bind output queue to task.
  561. /// @param [in] queue_id: output queue id from user.
  562. /// @param [in] addr: NetOutput Op input tensor address.
  563. /// @param [in] size: NetOutput Op input tensor size.
  564. /// @return: 0 for success / others for fail
  565. ///
  566. Status CpuModelEnqueue(uint32_t queue_id, uintptr_t addr, uint32_t size);
  567. ///
  568. /// @ingroup ge
  569. /// @brief definiteness queue schedule, active original model stream.
  570. /// @param [in] streams: streams will active by S0.
  571. /// @return: 0 for success / others for fail
  572. ///
  573. Status CpuActiveStream(const std::vector<rtStream_t> &stream_list);
  574. ///
  575. /// @ingroup ge
  576. /// @brief definiteness queue schedule, wait for end graph.
  577. /// @return: 0 for success / others for fail
  578. ///
  579. Status CpuWaitEndGraph();
  580. Status BindEnqueue();
  581. Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf);
  582. ///
  583. /// @ingroup ge
  584. /// @brief definiteness queue schedule, repeat run model.
  585. /// @return: 0 for success / others for fail
  586. ///
  587. Status CpuModelRepeat();
  588. void InitRuntimeParams();
  589. ///
  590. /// @ingroup ge
  591. /// @brief set ts device.
  592. /// @return: 0 for success / others for fail
  593. ///
  594. Status SetTSDevice();
  595. void CheckHasHcomOp();
  596. Status DoTaskSink();
  597. void CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputDescInfo &output, uint32_t &format_result);
  598. uint32_t GetGraphID(const std::string &session_graph_id);
  599. Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id);
  600. Status CopyVarData(ComputeGraphPtr &graph);
  601. Status CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst);
  602. // get desc info of graph for profiling
  603. Status GetComputeGraphInfo(vector<ComputeGraphDescInfo> &compute_graph_desc_info);
  604. void SetDataDumperArgs();
  605. bool is_model_has_inited_;
  606. uint32_t model_id_;
  607. uint32_t runtime_model_id_;
  608. string name_;
  609. uint32_t version_;
  610. GeModelPtr ge_model_;
  611. map<uint32_t, OpDescPtr> op_list_;
  612. // data op_desc
  613. vector<OpDescPtr> data_op_list_;
  614. vector<OpDescPtr> output_op_list_;
  615. vector<OpDescPtr> variable_op_list_;
  616. std::map<uint32_t, std::pair<int64_t, void *>> input_data_info_; // Init by Data Output Tensor
  617. std::map<uint32_t, std::pair<int64_t, void *>> output_data_info_; // Init by NetOutput Input Tensor
  618. // output op: save cce op actual needed memory size
  619. vector<int64_t> output_memory_size_list_;
  620. std::thread thread_id_;
  621. std::shared_ptr<ModelListener> listener_;
  622. bool run_flg_;
  623. std::mutex mux_run_flg_;
  624. static SysMode mode_;
  625. static std::mutex mutex_mode_;
  626. int32_t priority_;
  627. vector<rtStream_t> stream_list_;
  628. std::mutex all_hccl_stream_list_mutex_;
  629. vector<rtStream_t> all_hccl_stream_list_;
  630. vector<rtEvent_t> event_list_;
  631. vector<rtLabel_t> label_list_;
  632. set<uint32_t> label_id_indication_;
  633. std::mutex outside_addrs_mutex_;
  634. std::map<const void *, std::vector<void *>> input_outside_addrs_;
  635. std::map<const void *, std::vector<void *>> output_outside_addrs_;
  636. // {op_id, batch_label}
  637. map<int64_t, std::string> zero_copy_op_id_batch_label_;
  638. // {batch_label, addrs}
  639. map<std::string, std::vector<void *>> zero_copy_batch_label_addrs_;
  640. std::vector<TaskInfoPtr> task_list_;
  641. // rt_moodel_handle
  642. rtModel_t rt_model_handle_;
  643. rtStream_t rt_model_stream_;
  644. bool is_inner_model_stream_;
  645. bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_.
  646. // ACL queue schedule, save queue ids for Init.
  647. std::vector<TaskInfoPtr> cpu_task_list_;
  648. std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller.
  649. std::vector<uint32_t> output_queue_ids_; // output queue ids created by caller.
  650. std::vector<uintptr_t> input_mbuf_list_; // input mbuf created by dequeue task.
  651. std::vector<uintptr_t> output_mbuf_list_; // output mbuf created by dequeue task.
  652. // save input/output tensor descriptor in maps
  653. std::map<std::string, ConstGeTensorDescPtr> data_op_input_tensor_desc_map_;
  654. std::map<std::string, ConstGeTensorDescPtr> data_op_output_tensor_desc_map_;
  655. bool support_mem_shared_flag_;
  656. uint64_t session_id_;
  657. uint32_t device_id_;
  658. std::mutex flowctrl_op_index_internal_map_mutex_;
  659. std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_;
  660. std::set<uint32_t> active_stream_indication_;
  661. std::shared_ptr<domi::ModelTaskDef> model_task_def_;
  662. std::set<uint32_t> aicpu_streams_;
  663. std::set<uint32_t> hcom_streams_;
  664. RuntimeParam runtime_param_;
  665. static std::mutex tvm_bin_mutex_; // lock for tvm maps.
  666. static std::set<std::string> tvm_bin_kernel_;
  667. std::map<std::string, uint32_t> used_tbe_handle_map_;
  668. // for profiling task and graph info
  669. std::map<uint32_t, std::string> op_name_map_;
  670. std::vector<TaskDescInfo> task_desc_info_;
  671. ComputeGraphPtr compute_graph_;
  672. int64_t maxDumpOpNum_;
  673. // for data dump
  674. DataDumper data_dumper_;
  675. bool input_use_zero_copy_;
  676. bool output_use_zero_copy_;
  677. uint64_t iterator_count_;
  678. bool is_l1_fusion_enable_;
  679. uint32_t end_graph_id_;
  680. OpDescPtr end_graph_op_;
  681. };
  682. #define TIME_LOG_HEAD_FMT " OP_ID OP_NAME OP_TYPE ELAPSED TIME(ms)"
  683. #define OP_TIME_LOG_FMT "%d_%-5d %-5d | %-20s | %-15s | %10f | %10d"
  684. #define MODEL_TIME_LOG_FMT "******** Model %d ends, elapsed time: %f ms ********"
  685. const size_t INPUT_OUTPUT_NAME_MAX_LEN = 256;
  686. } // namespace ge
  687. #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示