You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_handler.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Description: This file is used for testing graph handler.
  17. """
  18. import os
  19. import pytest
  20. from ....utils.tools import compare_result_with_file
  21. from .conftest import init_graph_handler
  22. class TestGraphHandler:
  23. """Test GraphHandler."""
  24. graph_results_dir = os.path.join(os.path.dirname(__file__), 'expect_results')
  25. graph_handler = init_graph_handler()
  26. @pytest.mark.level0
  27. @pytest.mark.env_single
  28. @pytest.mark.platform_x86_cpu
  29. @pytest.mark.platform_arm_ascend_training
  30. @pytest.mark.platform_x86_gpu_training
  31. @pytest.mark.platform_x86_ascend_training
  32. @pytest.mark.parametrize("filter_condition, result_file", [
  33. (None, "graph_handler_get_1_no_filter_condintion.json"),
  34. ({'name': 'Default'}, "graph_handler_get_2_list_nodes.json"),
  35. ({'name': 'Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190', 'single_node': True},
  36. "graph_handler_get_3_single_node.json")
  37. ])
  38. def test_get(self, filter_condition, result_file):
  39. """Test get."""
  40. result = self.graph_handler.get(filter_condition)
  41. file_path = os.path.join(self.graph_results_dir, result_file)
  42. compare_result_with_file(result, file_path)
  43. @pytest.mark.level0
  44. @pytest.mark.env_single
  45. @pytest.mark.platform_x86_cpu
  46. @pytest.mark.platform_arm_ascend_training
  47. @pytest.mark.platform_x86_gpu_training
  48. @pytest.mark.platform_x86_ascend_training
  49. @pytest.mark.parametrize("node_name, result_file", [
  50. ("Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190",
  51. "tenor_hist_0.json"),
  52. ("Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op22",
  53. "tensor_hist_1.json")
  54. ])
  55. def test_get_tensor_history(self, node_name, result_file):
  56. """Test get tensor history."""
  57. result = self.graph_handler.get_tensor_history(node_name)
  58. file_path = os.path.join(self.graph_results_dir, result_file)
  59. compare_result_with_file(result, file_path)
  60. @pytest.mark.level0
  61. @pytest.mark.env_single
  62. @pytest.mark.platform_x86_cpu
  63. @pytest.mark.platform_arm_ascend_training
  64. @pytest.mark.platform_x86_gpu_training
  65. @pytest.mark.platform_x86_ascend_training
  66. @pytest.mark.parametrize("pattern, result_file", [
  67. ("withlogits", "search_nodes_0.json"),
  68. ("cst", "search_node_1.json")
  69. ])
  70. def test_search_nodes(self, pattern, result_file):
  71. """Test search nodes."""
  72. result = self.graph_handler.search_nodes(pattern)
  73. file_path = os.path.join(self.graph_results_dir, result_file)
  74. compare_result_with_file(result, file_path)
  75. @pytest.mark.level0
  76. @pytest.mark.env_single
  77. @pytest.mark.platform_x86_cpu
  78. @pytest.mark.platform_arm_ascend_training
  79. @pytest.mark.platform_x86_gpu_training
  80. @pytest.mark.platform_x86_ascend_training
  81. @pytest.mark.parametrize("node_name, expect_type", [
  82. ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/cst1", 'Const'),
  83. ("Default/TransData-op99", "TransData")
  84. ])
  85. def test_get_node_type(self, node_name, expect_type):
  86. """Test get node type."""
  87. node_type = self.graph_handler.get_node_type(node_name)
  88. assert node_type == expect_type
  89. @pytest.mark.level0
  90. @pytest.mark.env_single
  91. @pytest.mark.platform_x86_cpu
  92. @pytest.mark.platform_arm_ascend_training
  93. @pytest.mark.platform_x86_gpu_training
  94. @pytest.mark.platform_x86_ascend_training
  95. @pytest.mark.parametrize("node_name, expect_full_name", [
  96. (None, ""),
  97. ("Default/make_tuple[9]_3/make_tuple-op284", "Default/make_tuple-op284"),
  98. ("Default/args0", "Default/args0")
  99. ])
  100. def test_get_full_name(self, node_name, expect_full_name):
  101. """Test get full name."""
  102. full_name = self.graph_handler.get_full_name(node_name)
  103. assert full_name == expect_full_name
  104. @pytest.mark.level0
  105. @pytest.mark.env_single
  106. @pytest.mark.platform_x86_cpu
  107. @pytest.mark.platform_arm_ascend_training
  108. @pytest.mark.platform_x86_gpu_training
  109. @pytest.mark.platform_x86_ascend_training
  110. @pytest.mark.parametrize("full_name, expect_node_name", [
  111. (None, ""),
  112. ("Default/make_tuple-op284", "Default/make_tuple[9]_3/make_tuple-op284"),
  113. ("Default/args0", "Default/args0")
  114. ])
  115. def test_get_node_name_by_full_name(self, full_name, expect_node_name):
  116. """Test get node name by full name."""
  117. node_name = self.graph_handler.get_node_name_by_full_name(full_name)
  118. assert node_name == expect_node_name
  119. @pytest.mark.level0
  120. @pytest.mark.env_single
  121. @pytest.mark.platform_x86_cpu
  122. @pytest.mark.platform_arm_ascend_training
  123. @pytest.mark.platform_x86_gpu_training
  124. @pytest.mark.platform_x86_ascend_training
  125. @pytest.mark.parametrize("node_name, ascend, expect_next", [
  126. (None, True, "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0"),
  127. (None, False, None),
  128. ("Default/tuple_getitem[10]_0/tuple_getitem-op206", True,
  129. "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89"),
  130. ("Default/tuple_getitem[10]_0/tuple_getitem-op206", False,
  131. "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205")
  132. ])
  133. def test_get_node_by_bfs_order(self, node_name, ascend, expect_next):
  134. """Test get node by BFS order."""
  135. next_node = self.graph_handler.get_node_by_bfs_order(node_name, ascend)
  136. assert next_node == expect_next