Browse Source

!189 fix the bug that when the profiler parameter subgraph is Default or Gradients, the profiler analyse will raise an exception

Merge pull request !189 from chenchao99/profiler_analyser
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b2062a9857
2 changed files with 28 additions and 1 deletions
  1. +2
    -0
      mindinsight/profiler/analyser/analyser.py
  2. +26
    -1
      tests/ut/profiler/analyser/test_analyser_aicore_detail.py

+ 2
- 0
mindinsight/profiler/analyser/analyser.py View File

@@ -124,6 +124,8 @@ class AicoreDetailAnalyser(BaseAnalyser):
result = []
for op_type in op_type_order:
detail_infos = type_detail_cache.get(op_type)
if detail_infos is None:
continue
detail_infos.sort(key=lambda item: item[2], reverse=True)
result.extend(detail_infos)



+ 26
- 1
tests/ut/profiler/analyser/test_analyser_aicore_detail.py View File

@@ -267,7 +267,7 @@ class TestAicoreDetailAnalyser(TestCase):
result = self._analyser.query(condition)
self.assertDictEqual(expect_result, result)

def test_query_and_sort_by_op_type(self):
def test_query_and_sort_by_op_type_1(self):
"""Test the success of the querying and sorting function by operator type."""
detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 5, 3, 4])
expect_result = {
@@ -289,6 +289,31 @@ class TestAicoreDetailAnalyser(TestCase):
)
self.assertDictEqual(expect_result, result)

def test_query_and_sort_by_op_type_2(self):
"""Test the success of the querying and sorting function by operator type."""
detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 3, 4, 8, 6])
expect_result = {
'col_name': AicoreDetailAnalyser.__col_names__[0:4],
'object': [item[0:4] for item in detail_infos]
}

filter_condition = {
'op_type': {},
'subgraph': {
'in': ['Default']
},
'is_display_detail': False,
'is_display_full_op_name': False
}
op_type_order = [
'MatMul', 'AtomicAddrClean', 'Cast', 'Conv2D', 'TransData'
]
result = self._analyser.query_and_sort_by_op_type(
filter_condition, op_type_order
)
print(result)
self.assertDictEqual(expect_result, result)

def test_col_names(self):
"""Test the querying column names function."""
self.assertListEqual(


Loading…
Cancel
Save