You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.h 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PREDICT_INCLUDE_TENSOR_H_
  17. #define PREDICT_INCLUDE_TENSOR_H_
  18. #include <memory>
  19. #include <vector>
  20. #include "dlpack/dlpack.h"
  21. #include "schema/inner/ms_generated.h"
  22. #define MSPREDICT_API __attribute__((visibility("default")))
  23. namespace mindspore {
  24. namespace predict {
  25. ///\brief Allocator definition of MindSpore predict.
  26. class Allocator;
  27. ///\brief Tensor definition of MindSpore predict.
  28. class MSPREDICT_API Tensor {
  29. public:
  30. ///\brief Constructor of MindSpore predict tensor.
  31. ///
  32. ///\param[in] tensor Define the parameters of the tensor.
  33. ///\param[in] copyData Malloc data for the tensor, and copy origin data from
  34. /// input tensor.
  35. ///
  36. ///\return Instance of MindSpore predict tensor.
  37. Tensor(const Tensor &tensor, bool copyData = false);
  38. ///\brief Constructor of MindSpore predict tensor.
  39. ///
  40. ///\param[in] dt Data Type of the tensor, see introduction to 'enum DataType'
  41. /// for supported type.
  42. ///\param[in] dims Dimension Values such as height and width, which defined
  43. /// the shape of the tensor.
  44. ///\param[in] format Tensor format, see introduction to 'enum Format' for
  45. /// supported format.
  46. ///\param[in] data Data of the tensor.
  47. ///
  48. ///\return Instance of MindSpore predict tensor.
  49. ///
  50. ///\note
  51. /// Length of data should align with dt, format and dims, otherwise the
  52. /// application might run into unexpected error,
  53. /// such as segment fault.
  54. /// For example, dt is DT_FLOAT, format is FORMAT_NCHW, dims is [1,3,300,300],
  55. /// then minimum length of data should
  56. /// be 1 * 3 * 300 * 300 * sizeof(float).
  57. Tensor(DataType dt, const std::vector<int64_t> &dims, Format format, void *data);
  58. ///\brief Destructor of MindSpore predict tensor.
  59. ~Tensor();
  60. ///\brief Get MindSpore predict tensor.
  61. ///
  62. ///\param[in] Definition of the tensor.
  63. ///
  64. ///\return Address of MindSpore predict tensor.
  65. static Tensor *CopyFromTensorDef(const TensorDef &tensordef);
  66. ///\brief Get dtype of MindSpore predict tensor.
  67. ///
  68. ///\return Dtype of MindSpore predict tensor.
  69. DLDataType GetTensorDtype() const;
  70. ///\brief Get data of MindSpore predict tensor.
  71. ///
  72. ///\return Address of MindSpore predict tensor data.
  73. void *GetData() const;
  74. ///\brief Set data of MindSpore predict tensor.
  75. ///
  76. ///\param[in] data Address for data of the MindSpore predict tensor instance.
  77. ///
  78. ///\note
  79. /// Length of data should align with dt, format and dims, otherwise the
  80. /// application might run into unexpected error,
  81. /// such as segment fault.
  82. /// For example, dt is DT_FLOAT, format is FORMAT_NCHW, dims is [1,3,300,300],
  83. /// then minimum length of data should
  84. /// be 1 * 3 * 300 * 300 * sizeof(float).
  85. void SetData(void *data);
  86. ///\brief Get data type of MindSpore predict tensor.
  87. ///
  88. ///\return Data Type of the tensor.
  89. DataType GetDataType() const;
  90. ///\brief Set data type of MindSpore predict tensor.
  91. ///
  92. ///\param[in] dt Data Type of the tensor, see introduction to 'enum DataType'
  93. /// for supported type.
  94. void SetDataType(DataType dt);
  95. ///\brief Get number of dimension of MindSpore predict tensor.
  96. ///
  97. ///\return Number of dimension of the MindSpore predict tensor.
  98. int GetNDim() const;
  99. ///\brief Get dimension of MindSpore predict tensor.
  100. ///
  101. ///\return Dimension of the MindSpore predict tensor.
  102. std::vector<int64_t> GetDims() const;
  103. ///\brief Set dimension of MindSpore predict tensor.
  104. ///
  105. ///\param[in] dims Vector that has values of dimension.
  106. void SetDims(const std::vector<int64_t> &dims);
  107. ///\brief Get format of MindSpore predict tensor.
  108. ///
  109. ///\return Format of the MindSpore predict tensor.
  110. Format GetFormat() const { return format; }
  111. ///\brief Set format of MindSpore predict tensor.
  112. ///
  113. ///\param[in] format Format of the tensor.
  114. void SetFormat(Format format) { this->format = format; }
  115. ///\brief Get reference count of MindSpore predict tensor.
  116. ///
  117. ///\return Reference count of the MindSpore predict tensor.
  118. int RefCount() { return refCount; }
  119. ///\brief Increase reference count of MindSpore predict tensor.
  120. ///
  121. ///\param[in] ref The increase of the reference count.
  122. void AddRef(int ref) { refCount += ref; }
  123. ///\brief Decrease reference count of MindSpore predict tensor.
  124. ///
  125. ///\param[in] ref The decrease of the reference count.
  126. void DefRef(int ref) { refCount -= ref; }
  127. ///\brief Get element size of MindSpore predict tensor.
  128. ///
  129. ///\return Element size of MindSpore predict tensor.
  130. size_t GetElementSize() const;
  131. ///\brief Get data size of MindSpore predict tensor.
  132. ///
  133. ///\return Data size of MindSpore predict tensor.
  134. size_t GetDataSize() const;
  135. ///\brief Get element size of MindSpore predict tensor in NC4HW4 format.
  136. ///
  137. ///\param[in] isNhwc Whether the current format is NHWC.
  138. ///
  139. ///\return Element size of MindSpore predict tensor in NC4HW4 format.
  140. size_t GetNC4HW4ElementSize(bool isNhwc);
  141. ///\brief Get data size of MindSpore predict tensor in NC4HW4 format.
  142. ///
  143. ///\param[in] isNhwc Whether the current format is NHWC.
  144. ///
  145. ///\return Data size of MindSpore predict tensor in NC4HW4 format.
  146. size_t GetNC4HW4DataSize(bool isNhwc);
  147. ///\brief Malloc data for the MindSpore predict tensor.
  148. ///
  149. ///\param[in] allocator The malloc source for data.
  150. ///\param[in] refCount The reference count of the data.
  151. ///
  152. ///\return Return RET_OK if the data is successfully allocated, otherwhise return RET_ERROR.
  153. int MallocData(std::shared_ptr<Allocator> allocator = nullptr, int refCount = 0);
  154. ///\brief Free the MindSpore predict tensor.
  155. void FreeTensor();
  156. ///\brief Free the data of MindSpore predict tensor.
  157. void ForceFreeData();
  158. ///\brief Free the data of MindSpore predict tensor.
  159. void FreeData();
  160. ///\brief Compare data size of MindSpore predict tensor in NC4HW4 format.
  161. ///
  162. ///\param[in] dst The compare tensor.
  163. ///
  164. ///\return The result of fuction.
  165. bool CompareShape(const Tensor &dst);
  166. ///\brief Compare shape of MindSpore predict tensor with another shape.
  167. ///
  168. ///\param[in] other The compare shape information.
  169. ///
  170. ///\return The result of function.
  171. bool CompareShape(const std::vector<int64_t> &other);
  172. ///\brief Get instance of MindSpore predict tensor.
  173. ///
  174. ///\return Instance of MindSpore predict dlTensor.
  175. DLTensor *GetDLTensor() { return &dlTensor; }
  176. ///\brief Get height of MindSpore predict tensor.
  177. ///
  178. ///\return Height of MindSpore predict tensor.
  179. int64_t Height() const;
  180. ///\brief Get width of MindSpore predict tensor.
  181. ///
  182. ///\return Width of MindSpore predict tensor.
  183. int64_t Width() const;
  184. ///\brief Get channel of MindSpore predict tensor.
  185. ///
  186. ///\return Channel of MindSpore predict tensor.
  187. int64_t Channel() const;
  188. ///\brief Get batch of MindSpore predict tensor.
  189. ///
  190. ///\return Batch of MindSpore predict tensor.
  191. int64_t Batch() const;
  192. ///\brief Get stride of MindSpore predict tensor.
  193. ///
  194. ///\param[in] index the index of stride.
  195. ///
  196. ///\return Stride of MindSpore predict tensor.
  197. int64_t Stride(int index) const;
  198. ///\brief Set stride of MindSpore predict tensor by input.
  199. ///
  200. ///\param[in] index Index of stride
  201. ///\param[in] stride The stride to set
  202. void SetStride(int index, int64_t stride);
  203. ///\brief Set stride of MindSpore predict tensor by dims.
  204. void SetStride();
  205. void SetScale(bool isScale = true);
  206. private:
  207. bool isScale = false;
  208. int refCount = 0;
  209. int isConst;
  210. Format format;
  211. DLTensor dlTensor;
  212. std::shared_ptr<Allocator> allocator = nullptr;
  213. std::vector<float> scale;
  214. std::vector<int> zeroPoint;
  215. };
  216. } // namespace predict
  217. } // namespace mindspore
  218. #endif // PREDICT_INCLUDE_TENSOR_H_