|
- #include "test_utils.h"
- #include "performance_utils.h"
- #include "yaml_reporter.h"
- #include <iostream>
- #include <vector>
- #include <iomanip>
-
- // ============================================================================
- // 实现标记宏 - 参赛者修改实现时请将此宏设为0
- // ============================================================================
- #ifndef USE_DEFAULT_REF_IMPL
- #define USE_DEFAULT_REF_IMPL 1 // 1=默认实现, 0=参赛者自定义实现
- #endif
-
- #if USE_DEFAULT_REF_IMPL
- #include <thrust/sort.h>
- #include <thrust/device_vector.h>
- #include <thrust/execution_policy.h>
- #include <thrust/iterator/zip_iterator.h>
- #include <thrust/tuple.h>
- #endif
-
- // ============================================================================
- // SortPair算法实现接口
- // 参赛者需要替换Thrust实现为自己的高性能kernel
- // ============================================================================
-
/**
 * Key/value pair sort used by the correctness test and the benchmark.
 *
 * When USE_DEFAULT_REF_IMPL is non-zero, the reference path copies the inputs
 * into the output buffers and sorts them there with
 * thrust::stable_sort_by_key. When it is zero, the body of sort() is left for
 * a contestant-provided implementation.
 */
template <typename KeyType, typename ValueType>
class SortPairAlgorithm {
public:
    /**
     * Main entry point: sort num_items key/value pairs by key.
     *
     * @param d_keys_in    input keys (copied device-to-device by the reference path)
     * @param d_keys_out   receives the sorted keys
     * @param d_values_in  input values, paired with d_keys_in by index
     * @param d_values_out receives the values reordered alongside their keys
     * @param num_items    number of pairs; values <= 0 make the call a no-op
     * @param descending   true for descending key order, false for ascending
     */
    void sort(const KeyType* d_keys_in, KeyType* d_keys_out,
              const ValueType* d_values_in, ValueType* d_values_out,
              int num_items, bool descending) {
        // Nothing to sort. This also keeps a negative count from turning into
        // a huge size_t byte count or a reversed iterator range below.
        if (num_items <= 0) {
            return;
        }

#if !USE_DEFAULT_REF_IMPL
        // ========================================
        // Contestant implementation area
        // ========================================

        // TODO: implement the high-performance sort here.

        // Example: one or more custom kernels may be launched, e.g.
        // preprocessKernel<<<grid, block>>>(d_keys_in, d_values_in, num_items);
        // mainSortKernel<<<grid, block>>>(d_keys_out, d_values_out, num_items, descending);
        // postprocessKernel<<<grid, block>>>(d_keys_out, d_values_out, num_items);
#else
        // ========================================
        // Default reference implementation
        // ========================================

        // Stage the inputs in the output buffers so the sort can run in place
        // there. When the caller passes the same buffer as input and output,
        // the copy is skipped: it would be an overlapping copy (presumably
        // disallowed for mcMemcpy as it is for cudaMemcpy -- TODO confirm)
        // and is unnecessary anyway.
        if (d_keys_out != d_keys_in) {
            MACA_CHECK(mcMemcpy(d_keys_out, d_keys_in, num_items * sizeof(KeyType), mcMemcpyDeviceToDevice));
        }
        if (d_values_out != d_values_in) {
            MACA_CHECK(mcMemcpy(d_values_out, d_values_in, num_items * sizeof(ValueType), mcMemcpyDeviceToDevice));
        }

        auto key_ptr = thrust::device_pointer_cast(d_keys_out);
        auto value_ptr = thrust::device_pointer_cast(d_values_out);

        // Stable sort: pairs with equal keys keep their input order.
        if (descending) {
            thrust::stable_sort_by_key(thrust::device, key_ptr, key_ptr + num_items, value_ptr, thrust::greater<KeyType>());
        } else {
            thrust::stable_sort_by_key(thrust::device, key_ptr, key_ptr + num_items, value_ptr, thrust::less<KeyType>());
        }
#endif
    }

    /**
     * Report which implementation was compiled in.
     *
     * @return "DEFAULT_REF_IMPL" when USE_DEFAULT_REF_IMPL is non-zero,
     *         otherwise "CUSTOM_IMPL".
     */
    static const char* getImplementationStatus() {
#if USE_DEFAULT_REF_IMPL
        return "DEFAULT_REF_IMPL";
#else
        return "CUSTOM_IMPL";
#endif
    }

private:
    // Contestants may add helper functions and member variables here,
    // e.g. temporary buffers, additional kernels, streams.
};
-
- // ============================================================================
- // 测试和性能评估
- // ============================================================================
-
- bool testCorrectness() {
- std::cout << "SortPair 正确性测试..." << std::endl;
- TestDataGenerator generator;
- SortPairAlgorithm<float, uint32_t> algorithm;
-
- // 测试小规模数据
- int size = 10000;
- auto keys = generator.generateRandomFloats(size);
- auto values = generator.generateRandomUint32(size);
-
- // 分配GPU内存
- float *d_keys_in, *d_keys_out;
- uint32_t *d_values_in, *d_values_out;
-
- MACA_CHECK(mcMalloc(&d_keys_in, size * sizeof(float)));
- MACA_CHECK(mcMalloc(&d_keys_out, size * sizeof(float)));
- MACA_CHECK(mcMalloc(&d_values_in, size * sizeof(uint32_t)));
- MACA_CHECK(mcMalloc(&d_values_out, size * sizeof(uint32_t)));
-
- MACA_CHECK(mcMemcpy(d_keys_in, keys.data(), size * sizeof(float), mcMemcpyHostToDevice));
- MACA_CHECK(mcMemcpy(d_values_in, values.data(), size * sizeof(uint32_t), mcMemcpyHostToDevice));
-
- // 测试升序和降序
- bool allPassed = true;
- for (bool descending : {false, true}) {
- std::cout << " " << (descending ? "降序" : "升序") << " 测试..." << std::endl;
-
- // CPU参考结果
- auto cpu_keys = keys;
- auto cpu_values = values;
- cpuSortPair(cpu_keys, cpu_values, descending);
-
- // GPU算法结果
- algorithm.sort(d_keys_in, d_keys_out, d_values_in, d_values_out, size, descending);
-
- // 获取结果
- std::vector<float> gpu_keys(size);
- std::vector<uint32_t> gpu_values(size);
- MACA_CHECK(mcMemcpy(gpu_keys.data(), d_keys_out, size * sizeof(float), mcMemcpyDeviceToHost));
- MACA_CHECK(mcMemcpy(gpu_values.data(), d_values_out, size * sizeof(uint32_t), mcMemcpyDeviceToHost));
-
- // 验证结果
- bool keysMatch = compareArrays(cpu_keys, gpu_keys, 1e-5);
- bool valuesMatch = compareArrays(cpu_values, gpu_values);
-
- if (!keysMatch || !valuesMatch) {
- std::cout << " 失败: 结果不匹配" << std::endl;
- allPassed = false;
- } else {
- std::cout << " 通过" << std::endl;
- }
- }
-
- // 清理内存
- mcFree(d_keys_in);
- mcFree(d_keys_out);
- mcFree(d_values_in);
- mcFree(d_values_out);
-
- return allPassed;
- }
-
- void benchmarkPerformance() {
- PerformanceDisplay::printSortPairHeader();
-
- TestDataGenerator generator;
- PerformanceMeter meter;
- SortPairAlgorithm<float, uint32_t> algorithm;
-
- const int WARMUP_ITERATIONS = 5;
- const int BENCHMARK_ITERATIONS = 10;
-
- // 用于YAML报告的数据收集
- std::vector<std::map<std::string, std::string>> perf_data;
-
- for (int i = 0; i < NUM_TEST_SIZES; i++) {
- int size = TEST_SIZES[i];
-
- // 生成测试数据
- auto keys = generator.generateRandomFloats(size);
- auto values = generator.generateRandomUint32(size);
-
- // 分配GPU内存
- float *d_keys_in, *d_keys_out;
- uint32_t *d_values_in, *d_values_out;
-
- MACA_CHECK(mcMalloc(&d_keys_in, size * sizeof(float)));
- MACA_CHECK(mcMalloc(&d_keys_out, size * sizeof(float)));
- MACA_CHECK(mcMalloc(&d_values_in, size * sizeof(uint32_t)));
- MACA_CHECK(mcMalloc(&d_values_out, size * sizeof(uint32_t)));
-
- MACA_CHECK(mcMemcpy(d_keys_in, keys.data(), size * sizeof(float), mcMemcpyHostToDevice));
- MACA_CHECK(mcMemcpy(d_values_in, values.data(), size * sizeof(uint32_t), mcMemcpyHostToDevice));
-
- float asc_time = 0, desc_time = 0;
-
- // 测试升序和降序
- for (bool descending : {false, true}) {
- // Warmup阶段
- for (int iter = 0; iter < WARMUP_ITERATIONS; iter++) {
- algorithm.sort(d_keys_in, d_keys_out, d_values_in, d_values_out, size, descending);
- }
-
- // 正式测试阶段
- float total_time = 0;
- for (int iter = 0; iter < BENCHMARK_ITERATIONS; iter++) {
- meter.startTiming();
- algorithm.sort(d_keys_in, d_keys_out, d_values_in, d_values_out, size, descending);
- total_time += meter.stopTiming();
- }
-
- float avg_time = total_time / BENCHMARK_ITERATIONS;
- if (descending) {
- desc_time = avg_time;
- } else {
- asc_time = avg_time;
- }
- }
-
- // 计算性能指标
- auto asc_metrics = PerformanceCalculator::calculateSortPair(size, asc_time);
- auto desc_metrics = PerformanceCalculator::calculateSortPair(size, desc_time);
-
- // 显示性能数据
- PerformanceDisplay::printSortPairData(size, asc_time, desc_time, asc_metrics, desc_metrics);
-
- // 收集YAML报告数据
- auto entry = YAMLPerformanceReporter::createEntry();
- entry["data_size"] = std::to_string(size);
- entry["asc_time_ms"] = std::to_string(asc_time);
- entry["desc_time_ms"] = std::to_string(desc_time);
- entry["asc_throughput_gps"] = std::to_string(asc_metrics.throughput_gps);
- entry["desc_throughput_gps"] = std::to_string(desc_metrics.throughput_gps);
- entry["key_type"] = "float";
- entry["value_type"] = "uint32_t";
- perf_data.push_back(entry);
-
- // 清理内存
- mcFree(d_keys_in);
- mcFree(d_keys_out);
- mcFree(d_values_in);
- mcFree(d_values_out);
- }
-
- // 生成YAML性能报告
- YAMLPerformanceReporter::generateSortPairYAML(perf_data, "sort_pair_performance.yaml");
- PerformanceDisplay::printSavedMessage("sort_pair_performance.yaml");
- }
-
- // ============================================================================
- // 主函数
- // ============================================================================
- int main(int argc, char* argv[]) {
- std::cout << "=== SortPair 算法测试 ===" << std::endl;
-
- // 检查参数
- std::string mode = "all";
- if (argc > 1) {
- mode = argv[1];
- }
-
- bool correctness_passed = true;
- bool performance_completed = true;
-
- try {
- if (mode == "correctness" || mode == "all") {
- correctness_passed = testCorrectness();
- }
-
- if (mode == "performance" || mode == "all") {
- if (correctness_passed || mode == "performance") {
- benchmarkPerformance();
- } else {
- std::cout << "跳过性能测试,因为正确性测试未通过" << std::endl;
- performance_completed = false;
- }
- }
-
- std::cout << "\n=== 测试完成 ===" << std::endl;
- std::cout << "实现状态: " << SortPairAlgorithm<float, uint32_t>::getImplementationStatus() << std::endl;
- if (mode == "all") {
- std::cout << "正确性: " << (correctness_passed ? "通过" : "失败") << std::endl;
- std::cout << "性能测试: " << (performance_completed ? "完成" : "跳过") << std::endl;
- }
-
- return correctness_passed ? 0 : 1;
-
- } catch (const std::exception& e) {
- std::cerr << "测试出错: " << e.what() << std::endl;
- return 1;
- }
- }
|