|
|
@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
|
|
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
-        constant   int64_t  & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t  & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
|
|
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
-        constant   int64_t  & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t  & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
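
Since the element-wise kernels above now read and write float4, the host side has to pass the broadcast row length in float4 units and launch a quarter as many threads. A minimal host-side sketch (an assumption, not part of this patch; it follows the usual ggml-metal.m encoder pattern and presumes ne00 and the total element count are multiples of 4):

    // nb is the src1 row length in float4 units
    const int64_t nb = ne00/4;
    [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
    // one thread per float4, i.e. a quarter of the scalar element count
    const int64_t n = ggml_nelements(dst)/4;
    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];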
|
|
@@ -133,19 +133,24 @@ kernel void kernel_soft_max(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
-    }
+    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
+    // the loop, and when that is done, buf[0] has the correct (synchronized) value
+    //if (tpitg[0] == 0) {
+    //    buf[0] = buf[0];
+    //}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
 
     const float max = buf[0];
 
     // parallel sum
     buf[tpitg[0]] = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] += exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp(psrc0[i00] - max);
+        buf[tpitg[0]] += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
+        pdst[i00] = exp_psrc0;
     }
 
     // reduce
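
Caching exp(psrc0[i00] - max) in pdst works because these shifted exponentials are exactly the softmax numerators, so the final pass (next hunk) only needs a division. With m = max_j x_j:

    p_i = e^{x_i - m}, \qquad \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = \frac{p_i}{\sum_j p_j}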
|
|
@@ -157,17 +162,18 @@ kernel void kernel_soft_max(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
-    }
+    // broadcast - not needed, see above
+    //// broadcast
+    //if (tpitg[0] == 0) {
+    //    buf[0] = buf[0];
+    //}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
 
     const float sum = buf[0];
 
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] = exp(psrc0[i00] - max) / sum;
+        pdst[i00] /= sum;
     }
 }
|
|
|
|
|
|
|
|
@@ -214,25 +220,17 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float mean = sum[0];
+    const float mean = sum[0] / ne00;
 
-    // recenter
+    // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
 
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
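
Both reductions in kernel_norm now divide by ne00 at the point of use: after the reduction loop every thread already observes the same sum[0], so there is no need for thread 0 to overwrite it with sum[0]/ne00 and re-broadcast it behind an extra barrier. The quantities are the usual ones,

    \mu = \frac{1}{ne_{00}} \sum_i x_i, \qquad \sigma^2 = \frac{1}{ne_{00}} \sum_i (x_i - \mu)^2, \qquad y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}

and fusing the recenter loop into the variance accumulation saves one full pass over the row.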
|
|
@@ -241,12 +239,7 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float variance = sum[0];
+    const float variance = sum[0] / ne00;
 
     const float scale = 1.0f/sqrt(variance + eps);
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
@@ -254,7 +247,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const void * src0,
         device       float * dst,
|
|
@@ -435,6 +427,8 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+#define NB_Q8_0 8
+
 kernel void kernel_mul_mat_q8_0_f32(
         device const void * src0,
         device const float * src1,
|
|
@@ -463,30 +457,30 @@ kernel void kernel_mul_mat_q8_0_f32(
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[NB_Q8_0];
     float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + 16*il;
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
 
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
             float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*x[ib+row*nb].d;
         }
 
-        yb += QK8_0 * 16;
+        yb += NB_Q8_0 * nw;
     }
 
     for (int row = 0; row < nr; ++row) {
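
The remapping arithmetic (QK8_0 = 32 quants per block, nw = 32 lanes per simdgroup): il = tiisg%4 picks one of the four NB_Q8_0-sized slices of a block (offsets 0, 8, 16, 24) and ix = tiisg/4 spreads the lanes over eight consecutive blocks, so

    floats consumed per iteration = NB_Q8_0 * nw = 8 * 32 = 256 = 8 blocks,

which is why ib strides by nw/4 and yb advances by NB_Q8_0 * nw. Before, two lanes split a block (16 quants each) and sixteen blocks were in flight per iteration (yb += QK8_0 * 16 = 512 floats).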
|
|
@@ -497,7 +491,7 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
|
|
@@ -515,11 +509,8 @@ kernel void kernel_mul_mat_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tpig[[thread_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3  tptg[[threads_per_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
|
|
@@ -528,23 +519,100 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
-
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
-    }
-
-    // accumulate the sum from all threads in the threadgroup
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    if (tpitg.x == 0) {
-        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
     }
 }
+
+#define N_F16_F32 4
+
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F32;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
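
Two structural changes here: the per-row dot product is reduced with simd_sum over a single simdgroup instead of a threadgroup-memory tree, which is why the threadgroup float * sum argument and all barriers disappear, and the new batched kernel reuses the loaded src0 row across N_F16_F32 rows of src1. The dispatch presumably shrinks along y to match (a host-side assumption, consistent with the r1 >= ne11 break above):

    threadgroups.y = (ne11 + N_F16_F32 - 1) / N_F16_F32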
|
|
|
|
|
|
|
|
@@ -614,25 +682,27 @@ kernel void kernel_rope(
         constant       int & mode,
         constant     float & freq_base,
         constant     float & freq_scale,
-        uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i3 = tpig[2];
-    const int64_t i2 = tpig[1];
-    const int64_t i1 = tpig[0];
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = freq_scale * (float)p;
+    const float theta_0 = freq_scale * (float)p;
+    const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
-        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            theta *= theta_scale;
-
             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
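
The closed form is equivalent to the removed running product: theta_scale was freq_base^(-2/n_dims), so after i0/2 iterations the old loop reached theta = theta_0 * freq_base^(-i0/n_dims), which is exactly theta_0 * pow(freq_base, inv_ndims*i0). In LaTeX:

    \theta(i_0) = \theta_0 \cdot b^{-i_0/n_{\text{dims}}}, \qquad \theta_0 = \text{freq\_scale} \cdot p, \quad b = \text{freq\_base}

Eliminating the loop-carried theta *= theta_scale is what allows the i0 loop to be strided across the tptg.x threads of the threadgroup.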
|
|
|
|
|
|
|
|
@@ -644,12 +714,12 @@ kernel void kernel_rope(
         }
     } else {
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
                 const float cos_theta = cos(theta);
                 const float sin_theta = sin(theta);
 
-                theta *= theta_scale;
-
                 const int64_t i0 = ib*n_dims + ic/2;
 
                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
@@ -1053,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
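
A worked instance of the lookup: for ip = 1, il = 2 the old code built m1 = 1 << (4*ip + il) = 0x0040 and m2 = m1 << 8 = 0x4000; the table row mm[2*ip + il/2] = mm[3] = {0x0040, 0x4000, 0x0080, 0x8000} contains exactly those two masks plus the analogous pair for bit il+1, which the rewritten loop needs because it now also folds in the second half of the block (yl[16..31]). Each row is {1 << (4*ip+il), the same mask shifted into the high byte, and the two masks for il+1}.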
|
|
@@ -1086,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
|
@@ -1102,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
|
|
@@ -1131,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
|
|
@@ -1244,7 +1349,8 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r2 = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
     const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
|
@@ -1493,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
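
The acc1/acc2 split is exact rather than an approximation: every old term had the form y * (q + 16*h) (the high-nibble terms carry an extra factor 1/16 in the scale), and with h in {0, 1}

    \sum_l y_l \,(q_l + 16\,h_l) = \sum_l y_l q_l + 16 \sum_l h_l y_l = \text{acc1} + 16\,\text{acc2},

so the integer add-and-multiply inside the hot loop becomes a select on y alone.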
|
|
|