Browse Source

Align with llama.cpp b1488

tags/v0.8.0
SignalRT 2 years ago
parent
commit
46fb472d42
4 changed files with 204 additions and 41 deletions
  1. +31
    -7
      LLama/Native/LLamaContextParams.cs
  2. +11
    -0
      LLama/Native/NativeApi.Sampling.cs
  3. +162
    -34
      LLama/runtimes/ggml-metal.metal
  4. BIN
      LLama/runtimes/libllama.dylib

+ 31
- 7
LLama/Native/LLamaContextParams.cs View File

@@ -42,17 +42,41 @@ namespace LLama.Native
public uint n_threads_batch;

/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE base frequency
/// RoPE scaling type, from `enum llama_rope_scaling_type`
/// </summary>
public float rope_freq_base;
public sbyte rope_scaling_type;

/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE frequency scaling factor
/// RoPE base frequency, 0 = from model
/// </summary>
public float rope_freq_scale;

public float rope_freq_base;
/// <summary>
/// RoPE frequency scaling factor, 0 = from model
/// </summary>
public float rope_freq_scale;
/// <summary>
/// YaRN extrapolation mix factor, NaN = from model
/// </summary>
public float yarn_ext_factor;
/// <summary>
/// YaRN magnitude scaling factor
/// </summary>
public float yarn_attn_factor;
/// <summary>
/// YaRN low correction dim
/// </summary>
public float yarn_beta_fast;
/// <summary>
/// YaRN high correction dim
/// </summary>
public float yarn_beta_slow;
/// <summary>
/// YaRN original context size
/// </summary>
public uint yarn_orig_ctx;
/// <summary>
/// if true, use experimental mul_mat_q kernels
/// </summary>


+ 11
- 0
LLama/Native/NativeApi.Sampling.cs View File

@@ -64,6 +64,17 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);

/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="p"></param>
/// <param name="min_keep"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_min_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>


+ 162
- 34
LLama/runtimes/ggml-metal.metal View File

@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
}
const float max = simd_max(lmax);

float max = simd_max(lmax);
if (tiisg == 0) {
buf[sgitg] = max;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

max = buf[0];

// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
// wish to compute it twice.
pdst[i00] = exp_psrc0;
}

const float sum = simd_sum(lsum);
float sum = simd_sum(lsum);
if (tiisg == 0) {
buf[sgitg] = sum;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] += buf[tpitg + i];
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

sum = buf[0];

for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] /= sum;
}
}
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (ne02*ne01);
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
}
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

const float max = simd_max(lmax);
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max = simd_max(lmax);
if (tiisg == 0) {
buf[sgitg] = max;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

max = buf[0];

// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];

const float sum = simd_sum(lsum);
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
float sum = simd_sum(lsum);
if (tiisg == 0) {
buf[sgitg] = sum;
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] += buf[tpitg + i];
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

sum = buf[0];

for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] /= sum;
}
}
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
}
}
}

kernel void kernel_diag_mask_inf_8(
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
}
}

static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
*cos_theta = cos(theta) * mscale;
*sin_theta = sin(theta) * mscale;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}

static void rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
}

typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,
constant float & attn_factor,
constant float & beta_fast,
constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(

const bool is_neox = mode & 2;

float corr_dims[2];
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

device const int32_t * pos = src1;

const int64_t p = pos[i2];

const float theta_0 = freq_scale * (float)p;
const float theta_0 = (float)p;
const float inv_ndims = -1.f/n_dims;

if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {

const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {

const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;

const float theta = theta_0 * pow(freq_base, cur_rot);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

const int64_t i0 = ib*n_dims + ic/2;



BIN
LLama/runtimes/libllama.dylib View File


Loading…
Cancel
Save