You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

generate_kernel.py 30 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. #!/usr/bin/python3
  2. import sys, os
  3. import contextlib
  4. #-----------------------------------------------------------------------
  5. def ERROR(*args, **kwargs):
  6. print(*args, file=sys.stderr, **kwargs)
  7. sys.exit(-1)
  8. class Target(object):
  9. def __init__( self, out, mappings, initial_level=0, tab_width=4 ):
  10. self._level = initial_level
  11. self._tab_width = tab_width
  12. self._out = out
  13. self._mappings = mappings
  14. @contextlib.contextmanager
  15. def map( self, **items ):
  16. old_mappings = self._mappings
  17. self._mappings = dict(old_mappings, **items)
  18. yield self._mappings
  19. self._mappings = old_mappings
  20. @contextlib.contextmanager
  21. def block( self, start=None, end=None, **args ):
  22. with self.map(**args):
  23. if start is not None:
  24. self.write();
  25. self.write(start)
  26. self._level += 1
  27. yield self._level
  28. self._level -= 1
  29. if end is not None:
  30. self.write(end)
  31. self.write()
  32. def write( self, fmt=None, *args, **kwargs ):
  33. if fmt is not None:
  34. mappings = dict(self._mappings, **kwargs) if kwargs else self._mappings
  35. self._out(self._indent_str() + fmt.format(*args, **mappings))
  36. else:
  37. self._out("")
  38. def _indent_str( self ):
  39. return ' ' * (self._level * self._tab_width)
  40. #-----------------------------------------------------------------------
  41. def generate_trmm_block( dest ):
  42. dest.write("{index_type} pass_K = K;")
  43. dest.write("#ifdef LEFT")
  44. with dest.block():
  45. dest.write("{index_type} off = offset + m_top;")
  46. dest.write("#else")
  47. with dest.block():
  48. dest.write("{index_type} off = -offset + n_top;")
  49. dest.write("#endif")
  50. dest.write("#ifdef BACKWARDS")
  51. with dest.block():
  52. dest.write("ai += off*{M}{elt_size};")
  53. dest.write("bi += off*{N}{elt_size};")
  54. dest.write("pass_K -= off;")
  55. dest.write("#else")
  56. with dest.block():
  57. dest.write("#ifdef LEFT")
  58. with dest.block():
  59. dest.write("pass_K = off + {M};")
  60. dest.write("#else")
  61. with dest.block():
  62. dest.write("pass_K = off + {N};")
  63. dest.write("#endif")
  64. dest.write("#endif")
  65. #-----------------------------------------------------------------------
  66. def generate_gemm_kernel_inner_real( settings, dest, M, N, vlen, a_regs ):
  67. TRMM = (settings['op'].value == 'trmm')
  68. narrow_result = (settings['param_precision'].value != 'double') and settings['force_acc_double'].value
  69. with dest.map(
  70. M=M,
  71. N=N,
  72. ):
  73. dest.write("{index_type} ai=m_top*K{elt_size};")
  74. dest.write("{index_type} bi=n_top*K{elt_size};")
  75. if TRMM:
  76. generate_trmm_block( dest )
  77. for i in range(N):
  78. dest.write("{param_scalar_t} B{i} = B[bi+{i}];", i=i)
  79. dest.write("bi += {N};")
  80. dest.write()
  81. for i in range(a_regs):
  82. dest.write("{param_vector_t} A{i} = {VLEV}( &A[ai+{i}*gvl], gvl );", i=i)
  83. dest.write("ai += {M};")
  84. dest.write()
  85. for j in range(N):
  86. for i in range(a_regs):
  87. dest.write("{acc_vector_t} result{dest} = {VMUL_TO_ACC}( A{i}, B{j}, gvl);", dest=j*a_regs+i, i=i, j=j)
  88. with dest.block("for({index_type} k=1; k<{Kend}; k++) {{", "}}", Kend=('pass_K' if TRMM else 'K')):
  89. for i in range(N):
  90. dest.write("B{i} = B[bi+{i}];", i=i )
  91. dest.write("bi += {N};")
  92. dest.write()
  93. for i in range(a_regs):
  94. dest.write("A{i} = {VLEV}( &A[ai+{i}*gvl], gvl );", i=i)
  95. dest.write("ai += {M};")
  96. dest.write()
  97. for j in range(N):
  98. for i in range(a_regs):
  99. dest.write("result{dest} = {VMACC_TO_ACC}( result{dest}, B{j}, A{i}, gvl);", dest= j*a_regs+i, j=j, i=i )
  100. dest.write()
  101. dest.write("{index_type} ci=n_top*ldc+m_top;")
  102. dest.write()
  103. if narrow_result:
  104. for j in range(N):
  105. for i in range(a_regs):
  106. dest.write("{param_vector_t} narrowed{idx} = {VFNCVT}( result{idx}, gvl );", idx=j*a_regs+i)
  107. if not TRMM:
  108. for j in range(N):
  109. for i in range(a_regs):
  110. idx = j*a_regs+i
  111. increment = ' ci += ldc-gvl*{};'.format(a_regs-1) if (i == a_regs-1) else ' ci += gvl;'
  112. if idx == N*a_regs-1:
  113. increment = ''
  114. dest.write("{param_vector_t} c{idx} = {VLEV}( &C[ci], gvl);{increment}", idx=idx, increment=increment)
  115. if narrow_result:
  116. for j in range(N):
  117. for i in range(a_regs):
  118. idx = j*a_regs+i
  119. if TRMM:
  120. dest.write("{param_vector_t} c{idx} = {VFMUL}( narrowed{idx}, alpha, gvl );", idx=idx)
  121. else:
  122. dest.write("c{idx} = {VFMACC}( c{idx}, alpha, narrowed{idx}, gvl );", idx=idx)
  123. else:
  124. for j in range(N):
  125. for i in range(a_regs):
  126. idx = j*a_regs+i
  127. if TRMM:
  128. dest.write("{param_vector_t} c{idx} = {VFMUL}( result{idx}, alpha, gvl );", idx=idx)
  129. else:
  130. dest.write("c{idx} = {VFMACC}( c{idx}, alpha, result{idx}, gvl );", idx=idx)
  131. if not TRMM:
  132. dest.write()
  133. dest.write("ci=n_top*ldc+m_top;")
  134. dest.write()
  135. for j in range(N):
  136. for i in range(a_regs):
  137. idx = j*a_regs+i
  138. increment = ' ci += ldc-gvl*{};'.format(a_regs-1) if (i == a_regs-1) else ' ci += gvl;'
  139. if idx == N*a_regs-1:
  140. increment = ''
  141. dest.write("{VSEV}( &C[ci], c{idx}, gvl);{increment}", idx=idx, increment=increment)
  142. #-----------------------------------------------------------------------
  143. def generate_gemm_kernel_inner_complex( settings, dest, M, N, vlen, a_regs ):
  144. TRMM = (settings['op'].value == 'trmm')
  145. narrow_result = (settings['param_precision'].value != 'double') and settings['force_acc_double'].value
  146. if narrow_result:
  147. raise RuntimeError("wide accumulator not supported for generated complex kernels")
  148. # we could, but we run out of registers really really fast
  149. with dest.map(
  150. M=M,
  151. N=N,
  152. ):
  153. dest.write("{index_type} ai=m_top*K*2;")
  154. dest.write("{index_type} bi=n_top*K*2;")
  155. if TRMM:
  156. generate_trmm_block( dest )
  157. for i in range(N):
  158. dest.write("{param_scalar_t} B{i}r = B[bi+{i}*2+0];", i=i)
  159. dest.write("{param_scalar_t} B{i}i = B[bi+{i}*2+1];", i=i)
  160. dest.write("bi += {N}*2;")
  161. dest.write()
  162. for i in range(a_regs):
  163. dest.write("{param_vector_t} A{i}r = {VLSEV}( &A[ai+{i}*gvl*2], sizeof(FLOAT)*2, gvl );", i=i)
  164. dest.write("{param_vector_t} A{i}i = {VLSEV}( &A[ai+{i}*gvl*2+1], sizeof(FLOAT)*2, gvl );", i=i)
  165. dest.write("ai += {M}*2;")
  166. dest.write()
  167. # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
  168. accumulation_regs = a_regs * N
  169. dest.write("// {a_regs} vector regs to hold A array contents, {accumulation_regs} regs to hold values accumulated over k",
  170. a_regs=a_regs*2, accumulation_regs=accumulation_regs*2
  171. )
  172. pass_regs = (accumulation_regs + a_regs)*2
  173. tmp_regs = (32 // settings['LMUL_ACC'].value) - pass_regs
  174. if tmp_regs < 2:
  175. raise RuntimeError("Complex kernel would use too many registers!")
  176. dest.write("// leaving {tmp_regs} vector registers for temporaries", tmp_regs=tmp_regs)
  177. tmp_unroll_i = min(tmp_regs, a_regs)
  178. tmp_unroll_j = N
  179. while tmp_unroll_j > 1 and (tmp_regs/(tmp_unroll_i*2)) < tmp_unroll_j:
  180. tmp_unroll_j = int(tmp_unroll_j / 2)
  181. if tmp_unroll_i < a_regs or tmp_unroll_j < N:
  182. dest.write("// performing {ops} operations between reuses of temporaries", ops=tmp_unroll_j*tmp_unroll_i)
  183. for tj in range(0, N, tmp_unroll_j):
  184. for ti in range(0, a_regs, tmp_unroll_i):
  185. for j in range(tj, tj+tmp_unroll_j):
  186. for i in range(ti, ti+tmp_unroll_i):
  187. with dest.map(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j):
  188. if ti == 0 and tj==0:
  189. dest.write("{acc_vector_t} tmp{tmp}r = {VMUL_TO_ACC}( A{i}i, B{j}i, gvl);")
  190. dest.write("{acc_vector_t} tmp{tmp}i = {VMUL_TO_ACC}( A{i}r, B{j}i, gvl);")
  191. else:
  192. dest.write("tmp{tmp}r = {VMUL_TO_ACC}( A{i}i, B{j}i, gvl);")
  193. dest.write("tmp{tmp}i = {VMUL_TO_ACC}( A{i}r, B{j}i, gvl);")
  194. for j in range(tj, tj+tmp_unroll_j):
  195. for i in range(ti, ti+tmp_unroll_i):
  196. with dest.map(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j):
  197. dest.write("tmp{tmp}r = VFMACC_RR( tmp{tmp}r, B{j}r, A{i}r, gvl);")
  198. dest.write("tmp{tmp}i = VFMACC_RI( tmp{tmp}i, B{j}r, A{i}i, gvl);")
  199. for j in range(tj, tj+tmp_unroll_j):
  200. for i in range(ti, ti+tmp_unroll_i):
  201. with dest.map(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j):
  202. dest.write("{acc_vector_t} ACC{dest}r = tmp{tmp}r;")
  203. dest.write("{acc_vector_t} ACC{dest}i = tmp{tmp}i;")
  204. with dest.block("for({index_type} k=1; k<{Kend}; k++) {{", "}}", Kend=('pass_K' if TRMM else 'K')):
  205. for i in range(N):
  206. dest.write("B{i}r = B[bi+{i}*2+0];", i=i)
  207. dest.write("B{i}i = B[bi+{i}*2+1];", i=i)
  208. dest.write("bi += {N}*2;")
  209. dest.write()
  210. for i in range(a_regs):
  211. dest.write("A{i}r = {VLSEV}( &A[ai+{i}*gvl*2], sizeof(FLOAT)*2, gvl );", i=i)
  212. dest.write("A{i}i = {VLSEV}( &A[ai+{i}*gvl*2+1], sizeof(FLOAT)*2, gvl );", i=i)
  213. dest.write("ai += {M}*2;")
  214. dest.write()
  215. for tj in range(0, N, tmp_unroll_j):
  216. for ti in range(0, a_regs, tmp_unroll_i):
  217. # note the values in tmp{tmp}* are frequently of similar magnitude and opposite sign
  218. # so accumulating them directly to ACC would lose precision when ACC is larger
  219. for j in range(tj, tj+tmp_unroll_j):
  220. for i in range(ti, ti+tmp_unroll_i):
  221. with dest.map(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j):
  222. dest.write("tmp{tmp}r = {VMUL_TO_ACC}( A{i}i, B{j}i, gvl);")
  223. dest.write("tmp{tmp}i = {VMUL_TO_ACC}( A{i}r, B{j}i, gvl);")
  224. for j in range(tj, tj+tmp_unroll_j):
  225. for i in range(ti, ti+tmp_unroll_i):
  226. with dest.map(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j):
  227. dest.write("tmp{tmp}r = VFMACC_RR( tmp{tmp}r, B{j}r, A{i}r, gvl);")
  228. dest.write("tmp{tmp}i = VFMACC_RI( tmp{tmp}i, B{j}r, A{i}i, gvl);")
  229. for j in range(tj, tj+tmp_unroll_j):
  230. for i in range(ti, ti+tmp_unroll_i):
  231. with dest.map(dest=j*a_regs+i, tmp=(i-ti)+tmp_unroll_i*(j-tj), i=i, j=j):
  232. dest.write("ACC{dest}r = {__riscv_}vfadd( ACC{dest}r, tmp{tmp}r, gvl);")
  233. dest.write("ACC{dest}i = {__riscv_}vfadd( ACC{dest}i, tmp{tmp}i, gvl);")
  234. dest.write()
  235. dest.write("{index_type} ci=n_top*ldc+m_top;")
  236. dest.write()
  237. for j in range(N):
  238. if TRMM:
  239. for i in range(a_regs):
  240. with dest.map(idx=j*a_regs+i):
  241. dest.write("{param_vector_t} C{idx}r = {__riscv_}vfmul( ACC{idx}r, alphar, gvl );")
  242. dest.write("{param_vector_t} C{idx}i = {__riscv_}vfmul( ACC{idx}i, alphar, gvl );")
  243. else:
  244. for i in range(a_regs):
  245. idx = j*a_regs+i
  246. increment = 'ci += ldc-gvl*{};'.format(a_regs-1) if (i == a_regs-1) else ' ci += gvl;'
  247. if idx == N*a_regs-1:
  248. increment = ''
  249. with dest.map(idx=j*a_regs+i, increment=increment):
  250. dest.write("{param_vector_t} C{idx}r = {VLSEV}( &C[ci*2+0], sizeof(FLOAT)*2, gvl );")
  251. dest.write("{param_vector_t} C{idx}i = {VLSEV}( &C[ci*2+1], sizeof(FLOAT)*2, gvl );")
  252. dest.write("{increment}")
  253. if not TRMM:
  254. for j in range(N):
  255. for i in range(a_regs):
  256. with dest.map(idx=j*a_regs+i):
  257. dest.write("C{idx}r = {__riscv_}vfmacc( C{idx}r, alphar, ACC{idx}r, gvl );")
  258. dest.write("C{idx}i = {__riscv_}vfmacc( C{idx}i, alphar, ACC{idx}i, gvl );")
  259. for j in range(N):
  260. for i in range(a_regs):
  261. with dest.map(idx=j*a_regs+i):
  262. dest.write("C{idx}r = {__riscv_}vfnmsac( C{idx}r, alphai, ACC{idx}i, gvl );")
  263. dest.write("C{idx}i = {__riscv_}vfmacc ( C{idx}i, alphai, ACC{idx}r, gvl );")
  264. if not TRMM:
  265. dest.write()
  266. dest.write("ci=n_top*ldc+m_top;")
  267. dest.write()
  268. for j in range(N):
  269. for i in range(a_regs):
  270. idx = j*a_regs+i
  271. increment = 'ci += ldc-gvl*{};'.format(a_regs-1) if (i == a_regs-1) else ' ci += gvl;'
  272. if idx == N*a_regs-1:
  273. increment = ''
  274. with dest.map(idx=j*a_regs+i, increment=increment):
  275. dest.write("{VSSEV}( &C[ci*2+0], sizeof(FLOAT)*2, C{idx}r, gvl);")
  276. dest.write("{VSSEV}( &C[ci*2+1], sizeof(FLOAT)*2, C{idx}i, gvl);")
  277. dest.write("{increment}")
  278. #-----------------------------------------------------------------------
  279. def generate_gemm_kernel( settings, OUTPUT ):
  280. if settings['conjugate'].value:
  281. ERROR('conjugate gemm not yet supported')
  282. is_complex = settings['complex'].value
  283. generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
  284. dest = Target(OUTPUT, { k:str(settings[k].value) for k in settings })
  285. M = settings['M'].value
  286. N = settings['N'].value
  287. vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value /
  288. settings['ELEN_PARAM'].value)
  289. a_regs = max(int(M/vlenmax), 1)
  290. # for each vector register loaded from matrix A, we require N registers to hold vector-scalar multiply-accumulate results
  291. accumulation_regs = a_regs * N
  292. required_regs = accumulation_regs + a_regs
  293. if is_complex:
  294. required_regs = required_regs * 2 + 2
  295. dest.write('''
  296. #if defined(NN) || defined(NT) || defined(TN) || defined(TT)
  297. #define S0 1
  298. #define S1 -1
  299. #define S2 1
  300. #define S3 1
  301. #define VFMACC_RR __riscv_vfmsac{tail_policy}
  302. #define VFMACC_RI __riscv_vfmacc{tail_policy}
  303. #endif
  304. #if defined(NR) || defined(NC) || defined(TR) || defined(TC)
  305. #define S0 1
  306. #define S1 1
  307. #define S2 1
  308. #define S3 -1
  309. #define VFMACC_RR __riscv_vfmacc{tail_policy}
  310. #define VFMACC_RI __riscv_vfmsac{tail_policy}
  311. #endif
  312. #if defined(RN) || defined(RT) || defined(CN) || defined(CT)
  313. #define S0 1
  314. #define S1 1
  315. #define S2 -1
  316. #define S3 1
  317. #define VFMACC_RR __riscv_vfmacc{tail_policy}
  318. #define VFMACC_RI __riscv_vfnmsac{tail_policy}
  319. #endif
  320. #if defined(RR) || defined(RC) || defined(CR) || defined(CC)
  321. #define S0 1
  322. #define S1 -1
  323. #define S2 -1
  324. #define S3 -1
  325. #define VFMACC_RR __riscv_vfmsac{tail_policy}
  326. #define VFMACC_RI __riscv_vfnmacc{tail_policy}
  327. #endif
  328. '''.format(tail_policy=settings['tail_policy'].value))
  329. if required_regs > (32 // settings['LMUL_ACC'].value):
  330. raise Exception("{} vector registers needed during accumulation for unrolling {} x {}{} but only {} are available".format(
  331. required_regs, N, M, (" with wide accumulator" if settings['LMUL_ACC'].value > 1 else ''), 32 // settings['LMUL_ACC'].value
  332. ))
  333. TRMM = (settings['op'].value == 'trmm')
  334. if TRMM:
  335. with dest.block("#if defined(LEFT) != defined(TRANSA)", "#endif"):
  336. dest.write("#define BACKWARDS")
  337. dest.write("int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, {alpha}, FLOAT* A, FLOAT* B, FLOAT* C, BLASLONG ldc{trmm})",
  338. alpha = ('FLOAT alphar, FLOAT alphai' if is_complex else 'FLOAT alpha'),
  339. trmm = (', BLASLONG offset' if TRMM else '')
  340. )
  341. with dest.block("{{", "}}", elt_size='*2' if is_complex else ''):
  342. if settings['trace'].value:
  343. dest.write("printf(\"\\n\\nENTRY: %s(%d) M %d N %d K %d ldc %d\\n\", __FILE__, __LINE__, M, N, K, ldc);")
  344. dest.write("{index_type} gvl = 0;")
  345. dest.write("{index_type} m_top = 0;")
  346. dest.write("{index_type} n_top = 0;")
  347. dest.write()
  348. dest.write()
  349. dest.write("// -- MAIN PASS")
  350. with dest.block("for ({index_type} j=0; j<N/{N}; j+=1) {{", "}}"):
  351. dest.write("m_top = 0;")
  352. dest.write("{index_type} gvl = {VSETVL}({vlenmax});", vlenmax=min(vlenmax,max(int(M/a_regs),1)))
  353. dest.write()
  354. with dest.block("for ({index_type} i=0; i<M/{M}; i+=1) {{", "}}"):
  355. generate_gemm_kernel_inner( settings, dest, M, N, vlenmax, a_regs )
  356. dest.write( "m_top += {M};" )
  357. dest.write()
  358. dest.write()
  359. dest.write("// -- tails for main pass")
  360. generate_M_tails( dest, settings, M, N )
  361. dest.write( "n_top += {N};" )
  362. N_tail = int(N/2)
  363. while( N_tail > 0 ):
  364. with dest.map(N=N_tail):
  365. dest.write()
  366. dest.write()
  367. dest.write("// -- tails for N={N}")
  368. with dest.block("if( N & {N} ) {{", "}}" ):
  369. if settings['trace'].value:
  370. dest.write("printf(\"N tail entry: %s(%d) M %d N %d K %d m_top %d n_top %d\\n\", __FILE__, __LINE__, M, N, K, m_top, n_top);")
  371. dest.write("gvl = {VSETVL}({vlenmax});", vlenmax=min(vlenmax,max(int(M/a_regs),1)))
  372. dest.write("m_top = 0;")
  373. with dest.block("for ({index_type} i=0; i<M/{M}; i+=1) {{", "}}"):
  374. generate_gemm_kernel_inner( settings, dest, M, N_tail, vlenmax, a_regs )
  375. dest.write("m_top += {M};")
  376. generate_M_tails( dest, settings, M, N_tail )
  377. dest.write("n_top += {N};")
  378. N_tail = int(N_tail/2)
  379. dest.write("return 0;");
  380. #-----------------------------------------------------------------------
  381. def generate_M_tails( dest, settings, M, N ):
  382. M_tail = int(M/2)
  383. M_tail_min = settings['M_tail_scalar_from'].value
  384. vlenmax = int(settings['reg_width_bits'].value * settings['LMUL_ACC'].value
  385. / settings['ELEN_PARAM'].value )
  386. TRMM = (settings['op'].value == 'trmm')
  387. is_complex = settings['complex'].value
  388. generate_gemm_kernel_inner = generate_gemm_kernel_inner_complex if is_complex else generate_gemm_kernel_inner_real
  389. while( M_tail > M_tail_min ):
  390. with dest.block("if( M & {M_tail} ) {{", "}}", M_tail=M_tail ):
  391. if settings['trace'].value:
  392. dest.write("printf(\"tail: %s(%d) M %d N %d K %d m_top %d n_top %d\\n\", __FILE__, __LINE__, M, N, K, m_top, n_top);")
  393. a_regs = max( 1, int(M_tail/vlenmax) )
  394. vlen = int(M_tail/a_regs)
  395. dest.write("gvl = {VSETVL}({vlen});\n", vlen=vlen)
  396. generate_gemm_kernel_inner( settings, dest, M_tail, N, vlen, a_regs )
  397. dest.write( "m_top += {M_tail};" )
  398. M_tail = int( M_tail / 2 )
  399. while( M_tail > 0 ):
  400. with dest.block("if( M & {M_tail} ) {{", "}}",
  401. M_tail=M_tail,
  402. N=N,
  403. result_t = ('double' if settings['force_acc_double'].value else settings['param_scalar_t'].value)
  404. ):
  405. if settings['trace'].value:
  406. dest.write("printf(\"tail: %s(%d) M %d N %d K %d m_top %d n_top %d\\n\", __FILE__, __LINE__, M, N, K, m_top, n_top);")
  407. for r in range(M_tail * N * (2 if is_complex else 1)):
  408. dest.write("{result_t} result{r} = 0;",
  409. r=r
  410. )
  411. dest.write("{index_type} ai=m_top*K{elt_size};")
  412. dest.write("{index_type} bi=n_top*K{elt_size};")
  413. if TRMM:
  414. with dest.map(M=M_tail, N=N):
  415. generate_trmm_block( dest )
  416. with dest.block("for({index_type} k=0; k<{Kend}; k++) {{", "}}", Kend = ('pass_K' if TRMM else 'K') ):
  417. for ki in range( N ):
  418. for kj in range( M_tail ):
  419. if is_complex:
  420. dest.write("result{dest}+=S0*A[ai+{kj}+0]*B[bi+{ki}+0] + S1*A[ai+{kj}+1]*B[bi+{ki}+1];".format(
  421. dest=(ki*M_tail+kj)*2, kj=kj*2, ki=ki*2
  422. ))
  423. dest.write("result{dest}+=S2*A[ai+{kj}+1]*B[bi+{ki}+0] + S3*A[ai+{kj}+0]*B[bi+{ki}+1];".format(
  424. dest=(ki*M_tail+kj)*2+1, kj=kj*2, ki=ki*2
  425. ))
  426. else:
  427. dest.write("result{dest}+=A[ai+{kj}]*B[bi+{ki}];".format(
  428. dest=ki*M_tail+kj, kj=kj, ki=ki
  429. ))
  430. dest.write("ai+={M_tail}{elt_size};")
  431. dest.write("bi+={N}{elt_size};")
  432. dest.write("{index_type} ci=n_top*ldc+m_top;")
  433. if is_complex:
  434. dest.write("{result_t} Cr, Ci;")
  435. for ki in range( N ):
  436. for kj in range( M_tail ):
  437. if is_complex:
  438. if TRMM:
  439. dest.write('Cr = result{dest}*alphar;', dest=(ki*M_tail+kj)*2+0)
  440. dest.write('Ci = result{dest}*alphar;', dest=(ki*M_tail+kj)*2+1)
  441. else:
  442. dest.write('Cr = C[(ci+{ki}*ldc+{kj})*2+0];', ki=ki, kj=kj)
  443. dest.write('Ci = C[(ci+{ki}*ldc+{kj})*2+1];', ki=ki, kj=kj)
  444. dest.write('Cr += result{dest}*alphar;', dest=(ki*M_tail+kj)*2+0)
  445. dest.write('Ci += result{dest}*alphar;', dest=(ki*M_tail+kj)*2+1)
  446. dest.write('Cr -= result{dest}*alphai;', dest=(ki*M_tail+kj)*2+1)
  447. dest.write('Ci += result{dest}*alphai;', dest=(ki*M_tail+kj)*2+0)
  448. dest.write("C[(ci+{ki}*ldc+{kj})*2+0] = Cr;", ki=ki, kj=kj )
  449. dest.write("C[(ci+{ki}*ldc+{kj})*2+1] = Ci;", ki=ki, kj=kj )
  450. else:
  451. op = '' if TRMM else '+'
  452. dest.write("C[ci+{ki}*ldc+{kj}] {op}= alpha * result{dest};",
  453. ki=ki, kj=kj, op=op, dest=ki*M_tail+kj
  454. )
  455. dest.write("m_top+={M_tail};")
  456. M_tail = int(M_tail/2)
  457. #-----------------------------------------------------------------------
  458. class Setting(object):
  459. def __init__( self, value, convert = None ):
  460. self._value = value
  461. self._convert = convert
  462. @classmethod
  463. def ENUM( cls, *values ):
  464. def closure( values ):
  465. return lambda value: values[value.lower()]
  466. return closure( { v.lower():v for v in values } )
  467. @classmethod
  468. def BOOL( cls, value ):
  469. return value.lower().startswith('t') or value == '1'
  470. @property
  471. def value( self ):
  472. return self._value
  473. @property
  474. def configurable( self ):
  475. return self._convert is not None
  476. @value.setter
  477. def value( self, value ):
  478. self._value = self._convert( value )
  479. def __str__( self ):
  480. return str(self._value)
  481. #-----------------------------------------------------------------------
  482. def main():
  483. settings = {
  484. 'op': Setting( 'gemm', Setting.ENUM( 'gemm', 'trmm' ) ),
  485. 'M': Setting( 16, int ),
  486. 'N': Setting( 4, int ),
  487. 'reg_width_bits': Setting( 256, int ),
  488. 'LMUL': Setting( 1, int ),
  489. 'M_tail_scalar_from':Setting( 2, int ),
  490. 'cpu': Setting( 'zvl256b', str ),
  491. 'param_precision': Setting( 'float', Setting.ENUM( 'float', 'double' ) ),
  492. 'force_acc_double': Setting( False, Setting.BOOL ),
  493. 'complex': Setting( False, Setting.BOOL ),
  494. 'conjugate': Setting( False, Setting.BOOL ),
  495. 'index_type': Setting( 'BLASLONG', str ),
  496. 'trace': Setting( False, Setting.BOOL ),
  497. 'output': Setting( None, str ),
  498. 'tail_policy': Setting( '', str ), # _ta, if toolchain supports it
  499. '__riscv_': Setting( '__riscv_', str),
  500. }
  501. for item in sys.argv[1:]:
  502. try:
  503. name, value = tuple(item.split( '=', 1 ))
  504. except:
  505. ERROR("couldn't parse {}, expected arguments of the form name=value".format(item))
  506. if name not in settings:
  507. ERROR("couldn't parse {}, {} it is not a known option\n".format( item, name )
  508. +"options (and current defaults) are\n{}".format(
  509. " ".join([ '{}={}'.format(k, settings[k].value) for k in settings.keys()]))
  510. )
  511. try:
  512. settings[name].value = value
  513. except:
  514. import traceback
  515. traceback.print_exc()
  516. ERROR("couldn't parse {}".format(item))
  517. if settings['output'].value is None:
  518. if settings['complex'].value:
  519. prefix = 'z' if settings['param_precision'].value == 'double' else 'c'
  520. else:
  521. prefix = 'd' if settings['param_precision'].value == 'double' else 's'
  522. settings['output'] = Setting('{}{}_kernel_{}x{}_{}.c'.format(
  523. prefix,
  524. settings['op'],
  525. settings['M'],
  526. settings['N'],
  527. settings['cpu']
  528. ))
  529. if settings['param_precision'].value == 'double':
  530. settings['param_scalar_t'] = Setting( 'double' )
  531. settings['ELEN_PARAM'] = Setting(64)
  532. else:
  533. settings['param_scalar_t'] = Setting( 'float' )
  534. settings['ELEN_PARAM'] = Setting(32)
  535. settings['VFMUL'] = Setting( '{}vfmul_vf_f{}m{}{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL'], settings['tail_policy']) )
  536. settings['VFMACC'] = Setting( '{}vfmacc_vf_f{}m{}{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL'], settings['tail_policy']) )
  537. settings['ELEN_ACC'] = settings['ELEN_PARAM']
  538. settings['LMUL_ACC'] = Setting(settings['LMUL'].value)
  539. widen = ''
  540. if settings['force_acc_double'].value and (settings['param_precision'].value == 'float'):
  541. settings['ELEN_ACC'] = Setting(64)
  542. settings['LMUL_ACC'] = Setting(settings['LMUL'].value*2)
  543. settings['VFNCVT'] = Setting('{}vfncvt_f_f_w_f{}m{}{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL'], settings['tail_policy']))
  544. widen = 'w'
  545. settings['VMUL_TO_ACC'] = Setting( '{}vf{}mul_vf_f{}m{}{}'.format(settings['__riscv_'], widen, settings['ELEN_ACC'], settings['LMUL_ACC'], settings['tail_policy']) )
  546. settings['VMACC_TO_ACC'] = Setting( '{}vf{}macc_vf_f{}m{}{}'.format(settings['__riscv_'], widen, settings['ELEN_ACC'], settings['LMUL_ACC'], settings['tail_policy']) )
  547. settings['param_vector_t']=Setting('vfloat{}m{}_t'.format(settings['ELEN_PARAM'], settings['LMUL']))
  548. settings['acc_vector_t'] =Setting('vfloat{}m{}_t'.format(settings['ELEN_ACC'], settings['LMUL_ACC']))
  549. settings['VLEV'] =Setting('{}vle{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
  550. settings['VSEV'] =Setting('{}vse{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
  551. settings['VLSEV'] =Setting('{}vlse{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
  552. settings['VSSEV'] =Setting('{}vsse{}_v_f{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['ELEN_PARAM'], settings['LMUL']))
  553. settings['VSETVL'] =Setting('{}vsetvl_e{}m{}'.format(settings['__riscv_'], settings['ELEN_PARAM'], settings['LMUL']))
  554. to_stdout = (settings['output'].value == '-')
  555. if not to_stdout:
  556. print("Writing {}".format(settings['output'].value), file=sys.stderr)
  557. with open(sys.stdout.fileno() if to_stdout else settings['output'].value, 'w') as destination_file:
  558. def OUTPUT(*args, **kwargs):
  559. print(*args, file=destination_file, **kwargs)
  560. OUTPUT("/*\n\nAUTOGENERATED KERNEL\nSettings:\n {}".format(" ".join([ "{}={}\n".format(k, repr(settings[k].value)) for k in sorted(settings.keys()) if settings[k].configurable])))
  561. OUTPUT("Derived:\n {}\n*/\n".format(" ".join([ "{}={}\n".format(k, repr(settings[k].value)) for k in sorted(settings.keys()) if not settings[k].configurable])))
  562. OUTPUT('#include "common.h"')
  563. OUTPUT("\n")
  564. if settings['op'].value in ('gemm', 'trmm'):
  565. generate_gemm_kernel(settings, OUTPUT)
  566. else:
  567. ERROR("unsupported kernel type {}".format(settings['op']))
# Generate a kernel only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()