You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensor.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The definition of tensor stream."""
  16. from abc import abstractmethod, ABC
  17. import numpy as np
  18. from mindinsight.utils.tensor import TensorUtils
  19. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  20. from mindinsight.debugger.common.log import logger as log
  21. from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP
  22. from mindinsight.debugger.proto.ms_graph_pb2 import DataType
  23. class BaseTensor(ABC):
  24. """Tensor data structure."""
  25. def __init__(self, step=0):
  26. self._step = step
  27. @property
  28. @abstractmethod
  29. def name(self):
  30. """The property of tensor name."""
  31. @property
  32. @abstractmethod
  33. def dtype(self):
  34. """The property of tensor dtype."""
  35. @property
  36. @abstractmethod
  37. def shape(self):
  38. """The property of tensor shape."""
  39. @property
  40. @abstractmethod
  41. def value(self):
  42. """The property of tensor shape."""
  43. @abstractmethod
  44. def get_tensor_value_by_shape(self, shape=None):
  45. """Get tensor value by shape."""
  46. def _to_dict(self):
  47. """Get tensor info in dict format."""
  48. res = {
  49. 'full_name': self.name,
  50. 'step': self._step,
  51. 'dtype': self.dtype,
  52. 'shape': self.shape,
  53. 'has_prev_step': False
  54. }
  55. return res
  56. def get_basic_info(self):
  57. """Return basic info about tensor info."""
  58. if not self.shape:
  59. value = self.value
  60. else:
  61. value = 'click to view'
  62. res = self._to_dict()
  63. res['value'] = value
  64. return res
  65. def get_full_info(self, shape=None):
  66. """Get tensor info with value."""
  67. res = self._to_dict()
  68. value_info = self.get_tensor_serializable_value_by_shape(shape)
  69. res.update(value_info)
  70. return res
  71. class OpTensor(BaseTensor):
  72. """Tensor data structure for operator Node."""
  73. max_number_data_show_on_ui = 100000
  74. def __init__(self, tensor_proto, step=0):
  75. # the type of tensor_proto is TensorProto
  76. super(OpTensor, self).__init__(step)
  77. self._tensor_proto = tensor_proto
  78. self._value = self.generate_value(tensor_proto)
  79. @property
  80. def name(self):
  81. """The property of tensor name."""
  82. node_name = self._tensor_proto.node_name
  83. slot = self._tensor_proto.slot
  84. return ':'.join([node_name, slot])
  85. @property
  86. def dtype(self):
  87. """The property of tensor dtype."""
  88. tensor_type = DataType.Name(self._tensor_proto.data_type)
  89. return tensor_type
  90. @property
  91. def shape(self):
  92. """The property of tensor shape."""
  93. return list(self._tensor_proto.dims)
  94. @property
  95. def value(self):
  96. """The property of tensor value."""
  97. tensor_value = None
  98. if self._value is not None:
  99. tensor_value = self._value.tolist()
  100. return tensor_value
  101. @property
  102. def numpy_value(self):
  103. """The property of tensor value in numpy type."""
  104. return self._value
  105. def generate_value(self, tensor_proto):
  106. """Generate tensor value from proto."""
  107. tensor_value = None
  108. if tensor_proto.tensor_content:
  109. tensor_value = tensor_proto.tensor_content
  110. np_type = NUMPY_TYPE_MAP.get(self.dtype)
  111. tensor_value = np.frombuffer(tensor_value, dtype=np_type)
  112. tensor_value = tensor_value.reshape(self.shape)
  113. return tensor_value
  114. def get_tensor_serializable_value_by_shape(self, shape=None):
  115. """
  116. Get tensor value info by shape.
  117. Args:
  118. shape (tuple): The specified range of tensor value.
  119. Returns:
  120. dict, the specified tensor value and value statistics.
  121. """
  122. tensor_value = self.get_tensor_value_by_shape(shape)
  123. res = {}
  124. # the type of tensor_value is one of None, np.ndarray or str
  125. if isinstance(tensor_value, np.ndarray):
  126. statistics = TensorUtils.get_statistics_from_tensor(tensor_value)
  127. res['statistics'] = TensorUtils.get_statistics_dict(statistics)
  128. res['value'] = tensor_value.tolist()
  129. elif isinstance(tensor_value, str):
  130. res['value'] = tensor_value
  131. return res
  132. def get_tensor_value_by_shape(self, shape=None):
  133. """
  134. Get tensor value by shape.
  135. Args:
  136. shape (tuple): The specified shape.
  137. Returns:
  138. Union[None, str, numpy.ndarray], the sub-tensor.
  139. """
  140. if self._value is None:
  141. log.warning("%s has no value yet.", self.name)
  142. return None
  143. if shape is None or not isinstance(shape, tuple):
  144. log.info("Get the whole tensor value with shape is %s", shape)
  145. return self._value
  146. if len(shape) != len(self.shape):
  147. log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
  148. raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
  149. try:
  150. value = self._value[shape]
  151. except IndexError as err:
  152. log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
  153. log.exception(err)
  154. raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
  155. if isinstance(value, np.ndarray):
  156. if value.size > self.max_number_data_show_on_ui:
  157. value = "Too large to show."
  158. log.info("The tensor size is %s, which is too large to show on UI.")
  159. else:
  160. value = np.asarray(value)
  161. return value
  162. class ConstTensor(BaseTensor):
  163. """Tensor data structure for Const Node."""
  164. def __init__(self, const_proto):
  165. # the type of const_proto is NamedValueProto
  166. super(ConstTensor, self).__init__()
  167. self._const_proto = const_proto
  168. def set_step(self, step):
  169. """Set step value."""
  170. self._step = step
  171. @property
  172. def name(self):
  173. """The property of tensor name."""
  174. return self._const_proto.key + ':0'
  175. @property
  176. def dtype(self):
  177. """The property of tensor dtype."""
  178. return DataType.Name(self._const_proto.value.dtype)
  179. @property
  180. def shape(self):
  181. """The property of tensor shape."""
  182. return []
  183. @property
  184. def value(self):
  185. """The property of tensor shape."""
  186. fields = self._const_proto.value.ListFields()
  187. if len(fields) != 2:
  188. log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto)
  189. for field_name, field_value in fields:
  190. if field_name != 'dtype':
  191. return field_value
  192. return None
  193. def get_tensor_value_by_shape(self, shape=None):
  194. """Get tensor info with value."""
  195. if shape is not None:
  196. log.warning("Invalid shape for const value.")
  197. return self.value