You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_converter.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test mindconverter to convert user's PyTorch network script.
  18. Usage:
  19. pytest tests/st/func/mindconverter
  20. """
  21. import difflib
  22. import os
  23. import re
  24. import sys
  25. import pytest
  26. from mindinsight.mindconverter.converter import main
  27. from mindinsight.mindconverter.graph_based_converter.framework import main_graph_base_converter
  28. @pytest.mark.usefixtures('create_output_dir')
  29. class TestConverter:
  30. """Test Converter module."""
  31. @classmethod
  32. def setup_class(cls):
  33. """Setup method."""
  34. cls.script_dir = os.path.join(os.path.dirname(__file__), 'data')
  35. pytorch_base_dir = os.path.dirname(__file__).split('/')[:3]
  36. cls.pytorch_dir = \
  37. '/'.join(pytorch_base_dir + ['share-data', 'dataset', 'mindinsight_dataset', 'resnet50'])
  38. sys.path.insert(0, cls.script_dir)
  39. @classmethod
  40. def teardown_class(cls):
  41. """Teardown method."""
  42. sys.path.remove(cls.script_dir)
  43. @pytest.mark.level0
  44. @pytest.mark.platform_arm_ascend_training
  45. @pytest.mark.platform_x86_gpu_training
  46. @pytest.mark.platform_x86_ascend_training
  47. @pytest.mark.platform_x86_cpu
  48. @pytest.mark.env_single
  49. def test_convert_lenet(self, output):
  50. """Test LeNet script of the PyTorch convert to MindSpore script"""
  51. script_filename = "lenet_script.py"
  52. expect_filename = "lenet_converted.py"
  53. files_config = {
  54. 'root_path': self.script_dir,
  55. 'in_files': [os.path.join(self.script_dir, script_filename)],
  56. 'outfile_dir': output,
  57. 'report_dir': output
  58. }
  59. main(files_config)
  60. assert os.path.isfile(os.path.join(output, script_filename))
  61. with open(os.path.join(output, script_filename)) as converted_f:
  62. converted_source = converted_f.readlines()
  63. with open(os.path.join(self.script_dir, expect_filename)) as expect_f:
  64. expect_source = expect_f.readlines()
  65. diff = difflib.ndiff(converted_source, expect_source)
  66. diff_lines = 0
  67. for line in diff:
  68. if line.startswith('+'):
  69. diff_lines += 1
  70. converted_ratio = 100 - (diff_lines * 100) / (len(expect_source))
  71. assert converted_ratio >= 80
  72. @pytest.mark.level0
  73. @pytest.mark.platform_arm_ascend_training
  74. @pytest.mark.platform_x86_gpu_training
  75. @pytest.mark.platform_x86_ascend_training
  76. @pytest.mark.platform_x86_cpu
  77. @pytest.mark.env_single
  78. def test_main_graph_based_converter(self, output):
  79. """Test main graph based converter."""
  80. pytorch_filename = "resnet50.pth"
  81. expected_model_filename = "resnet50.py"
  82. expected_report_filename = "report_of_resnet50.txt"
  83. file_config = {
  84. 'model_file': os.path.join(self.pytorch_dir, pytorch_filename),
  85. 'shape': (1, 3, 224, 224),
  86. 'outfile_dir': output,
  87. 'report_dir': output
  88. }
  89. with pytest.raises(ValueError) as e:
  90. main_graph_base_converter(file_config=file_config)
  91. assert os.path.isfile(os.path.join(output, expected_model_filename))
  92. assert os.path.isfile(os.path.join(output, expected_report_filename))
  93. with open(os.path.join(output, expected_report_filename)) as converted_r:
  94. converted_report = converted_r.readlines()
  95. converted_rate = re.findall(r".*(?:Converted Rate: )(.*)[.]", converted_report[-1])
  96. assert converted_rate[0] == '100.00%'
  97. exec_msg = e.value.args[0]
  98. assert exec_msg == "torch.__spec__ is None"