You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resolve_test.cc 3.0 kB

(line-number gutter: lines 1–103)
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <string>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "utils/log_adapter.h"
  21. #include "pipeline/jit/parse/parse.h"
  22. #include "debug/draw.h"
  23. namespace mindspore {
  24. namespace parse {
  25. class TestResolve : public UT::Common {
  26. public:
  27. TestResolve() {}
  28. virtual void SetUp();
  29. virtual void TearDown();
  30. };
  31. void TestResolve::SetUp() { UT::InitPythonPath(); }
  32. void TestResolve::TearDown() {}
  33. TEST_F(TestResolve, TestResolveApi) {
  34. py::function fn_ = python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "get_resolve_fn");
  35. // parse graph
  36. FuncGraphPtr func_graph = ParsePythonCode(fn_);
  37. ASSERT_FALSE(nullptr == func_graph);
  38. // save the func_graph to manager
  39. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  40. // call resolve
  41. bool ret_ = ResolveAll(manager);
  42. ASSERT_TRUE(ret_);
  43. ASSERT_EQ(manager->func_graphs().size(), (size_t)2);
  44. // draw graph
  45. int i = 0;
  46. for (auto func_graph : manager->func_graphs()) {
  47. std::string name = "ut_resolve_graph_" + std::to_string(i) + ".dot";
  48. draw::Draw(name, func_graph);
  49. i++;
  50. }
  51. }
  52. TEST_F(TestResolve, TestParseGraphTestClosureResolve) {
  53. py::function test_fn =
  54. python_adapter::CallPyFn("gtest_input.pipeline.parse.parser_test", "test_reslove_closure", 123);
  55. FuncGraphPtr func_graph = ParsePythonCode(test_fn);
  56. ASSERT_TRUE(func_graph != nullptr);
  57. draw::Draw("test_reslove_closure.dot", func_graph);
  58. // save the func_graph to manager
  59. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  60. // call resolve
  61. bool ret_ = ResolveAll(manager);
  62. ASSERT_TRUE(ret_);
  63. ASSERT_EQ(manager->func_graphs().size(), (size_t)2);
  64. // draw graph
  65. int i = 0;
  66. for (auto func_graph : manager->func_graphs()) {
  67. std::string name = "ut_test_reslove_closure_graph_" + std::to_string(i) + ".dot";
  68. draw::Draw(name, func_graph);
  69. i++;
  70. }
  71. }
  72. TEST_F(TestResolve, TestResolveFail) {
  73. py::function fn_ = python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "test_resolvefail");
  74. // parse graph
  75. FuncGraphPtr func_graph = ParsePythonCode(fn_);
  76. ASSERT_FALSE(nullptr == func_graph);
  77. // save the func_graph to manager
  78. std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  79. // call resolve
  80. EXPECT_THROW({ ResolveAll(manager); }, std::runtime_error);
  81. }
  82. } // namespace parse
  83. } // namespace mindspore