You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.h 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
#pragma once
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>
#include <mc_runtime.h>
#include <maca_fp16.h>
// Project-local modular headers
#include "yaml_reporter.h"
#include "performance_utils.h"
  13. // ============================================================================
  14. // 测试配置常量
  15. // ============================================================================
  16. #ifndef RUN_FULL_TEST
  17. const int TEST_SIZES[] = {1000000, 134217728}; // 1M, 128M, 512M, 1G
  18. #else
  19. const int TEST_SIZES[] = {1000000, 134217728, 536870912, 1073741824}; // 1M, 128M, 512M, 1G
  20. #endif
  21. const int NUM_TEST_SIZES = sizeof(TEST_SIZES) / sizeof(TEST_SIZES[0]);
  22. // 性能测试重复次数
  23. constexpr int WARMUP_ITERATIONS = 5;
  24. constexpr int BENCHMARK_ITERATIONS = 10;
  25. // ============================================================================
  26. // 错误检查宏
  27. // ============================================================================
  28. #define MACA_CHECK(call) \
  29. do { \
  30. mcError_t error = call; \
  31. if (error != mcSuccess) { \
  32. std::cerr << "MACA error at " << __FILE__ << ":" << __LINE__ \
  33. << " - " << mcGetErrorString(error) << std::endl; \
  34. exit(1); \
  35. } \
  36. } while(0)
  37. // ============================================================================
  38. // 测试数据生成器
  39. // ============================================================================
  40. class TestDataGenerator {
  41. private:
  42. std::mt19937 rng;
  43. public:
  44. TestDataGenerator(uint32_t seed = 42) : rng(seed) {}
  45. // 生成随机float数组
  46. std::vector<float> generateRandomFloats(int size, float min_val = -1000.0f, float max_val = 1000.0f) {
  47. std::vector<float> data(size);
  48. std::uniform_real_distribution<float> dist(min_val, max_val);
  49. for (int i = 0; i < size; i++) {
  50. data[i] = dist(rng);
  51. }
  52. return data;
  53. }
  54. // 生成随机half数组
  55. std::vector<half> generateRandomHalfs(int size, float min_val = -100.0f, float max_val = 100.0f) {
  56. std::vector<half> data(size);
  57. std::uniform_real_distribution<float> dist(min_val, max_val);
  58. for (int i = 0; i < size; i++) {
  59. data[i] = __float2half(dist(rng));
  60. }
  61. return data;
  62. }
  63. // 生成随机uint32_t数组
  64. std::vector<uint32_t> generateRandomUint32(int size) {
  65. std::vector<uint32_t> data(size);
  66. for (int i = 0; i < size; i++) {
  67. data[i] = static_cast<uint32_t>(i); // 使用索引作为值,便于验证稳定排序
  68. }
  69. return data;
  70. }
  71. // 生成随机int64_t数组
  72. std::vector<int64_t> generateRandomInt64(int size) {
  73. std::vector<int64_t> data(size);
  74. for (int i = 0; i < size; i++) {
  75. data[i] = static_cast<int64_t>(i);
  76. }
  77. return data;
  78. }
  79. // 生成包含NaN和Inf的测试数据 (half版本)
  80. std::vector<half> generateSpecialHalfs(int size) {
  81. std::vector<half> data = generateRandomHalfs(size, -10.0f, 10.0f);
  82. if (size > 100) {
  83. data[10] = __float2half(NAN);
  84. data[20] = __float2half(INFINITY);
  85. data[30] = __float2half(-INFINITY);
  86. }
  87. return data;
  88. }
  89. // 生成包含NaN和Inf的测试数据 (float版本)
  90. std::vector<float> generateSpecialFloats(int size) {
  91. std::vector<float> data = generateRandomFloats(size, -10.0f, 10.0f);
  92. if (size > 100) {
  93. data[10] = NAN;
  94. data[20] = INFINITY;
  95. data[30] = -INFINITY;
  96. }
  97. return data;
  98. }
  99. };
  100. // ============================================================================
  101. // 性能测试工具
  102. // ============================================================================
  103. class PerformanceMeter {
  104. private:
  105. mcEvent_t start, stop;
  106. public:
  107. PerformanceMeter() {
  108. MACA_CHECK(mcEventCreate(&start));
  109. MACA_CHECK(mcEventCreate(&stop));
  110. }
  111. ~PerformanceMeter() {
  112. mcEventDestroy(start);
  113. mcEventDestroy(stop);
  114. }
  115. void startTiming() {
  116. MACA_CHECK(mcEventRecord(start));
  117. }
  118. float stopTiming() {
  119. MACA_CHECK(mcEventRecord(stop));
  120. MACA_CHECK(mcEventSynchronize(stop));
  121. float milliseconds = 0;
  122. MACA_CHECK(mcEventElapsedTime(&milliseconds, start, stop));
  123. return milliseconds;
  124. }
  125. };
  126. // ============================================================================
  127. // 正确性验证工具
  128. // ============================================================================
  129. template<typename T>
  130. bool compareArrays(const std::vector<T>& a, const std::vector<T>& b, double tolerance = 1e-6) {
  131. if (a.size() != b.size()) return false;
  132. for (size_t i = 0; i < a.size(); i++) {
  133. if constexpr (std::is_same_v<T, half>) {
  134. float fa = __half2float(a[i]);
  135. float fb = __half2float(b[i]);
  136. if (std::isnan(fa) && std::isnan(fb)) continue;
  137. if (std::isinf(fa) && std::isinf(fb) && (fa > 0) == (fb > 0)) continue;
  138. if (std::abs(fa - fb) > tolerance) return false;
  139. } else if constexpr (std::is_floating_point_v<T>) {
  140. if (std::isnan(a[i]) && std::isnan(b[i])) continue;
  141. if (std::isinf(a[i]) && std::isinf(b[i]) && (a[i] > 0) == (b[i] > 0)) continue;
  142. if (std::abs(a[i] - b[i]) > tolerance) return false;
  143. } else {
  144. if (a[i] != b[i]) return false;
  145. }
  146. }
  147. return true;
  148. }
  149. // CPU参考实现 - 稳定排序
  150. template<typename KeyType, typename ValueType>
  151. void cpuSortPair(std::vector<KeyType>& keys, std::vector<ValueType>& values, bool descending) {
  152. std::vector<std::pair<KeyType, ValueType>> pairs;
  153. for (size_t i = 0; i < keys.size(); i++) {
  154. pairs.emplace_back(keys[i], values[i]);
  155. }
  156. if (descending) {
  157. std::stable_sort(pairs.begin(), pairs.end(),
  158. [](const auto& a, const auto& b) { return a.first > b.first; });
  159. } else {
  160. std::stable_sort(pairs.begin(), pairs.end());
  161. }
  162. for (size_t i = 0; i < pairs.size(); i++) {
  163. keys[i] = pairs[i].first;
  164. values[i] = pairs[i].second;
  165. }
  166. }
  167. // CPU参考实现 - TopK
  168. template<typename KeyType, typename ValueType>
  169. void cpuTopkPair(const std::vector<KeyType>& keys_in, const std::vector<ValueType>& values_in,
  170. std::vector<KeyType>& keys_out, std::vector<ValueType>& values_out,
  171. int k, bool descending) {
  172. std::vector<std::pair<KeyType, ValueType>> pairs;
  173. for (size_t i = 0; i < keys_in.size(); i++) {
  174. pairs.emplace_back(keys_in[i], values_in[i]);
  175. }
  176. if (descending) {
  177. std::stable_sort(pairs.begin(), pairs.end(),
  178. [](const auto& a, const auto& b) { return a.first > b.first; });
  179. } else {
  180. std::stable_sort(pairs.begin(), pairs.end());
  181. }
  182. keys_out.resize(k);
  183. values_out.resize(k);
  184. for (int i = 0; i < k; i++) {
  185. keys_out[i] = pairs[i].first;
  186. values_out[i] = pairs[i].second;
  187. }
  188. }
  189. // CPU参考实现 - ReduceSum (使用double精度)
  190. template<typename InputT>
  191. double cpuReduceSum(const std::vector<InputT>& data, double init_value) {
  192. double sum = init_value;
  193. for (const auto& val : data) {
  194. if constexpr (std::is_same_v<InputT, half>) {
  195. float f_val = __half2float(val);
  196. if (!std::isnan(f_val)) {
  197. sum += static_cast<double>(f_val);
  198. }
  199. } else {
  200. if (!std::isnan(val)) {
  201. sum += static_cast<double>(val);
  202. }
  203. }
  204. }
  205. return sum;
  206. }