You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Base module."""
  16. import os
  17. import re
  18. import enum
  19. import collections
  20. import numpy as np
  21. from mindinsight.domain.graph.exceptions import UnknownTensorError
  22. class MindSporeType(enum.Enum):
  23. """MindSpore Type."""
  24. INT = 'int32'
  25. UINT = 'uint'
  26. FLOAT = 'float16'
  27. TENSOR = 'tensor'
  28. class DeviceType(enum.Enum):
  29. """Device Type."""
  30. ASCEND = 'ascend'
  31. GPU = 'gpu'
  32. class DumpType(enum.Enum):
  33. """Dump Type."""
  34. E2E = 'e2e'
  35. ASYNC = 'async'
  36. class Tensor:
  37. """
  38. Tensor object of dump file.
  39. Args:
  40. op_id (str): Operator ID.
  41. index (int): Index of operator inputs/outputs.
  42. file_path (str): Absolute file path of tensor file.
  43. """
  44. FEATURE = collections.namedtuple('Feature', ['type', 'id', 'io'])
  45. VALUE = collections.namedtuple('Value', ['index', 'shape', 'dtype', 'path'])
  46. @classmethod
  47. def extract_shape_from_str(cls, shape_str):
  48. """
  49. Extract shape from tensor file.
  50. Args:
  51. shape_str (str): Shape string.
  52. Returns:
  53. tuple, shape of tensor file.
  54. """
  55. shape = tuple([int(dim.strip()) for dim in shape_str.strip('_').split('_')])
  56. # The shape info in dump file name is (0,) which is inconsistent with the actual tensor shape.
  57. # The shape needs to be converted to (1,).
  58. if shape == (0,):
  59. shape = (1,)
  60. return shape
  61. @classmethod
  62. def parse_tensor_file_name(cls, file_name):
  63. """
  64. Parse tensor file name.
  65. Args:
  66. file_name (str): Tensor file name.
  67. Returns:
  68. bool, indicating if node is operator.
  69. dict, tensor file info.
  70. Raises:
  71. UnknownTensorError: If tensor file name can not be recognized.
  72. """
  73. is_op = False
  74. is_npy = file_name.endswith('.npy')
  75. if re.search(r'-op\d+(_|(\.\d+\.\d+\.))(input|output)(_|\.)', file_name):
  76. is_op = True
  77. dump_type = DumpType.E2E
  78. if re.search(r'-op(?P<op_id>\d+)\.(?P<stream_id>\d+)\.(?P<task_id>\d+)', file_name):
  79. dump_type = DumpType.ASYNC
  80. if dump_type == DumpType.ASYNC:
  81. file_name = file_name[file_name.find('.')+1:]
  82. if is_npy:
  83. regex = r'_(?P<op_name>[A-Za-z0-9]+)-op(?P<op_id>\d+)' \
  84. r'\.(?P<stream_id>\d+)\.(?P<task_id>\d+)' \
  85. r'\.(?P<io>input|output)' \
  86. r'\.(?P<index>\d+)' \
  87. r'\.npy$'
  88. else:
  89. regex = r'_(?P<op_name>[A-Za-z0-9]+)-op(?P<op_id>\d+)' \
  90. r'\.(?P<stream_id>\d+)\.(?P<task_id>\d+)' \
  91. r'\.(?P<io>input|output)' \
  92. r'\.(?P<index>\d+)' \
  93. r'\.(?P<shape>[0-9\_]+)' \
  94. r'\.(?P<dtype>bool|((uint|int|float)\d+))' \
  95. r'\.(?P<format>[A-Za-z0-9\_]+)\.bin$'
  96. else:
  97. regex = r'--(?P<op_name>[A-Za-z0-9\_]+)-op(?P<op_id>\d+)' \
  98. r'_(?P<io>input|output)' \
  99. r'_(?P<index>\d+)' \
  100. r'_shape_(?P<shape>[0-9\_]+)' \
  101. r'_.*(?P<dtype>Bool|((UInt|Int|Float)\d+))' \
  102. r'_(?P<format>[A-Za-z0-9\_]+)\.bin$'
  103. else:
  104. regex = r'^(?P<node_name>[A-Za-z0-9\.\_]+)' \
  105. r'_(?P<io>input|output)' \
  106. r'_(?P<index>\d+)' \
  107. r'_shape_(?P<shape>[0-9\_]+)' \
  108. r'_.*(?P<dtype>Bool|((UInt|Int|Float)\d+))' \
  109. r'_(?P<format>[A-Za-z0-9\_]+)\.bin$'
  110. pattern = re.search(regex, file_name)
  111. if pattern is None:
  112. raise UnknownTensorError(is_op, file_name)
  113. info = pattern.groupdict()
  114. info['index'] = int(info['index'])
  115. info['shape'] = None if is_npy else cls.extract_shape_from_str(info['shape'])
  116. info['dtype'] = None if is_npy else info['dtype'].lower()
  117. return is_op, info
  118. @classmethod
  119. def scan_tensors(cls, tensor_dir):
  120. """
  121. Scan tensors.
  122. Args:
  123. tensor_dir (str): Directory path where holds the tensor files.
  124. check (lambda): Function to check tensor values.
  125. Returns:
  126. dict, tensor file mapping.
  127. """
  128. tensor_mapping = {}
  129. if not tensor_dir:
  130. return tensor_mapping
  131. file_names = os.listdir(tensor_dir)
  132. for file_name in file_names:
  133. full_path = os.path.join(tensor_dir, file_name)
  134. if not re.search(r'\.(bin|npy)$', file_name) or os.path.isdir(full_path):
  135. continue
  136. try:
  137. is_op, info = cls.parse_tensor_file_name(file_name)
  138. except UnknownTensorError:
  139. continue
  140. if is_op:
  141. feature = cls.FEATURE(type=info['op_name'], id=info['op_id'], io=info['io'])
  142. else:
  143. feature = cls.FEATURE(type='', id=info['node_name'], io=info['io'])
  144. value = cls.VALUE(index=info['index'], shape=info['shape'], dtype=info['dtype'], path=full_path)
  145. tensors = tensor_mapping.get(feature)
  146. if tensors:
  147. tensor_mapping[feature].append(value)
  148. tensor_mapping[feature].sort(key=lambda x: x[0])
  149. else:
  150. tensor_mapping[feature] = [value]
  151. return tensor_mapping
  152. def __init__(self, op_id, index, file_path):
  153. self.op_id = op_id
  154. self.index = index
  155. self.file_path = file_path
  156. def load(self):
  157. """
  158. Load tensor file.
  159. Returns:
  160. ndarray, tensor data.
  161. """
  162. if self.file_path.endswith('.npy'):
  163. tensor = np.load(self.file_path)
  164. return tensor
  165. metas = self.metas
  166. if metas is None:
  167. return None
  168. dtype = getattr(np, metas['dtype'])
  169. tensor = np.fromfile(self.file_path, dtype=dtype)
  170. try:
  171. tensor = tensor.reshape(metas['shape'])
  172. except ValueError:
  173. pass
  174. return tensor
  175. @property
  176. def metas(self):
  177. """
  178. Metas property.
  179. Returns:
  180. dict, metas extracted from tensor file name.
  181. """
  182. file_name = os.path.basename(self.file_path)
  183. try:
  184. is_op, info = self.parse_tensor_file_name(file_name)
  185. except UnknownTensorError:
  186. return None
  187. if is_op:
  188. info.pop('op_name')
  189. info.pop('op_id')
  190. else:
  191. info.pop('node_name')
  192. if file_name.endswith('.npy'):
  193. info.pop('dtype')
  194. info.pop('shape')
  195. return info
  196. @property
  197. def full_name(self):
  198. """
  199. Full name property.
  200. Returns:
  201. str, full name.
  202. """
  203. full_name_str, _ = os.path.basename(self.file_path).split('_output_')
  204. return full_name_str.replace('--', '/')
  205. @property
  206. def scope(self):
  207. """
  208. Scope property.
  209. Returns:
  210. str, scope.
  211. """
  212. return os.path.dirname(self.full_name)
  213. def __repr__(self):
  214. return str({
  215. 'op_id': self.op_id,
  216. 'index': self.index,
  217. 'file_path': self.file_path,
  218. })
  219. class NodeType(enum.Enum):
  220. """Node Type."""
  221. OPERATOR = 'operator'
  222. PARAMETER = 'parameter'
  223. CONSTANT = 'constant'
  224. class InputType(enum.Enum):
  225. """Input Type."""
  226. OPERATOR = 'operator'
  227. PARAMETER = 'parameter'
  228. CONSTANT = 'constant'
  229. TENSOR = 'tensor'
  230. SCALAR = 'scalar'
  231. REFERENCE = 'reference'
  232. NONE = 'none'
  233. class OutputType(enum.Enum):
  234. """Output Type."""
  235. NONE = 'none'
  236. BOOL = 'bool'
  237. INT8 = 'int8'
  238. INT16 = 'int16'
  239. INT32 = 'int32'
  240. INT64 = 'int64'
  241. UINT8 = 'uint8'
  242. UINT16 = 'uint16'
  243. UINT32 = 'uint32'
  244. UINT64 = 'uint64'
  245. FLOAT16 = 'float16'
  246. FLOAT32 = 'float32'
  247. FLOAT64 = 'float64'
  248. TENSOR = 'tensor'
  249. TUPLE = 'tuple'
  250. class Input:
  251. """
  252. Graph node input.
  253. Args:
  254. input_type (InputType): Input type.
  255. input_name (str): Input name.
  256. """
  257. def __init__(self, input_type, input_name):
  258. self.type = input_type
  259. self.name = input_name
  260. self.op_id = ''
  261. self.info = None
  262. def __repr__(self):
  263. return str({
  264. 'type': self.type,
  265. 'name': self.name,
  266. 'op_id': self.op_id,
  267. 'info': self.info,
  268. })
  269. class Output:
  270. """
  271. Graph node output.
  272. Args:
  273. output_type (OutputType): Output type.
  274. """
  275. SCALAR_TYPES = (
  276. OutputType.INT8,
  277. OutputType.INT16,
  278. OutputType.INT32,
  279. OutputType.INT64,
  280. OutputType.UINT8,
  281. OutputType.UINT16,
  282. OutputType.UINT32,
  283. OutputType.UINT64,
  284. OutputType.FLOAT16,
  285. OutputType.FLOAT32,
  286. )
  287. def __init__(self, output_type):
  288. self.type = output_type
  289. if output_type == OutputType.NONE:
  290. self.info = None
  291. elif output_type == OutputType.BOOL:
  292. self.info = dict(value=None)
  293. elif output_type in self.SCALAR_TYPES:
  294. self.info = dict(value=None)
  295. elif output_type == OutputType.TENSOR:
  296. self.info = dict(dtype='', shape=(), tensor=None)
  297. elif output_type == OutputType.TUPLE:
  298. self.info = dict(dtypes=[], shapes=[], tensors=[])
  299. def __repr__(self):
  300. return str({
  301. 'type': self.type,
  302. 'info': self.info,
  303. })
  304. class Source:
  305. """
  306. Source address info.
  307. Args:
  308. file_path (str): Absolute path of source file.
  309. line_no (int): Line number of code line in source file.
  310. code_line (int): Code line content.
  311. """
  312. def __init__(self, file_path, line_no, code_line):
  313. self.file_path = file_path
  314. self.line_no = line_no
  315. self.code_line = code_line
  316. def to_dict(self):
  317. """Parse to dict."""
  318. return {
  319. 'file_path': self.file_path,
  320. 'line_no': self.line_no,
  321. 'code_line': self.code_line,
  322. }
  323. def __repr__(self):
  324. return str(self.to_dict())
  325. @classmethod
  326. def build_stack_from_source_address(cls, source_address):
  327. """
  328. Build stack from source address.
  329. Args:
  330. source_address (str): Source address content.
  331. Returns:
  332. list, list of Source objects.
  333. """
  334. stack = []
  335. for line in source_address.strip().split('\n'):
  336. regex = r'#\sIn\sfile\s(?P<file_path>.+)\((?P<line_no>\d+)\)/(?P<code_line>.+)/'
  337. pattern = re.search(regex, line.strip())
  338. source = pattern.groupdict()
  339. source['line_no'] = int(source['line_no'])
  340. source['code_line'] = source['code_line'].strip()
  341. stack.append(cls(**source))
  342. return stack
  343. class Node:
  344. """
  345. Graph node.
  346. Args:
  347. name (str): Node name.
  348. """
  349. def __init__(self, name):
  350. self.name = name
  351. self.output = None
  352. self.downstream = []
  353. self.raw = ''
  354. class Constant(Node):
  355. """Constant node within graph."""
  356. def __repr__(self):
  357. return str({
  358. 'name': self.name,
  359. 'output': self.output,
  360. 'downstream': self.downstream,
  361. })
  362. class Parameter(Node):
  363. """Parameter node within graph."""
  364. def __repr__(self):
  365. return str({
  366. 'name': self.name,
  367. 'output': self.output,
  368. 'downstream': self.downstream,
  369. })
  370. class Operator(Node):
  371. """
  372. Operator node within graph.
  373. Args:
  374. op_name (str): Operator name.
  375. op_type (str): Operator type.
  376. """
  377. def __init__(self, op_name, op_type):
  378. super().__init__(op_name)
  379. self.type = op_type
  380. self.inputs = []
  381. self.attrs = {}
  382. self.full_name = ''
  383. self.stack = []
  384. @property
  385. def scope(self):
  386. """
  387. Scope property.
  388. Returns:
  389. str, scope.
  390. """
  391. return os.path.dirname(self.full_name)
  392. @property
  393. def op_id(self):
  394. """
  395. Op ID property.
  396. Returns:
  397. str, op ID.
  398. """
  399. pattern = re.search(r'-op(?P<op_id>\d+)$', self.full_name)
  400. if not pattern:
  401. return self.full_name
  402. info = pattern.groupdict()
  403. return info['op_id']
  404. def __repr__(self):
  405. return str({
  406. 'name': self.name,
  407. 'type': self.type,
  408. 'inputs': self.inputs,
  409. 'output': self.output,
  410. 'downstream': self.downstream,
  411. 'attrs': self.attrs,
  412. 'full_name': self.full_name,
  413. 'op_id': self.op_id,
  414. })
  415. class Parser:
  416. """Graph file parser."""
  417. def __init__(self, graph_data=None, tensor_dir=''):
  418. self.graph_data = graph_data
  419. self.tensor_dir = os.path.realpath(tensor_dir) if tensor_dir else ''
  420. self.constants = []
  421. self.parameters = []
  422. self.operators = []
  423. self.tensor_mapping = {}
  424. def parse(self):
  425. """Parse."""
  426. raise NotImplementedError