You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

session.h 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PREDICT_INCLUDE_SESSION_H_
  17. #define PREDICT_INCLUDE_SESSION_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include <map>
  22. #include <unordered_set>
  23. #include "include/context.h"
  24. #include "include/tensor.h"
  25. #define MSPREDICT_API __attribute__((visibility("default")))
  26. namespace mindspore {
  27. namespace predict {
  /// Node identifier type; an alias for std::string.
  using NODE_ID = std::string;

  ///\brief Graph defined by MindSpore predict.
  ///
  ///\note
  /// Opaque to callers: only a forward declaration is exposed here, because the
  /// implementation details are not needed to use the Session API.
  class Graph;

  ///\brief GraphExecution defined by MindSpore predict.
  ///
  ///\note
  /// Opaque to callers: only a forward declaration is exposed here, because the
  /// implementation details are not needed to use the Session API.
  class GraphExecution;
  39. ///\brief MindSpore predict session.
  40. ///
  41. /// This class represents session of MindSpore predict.
  42. ///
  43. ///\note
  44. /// The caller needs to allocate and free memory of inputs and outputs.
  45. /// New Session is not suggested, please use CreateSession function to create new session class.
  46. class MSPREDICT_API Session {
  47. public:
  48. ///\brief Constructor of MindSpore predict session.
  49. ///
  50. ///\param[in] ctx The context of the session.
  51. ///
  52. ///\return Instance of MindSpore predict session.
  53. explicit Session(const Context &ctx);
  54. ///\brief Destructor of MindSpore predict session.
  55. ~Session();
  56. ///\brief Init the session.
  57. ///
  58. ///\param[in] ctx The context of the session.
  59. ///\param[in] size The size of the session.
  60. ///\param[in] graphBuf The buffer of the graph, used for build session.
  61. ///
  62. ///\return Return RET_OK if the initialization is success, otherwhise return RET_ERROR.
  63. int Init(const char *graphBuf, size_t size);
  64. ///\brief Get the input of session.
  65. ///
  66. ///\return Input node's input tensors if found, empty vector otherwise.
  67. ///
  68. ///\note
  69. /// The caller needs to allocate and free memory of inputs.
  70. std::vector<Tensor *> GetInput();
  71. ///\brief Run the session.
  72. ///
  73. ///\param[in] inputs The input of the session.
  74. ///
  75. ///\return Return RET_OK if run success, otherwhise return RET_ERROR.
  76. ///\note
  77. /// Currently input tensors' data format only support FORMAT_NCHW.
  78. /// Currently input tensors' data type only support FLOAT.
  79. int Run(const std::vector<Tensor *> &inputs);
  80. ///\brief Get the output of session.
  81. ///
  82. ///\param[in] nodeName Given output node name.
  83. ///
  84. ///\return Output node's output tensors if found, empty vector otherwise.
  85. ///
  86. ///\note
  87. /// The caller needs to free memory of outputs.
  88. std::vector<Tensor *> GetOutput(const std::string &nodeName);
  89. ///\brief Get the all output of session.
  90. ///
  91. ///\return Every output node's output tensors.
  92. ///
  93. ///\note
  94. /// The caller needs to free memory of outputs.
  95. std::map<std::string, std::vector<Tensor *>> GetAllOutput();
  96. protected:
  97. ///\brief Init the executor.
  98. ///
  99. ///\return Return RET_OK if the initialization is success, otherwhise return RET_ERROR.
  100. int InitExecutor();
  101. const Context &_ctx;
  102. Graph *_graph = nullptr;
  103. GraphExecution *_executor = nullptr;
  104. bool reinitExecutor = true;
  105. };
  106. ///\brief MindSpore predict neural network session create function
  107. ///
  108. /// This function used to create MindSpore predict neural network session, which will be used to run the neural network.
  109. ///
  110. ///\param[in] sessionName The name of the session.
  111. ///\param[in] graphBuf The buffer of the graph, used for build session.
  112. ///\param[in] size The size of the session.
  113. ///\param[in] ctx The context of the session.
  114. ///
  115. ///\return Instance of MindSpore predict session.
  116. ///
  117. ///\note
  118. /// The caller needs to allocate and free memory of graph buffer.
  119. std::shared_ptr<Session> MSPREDICT_API CreateSession(const char *graphBuf, size_t size, const Context &ctx);
  120. } // namespace predict
  121. } // namespace mindspore
  122. #endif // PREDICT_INCLUDE_SESSION_H_