@@ -47,14 +47,24 @@ public sealed class BeamTests
     for (var i = 0; i < state.Beams.Length; i++)
     {
         ref var view = ref state.Beams[i];
-        var tokens = context.DeTokenize(view.Tokens.ToArray());
+        var decoder = new StreamingTokenDecoder(context);
+        decoder.AddRange(view.Tokens);
+        var tokens = decoder.Read();
         _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
     }

     if (state.CommonPrefixLength > 0)
     {
         var view = state.Beams[0];
-        result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
+        var decoder = new StreamingTokenDecoder(context);
+        decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
+        var tokens = decoder.Read();
+        result.Append(tokens);
     }
 }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
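The change above swaps one-shot `context.DeTokenize(...)` calls for a `StreamingTokenDecoder`, which buffers bytes internally and so decodes characters correctly even when a multi-byte character spans several tokens. A minimal sketch of the pattern, assuming a loaded `LLamaContext` named `context` and a span of tokens `tokens`:

```csharp
// Accumulate tokens, then read back whatever has fully decoded. Everything here
// appears in the diff above; only `context` and `tokens` are assumed inputs.
var decoder = new StreamingTokenDecoder(context);
decoder.AddRange(tokens);      // buffers raw token bytes, including partial UTF-8 sequences
string text = decoder.Read();  // drains the decoded text accumulated so far
```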
@@ -36,12 +36,12 @@ namespace LLama.Abstractions
     /// </summary>
     public int TopK { get; set; }

-    /// <summary>llama_eval
+    /// <summary>
     /// 1.0 = disabled
     /// </summary>
     public float TopP { get; set; }

-    /// <summary>llama_eval
+    /// <summary>
     /// 0.0 = disabled
     /// </summary>
     public float MinP { get; set; }
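The removed lines show the bug: a stray `llama_eval` fragment pasted into the `<summary>` tags. For reference, a hedged sketch of setting these knobs, assuming LLamaSharp's stock `InferenceParams` class implements this interface (the class name is an assumption; the property semantics come from the docs above):

```csharp
// TopP = 1.0f and MinP = 0.0f leave those filters disabled, per the XML docs.
var inferenceParams = new InferenceParams // assumed implementation of IInferenceParams
{
    TopK = 40,    // sample from the 40 most likely tokens
    TopP = 0.95f, // nucleus sampling threshold; 1.0f disables
    MinP = 0.05f, // drop tokens below 5% of the best token's probability; 0.0f disables
};
```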
@@ -55,6 +55,9 @@ public sealed class BatchedExecutor
     Epoch = 1;
 }

+/// <summary>
+/// Finalizer for BatchedExecutor
+/// </summary>
 ~BatchedExecutor()
 {
     Dispose();
@@ -54,6 +54,9 @@ public sealed class Conversation
     _end = end;
 }

+/// <summary>
+/// Finalizer for Conversation
+/// </summary>
 ~Conversation()
 {
     Dispose();
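Both finalizers delegate to `Dispose()`, so native llama.cpp resources are reclaimed even when a caller forgets deterministic disposal. A hedged sketch of the intended lifecycle; the constructor arguments and the `Create()` factory are assumptions about the surrounding API:

```csharp
// Deterministic cleanup is still preferable; the finalizers are only a safety net.
using var executor = new BatchedExecutor(model, parameters); // assumed ctor shape
using var conversation = executor.Create();                  // assumed factory method
// ... prompt and infer ...
// Without `using`, ~BatchedExecutor() and ~Conversation() call Dispose() at GC time.
```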
@@ -96,7 +99,7 @@ public sealed class Conversation
     AssertNotDisposed();

     if (RequiresInference)
-        throw new CannotForkWhileRequiresInference();
+        throw new CannotForkWhileRequiresInferenceException();

     // Create a new conversation which references the current position in this one
     var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
@@ -195,13 +198,13 @@ public sealed class Conversation
 /// Directly modify the KV cache of this conversation
 /// </summary>
 /// <param name="modifier"></param>
-/// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
+/// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
 public void Modify(ModifyKvCache modifier)
 {
     AssertNotDisposed();

     if (RequiresInference)
-        throw new CannotModifyWhileRequiresInference();
+        throw new CannotModifyWhileRequiresInferenceException();

     // do whatever the modification is
     _end = modifier.Invoke(_end, new KvAccessor(this));
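`Modify` hands the caller the current end position plus a `KvAccessor` and records whatever position the delegate returns. A hedged sketch of a rewind operation; the `Remove`-style accessor method and the `LLamaPos`/int conversions are assumptions, not confirmed API:

```csharp
// Rewind this conversation by discarding its last 8 cached positions.
conversation.Modify((end, kv) =>
{
    kv.Remove(end.Value - 8, 8); // hypothetical KvAccessor call; verify the real surface
    return end.Value - 8;        // new end position (assumes implicit int -> LLamaPos)
});
```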
@@ -59,10 +59,10 @@ public class CannotSampleRequiresPromptException
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotForkWhileRequiresInference
+public class CannotForkWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotForkWhileRequiresInference()
+    internal CannotForkWhileRequiresInferenceException()
         : base("Cannot `Fork()` a conversation while RequiresInference is true")
     {
     }
@@ -71,10 +71,10 @@ public class CannotForkWhileRequiresInference
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotModifyWhileRequiresInference
+public class CannotModifyWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotModifyWhileRequiresInference()
+    internal CannotModifyWhileRequiresInferenceException()
         : base("Cannot `Modify()` a conversation while RequiresInference is true")
     {
     }
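Appending the `Exception` suffix brings both types in line with .NET naming conventions. Usage is unchanged; a short sketch using only names from the diffs above:

```csharp
// Checking RequiresInference first avoids the throw entirely.
try
{
    var forked = conversation.Fork(); // throws while RequiresInference is true
}
catch (CannotForkWhileRequiresInferenceException)
{
    // Run the pending inference on the executor, then retry the fork.
}
```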
@@ -155,7 +155,7 @@ namespace LLama
 }

 /// <inheritdoc />
-protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
 {
     if (_embeds.Count > 0)
     {
@@ -238,6 +238,8 @@ namespace LLama
             }
         }
     }
+
+    return Task.CompletedTask;
 }

 /// <summary>
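Dropping `async` here silences the compiler warning for an `async` method containing no `await` (CS1998) and avoids allocating an async state machine; the synchronous body simply returns `Task.CompletedTask`. The general idiom:

```csharp
// A Task-returning override whose body is entirely synchronous should return
// Task.CompletedTask rather than be declared `async` with no awaits.
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
    // ... synchronous work, elided (see the diff above) ...
    return Task.CompletedTask; // callers can still await this normally
}
```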
@@ -1,12 +1,34 @@
 namespace LLama.Native;

+/// <summary>
+/// Token Types
+/// </summary>
+/// <remarks>C# equivalent of llama_token_get_type</remarks>
 public enum LLamaTokenType
 {
+    /// <summary>
+    /// No specific type has been set for this token
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNDEFINED = 0,

+    /// <summary>
+    /// This is a "normal" token
+    /// </summary>
     LLAMA_TOKEN_TYPE_NORMAL = 1,

+    /// <summary>
+    /// An "unknown" character/text token e.g. &lt;unk&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNKNOWN = 2,

+    /// <summary>
+    /// A special control token e.g. &lt;/s&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_CONTROL = 3,

+    /// <summary>
+    /// A user-defined token, loaded from the model vocabulary
+    /// </summary>
     LLAMA_TOKEN_TYPE_USER_DEFINED = 4,

+    /// <summary>
+    /// An unused slot in the vocabulary
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNUSED = 5,

+    /// <summary>
+    /// A byte-level fallback token e.g. &lt;0x0A&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_BYTE = 6,
 }
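One place the token type matters is deciding whether a token should be rendered as text. A hedged sketch; the `llama_token_get_type` binding name and its model-handle parameter are assumptions modeled on the other P/Invoke declarations, so verify before use:

```csharp
// Treat only non-control token types as printable output.
var type = NativeApi.llama_token_get_type(model, token); // assumed binding signature
bool printable = type == LLamaTokenType.LLAMA_TOKEN_TYPE_NORMAL
              || type == LLamaTokenType.LLAMA_TOKEN_TYPE_USER_DEFINED
              || type == LLamaTokenType.LLAMA_TOKEN_TYPE_BYTE;
```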
@@ -172,6 +172,11 @@ namespace LLama.Native
 [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
 public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);

+/// <summary>
+/// Get the batch size for this context
+/// </summary>
+/// <param name="ctx">The context to query</param>
+/// <returns>The batch size (n_batch) the context was created with</returns>
 [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
 public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
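`llama_n_batch` mirrors the existing `llama_n_ctx` accessor: both read back sizes the context was configured with, which callers need when chunking prompts for decode. A sketch, assuming these declarations live on the usual `NativeApi` class and `ctx` is a valid handle:

```csharp
uint contextSize = NativeApi.llama_n_ctx(ctx);   // max positions in the KV cache
uint batchSize   = NativeApi.llama_n_batch(ctx); // max tokens per decode call
Console.WriteLine($"n_ctx={contextSize}, n_batch={batchSize}");
```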
@@ -1,17 +1,30 @@
 namespace LLama.Native
 {
     /// <summary>
-    /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
+    /// RoPE scaling type.
     /// </summary>
+    /// <remarks>C# equivalent of llama_rope_scaling_type</remarks>
     public enum RopeScalingType
         : sbyte
     {
+        /// <summary>
+        /// No particular scaling type has been specified
+        /// </summary>
         LLAMA_ROPE_SCALING_UNSPECIFIED = -1,

+        /// <summary>
+        /// Do not apply any RoPE scaling
+        /// </summary>
         LLAMA_ROPE_SCALING_NONE = 0,

+        /// <summary>
+        /// Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
+        /// </summary>
         LLAMA_ROPE_SCALING_LINEAR = 1,

+        /// <summary>
+        /// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
+        /// </summary>
         LLAMA_ROPE_SCALING_YARN = 2,
     }
 }
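A hedged sketch of selecting one of these values when configuring a model; the `ModelParams` class and the `YarnScalingType` property name are assumptions about the parameters surface, while the enum values come from the file above:

```csharp
// Opt into YaRN RoPE scaling (https://arxiv.org/pdf/2309.00071.pdf) for longer contexts.
var parameters = new ModelParams("model.gguf") // assumed params type
{
    YarnScalingType = RopeScalingType.LLAMA_ROPE_SCALING_YARN, // assumed property name
};
```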
@@ -142,20 +142,21 @@ namespace LLama
     /// Add all tokens in the given enumerable
     /// </summary>
     /// <param name="tokens"></param>
-    public void AddRange(IEnumerable<int> tokens)
+    public void AddRange<T>(T tokens)
+        where T : IEnumerable<LLamaToken>
     {
         foreach (var item in tokens)
-            Add(item);
+            Add((int)item);
     }

     /// <summary>
-    /// Add all tokens in the given enumerable
+    /// Add all tokens in the given span
     /// </summary>
     /// <param name="tokens"></param>
-    public void AddRange(IEnumerable<LLamaToken> tokens)
+    public void AddRange(ReadOnlySpan<LLamaToken> tokens)
     {
         foreach (var item in tokens)
-            Add((int)item);
+            Add(item);
     }

     /// <summary>
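The span overload lets callers pass token memory directly, without the `ToArray()` copies the beam-search test previously needed, and the generic `AddRange<T>` keeps `IEnumerable<LLamaToken>` sources working; constraining on `T` also lets struct enumerators (such as `List<T>.Enumerator`) avoid boxing. A sketch using only members shown in the diffs, with `context` and the token sources assumed:

```csharp
var decoder = new StreamingTokenDecoder(context);
decoder.AddRange(view.Tokens);   // ReadOnlySpan<LLamaToken> overload, no copy
decoder.AddRange(someTokenList); // generic overload, e.g. a List<LLamaToken>
string text = decoder.Read();
```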