You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

generate.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. import torch
  # Matrix dimensions: A is M x K, B is K x N, C is M x N.
  # These can be increased for larger test cases.
  M, K, N = 31, 31, 31
  # Random integer-valued inputs in [0, 10], stored as float16.
  A = np.random.randint(0, 11, size=(M, K)).astype(np.float16)
  B = np.random.randint(0, 11, size=(K, N)).astype(np.float16)
  # Use the GPU when one is available, otherwise fall back to the CPU so the
  # script still runs on machines without CUDA.
  _device = 'cuda' if torch.cuda.is_available() else 'cpu'
  A_torch = torch.tensor(A, dtype=torch.float16, device=_device)
  B_torch = torch.tensor(B, dtype=torch.float16, device=_device)
  # Multiply in float32 rather than float16: the generated C program stores
  # its result in a float array, and a float16 result would round integers
  # above 2048 and overflow above 65504 once the sizes are increased.
  C_torch = torch.matmul(A_torch.float(), B_torch.float())
  C_ref = C_torch.cpu().numpy().astype(np.float32)
  def format_array_c(name: str, array, c_type: str = "hfloat16") -> str:
      """Return a C definition of a flat array initialised from *array*.

      Args:
          name: Identifier of the C array.
          array: Values to emit; any array-like (ndarray, nested lists, ...).
              It is flattened in row-major order.
          c_type: C element type written in front of the identifier.

      Returns:
          One line of C source, terminated by a newline, of the form
          ``<c_type> <name>[<count>] = { v0, v1, ... };`` with every value
          printed with five decimal places.
      """
      # np.asarray accepts plain Python sequences as well as ndarrays, and
      # ravel() flattens without forcing a copy.
      flat = np.asarray(array).ravel()
      elements = ", ".join(f"{x:.5f}" for x in flat)
      return f"{c_type} {name}[{flat.size}] = {{ {elements} }};\n"
  def format_array_c_float(name: str, array) -> str:
      """Return a C ``float`` array definition initialised from *array*.

      Args:
          name: Identifier of the C array.
          array: Values to emit; any array-like (ndarray, nested lists, ...).
              It is flattened in row-major order.

      Returns:
          One line of C source, terminated by a newline, of the form
          ``float <name>[<count>] = { v0, v1, ... };`` with every value
          printed with five decimal places.
      """
      # np.asarray accepts plain Python sequences as well as ndarrays, and
      # ravel() flattens without forcing a copy.
      flat = np.asarray(array).ravel()
      elements = ", ".join(f"{x:.5f}" for x in flat)
      return f"float {name}[{flat.size}] = {{ {elements} }};\n"
  # Write the C test program.
  with open("generated_test.c", "w") as f:
      # Includes, problem constants, and the input/output arrays.
      f.writelines([
          '#include <stdio.h>\n',
          '#include <stdlib.h>\n',
          '#include <string.h>\n',
          '#include <cblas.h>\n\n',
          f"const int M = {M}, K = {K}, N = {N};\n",
          "const float alpha = 1.0f, beta = 0.0f;\n\n",
          format_array_c("A", A),
          format_array_c("B", B),
          f"float C[{M*N}] = {{ 0 }};\n\n",
      ])
      # main(): a single row-major GEMM call, then print C with a line
      # break after every N values.
      f.writelines([
          "int main() {\n",
          " cblas_shgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,\n",
          " M, N, K,\n",
          " alpha,\n",
          " A, K,\n",
          " B, N,\n",
          " beta,\n",
          " C, N);\n\n",
          ' printf("Result C = A * B:\\n");\n',
          " for (int i = 0; i < M * N; i++) {\n",
          " printf(\"%.5f \", C[i]);\n",
          " if ((i + 1) % N == 0) printf(\"\\n\");\n",
          " }\n",
          " return 0;\n",
          "}\n\n",
      ])
      # Append the Python-side reference result as a trailing C comment.
      c_ref_flat = ", ".join(f"{x:.5f}" for x in C_ref.flatten())
      f.writelines([
          "// Reference result computed in Python:\n",
          f"// C_ref = {{ {c_ref_flat} }}\n",
      ])