|
|
@@ -67,6 +67,17 @@ kernel void kernel_add( |
|
|
dst[tpig] = src0[tpig] + src1[tpig]; |
|
|
dst[tpig] = src0[tpig] + src1[tpig]; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// assumption: src1 is a row |
|
|
|
|
|
// broadcast src1 into src0 |
|
|
|
|
|
kernel void kernel_add_row( |
|
|
|
|
|
device const float * src0, |
|
|
|
|
|
device const float * src1, |
|
|
|
|
|
device float * dst, |
|
|
|
|
|
constant int64_t & ne00, |
|
|
|
|
|
uint tpig[[thread_position_in_grid]]) { |
|
|
|
|
|
dst[tpig] = src0[tpig] + src1[tpig % ne00]; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
kernel void kernel_mul( |
|
|
kernel void kernel_mul( |
|
|
device const float * src0, |
|
|
device const float * src0, |
|
|
device const float * src1, |
|
|
device const float * src1, |
|
|
@@ -376,87 +387,90 @@ kernel void kernel_rms_norm( |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) |
|
|
|
|
|
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { |
|
|
|
|
|
|
|
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) |
|
|
|
|
|
// il indicates where the q4 quants begin (0 or QK4_0/4) |
|
|
|
|
|
// we assume that the yl's have been multiplied with the appropriate scale factor |
|
|
|
|
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) |
|
|
|
|
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { |
|
|
float d = qb_curr->d; |
|
|
float d = qb_curr->d; |
|
|
float4 acc = 0.f; |
|
|
|
|
|
device uint16_t * qs = ((device uint16_t *)qb_curr + 1); |
|
|
|
|
|
for (int i = 0; i < 16; i+=2) { |
|
|
|
|
|
acc[0] += yl[i] * (qs[i / 2] & 0x000F); |
|
|
|
|
|
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); |
|
|
|
|
|
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); |
|
|
|
|
|
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); |
|
|
|
|
|
|
|
|
float2 acc = 0.f; |
|
|
|
|
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); |
|
|
|
|
|
for (int i = 0; i < 8; i+=2) { |
|
|
|
|
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) |
|
|
|
|
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00); |
|
|
|
|
|
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) |
|
|
|
|
|
+ yl[i + 9] * (qs[i / 2] & 0xF000); |
|
|
} |
|
|
} |
|
|
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); |
|
|
|
|
|
|
|
|
return d * (sumy * -8.f + acc[0] + acc[1]); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) |
|
|
|
|
|
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { |
|
|
|
|
|
|
|
|
// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) |
|
|
|
|
|
// il indicates where the q4 quants begin (0 or QK4_0/4) |
|
|
|
|
|
// we assume that the yl's have been multiplied with the appropriate scale factor |
|
|
|
|
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) |
|
|
|
|
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { |
|
|
float d = qb_curr->d; |
|
|
float d = qb_curr->d; |
|
|
float m = qb_curr->m; |
|
|
float m = qb_curr->m; |
|
|
float4 acc = 0.f; |
|
|
|
|
|
device uint16_t * qs = ((device uint16_t *)qb_curr + 2); |
|
|
|
|
|
for (int i = 0; i < 16; i+=2) { |
|
|
|
|
|
acc[0] += yl[i] * (qs[i / 2] & 0x000F); |
|
|
|
|
|
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); |
|
|
|
|
|
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); |
|
|
|
|
|
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); |
|
|
|
|
|
|
|
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); |
|
|
|
|
|
float2 acc = 0.f; |
|
|
|
|
|
for (int i = 0; i < 8; i+=2) { |
|
|
|
|
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) |
|
|
|
|
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00); |
|
|
|
|
|
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) |
|
|
|
|
|
+ yl[i + 9] * (qs[i / 2] & 0xF000); |
|
|
} |
|
|
} |
|
|
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; |
|
|
|
|
|
|
|
|
return d * (acc[0] + acc[1]) + sumy * m; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// putting them in the kernel cause a significant performance penalty |
|
|
// putting them in the kernel cause a significant performance penalty |
|
|
#define N_DST 4 // each SIMD group works on 4 rows |
|
|
#define N_DST 4 // each SIMD group works on 4 rows |
|
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group |
|
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group |
|
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 |
|
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 |
|
|
template<typename block_q_type> |
|
|
|
|
|
|
|
|
//Note: This is a template, but strictly speaking it only applies to |
|
|
|
|
|
// quantizations where the block size is 32. It also does not |
|
|
|
|
|
// giard against the number of rows not being divisible by |
|
|
|
|
|
// N_DST, so this is another explicit assumption of the implementation. |
|
|
|
|
|
template<typename block_q_type, int nr, int nsg, int nw> |
|
|
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, |
|
|
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, |
|
|
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, |
|
|
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, |
|
|
uint2 tgpig, uint tiisg, uint sgitg) { |
|
|
uint2 tgpig, uint tiisg, uint sgitg) { |
|
|
const int nb = ne00/QK4_0; |
|
|
const int nb = ne00/QK4_0; |
|
|
const int r0 = tgpig.x; |
|
|
const int r0 = tgpig.x; |
|
|
const int r1 = tgpig.y; |
|
|
const int r1 = tgpig.y; |
|
|
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; |
|
|
|
|
|
|
|
|
const int first_row = (r0 * nsg + sgitg) * nr; |
|
|
|
|
|
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; |
|
|
device const float * y = (device const float *) src1 + r1*ne10; |
|
|
device const float * y = (device const float *) src1 + r1*ne10; |
|
|
float4 y_curr[8]; // src1 vector cache |
|
|
|
|
|
float sumf[N_DST]={0.f}, all_sum; |
|
|
|
|
|
thread float * yl=(thread float *)y_curr; |
|
|
|
|
|
|
|
|
float yl[16]; // src1 vector cache |
|
|
|
|
|
float sumf[nr]={0.f}; |
|
|
|
|
|
|
|
|
// each thread in a SIMD group deals with 1 block. |
|
|
|
|
|
for (int column = 0; column < nb / N_SIMDWIDTH; column++) { |
|
|
|
|
|
float sumy = 0; |
|
|
|
|
|
for (int i = 0; i < QK4_0 / 4; i++) { |
|
|
|
|
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); |
|
|
|
|
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
const int ix = tiisg/2; |
|
|
|
|
|
const int il = 8*(tiisg%2); |
|
|
|
|
|
|
|
|
for (int row = 0; row < N_DST; row++) { |
|
|
|
|
|
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
device const float * yb = y + ix * QK4_0 + il; |
|
|
|
|
|
|
|
|
// from now loads two rows every time and 16 blocks per row |
|
|
|
|
|
int ir = tiisg / (N_SIMDWIDTH / 2); |
|
|
|
|
|
int ib = tiisg % (N_SIMDWIDTH / 2); |
|
|
|
|
|
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { |
|
|
|
|
|
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start |
|
|
|
|
|
|
|
|
// each thread in a SIMD group deals with half a block. |
|
|
|
|
|
for (int ib = ix; ib < nb; ib += nw/2) { |
|
|
float sumy = 0; |
|
|
float sumy = 0; |
|
|
for (int i = 0; i < QK4_0 / 4; i++) { |
|
|
|
|
|
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); |
|
|
|
|
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; |
|
|
|
|
|
|
|
|
for (int i = 0; i < 8; i += 2) { |
|
|
|
|
|
sumy += yb[i] + yb[i+1]; |
|
|
|
|
|
yl[i+0] = yb[i+ 0]; |
|
|
|
|
|
yl[i+1] = yb[i+ 1]/256.f; |
|
|
|
|
|
sumy += yb[i+16] + yb[i+17]; |
|
|
|
|
|
yl[i+8] = yb[i+16]/16.f; |
|
|
|
|
|
yl[i+9] = yb[i+17]/4096.f; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (int row = 0; row < N_DST; row+=2) { |
|
|
|
|
|
if (nb_start + ib < nb) { |
|
|
|
|
|
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
for (int row = 0; row < nr; row++) { |
|
|
|
|
|
sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
yb += QK4_0 * 16; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (int row = 0; row < N_DST; ++row) { |
|
|
|
|
|
all_sum = simd_sum(sumf[row]); |
|
|
|
|
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { |
|
|
|
|
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; |
|
|
|
|
|
|
|
|
for (int row = 0; row < nr; ++row) { |
|
|
|
|
|
const float tot = simd_sum(sumf[row]); |
|
|
|
|
|
if (tiisg == 0 && first_row + row < ne01) { |
|
|
|
|
|
dst[r1*ne0 + first_row + row] = tot; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -472,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32( |
|
|
uint2 tgpig[[threadgroup_position_in_grid]], |
|
|
uint2 tgpig[[threadgroup_position_in_grid]], |
|
|
uint tiisg[[thread_index_in_simdgroup]], |
|
|
uint tiisg[[thread_index_in_simdgroup]], |
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) { |
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) { |
|
|
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); |
|
|
|
|
|
|
|
|
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
kernel void kernel_mul_mat_q4_1_f32( |
|
|
kernel void kernel_mul_mat_q4_1_f32( |
|
|
@@ -486,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32( |
|
|
uint2 tgpig[[threadgroup_position_in_grid]], |
|
|
uint2 tgpig[[threadgroup_position_in_grid]], |
|
|
uint tiisg[[thread_index_in_simdgroup]], |
|
|
uint tiisg[[thread_index_in_simdgroup]], |
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) { |
|
|
uint sgitg[[simdgroup_index_in_threadgroup]]) { |
|
|
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); |
|
|
|
|
|
|
|
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
kernel void kernel_mul_mat_f16_f32( |
|
|
kernel void kernel_mul_mat_f16_f32( |
|
|
@@ -495,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32( |
|
|
device float * dst, |
|
|
device float * dst, |
|
|
constant int64_t & ne00, |
|
|
constant int64_t & ne00, |
|
|
constant int64_t & ne01, |
|
|
constant int64_t & ne01, |
|
|
|
|
|
constant int64_t & ne02, |
|
|
constant uint64_t & nb00, |
|
|
constant uint64_t & nb00, |
|
|
constant uint64_t & nb01, |
|
|
constant uint64_t & nb01, |
|
|
constant uint64_t & nb02, |
|
|
constant uint64_t & nb02, |
|
|
constant int64_t & ne10, |
|
|
constant int64_t & ne10, |
|
|
constant int64_t & ne11, |
|
|
constant int64_t & ne11, |
|
|
|
|
|
constant int64_t & ne12, |
|
|
constant uint64_t & nb10, |
|
|
constant uint64_t & nb10, |
|
|
constant uint64_t & nb11, |
|
|
constant uint64_t & nb11, |
|
|
constant uint64_t & nb12, |
|
|
constant uint64_t & nb12, |
|
|
@@ -515,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32( |
|
|
const int64_t r1 = tgpig.y; |
|
|
const int64_t r1 = tgpig.y; |
|
|
const int64_t im = tgpig.z; |
|
|
const int64_t im = tgpig.z; |
|
|
|
|
|
|
|
|
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); |
|
|
|
|
|
|
|
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); |
|
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); |
|
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); |
|
|
|
|
|
|
|
|
sum[tpitg.x] = 0.0f; |
|
|
sum[tpitg.x] = 0.0f; |
|
|
@@ -538,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32( |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
kernel void kernel_alibi_f32( |
|
|
kernel void kernel_alibi_f32( |
|
|
device const float * src0, |
|
|
device const float * src0, |
|
|
device float * dst, |
|
|
device float * dst, |
|
|
|