You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ptest.h 11 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef ST_RESNET50_PTEST_H_
  17. #define ST_RESNET50_PTEST_H_
#include <stdarg.h>
#include <stdio.h>
#include <string.h>
#include <exception>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
  27. namespace ptest {
  28. class assertion_error : public std::exception {
  29. public:
  30. const char *what() const throw() { return "Assertion Exception"; }
  31. };
  32. class TestFixture {
  33. public:
  34. virtual void SetUp() {}
  35. virtual void TearDown() {}
  36. void Run() { _func(); }
  37. void BindFunction(std::function<void(void)> function) { _func = function; }
  38. void SetName(const std::string &name) { _name = name; }
  39. std::string Name() const { return _name; }
  40. virtual ~TestFixture() {}
  41. private:
  42. std::function<void(void)> _func;
  43. std::string _name;
  44. };
  45. enum TestResult { SUCCESS, FAILED, UNAVAILABLE, UNKNOWN, NOCASEFOUND };
  46. class TestManager {
  47. public:
  48. static TestManager &GetSingleton() {
  49. static TestManager instance;
  50. return instance;
  51. }
  52. void RegisterTest(const std::string &name, TestFixture *fixture) { _testfixtures[name] = fixture; }
  53. const std::string GetRunningTestcaseName() const { return _running_testcase_name; }
  54. const std::list<std::string> GetAllTestNames() const {
  55. std::list<std::string> result;
  56. for (auto &t : _testfixtures) {
  57. result.push_back(t.first);
  58. }
  59. return result;
  60. }
  61. TestResult RunTest(const std::string &name) {
  62. if (_testfixtures.find(name) == _testfixtures.end()) {
  63. return NOCASEFOUND;
  64. }
  65. _running_testcase_name = name;
  66. do {
  67. SetTestResult(name, UNKNOWN);
  68. _testfixtures[name]->SetUp();
  69. if (_testresults[name] == FAILED) {
  70. _testresults[name] = UNAVAILABLE;
  71. break;
  72. }
  73. SetTestResult(name, SUCCESS);
  74. try {
  75. _testfixtures[name]->Run();
  76. } catch (assertion_error &e) {
  77. // Do nothing as the error has been handled by the TestManager.
  78. }
  79. _testfixtures[name]->TearDown();
  80. } while (0);
  81. return _testresults[name];
  82. }
  83. void SetTestResult(const std::string &name, TestResult result) { _testresults[name] = result; }
  84. TestResult GetTestResult(const std::string &name) { return _testresults[name]; }
  85. private:
  86. std::map<std::string, TestFixture *> _testfixtures;
  87. std::map<std::string, TestResult> _testresults;
  88. std::string _running_testcase_name;
  89. };
  90. class TestFixtureRegister {
  91. public:
  92. TestFixtureRegister(const std::string &name, TestFixture *fixture, std::function<void(void)> function) {
  93. fixture->BindFunction(function);
  94. fixture->SetName(name);
  95. TestManager::GetSingleton().RegisterTest(name, fixture);
  96. }
  97. };
  98. } // namespace ptest
  99. #define _STR(x) #x
  100. #define _EMPTY_NAMESPACE
  101. #define _TEST(NAMESPACE, FIXTURECLASS, TESTNAME, CASENAME) \
  102. void g_func_##TESTNAME##_##CASENAME(void); \
  103. NAMESPACE::FIXTURECLASS g_fixture_##TESTNAME##_##CASENAME; \
  104. ptest::TestFixtureRegister g_register_##TESTNAME##_##CASENAME( \
  105. _STR(TESTNAME##_##CASENAME), &g_fixture_##TESTNAME##_##CASENAME, g_func_##TESTNAME##_##CASENAME); \
  106. void g_func_##TESTNAME##_##CASENAME(void)
  107. #define TEST(TESTNAME, CASENAME) _TEST(ptest, TestFixture, TESTNAME, CASENAME)
  108. #define TEST_F(TESTFIXTURE, CASENAME) _TEST(_EMPTY_NAMESPACE, TESTFIXTURE, TESTFIXTURE, CASENAME)
  109. #define EXPECT_TRUE(X) \
  110. do { \
  111. if (!(X)) { \
  112. std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \
  113. ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \
  114. std::cerr << #X << "Expectation Failed\n" \
  115. << "Testcase Name: " << test_name << "\n" \
  116. << "File: " __FILE__ << "\tLine:" << __LINE__ << std::endl; \
  117. } \
  118. } while (0);
  119. // With the macro definition ensures that the compiler can detect compiler warning.
  120. #define Max_Log_Len 1024
  121. #define PRINT_ERR(lpszFormat, ...) \
  122. do { \
  123. char szTmpBuf[Max_Log_Len + 1] = {0}; \
  124. snprintf(szTmpBuf, Max_Log_Len, lpszFormat, ##__VA_ARGS__); \
  125. std::cerr << szTmpBuf << std::endl; \
  126. } while (0)
  127. // Increase the content of print error messages and error to facilitate rapid analysis
  128. #define EXPECT_TRUE_C(X, ERR_TYPE, format, ...) \
  129. do { \
  130. if (!(X)) { \
  131. std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \
  132. ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \
  133. std::cerr << #X << " Expectation Failed." \
  134. << "Testcase Name: " << test_name << " File:" __FILE__ << " Line:" << __LINE__ << std::endl; \
  135. PRINT_ERR("[" ERR_TYPE "]" format, ##__VA_ARGS__); \
  136. } \
  137. } while (0)
  138. #define ASSERT_TRUE(X) \
  139. do { \
  140. if (!(X)) { \
  141. std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \
  142. ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \
  143. std::cerr << #X << "Assertion Failed\n" \
  144. << "Testcase Name: " << test_name << "\n" \
  145. << "File: " __FILE__ << "\tLine:" << __LINE__ << std::endl; \
  146. throw ptest::assertion_error(); \
  147. } \
  148. } while (0);
  149. // Add printing error information and error line content for quick analysis
  150. #define ASSERT_TRUE_C(X, ERR_TYPE, format, ...) \
  151. do { \
  152. if (!(X)) { \
  153. std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \
  154. ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \
  155. std::cerr << #X << " Assertion Failed." \
  156. << "Testcase Name: " << test_name << " File:" __FILE__ << " Line:" << __LINE__ << std::endl; \
  157. PRINT_ERR("[" ERR_TYPE "]" format, ##__VA_ARGS__); \
  158. throw ptest::assertion_error(); \
  159. } \
  160. } while (0);
  161. #define CONFIG_ERR "CONFIG_ERR"
  162. #define LOAD_MODEL_ERR "LOAD_MODEL_ERR"
  163. #define FILE_READ_ERR "FILE_READ_ERR"
  164. #define RUN_ERROR "RUN_ERROR"
  165. #define MEM_ERROR "MEM_ERROR"
  166. #define RESULT_ERR "RESULT_ERR"
  167. #define EXPECT_FALSE(X) EXPECT_TRUE(!(X))
  168. #define EXPECT_EQ(X, Y) EXPECT_TRUE(((X) == (Y)))
  169. #define EXPECT_NE(X, Y) EXPECT_TRUE(((X) != (Y)))
  170. #define EXPECT_GT(X, Y) EXPECT_TRUE(((X) > (Y)))
  171. #define EXPECT_GE(X, Y) EXPECT_TRUE(((X) >= (Y)))
  172. #define EXPECT_LT(X, Y) EXPECT_TRUE(((X) < (Y)))
  173. #define EXPECT_LE(X, Y) EXPECT_TRUE(((X) <= (Y)))
  174. #define EXPECT_FALSE_C(X, ERR_TYPE, format, ...) EXPECT_TRUE_C(!(X), ERR_TYPE, format, ##__VA_ARGS__)
  175. #define EXPECT_EQ_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) == (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  176. #define EXPECT_NE_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) != (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  177. #define EXPECT_GT_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) > (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  178. #define EXPECT_GE_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) >= (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  179. #define EXPECT_LT_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) < (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  180. #define EXPECT_LE_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) <= (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  181. #define ASSERT_FALSE(X) ASSERT_TRUE(!(X))
  182. #define ASSERT_EQ(X, Y) ASSERT_TRUE(((X) == (Y)))
  183. #define ASSERT_NE(X, Y) ASSERT_TRUE(((X) != (Y)))
  184. #define ASSERT_GT(X, Y) ASSERT_TRUE(((X) > (Y)))
  185. #define ASSERT_GE(X, Y) ASSERT_TRUE(((X) >= (Y)))
  186. #define ASSERT_LT(X, Y) ASSERT_TRUE(((X) < (Y)))
  187. #define ASSERT_LE(X, Y) ASSERT_TRUE(((X) <= (Y)))
  188. #define ASSERT_FALSE_C(X, ERR_TYPE, format, ...) ASSERT_TRUE_C(!(X), ERR_TYPE, format, ##__VA_ARGS__)
  189. #define ASSERT_EQ_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) == (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  190. #define ASSERT_NE_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) != (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  191. #define ASSERT_GT_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) > (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  192. #define ASSERT_GE_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) >= (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  193. #define ASSERT_LT_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) < (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  194. #define ASSERT_LE_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) <= (Y)), ERR_TYPE, format, ##__VA_ARGS__)
  195. #endif // ST_RESNET50_PTEST_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：