You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_templates.py 8.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test TemplateManager class."""
  16. import os
  17. import shutil
  18. import tempfile
  19. import textwrap
  20. from mindinsight.wizard.base.templates import TemplateManager
  21. from tests.ut.wizard.utils import generate_file
  22. def create_template_files(template_dir):
  23. """Create network template files."""
  24. all_template_files = []
  25. train_file = os.path.join(template_dir, 'train.py-tpl')
  26. generate_file(train_file,
  27. textwrap.dedent("""\
  28. {% if loss=='SoftmaxCrossEntropyWithLogits' %}
  29. net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
  30. {% elif loss=='SoftmaxCrossEntropyExpand' %}
  31. net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
  32. {% endif %}
  33. """))
  34. all_template_files.append(train_file)
  35. os.mkdir(os.path.join(template_dir, 'src'))
  36. config_file = os.path.join(template_dir, 'src', 'config.py-tpl')
  37. generate_file(config_file,
  38. textwrap.dedent("""\
  39. {
  40. 'num_classes': 10,
  41. {% if optimizer=='Momentum' %}
  42. 'lr': 0.01,
  43. "momentum": 0.9,
  44. {% elif optimizer=='SGD' %}
  45. 'lr': 0.1,
  46. {% else %}
  47. 'lr': 0.001,
  48. {% endif %}
  49. 'epoch_size': 1
  50. }
  51. """))
  52. all_template_files.append(config_file)
  53. os.mkdir(os.path.join(template_dir, 'scripts'))
  54. run_standalone_train_file = os.path.join(template_dir, 'scripts', 'run_standalone_train.sh-tpl')
  55. generate_file(run_standalone_train_file,
  56. textwrap.dedent("""\
  57. python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
  58. """))
  59. all_template_files.append(run_standalone_train_file)
  60. os.mkdir(os.path.join(template_dir, 'dataset'))
  61. os.mkdir(os.path.join(template_dir, 'dataset', 'mnist'))
  62. dataset_file = os.path.join(template_dir, 'dataset', 'mnist', 'dataset.py-tpl')
  63. generate_file(dataset_file,
  64. textwrap.dedent("""\
  65. import mindspore.dataset as ds
  66. import mindspore.dataset.transforms.vision.c_transforms as CV
  67. """))
  68. all_template_files.append(dataset_file)
  69. return all_template_files
  70. class TestTemplateManager:
  71. """Test TemplateManager"""
  72. template_dir = None
  73. all_template_files = []
  74. def setup_method(self):
  75. """Setup before call test method."""
  76. self.template_dir = tempfile.mkdtemp()
  77. self.all_template_files = create_template_files(self.template_dir)
  78. def teardown_method(self):
  79. """Tear down after call test method."""
  80. self._remove_dirs()
  81. self.template_dir = None
  82. def _remove_dirs(self):
  83. """Recursively delete a directory tree."""
  84. if self.template_dir and os.path.exists(self.template_dir):
  85. shutil.rmtree(self.template_dir)
  86. def test_template_files(self):
  87. """Test get_template_files method."""
  88. src_file_num = 1
  89. dataset_file_num = 1
  90. template_mgr = TemplateManager(self.template_dir)
  91. all_files = template_mgr.get_template_files()
  92. assert set(all_files) == set(self.all_template_files)
  93. template_mgr = TemplateManager(os.path.join(self.template_dir, 'src'))
  94. all_files = template_mgr.get_template_files()
  95. assert len(all_files) == src_file_num
  96. template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset'))
  97. all_files = template_mgr.get_template_files()
  98. assert len(all_files) == dataset_file_num
  99. template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src'])
  100. all_files = template_mgr.get_template_files()
  101. assert len(all_files) == len(self.all_template_files) - src_file_num
  102. template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset'])
  103. all_files = template_mgr.get_template_files()
  104. assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num
  105. template_mgr = TemplateManager(self.template_dir,
  106. exclude_dirs=['src', 'dataset'],
  107. exclude_files=['train.py-tpl'])
  108. all_files = template_mgr.get_template_files()
  109. assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num - 1
  110. def test_src_render(self):
  111. """Test render file in src directory."""
  112. template_mgr = TemplateManager(os.path.join(self.template_dir, 'src'))
  113. source_files = template_mgr.render(optimizer='Momentum')
  114. assert source_files[0].content == textwrap.dedent("""\
  115. {
  116. 'num_classes': 10,
  117. 'lr': 0.01,
  118. "momentum": 0.9,
  119. 'epoch_size': 1
  120. }
  121. """)
  122. source_files = template_mgr.render(optimizer='SGD')
  123. assert source_files[0].content == textwrap.dedent("""\
  124. {
  125. 'num_classes': 10,
  126. 'lr': 0.1,
  127. 'epoch_size': 1
  128. }
  129. """)
  130. source_files = template_mgr.render()
  131. assert source_files[0].content == textwrap.dedent("""\
  132. {
  133. 'num_classes': 10,
  134. 'lr': 0.001,
  135. 'epoch_size': 1
  136. }
  137. """)
  138. def test_dataset_render(self):
  139. """Test render file in dataset directory."""
  140. template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset'))
  141. source_files = template_mgr.render()
  142. assert source_files[0].content == textwrap.dedent("""\
  143. import mindspore.dataset as ds
  144. import mindspore.dataset.transforms.vision.c_transforms as CV
  145. """)
  146. assert source_files[0].file_relative_path == 'mnist/dataset.py'
  147. assert source_files[0].template_file_path == os.path.join(self.template_dir, 'dataset', 'mnist/dataset.py-tpl')
  148. def test_assemble_render(self):
  149. """Test render assemble files in template directory."""
  150. template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset'])
  151. source_files = template_mgr.render(loss='SoftmaxCrossEntropyWithLogits')
  152. unmatched_files = []
  153. for source_file in source_files:
  154. if source_file.template_file_path == os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'):
  155. assert source_file.content == textwrap.dedent("""\
  156. python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
  157. """)
  158. assert source_file.file_relative_path == 'scripts/run_standalone_train.sh'
  159. elif source_file.template_file_path == os.path.join(self.template_dir, 'train.py-tpl'):
  160. assert source_file.content == textwrap.dedent("""\
  161. net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
  162. """)
  163. assert source_file.file_relative_path == 'train.py'
  164. else:
  165. unmatched_files.append(source_file)
  166. assert not unmatched_files