You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The definition of tensor stream."""
  16. from abc import abstractmethod, ABC
  17. import numpy as np
  18. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  19. from mindinsight.debugger.common.log import LOGGER as log
  20. from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP
  21. from mindinsight.debugger.proto.ms_graph_pb2 import DataType
  22. from mindinsight.utils.tensor import TensorUtils
  class BaseTensor(ABC):
      """
      Abstract tensor model shared by the concrete tensor classes.

      Subclasses provide the ``name``/``dtype``/``shape``/``value`` properties
      together with value slicing and statistics; this base class assembles
      them into the dictionaries returned to callers.

      Args:
          step (int): The step the tensor belongs to. Default: 0.
      """

      def __init__(self, step=0):
          self._step = step

      @property
      @abstractmethod
      def name(self):
          """The property of tensor name."""

      @property
      @abstractmethod
      def dtype(self):
          """The property of tensor dtype."""

      @property
      @abstractmethod
      def shape(self):
          """The property of tensor shape."""

      @property
      @abstractmethod
      def value(self):
          """The property of tensor value."""

      @property
      def empty(self):
          """bool, True when the tensor holds no value (``value`` is None)."""
          return self.value is None

      def get_tensor_serializable_value_by_shape(self, shape=None):
          """
          Get the (optionally sliced) tensor value together with its statistics.

          Args:
              shape (tuple): The specified range of tensor value. Default: None.

          Returns:
              dict, with key ``value`` (the result of ``get_tensor_value_by_shape``,
              passed through ``tolist()`` when it is an np.ndarray) and key
              ``statistics`` (the result of ``get_tensor_statistics``).
          """
          sliced = self.get_tensor_value_by_shape(shape)
          # The slice is None, a str or an np.ndarray; only arrays need to be
          # converted into plain Python objects to be serializable.
          if isinstance(sliced, np.ndarray):
              sliced = sliced.tolist()
          return {
              'value': sliced,
              'statistics': self.get_tensor_statistics(),
          }

      @abstractmethod
      def get_tensor_value_by_shape(self, shape=None):
          """Abstract method."""

      @abstractmethod
      def get_tensor_statistics(self):
          """Abstract method."""

      def _to_dict(self):
          """Collect the basic tensor fields (no value) into a dict."""
          return dict(
              full_name=self.name,
              step=self._step,
              dtype=self.dtype,
              shape=self.shape,
              has_prev_step=False,
          )

      def get_basic_info(self):
          """
          Get the basic tensor info.

          Returns:
              dict, the basic fields plus ``value``: the value itself for scalar
              (empty-shape) tensors, otherwise the placeholder 'click to view'.
          """
          raw = self.value
          if self.shape:
              # Non-scalar values are not inlined; the UI requests them separately.
              shown = 'click to view'
          elif isinstance(raw, np.ndarray):
              shown = raw.tolist()
          else:
              shown = raw
          info = self._to_dict()
          info['value'] = shown
          return info

      def get_full_info(self, shape=None):
          """
          Get tensor info with value.

          Args:
              shape (tuple): The specified range of tensor value. Default: None.

          Returns:
              dict, the basic fields merged with ``value`` and ``statistics``.
          """
          info = self._to_dict()
          info.update(self.get_tensor_serializable_value_by_shape(shape))
          return info
  class OpTensor(BaseTensor):
      """
      Tensor data structure for operator Node.

      Args:
          tensor_proto (TensorProto): Tensor proto contains tensor basic info.
          tensor_content (byte): Tensor content value in byte format.
          step (int): The step of the tensor.
      """
      # Slices with more elements than this are replaced by a placeholder
      # string in get_tensor_value_by_shape instead of being returned.
      max_number_data_show_on_ui = 100000

      def __init__(self, tensor_proto, tensor_content, step=0):
          # the type of tensor_proto is TensorProto
          super(OpTensor, self).__init__(step)
          self._tensor_proto = tensor_proto
          self._value = self.to_numpy(tensor_content)
          # Filled lazily by get_tensor_statistics().
          self._stats = None
          self._tensor_comparison = None

      @property
      def name(self):
          """str, ``<node_name>:<slot>`` taken from the tensor proto."""
          node_name = self._tensor_proto.node_name
          slot = self._tensor_proto.slot
          return ':'.join([node_name, slot])

      @property
      def dtype(self):
          """str, the enum name of the proto ``data_type`` (via ``DataType.Name``)."""
          tensor_type = DataType.Name(self._tensor_proto.data_type)
          return tensor_type

      @property
      def shape(self):
          """list, the proto dims; dims of ``[0]`` are reported as ``[]`` (scalar)."""
          dims = list(self._tensor_proto.dims)
          if dims == [0]:
              dims = []
          return dims

      @property
      def value(self):
          """The property of tensor value."""
          return self._value

      @property
      def stats(self):
          """The cached tensor statistics, or None before they are computed."""
          return self._stats

      @stats.setter
      def stats(self, stats):
          """
          Update tensor stats.

          Args:
              stats (Statistics): Instance of Statistics.
          """
          self._stats = stats

      @property
      def tensor_comparison(self):
          """The tensor comparison set by update_tensor_comparisons, or None."""
          return self._tensor_comparison

      def to_numpy(self, tensor_content):
          """
          Construct tensor content from byte to numpy.

          Args:
              tensor_content (byte): The tensor content.

          Returns:
              Union[None, np.ndarray], the value of the tensor reshaped to
              ``self.shape``, or None when ``tensor_content`` is empty.
          """
          tensor_value = None
          if tensor_content:
              np_type = NUMPY_TYPE_MAP.get(self.dtype)
              if np_type is None:
                  # np.frombuffer(dtype=None) falls back to float64; make that visible.
                  log.warning("No numpy type mapped for dtype %s, "
                              "tensor content is decoded as float64.", self.dtype)
              tensor_value = np.frombuffer(tensor_content, dtype=np_type)
              tensor_value = tensor_value.reshape(self.shape)
          return tensor_value

      def get_tensor_statistics(self):
          """
          Get Tensor statistics.

          Statistics are computed with TensorUtils on first use and cached in
          ``stats``.

          Returns:
              dict, overall statistics, or an empty dict when there is no value.
          """
          if self.empty:
              return {}
          if not self.stats:
              self.stats = TensorUtils.get_statistics_from_tensor(self.value)
          statistics = TensorUtils.get_overall_statistic_dict(self.stats)
          return statistics

      def update_tensor_comparisons(self, tensor_comparison):
          """
          Update tensor comparison for tensor.

          Args:
              tensor_comparison (TensorComparison) instance of TensorComparison.
          """
          self._tensor_comparison = tensor_comparison

      def get_tensor_value_by_shape(self, shape=None):
          """
          Get tensor value by shape.

          Args:
              shape (tuple): The specified shape, used directly as a numpy index
                  (e.g. a tuple of ints/slices). Any non-tuple value, including
                  None, selects the whole tensor.

          Returns:
              Union[None, str, numpy.ndarray], the value of parsed tensor. The
              string "Too large to show." replaces slices holding more than
              ``max_number_data_show_on_ui`` elements.

          Raises:
              DebuggerParamValueError: If ``shape`` does not have the tensor's
                  rank, or indexing the tensor with it raises IndexError.
          """
          if self._value is None:
              log.warning("%s has no value yet.", self.name)
              return None
          if shape is None or not isinstance(shape, tuple):
              log.info("Get the whole tensor value with shape is %s", shape)
              return self._value
          if len(shape) != len(self.shape):
              log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
              raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
          try:
              value = self._value[shape]
          except IndexError as err:
              log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
              log.exception(err)
              # Chain the cause so the original IndexError stays in the traceback.
              raise DebuggerParamValueError("Invalid shape. Shape unmatched.") from err
          if isinstance(value, np.ndarray):
              if value.size > self.max_number_data_show_on_ui:
                  log.info("The tensor size is %d, which is too large to show on UI.", value.size)
                  value = "Too large to show."
          else:
              # Indexing down to one element yields a numpy scalar; wrap it as a
              # 0-d array so an ndarray is returned.
              value = np.asarray(value)
          return value
  class ConstTensor(BaseTensor):
      """Tensor data structure for Const Node."""
      # dtype names, as produced by DataType.Name, that get special handling.
      _STRING_TYPE = 'DT_STRING'
      _DT_TYPE = 'DT_TYPE'

      def __init__(self, const_proto):
          # the type of const_proto is NamedValueProto
          super(ConstTensor, self).__init__()
          self._const_proto = const_proto
          self._value = self.generate_value_from_proto(const_proto.value)

      def set_step(self, step):
          """Set step value."""
          self._step = step

      @property
      def name(self):
          """str, the const key suffixed with ``:0``."""
          return self._const_proto.key + ':0'

      @property
      def dtype(self):
          """str, the enum name of the const value's ``dtype`` (via ``DataType.Name``)."""
          return DataType.Name(self._const_proto.value.dtype)

      @property
      def shape(self):
          """list, always empty: const values are reported with an empty shape."""
          return []

      @property
      def value(self):
          """The property of tensor value."""
          return self._value

      def generate_value_from_proto(self, tensor_proto):
          """
          Generate tensor value from proto.

          Args:
              tensor_proto: The value proto; it is expected to hold the ``dtype``
                  field plus exactly one populated value field.

          Returns:
              Union[None, str, np.ndarray], the value of the tensor. DT_STRING
              values are returned as str; other values are converted to an
              np.ndarray using the numpy type mapped to their dtype name.
          """
          fields = tensor_proto.ListFields()
          if len(fields) != 2:
              log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto)
          tensor_value = None
          for field_obj, field_value in fields:
              if field_obj.name == 'dtype':
                  continue
              if tensor_proto.dtype == DataType.DT_TUPLE:
                  # A tuple holds repeated nested value protos; decode each one.
                  tensor_value = [self.generate_value_from_proto(element) for element in field_value]
              elif tensor_proto.dtype == DataType.DT_TYPE:
                  tensor_value = DataType.Name(field_value.data_type)
              else:
                  tensor_value = field_value
              # Only the first non-dtype field carries the value.
              break
          if tensor_value is None:
              return None
          # ``tensor_proto.dtype`` is the enum's integer value, while _STRING_TYPE
          # and the NUMPY_TYPE_MAP keys are dtype names (the same strings the
          # ``dtype`` property returns), so translate before comparing/looking up.
          dtype_name = DataType.Name(tensor_proto.dtype)
          if dtype_name != self._STRING_TYPE:
              tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(dtype_name))
          return tensor_value

      def get_tensor_value_by_shape(self, shape=None):
          """
          Get tensor value by shape.

          Const values cannot be sliced: a non-empty ``shape`` only logs a
          warning and the whole value is returned.

          Args:
              shape (tuple): The specified shape.

          Returns:
              Union[None, str, np.ndarray], the value of parsed tensor.
          """
          if shape:
              log.warning("Invalid shape for const value.")
          return self._value

      def get_tensor_statistics(self):
          """
          Get Tensor statistics.

          Returns:
              dict, overall statistics; empty when there is no value or the
              dtype is DT_STRING or DT_TYPE.
          """
          if self.empty or self.dtype == self._STRING_TYPE or self.dtype == self._DT_TYPE:
              log.debug("The tensor dtype is: %s, skip getting statistics.", self.dtype)
              return {}
          stats = TensorUtils.get_statistics_from_tensor(self.value)
          statistics = TensorUtils.get_overall_statistic_dict(stats)
          return statistics
  287. log.debug("The tensor dtype is: %s, skip getting statistics.", self.dtype)
  288. return {}
  289. stats = TensorUtils.get_statistics_from_tensor(self.value)
  290. statistics = TensorUtils.get_overall_statistic_dict(stats)
  291. return statistics