@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-    const float max = simd_max(lmax);
 
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
+
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        // whish to compute it twice.
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = simd_sum(lsum);
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        pdst[i00] /= sum;
     }
 }
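
Read as a whole, the rewritten kernel still computes the same numerically stabilized softmax per row of ne00 values: subtract the row maximum, exponentiate once, then divide by the sum of the exponentials; only the reduction is now done in two stages (within each simdgroup, then across simdgroups via buf). A minimal scalar reference in C that can be used to sanity-check the kernel's output on the host; the helper name softmax_row_ref is made up for this sketch and is not a ggml symbol:

#include <math.h>

// Scalar reference for one row of softmax, mirroring the kernel's
// max-subtraction trick for numerical stability.
static void softmax_row_ref(const float * src, float * dst, int ne00) {
    float max = -INFINITY;
    for (int i = 0; i < ne00; ++i) {
        if (src[i] > max) {
            max = src[i];
        }
    }

    float sum = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        dst[i] = expf(src[i] - max); // keep the exp result, as the kernel does
        sum += dst[i];
    }

    for (int i = 0; i < ne00; ++i) {
        dst[i] /= sum;
    }
}
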
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
-    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float max = simd_max(lmax);
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = simd_sum(lsum);
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
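
Both the scalar and the float4 kernels now reduce in two levels: simd_max / simd_sum combine values within one simdgroup, the first lane of each simdgroup (tiisg == 0) parks its partial result in buf[sgitg], and the tree loop folds the ntg / 32 partials. A worked example under the 32-wide simdgroup assumption the comments make: with ntg = 1024 threads per threadgroup there are 1024 / 32 = 32 partials, the tree loop runs for i = 16, 8, 4, 2, 1, and buf needs at least (ntg / 32) * sizeof(float) = 128 bytes of threadgroup memory, which the host side has to reserve when encoding the dispatch (not shown in this diff). The float4 variant only changes how each thread builds its local value, folding the four lanes with MAX and a horizontal add before the simd-level reduction.
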
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-    }
+    }
 }
 
 kernel void kernel_diag_mask_inf_8(
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 typedef void (rope_t)(
         device const void * src0,
         device const int32_t * src1,
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
         constant int & n_past,
         constant int & n_dims,
         constant int & mode,
+        constant int & n_orig_ctx,
         constant float & freq_base,
         constant float & freq_scale,
+        constant float & ext_factor,
+        constant float & attn_factor,
+        constant float & beta_fast,
+        constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]);
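
A sketch of where rope_yarn_corr_factor's formula comes from, reading the comment above it with max_pos_emb taken to be n_orig_ctx: rotary pair d advances by base^(-2*d/n_dims) radians per position, so over the original context it completes n_orig_ctx / (2*pi * base^(2*d/n_dims)) full rotations. Setting that equal to n_rot and solving for d:

    n_rot = n_orig_ctx / (2*pi * base^(2*d/n_dims))
    base^(2*d/n_dims) = n_orig_ctx / (n_rot * 2*pi)
    d = n_dims * log(n_orig_ctx / (n_rot * 2*pi)) / (2 * log(base))

which is exactly the value the function returns. rope_yarn_corr_dims evaluates this at beta_fast and beta_slow and clamps the floor/ceil to [0, n_dims - 1] to get the start and end of the ramp, and rope_yarn uses that ramp to blend theta_interp = freq_scale * theta_extrap with theta_extrap, scaling the magnitude by mscale *= 1 + 0.1 * log(1 / freq_scale) when ext_factor is non-zero.
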
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
         constant int & n_past,
         constant int & n_dims,
         constant int & mode,
+        constant int & n_orig_ctx,
         constant float & freq_base,
         constant float & freq_scale,
+        constant float & ext_factor,
+        constant float & attn_factor,
+        constant float & beta_fast,
+        constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
     device const int32_t * pos = src1;
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = freq_scale * (float)p;
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 const int64_t i0 = ib*n_dims + ic/2;