You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

clone_test.cc 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include "common/common_test.h"
  18. #include "common/py_func_graph_fetcher.h"
  19. #include "ir/manager.h"
  20. #include "utils/log_adapter.h"
  21. #include "ir/func_graph_cloner.h"
  22. #include "pipeline/jit/parse/parse.h"
  23. #include "ir/graph_utils.h"
  24. #include "debug/draw.h"
  25. #include "base/core_ops.h"
  26. namespace mindspore {
  27. class TestCloner : public UT::Common {
  28. public:
  29. TestCloner() : getPyFun("gtest_input.ir.clone_test", true) {
  30. one = NewValueNode(static_cast<int64_t>(1));
  31. two = NewValueNode(static_cast<int64_t>(2));
  32. three = NewValueNode(static_cast<int64_t>(3));
  33. }
  34. FuncGraphPtr GraphForInline() { return nullptr; }
  35. void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params,
  36. FuncGraphPtr target);
  37. public:
  38. UT::PyFuncGraphFetcher getPyFun;
  39. ValueNodePtr one;
  40. ValueNodePtr two;
  41. ValueNodePtr three;
  42. };
  43. void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig,
  44. const std::vector<AnfNodePtr>& params, FuncGraphPtr target) {
  45. auto g = (*cl)[orig];
  46. ASSERT_TRUE(g != target);
  47. ASSERT_TRUE(g == orig);
  48. auto new_root = (*cl)[orig->output()];
  49. ASSERT_TRUE(new_root != orig->output());
  50. AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output()));
  51. AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root));
  52. for (auto& p : params) {
  53. ASSERT_TRUE(new_nodes.contains(p));
  54. }
  55. for (auto& node : orig_nodes) {
  56. if (node->func_graph() == orig) {
  57. ASSERT_TRUE((*cl)[node]);
  58. }
  59. }
  60. ASSERT_TRUE(target->output() == three);
  61. }
  62. TEST_F(TestCloner, test_clone_simple) {
  63. std::string py_code = "test_clone_simple";
  64. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  65. ASSERT_TRUE(g != nullptr);
  66. std::vector<FuncGraphPtr> gs = {g};
  67. Cloner cl(gs, true);
  68. auto g2 = cl[g];
  69. AnfNodeSet d1 = AnfNodeSet(DeepScopedGraphSearch(g->get_return()));
  70. AnfNodeSet d2 = AnfNodeSet(DeepScopedGraphSearch(g2->get_return()));
  71. auto common = d1 & d2;
  72. ASSERT_EQ((size_t)0, common.size());
  73. Cloner cl2(gs);
  74. auto g3 = cl2[g];
  75. std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("return")};
  76. AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return()));
  77. common = d1 & d3;
  78. for (auto& x : common) {
  79. ASSERT_TRUE(x->isa<ValueNode>());
  80. ASSERT_TRUE(find(results.begin(), results.end(), *x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()) !=
  81. results.end());
  82. }
  83. }
  84. TEST_F(TestCloner, test_clone_closure) {
  85. std::string py_code = "test_clone_closure";
  86. // parse ast to graph
  87. FuncGraphPtr parsed_f = getPyFun(py_code);
  88. FuncGraphIndex idx(parsed_f);
  89. auto g = idx.GetFirstFuncGraph("j");
  90. std::vector<FuncGraphPtr> gs = {g};
  91. Cloner cl(gs, true);
  92. auto g_clone = cl[g];
  93. draw::Draw("test_clone_closure_g_clone.dot", g_clone);
  94. FuncGraphIndex idx2(g_clone, DeepLinkedGraphSearch);
  95. std::string name_list = "xy";
  96. for (auto name : name_list) {
  97. ASSERT_EQ(idx.GetFirstNode(std::string(1, name)), idx2.GetFirstNode(std::string(1, name)));
  98. }
  99. ASSERT_FALSE(idx.GetFirstNode("z") == idx2.GetFirstNode("z"));
  100. ASSERT_FALSE(idx.GetFirstFuncGraph("j") == idx2.GetFirstFuncGraph("j"));
  101. }
  102. TEST_F(TestCloner, test_clone_lifting) {
  103. std::string py_code = "test_clone_closure";
  104. // parse ast to graph
  105. FuncGraphPtr parsed_f = getPyFun(py_code);
  106. draw::Draw("test_clone_before_lifting.dot", parsed_f);
  107. auto g_lifting = LiftingClone(parsed_f);
  108. draw::Draw("test_clone_after_lifting.dot", g_lifting);
  109. FuncGraphIndex idx(g_lifting);
  110. auto g = idx.GetFirstFuncGraph("j");
  111. auto params = g_lifting->parameters();
  112. auto child_params = g->parameters();
  113. ASSERT_TRUE(params.size() + 1 == child_params.size());
  114. }
  115. TEST_F(TestCloner, test_clone_scoping) {
  116. std::string py_code = "test_clone_scoping";
  117. // parse ast to graph
  118. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  119. std::vector<FuncGraphPtr> gs = {g};
  120. Cloner cl(gs, true);
  121. auto g2 = cl[g];
  122. FuncGraphIndex idx1(g);
  123. FuncGraphIndex idx2(g2);
  124. std::string name_list = "fgi";
  125. for (auto name : name_list) {
  126. auto result1 = idx1.GetFirstFuncGraph(std::string(1, name));
  127. auto result2 = idx2.GetFirstFuncGraph(std::string(1, name));
  128. ASSERT_FALSE(result1 == result2);
  129. }
  130. name_list = "h";
  131. for (auto name : name_list) {
  132. ASSERT_TRUE(idx1.GetFirstFuncGraph(std::string(1, name)) == idx2.GetFirstFuncGraph(std::string(1, name)));
  133. }
  134. }
  135. TEST_F(TestCloner, test_clone_total) {
  136. std::string py_code = "test_clone_total";
  137. // parse ast to graph
  138. getPyFun.SetDoResolve();
  139. FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  140. if (g == nullptr) {
  141. return;
  142. }
  143. FuncGraphIndex idx0(g);
  144. std::vector<FuncGraphPtr> gs = {g};
  145. Cloner cl1(gs, true, true, true);
  146. auto g2 = cl1[g];
  147. FuncGraphIndex idx1(g2);
  148. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total_sub") == idx1.GetFirstFuncGraph("clone_total_sub"));
  149. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx1.GetFirstFuncGraph("clone_total"));
  150. Cloner cl2(gs, true);
  151. FuncGraphIndex idx2(cl2[g]);
  152. ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx2.GetFirstFuncGraph("clone_total"));
  153. ASSERT_TRUE(idx0.GetFirstFuncGraph("clone_total_sub") == idx2.GetFirstFuncGraph("clone_total_sub"));
  154. }
  155. } // namespace mindspore