You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

hybrid_model_pipeline_executor.h 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
  2. #define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_
  3. #include "common/blocking_queue.h"
  4. #include "common/thread_pool.h"
  5. #include "hybrid/executor/hybrid_execution_context.h"
  6. #include "hybrid/executor/rt_callback_manager.h"
  7. #include "hybrid/executor/subgraph_executor.h"
  8. #include "hybrid_model_executor.h"
  9. namespace ge {
  10. namespace hybrid {
  11. struct PipeExecutionConfig {
  12. uint32_t device_id;
  13. rtContext_t rt_context;
  14. int num_executors;
  15. int num_stages;
  16. long iteration_end;
  17. };
// Executes a single stage of a pipelined hybrid-model run.
// Stages are linked into a chain via SetNext(); each executor owns its own
// execution context, task queue, and streams. Construction does not acquire
// resources — call Init() before use. Non-owning pointers (model, config,
// next executor) must outlive this object.
class StageExecutor {
 public:
  // A unit of work passed between stages.
  struct StageTask {
    rtEvent_t event = nullptr;  // runtime event handle; presumably used to order stages — confirm in .cc
    int stage = 0;              // index of the stage this task targets
    long iteration = 0;         // iteration the task belongs to
  };

  // id: index of this stage; model/config are borrowed, not owned.
  StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config);

  ~StageExecutor();

  // One-time setup; must succeed before Start/ExecuteAsync are used.
  Status Init();

  // Clears per-run state so the executor can be reused.
  void Reset();

  // Begins executing `loop_count` iterations with the given inputs.
  Status Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc,
               int loop_count);

  // Binds inputs/descs for the upcoming run without starting it.
  Status SetInputs(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc);

  // Enqueues a task for asynchronous execution on this stage.
  Status ExecuteAsync(const StageTask &args);

  // Retrieves outputs of the completed run.
  Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc);

  // Blocks until in-flight work on this stage has finished.
  Status Synchronize();

  // Links this stage to its successor in the pipeline (non-owning).
  void SetNext(StageExecutor *next_executor) { next_executor_ = next_executor; }

 private:
  friend class HybridModelPipelineExecutor;

  // Restores `context` to a clean state between runs.
  static Status ResetExecutionContext(GraphExecutionContext &context);

  // Builds this stage's private execution context.
  Status InitExecutionContext();

  int id_;                             // stage index within the pipeline
  HybridModel *model_;                 // borrowed; owned by the caller
  PipeExecutionConfig *pipe_config_;   // borrowed; shared pipeline settings
  BlockingQueue<StageTask> task_queue_;  // inbound tasks for this stage
  std::unique_ptr<SubgraphExecutor> root_graph_executor_;  // runs the model's root graph
  GraphExecutionContext context_;      // per-stage execution state
  StageExecutor *next_executor_ = nullptr;  // downstream stage; null for the last stage
  rtStream_t stream_ = nullptr;        // compute stream for this stage
  rtStream_t hccl_stream_ = nullptr;   // separate stream, presumably for HCCL collectives — confirm in .cc
};
// Drives pipelined execution of a hybrid model by coordinating a chain of
// StageExecutors (one per pipeline stage). Call Init() before Execute().
// The HybridModel pointer is borrowed and must outlive this executor.
class HybridModelPipelineExecutor {
 public:
  // model: borrowed, not owned; device_id: target device for the run.
  HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id);

  ~HybridModelPipelineExecutor();

  // One-time setup of the pipeline (config, context, stage executors).
  Status Init();

  // Creates and links the per-stage executors.
  Status InitStageExecutors();

  // Runs the model for the iterations described by `args`.
  Status Execute(HybridModelExecutor::ExecuteArgs &args);

 private:
  HybridModel *model_;       // borrowed; owned by the caller
  uint32_t device_id_;       // device the pipeline executes on
  std::vector<std::unique_ptr<StageExecutor>> stage_executors_;  // owned stages, in pipeline order
  PipeExecutionConfig config_;       // settings shared with the stages
  GraphExecutionContext context_;    // top-level execution state
  long iteration_ = 0;       // iterations completed so far — presumably advanced by Execute; confirm in .cc
};
  65. } // namespace hybrid
  66. } // namespace ge
  67. #endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示