You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

drot.c 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #include "common.h"
  2. #if defined(SKYLAKEX)
  3. #include "drot_microk_skylakex-2.c"
  4. #elif defined(HASWELL) || defined(ZEN)
  5. #include "drot_microk_haswell-2.c"
  6. #endif
  7. #ifndef HAVE_DROT_KERNEL
  8. #include "../simd/intrin.h"
  9. static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
  10. {
  11. BLASLONG i = 0;
  12. #if V_SIMD_F64 && V_SIMD > 256
  13. const int vstep = v_nlanes_f64;
  14. const int unrollx4 = n & (-vstep * 4);
  15. const int unrollx = n & -vstep;
  16. v_f64 __c = v_setall_f64(c);
  17. v_f64 __s = v_setall_f64(s);
  18. v_f64 vx0, vx1, vx2, vx3;
  19. v_f64 vy0, vy1, vy2, vy3;
  20. v_f64 vt0, vt1, vt2, vt3;
  21. for (; i < unrollx4; i += vstep * 4) {
  22. vx0 = v_loadu_f64(x + i);
  23. vx1 = v_loadu_f64(x + i + vstep);
  24. vx2 = v_loadu_f64(x + i + vstep * 2);
  25. vx3 = v_loadu_f64(x + i + vstep * 3);
  26. vy0 = v_loadu_f64(y + i);
  27. vy1 = v_loadu_f64(y + i + vstep);
  28. vy2 = v_loadu_f64(y + i + vstep * 2);
  29. vy3 = v_loadu_f64(y + i + vstep * 3);
  30. vt0 = v_mul_f64(__s, vy0);
  31. vt1 = v_mul_f64(__s, vy1);
  32. vt2 = v_mul_f64(__s, vy2);
  33. vt3 = v_mul_f64(__s, vy3);
  34. vt0 = v_muladd_f64(__c, vx0, vt0);
  35. vt1 = v_muladd_f64(__c, vx1, vt1);
  36. vt2 = v_muladd_f64(__c, vx2, vt2);
  37. vt3 = v_muladd_f64(__c, vx3, vt3);
  38. v_storeu_f64(x + i, vt0);
  39. v_storeu_f64(x + i + vstep, vt1);
  40. v_storeu_f64(x + i + vstep * 2, vt2);
  41. v_storeu_f64(x + i + vstep * 3, vt3);
  42. vt0 = v_mul_f64(__s, vx0);
  43. vt1 = v_mul_f64(__s, vx1);
  44. vt2 = v_mul_f64(__s, vx2);
  45. vt3 = v_mul_f64(__s, vx3);
  46. vt0 = v_mulsub_f64(__c, vy0, vt0);
  47. vt1 = v_mulsub_f64(__c, vy1, vt1);
  48. vt2 = v_mulsub_f64(__c, vy2, vt2);
  49. vt3 = v_mulsub_f64(__c, vy3, vt3);
  50. v_storeu_f64(y + i, vt0);
  51. v_storeu_f64(y + i + vstep, vt1);
  52. v_storeu_f64(y + i + vstep * 2, vt2);
  53. v_storeu_f64(y + i + vstep * 3, vt3);
  54. }
  55. for (; i < unrollx; i += vstep) {
  56. vx0 = v_loadu_f64(x + i);
  57. vy0 = v_loadu_f64(y + i);
  58. vt0 = v_mul_f64(__s, vy0);
  59. vt0 = v_muladd_f64(__c, vx0, vt0);
  60. v_storeu_f64(x + i, vt0);
  61. vt0 = v_mul_f64(__s, vx0);
  62. vt0 = v_mulsub_f64(__c, vy0, vt0);
  63. v_storeu_f64(y + i, vt0);
  64. }
  65. #else
  66. FLOAT f0, f1, f2, f3;
  67. FLOAT x0, x1, x2, x3;
  68. FLOAT g0, g1, g2, g3;
  69. FLOAT y0, y1, y2, y3;
  70. FLOAT* xp = x;
  71. FLOAT* yp = y;
  72. BLASLONG n1 = n & (~7);
  73. while (i < n1) {
  74. x0 = xp[0];
  75. y0 = yp[0];
  76. x1 = xp[1];
  77. y1 = yp[1];
  78. x2 = xp[2];
  79. y2 = yp[2];
  80. x3 = xp[3];
  81. y3 = yp[3];
  82. f0 = c*x0 + s*y0;
  83. g0 = c*y0 - s*x0;
  84. f1 = c*x1 + s*y1;
  85. g1 = c*y1 - s*x1;
  86. f2 = c*x2 + s*y2;
  87. g2 = c*y2 - s*x2;
  88. f3 = c*x3 + s*y3;
  89. g3 = c*y3 - s*x3;
  90. xp[0] = f0;
  91. yp[0] = g0;
  92. xp[1] = f1;
  93. yp[1] = g1;
  94. xp[2] = f2;
  95. yp[2] = g2;
  96. xp[3] = f3;
  97. yp[3] = g3;
  98. xp += 4;
  99. yp += 4;
  100. i += 4;
  101. }
  102. #endif
  103. while (i < n) {
  104. FLOAT temp = c*x[i] + s*y[i];
  105. y[i] = c*y[i] - s*x[i];
  106. x[i] = temp;
  107. i++;
  108. }
  109. }
  110. #endif
  111. static void rot_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
  112. {
  113. BLASLONG i = 0;
  114. BLASLONG ix = 0, iy = 0;
  115. FLOAT temp;
  116. if (n <= 0)
  117. return;
  118. if ((inc_x == 1) && (inc_y == 1)) {
  119. drot_kernel(n, x, y, c, s);
  120. }
  121. else {
  122. while (i < n) {
  123. temp = c * x[ix] + s * y[iy];
  124. y[iy] = c * y[iy] - s * x[ix];
  125. x[ix] = temp;
  126. ix += inc_x;
  127. iy += inc_y;
  128. i++;
  129. }
  130. }
  131. return;
  132. }
  #if defined(SMP)
/*
 * rot_thread_function: callback handed to blas_level1_thread.
 *
 * Unpacks the blas_arg_t fields into a rot_compute call:
 *   m       -> number of element pairs
 *   a, lda  -> x and its stride
 *   b, ldb  -> y and its stride
 *   alpha   -> points at two FLOATs, c followed by s
 * NOTE(review): presumably blas_level1_thread invokes this once per thread
 * with args describing that thread's slice of the vectors; confirm against
 * its implementation.
 * Always returns 0.
 */
static int rot_thread_function(blas_arg_t *args)
{
    const FLOAT *coef = (const FLOAT *)args->alpha;
    FLOAT c = coef[0];
    FLOAT s = coef[1];

    rot_compute(args->m, args->a, args->lda, args->b, args->ldb, c, s);
    return 0;
}

extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, int (*function)(void), int nthreads);
#endif
  145. int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
  146. {
  147. #if defined(SMP)
  148. int nthreads;
  149. FLOAT alpha[2]={c, s};
  150. FLOAT dummy_c;
  151. #endif
  152. #if defined(SMP)
  153. if (inc_x == 0 || inc_y == 0 || n <= 100000) {
  154. nthreads = 1;
  155. }
  156. else {
  157. nthreads = num_cpu_avail(1);
  158. }
  159. if (nthreads == 1) {
  160. rot_compute(n, x, inc_x, y, inc_y, c, s);
  161. }
  162. else {
  163. #if defined(DOUBLE)
  164. int mode = BLAS_DOUBLE | BLAS_REAL | BLAS_PTHREAD;
  165. #else
  166. int mode = BLAS_SINGLE | BLAS_REAL | BLAS_PTHREAD;
  167. #endif
  168. blas_level1_thread(mode, n, 0, 0, alpha, x, inc_x, y, inc_y, &dummy_c, 0, (int (*)(void))rot_thread_function, nthreads);
  169. }
  170. #else
  171. rot_compute(n, x, inc_x, y, inc_y, c, s);
  172. #endif
  173. return 0;
  174. }