| @@ -47,14 +47,24 @@ public sealed class BeamTests | |||||
| for (var i = 0; i < state.Beams.Length; i++) | for (var i = 0; i < state.Beams.Length; i++) | ||||
| { | { | ||||
| ref var view = ref state.Beams[i]; | ref var view = ref state.Beams[i]; | ||||
| var tokens = context.DeTokenize(view.Tokens.ToArray()); | |||||
| var decoder = new StreamingTokenDecoder(context); | |||||
| decoder.AddRange(view.Tokens); | |||||
| var tokens = decoder.Read(); | |||||
| _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); | _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); | ||||
| } | } | ||||
| if (state.CommonPrefixLength > 0) | if (state.CommonPrefixLength > 0) | ||||
| { | { | ||||
| var view = state.Beams[0]; | var view = state.Beams[0]; | ||||
| result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray())); | |||||
| var decoder = new StreamingTokenDecoder(context); | |||||
| decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength)); | |||||
| var tokens = decoder.Read(); | |||||
| result.Append(tokens); | |||||
| } | } | ||||
| }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); | }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); | ||||
| @@ -36,12 +36,12 @@ namespace LLama.Abstractions | |||||
| /// </summary> | /// </summary> | ||||
| public int TopK { get; set; } | public int TopK { get; set; } | ||||
| /// <summary>llama_eval | |||||
| /// <summary> | |||||
| /// 1.0 = disabled | /// 1.0 = disabled | ||||
| /// </summary> | /// </summary> | ||||
| public float TopP { get; set; } | public float TopP { get; set; } | ||||
| /// <summary>llama_eval | |||||
| /// <summary> | |||||
| /// 0.0 = disabled | /// 0.0 = disabled | ||||
| /// </summary> | /// </summary> | ||||
| public float MinP { get; set; } | public float MinP { get; set; } | ||||
| @@ -55,6 +55,9 @@ public sealed class BatchedExecutor | |||||
| Epoch = 1; | Epoch = 1; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Finalizer for BatchedExecutor | |||||
| /// </summary> | |||||
| ~BatchedExecutor() | ~BatchedExecutor() | ||||
| { | { | ||||
| Dispose(); | Dispose(); | ||||
| @@ -54,6 +54,9 @@ public sealed class Conversation | |||||
| _end = end; | _end = end; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Finalizer for Conversation | |||||
| /// </summary> | |||||
| ~Conversation() | ~Conversation() | ||||
| { | { | ||||
| Dispose(); | Dispose(); | ||||
| @@ -96,7 +99,7 @@ public sealed class Conversation | |||||
| AssertNotDisposed(); | AssertNotDisposed(); | ||||
| if (RequiresInference) | if (RequiresInference) | ||||
| throw new CannotForkWhileRequiresInference(); | |||||
| throw new CannotForkWhileRequiresInferenceException(); | |||||
| // Create a new conversation which references the current position in this one | // Create a new conversation which references the current position in this one | ||||
| var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end) | var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end) | ||||
| @@ -195,13 +198,13 @@ public sealed class Conversation | |||||
| /// Directly modify the KV cache of this conversation | /// Directly modify the KV cache of this conversation | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="modifier"></param> | /// <param name="modifier"></param> | ||||
| /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception> | |||||
| /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception> | |||||
| public void Modify(ModifyKvCache modifier) | public void Modify(ModifyKvCache modifier) | ||||
| { | { | ||||
| AssertNotDisposed(); | AssertNotDisposed(); | ||||
| if (RequiresInference) | if (RequiresInference) | ||||
| throw new CannotModifyWhileRequiresInference(); | |||||
| throw new CannotModifyWhileRequiresInferenceException(); | |||||
| // do whatever the modification is | // do whatever the modification is | ||||
| _end = modifier.Invoke(_end, new KvAccessor(this)); | _end = modifier.Invoke(_end, new KvAccessor(this)); | ||||
| @@ -59,10 +59,10 @@ public class CannotSampleRequiresPromptException | |||||
| /// <summary> | /// <summary> | ||||
| /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true | /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true | ||||
| /// </summary> | /// </summary> | ||||
| public class CannotForkWhileRequiresInference | |||||
| public class CannotForkWhileRequiresInferenceException | |||||
| : ExperimentalBatchedExecutorException | : ExperimentalBatchedExecutorException | ||||
| { | { | ||||
| internal CannotForkWhileRequiresInference() | |||||
| internal CannotForkWhileRequiresInferenceException() | |||||
| : base("Cannot `Fork()` a conversation while RequiresInference is true") | : base("Cannot `Fork()` a conversation while RequiresInference is true") | ||||
| { | { | ||||
| } | } | ||||
| @@ -71,10 +71,10 @@ public class CannotForkWhileRequiresInference | |||||
| /// <summary> | /// <summary> | ||||
| /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true | /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true | ||||
| /// </summary> | /// </summary> | ||||
| public class CannotModifyWhileRequiresInference | |||||
| public class CannotModifyWhileRequiresInferenceException | |||||
| : ExperimentalBatchedExecutorException | : ExperimentalBatchedExecutorException | ||||
| { | { | ||||
| internal CannotModifyWhileRequiresInference() | |||||
| internal CannotModifyWhileRequiresInferenceException() | |||||
| : base("Cannot `Modify()` a conversation while RequiresInference is true") | : base("Cannot `Modify()` a conversation while RequiresInference is true") | ||||
| { | { | ||||
| } | } | ||||
| @@ -155,7 +155,7 @@ namespace LLama | |||||
| } | } | ||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args) | |||||
| protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args) | |||||
| { | { | ||||
| if (_embeds.Count > 0) | if (_embeds.Count > 0) | ||||
| { | { | ||||
| @@ -238,6 +238,8 @@ namespace LLama | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return Task.CompletedTask; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -1,12 +1,34 @@ | |||||
| namespace LLama.Native; | namespace LLama.Native; | ||||
| /// <summary> | |||||
| /// Token Types | |||||
| /// </summary> | |||||
| /// <remarks>C# equivalent of llama_token_get_type</remarks> | |||||
| public enum LLamaTokenType | public enum LLamaTokenType | ||||
| { | { | ||||
| /// <summary> | |||||
| /// No specific type has been set for this token | |||||
| /// </summary> | |||||
| LLAMA_TOKEN_TYPE_UNDEFINED = 0, | LLAMA_TOKEN_TYPE_UNDEFINED = 0, | ||||
| /// <summary> | |||||
| /// This is a "normal" token | |||||
| /// </summary> | |||||
| LLAMA_TOKEN_TYPE_NORMAL = 1, | LLAMA_TOKEN_TYPE_NORMAL = 1, | ||||
| /// <summary> | |||||
| /// An "unknown" character/text token e.g. <unk> | |||||
| /// </summary> | |||||
| LLAMA_TOKEN_TYPE_UNKNOWN = 2, | LLAMA_TOKEN_TYPE_UNKNOWN = 2, | ||||
| /// <summary> | |||||
| /// A special control token e.g. </s> | |||||
| /// </summary> | |||||
| LLAMA_TOKEN_TYPE_CONTROL = 3, | LLAMA_TOKEN_TYPE_CONTROL = 3, | ||||
| LLAMA_TOKEN_TYPE_USER_DEFINED = 4, | LLAMA_TOKEN_TYPE_USER_DEFINED = 4, | ||||
| LLAMA_TOKEN_TYPE_UNUSED = 5, | LLAMA_TOKEN_TYPE_UNUSED = 5, | ||||
| LLAMA_TOKEN_TYPE_BYTE = 6, | LLAMA_TOKEN_TYPE_BYTE = 6, | ||||
| } | } | ||||
| @@ -172,6 +172,11 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx); | public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx); | ||||
| /// <summary> | |||||
| /// Get the batch size for this context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern uint llama_n_batch(SafeLLamaContextHandle ctx); | public static extern uint llama_n_batch(SafeLLamaContextHandle ctx); | ||||
| @@ -1,17 +1,30 @@ | |||||
| namespace LLama.Native | namespace LLama.Native | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// RoPE scaling type. C# equivalent of llama_rope_scaling_type | |||||
| /// RoPE scaling type. | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>C# equivalent of llama_rope_scaling_type</remarks> | |||||
| public enum RopeScalingType | public enum RopeScalingType | ||||
| : sbyte | : sbyte | ||||
| { | { | ||||
| /// <summary> | |||||
| /// No particular scaling type has been specified | |||||
| /// </summary> | |||||
| LLAMA_ROPE_SCALING_UNSPECIFIED = -1, | LLAMA_ROPE_SCALING_UNSPECIFIED = -1, | ||||
| /// <summary> | |||||
| /// Do not apply any RoPE scaling | |||||
| /// </summary> | |||||
| LLAMA_ROPE_SCALING_NONE = 0, | LLAMA_ROPE_SCALING_NONE = 0, | ||||
| /// <summary> | |||||
| /// Positional linear interpolation, as described by kaikendev: https://kaiokendev.github.io/til#extending-context-to-8k | |||||
| /// </summary> | |||||
| LLAMA_ROPE_SCALING_LINEAR = 1, | LLAMA_ROPE_SCALING_LINEAR = 1, | ||||
| /// <summary> | |||||
| /// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf | |||||
| /// </summary> | |||||
| LLAMA_ROPE_SCALING_YARN = 2, | LLAMA_ROPE_SCALING_YARN = 2, | ||||
| } | } | ||||
| } | } | ||||
| @@ -142,20 +142,21 @@ namespace LLama | |||||
| /// Add all tokens in the given enumerable | /// Add all tokens in the given enumerable | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| public void AddRange(IEnumerable<int> tokens) | |||||
| public void AddRange<T>(T tokens) | |||||
| where T : IEnumerable<LLamaToken> | |||||
| { | { | ||||
| foreach (var item in tokens) | foreach (var item in tokens) | ||||
| Add(item); | |||||
| Add((int)item); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Add all tokens in the given enumerable | |||||
| /// Add all tokens in the given span | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="tokens"></param> | /// <param name="tokens"></param> | ||||
| public void AddRange(IEnumerable<LLamaToken> tokens) | |||||
| public void AddRange(ReadOnlySpan<LLamaToken> tokens) | |||||
| { | { | ||||
| foreach (var item in tokens) | foreach (var item in tokens) | ||||
| Add((int)item); | |||||
| Add(item); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||