You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_generic_network.py 2.8 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test GenericNetwork class."""
  16. import os
  17. import pytest
  18. from mindinsight.wizard.network import lenet
  19. class TestGenericNetwork:
  20. """Test SourceFile"""
  21. def test_generate_scripts(self):
  22. """Test network object to generate network scripts"""
  23. network_inst = lenet.Network()
  24. network_inst.configure({
  25. "loss": "SoftmaxCrossEntropyWithLogits",
  26. "optimizer": "Momentum",
  27. "dataset": "mnist"})
  28. sources_files = network_inst.generate()
  29. dataset_source_file = None
  30. config_source_file = None
  31. shell_script_dir_files = []
  32. out_files = []
  33. for sources_file in sources_files:
  34. if sources_file.file_relative_path == 'src/dataset.py':
  35. dataset_source_file = sources_file
  36. elif sources_file.file_relative_path == 'src/config.py':
  37. config_source_file = sources_file
  38. elif sources_file.file_relative_path.startswith('scripts'):
  39. shell_script_dir_files.append(sources_file)
  40. elif not os.path.dirname(sources_file.file_relative_path):
  41. out_files.append(sources_file)
  42. else:
  43. pass
  44. assert sources_files
  45. assert dataset_source_file is not None
  46. assert config_source_file is not None
  47. assert shell_script_dir_files
  48. assert out_files
  49. def test_config(self):
  50. """Test network object to config."""
  51. network_inst = lenet.Network()
  52. settings = {
  53. "loss": "SoftmaxCrossEntropyWithLogits",
  54. "optimizer": "Momentum",
  55. "dataset": "mnist"}
  56. configurations = network_inst.configure(settings)
  57. assert configurations["dataset"] == settings["dataset"]
  58. assert configurations["loss"] == settings["loss"]
  59. assert configurations["optimizer"] == settings["optimizer"]
  60. settings["dataset"] = "mnist_another"
  61. with pytest.raises(ModuleNotFoundError) as exec_info:
  62. network_inst.configure(settings)
  63. assert exec_info.value.name == f'mindinsight.wizard.dataset.{settings["dataset"]}'