#include "common.h" /* for debugging/unit tests * this is a drop-in replacement for zgemm/cgemm/ztrmm/ctrmm kernels that supports arbitrary combinations of unroll values */ #ifdef TRMMKERNEL #if defined(LEFT) != defined(TRANSA) #define BACKWARDS #endif #endif #ifdef DOUBLE #define UNROLL_M ZGEMM_DEFAULT_UNROLL_M #define UNROLL_N ZGEMM_DEFAULT_UNROLL_N #else #define UNROLL_M CGEMM_DEFAULT_UNROLL_M #define UNROLL_N CGEMM_DEFAULT_UNROLL_N #endif int CNAME(BLASLONG M,BLASLONG N,BLASLONG K,FLOAT alphar,FLOAT alphai,FLOAT* A,FLOAT* B,FLOAT* C,BLASLONG ldc #ifdef TRMMKERNEL ,BLASLONG offset #endif ) { FLOAT res[UNROLL_M*UNROLL_N*2]; #if defined(NN) || defined(NT) || defined(TN) || defined(TT) FLOAT sign[4] = { 1, -1, 1, 1}; #endif #if defined(NR) || defined(NC) || defined(TR) || defined(TC) FLOAT sign[4] = { 1, 1, 1, -1}; #endif #if defined(RN) || defined(RT) || defined(CN) || defined(CT) FLOAT sign[4] = { 1, 1, -1, 1}; #endif #if defined(RR) || defined(RC) || defined(CR) || defined(CC) FLOAT sign[4] = { 1, -1, -1, -1}; #endif BLASLONG n_packing = UNROLL_N; BLASLONG n_top = 0; while(n_top < N) { while( n_top+n_packing > N ) n_packing >>= 1; BLASLONG m_packing = UNROLL_M; BLASLONG m_top = 0; while (m_top < M) { while( m_top+m_packing > M ) m_packing >>= 1; BLASLONG ai = K*m_top*2; BLASLONG bi = K*n_top*2; BLASLONG pass_K = K; #ifdef TRMMKERNEL #ifdef LEFT BLASLONG off = offset + m_top; #else BLASLONG off = -offset + n_top; #endif #ifdef BACKWARDS ai += off * m_packing*2; bi += off * n_packing*2; pass_K -= off; #else #ifdef LEFT pass_K = off + m_packing; #else pass_K = off + n_packing; #endif #endif #endif memset( res, 0, UNROLL_M*UNROLL_N*2*sizeof(FLOAT) ); for (BLASLONG k=0; k