You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context.h 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
  17. #define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"
  24. namespace mindspore {
  25. enum DeviceType {
  26. kCPU = 0,
  27. kMaliGPU,
  28. kNvidiaGPU,
  29. kKirinNPU,
  30. kAscend910,
  31. kAscend310,
  32. // add new type here
  33. kInvalidDeviceType = 100,
  34. };
  35. class Allocator;
  36. class DeviceInfoContext;
  37. class MS_API Context {
  38. public:
  39. Context();
  40. ~Context() = default;
  41. void SetThreadNum(int32_t thread_num);
  42. int32_t GetThreadNum() const;
  43. void SetAllocator(const std::shared_ptr<Allocator> &allocator);
  44. std::shared_ptr<Allocator> GetAllocator() const;
  45. std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
  46. private:
  47. struct Data;
  48. std::shared_ptr<Data> data_;
  49. };
  50. class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
  51. public:
  52. struct Data;
  53. DeviceInfoContext();
  54. virtual ~DeviceInfoContext() = default;
  55. virtual enum DeviceType GetDeviceType() const = 0;
  56. template <class T>
  57. std::shared_ptr<T> Cast() {
  58. static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
  59. if (GetDeviceType() != T().GetDeviceType()) {
  60. return nullptr;
  61. }
  62. return std::static_pointer_cast<T>(shared_from_this());
  63. }
  64. protected:
  65. std::shared_ptr<Data> data_;
  66. };
  67. class MS_API CPUDeviceInfo : public DeviceInfoContext {
  68. public:
  69. enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
  70. /// \brief Set the thread affinity to CPU cores.
  71. ///
  72. /// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
  73. void SetThreadAffinity(int mode);
  74. int GetThreadAffinity() const;
  75. void SetEnableFP16(bool is_fp16);
  76. bool GetEnableFP16() const;
  77. };
  78. class MS_API MaliGPUDeviceInfo : public DeviceInfoContext {
  79. public:
  80. enum DeviceType GetDeviceType() const override { return DeviceType::kMaliGPU; };
  81. void SetEnableFP16(bool is_fp16);
  82. bool GetEnableFP16() const;
  83. };
  84. class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
  85. public:
  86. enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
  87. void SetFrequency(int frequency);
  88. int GetFrequency() const;
  89. };
  90. class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {
  91. public:
  92. enum DeviceType GetDeviceType() const override { return DeviceType::kNvidiaGPU; };
  93. void SetDeviceID(uint32_t device_id);
  94. uint32_t GetDeviceID() const;
  95. void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
  96. bool GetGpuTrtInferMode() const;
  97. inline void SetPrecisionMode(const std::string &precison_mode);
  98. inline std::string GetPrecisionMode() const;
  99. private:
  100. void SetPrecisionMode(const std::vector<char> &precision_mode);
  101. std::vector<char> GetPrecisionModeChar() const;
  102. };
  103. void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  104. SetPrecisionMode(StringToChar(precision_mode));
  105. }
  106. std::string NvidiaGPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
  107. class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
  108. public:
  109. enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
  110. void SetDeviceID(uint32_t device_id);
  111. uint32_t GetDeviceID() const;
  112. };
  113. class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
  114. public:
  115. enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
  116. void SetDeviceID(uint32_t device_id);
  117. uint32_t GetDeviceID() const;
  118. inline void SetDumpConfigPath(const std::string &cfg_path);
  119. inline std::string GetDumpConfigPath() const;
  120. // aipp config file
  121. inline void SetInsertOpConfigPath(const std::string &cfg_path);
  122. inline std::string GetInsertOpConfigPath() const;
  123. // nchw or nhwc
  124. inline void SetInputFormat(const std::string &format);
  125. inline std::string GetInputFormat() const;
  126. // Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
  127. inline void SetInputShape(const std::string &shape);
  128. inline std::string GetInputShape() const;
  129. void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
  130. std::map<int, std::vector<int>> GetInputShapeMap() const;
  131. void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  132. inline std::string GetDynamicBatchSize() const;
  133. // FP32, UINT8 or FP16, default as FP32
  134. void SetOutputType(enum DataType output_type);
  135. enum DataType GetOutputType() const;
  136. // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
  137. inline void SetPrecisionMode(const std::string &precision_mode);
  138. inline std::string GetPrecisionMode() const;
  139. // Optional "high_performance" and "high_precision", "high_performance" is set as default
  140. inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
  141. inline std::string GetOpSelectImplMode() const;
  142. inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
  143. inline std::string GetFusionSwitchConfigPath() const;
  144. // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
  145. inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
  146. inline std::string GetBufferOptimizeMode() const;
  147. private:
  148. void SetDumpConfigPath(const std::vector<char> &cfg_path);
  149. std::vector<char> GetDumpConfigPathChar() const;
  150. void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  151. std::vector<char> GetInsertOpConfigPathChar() const;
  152. void SetInputFormat(const std::vector<char> &format);
  153. std::vector<char> GetInputFormatChar() const;
  154. void SetInputShape(const std::vector<char> &shape);
  155. std::vector<char> GetInputShapeChar() const;
  156. std::vector<char> GetDynamicBatchSizeChar() const;
  157. void SetPrecisionMode(const std::vector<char> &precision_mode);
  158. std::vector<char> GetPrecisionModeChar() const;
  159. void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  160. std::vector<char> GetOpSelectImplModeChar() const;
  161. void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  162. std::vector<char> GetFusionSwitchConfigPathChar() const;
  163. void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  164. std::vector<char> GetBufferOptimizeModeChar() const;
  165. };
  166. void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
  167. std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
  168. void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  169. SetInsertOpConfigPath(StringToChar(cfg_path));
  170. }
  171. std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
  172. void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
  173. std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
  174. void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
  175. std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
  176. std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
  177. void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  178. SetPrecisionMode(StringToChar(precision_mode));
  179. }
  180. std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
  181. void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  182. SetOpSelectImplMode(StringToChar(op_select_impl_mode));
  183. }
  184. std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
  185. void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  186. SetFusionSwitchConfigPath(StringToChar(cfg_path));
  187. }
  188. std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
  189. return CharToString(GetFusionSwitchConfigPathChar());
  190. }
  191. void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  192. SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
  193. }
  194. std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
  195. } // namespace mindspore
  196. #endif // MINDSPORE_INCLUDE_API_CONTEXT_H