
Assorted small changes to clean up some code warnings

tags/0.11.0
Martin Evans 1 year ago
commit c7d0dc915a
10 changed files with 77 additions and 18 deletions
  1. +12 -2   LLama.Unittest/BeamTests.cs
  2. +2  -2   LLama/Abstractions/IInferenceParams.cs
  3. +3  -0   LLama/Batched/BatchedExecutor.cs
  4. +6  -3   LLama/Batched/Conversation.cs
  5. +4  -4   LLama/Batched/Exceptions.cs
  6. +3  -1   LLama/LLamaInteractExecutor.cs
  7. +22 -0   LLama/Native/LLamaTokenType.cs
  8. +5  -0   LLama/Native/NativeApi.cs
  9. +14 -1   LLama/Native/RopeScalingType.cs
  10. +6 -5   LLama/StreamingTokenDecoder.cs

+12 -2   LLama.Unittest/BeamTests.cs

@@ -47,14 +47,24 @@ public sealed class BeamTests
     for (var i = 0; i < state.Beams.Length; i++)
     {
         ref var view = ref state.Beams[i];
-        var tokens = context.DeTokenize(view.Tokens.ToArray());
+
+        var decoder = new StreamingTokenDecoder(context);
+        decoder.AddRange(view.Tokens);
+        var tokens = decoder.Read();

         _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'");
     }

     if (state.CommonPrefixLength > 0)
     {
         var view = state.Beams[0];
-        result.Append(context.DeTokenize(view.Tokens.Slice(0, (int)state.CommonPrefixLength).ToArray()));
+
+        var decoder = new StreamingTokenDecoder(context);
+        decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength));
+        var tokens = decoder.Read();
+
+        result.Append(tokens);
     }

 }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2));
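
The change above swaps one-shot DeTokenize calls for StreamingTokenDecoder, which buffers raw token bytes and only emits completed characters, so multi-byte characters split across token boundaries are not mangled. A minimal sketch of the pattern (the setup is illustrative; the decoder API is taken from the diff):

    // Sketch: decode a stream of tokens safely. 'context' is an existing
    // LLamaContext; 'tokens' is any sequence of LLamaToken values.
    var decoder = new StreamingTokenDecoder(context);
    decoder.AddRange(tokens);        // buffer the tokens' bytes
    var text = decoder.Read();       // returns only fully-decoded characters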


+2 -2   LLama/Abstractions/IInferenceParams.cs

@@ -36,12 +36,12 @@ namespace LLama.Abstractions
     /// </summary>
     public int TopK { get; set; }

-    /// <summary>llama_eval
+    /// <summary>
     /// 1.0 = disabled
     /// </summary>
     public float TopP { get; set; }

-    /// <summary>llama_eval
+    /// <summary>
     /// 0.0 = disabled
     /// </summary>
     public float MinP { get; set; }


+3 -0   LLama/Batched/BatchedExecutor.cs

@@ -55,6 +55,9 @@ public sealed class BatchedExecutor
         Epoch = 1;
     }

+    /// <summary>
+    /// Finalizer for BatchedExecutor
+    /// </summary>
     ~BatchedExecutor()
     {
         Dispose();
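
These comments only document the pre-existing finalizer. For reference, the finalizer-plus-Dispose pattern being documented looks roughly like this (class name illustrative, not from the diff):

    // Sketch of the pattern: the finalizer is a safety net that releases
    // native resources if the owner forgets to call Dispose().
    class NativeOwner : IDisposable
    {
        ~NativeOwner() => Dispose();

        public void Dispose()
        {
            // free native handles here...
            GC.SuppressFinalize(this);   // already cleaned up; skip finalization
        }
    }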


+6 -3   LLama/Batched/Conversation.cs

@@ -54,6 +54,9 @@ public sealed class Conversation
         _end = end;
     }

+    /// <summary>
+    /// Finalizer for Conversation
+    /// </summary>
     ~Conversation()
     {
         Dispose();
@@ -96,7 +99,7 @@ public sealed class Conversation
         AssertNotDisposed();

         if (RequiresInference)
-            throw new CannotForkWhileRequiresInference();
+            throw new CannotForkWhileRequiresInferenceException();

         // Create a new conversation which references the current position in this one
         var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
@@ -195,13 +198,13 @@ public sealed class Conversation
     /// Directly modify the KV cache of this conversation
     /// </summary>
     /// <param name="modifier"></param>
-    /// <exception cref="CannotModifyWhileRequiresInference">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
+    /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
     public void Modify(ModifyKvCache modifier)
     {
         AssertNotDisposed();

         if (RequiresInference)
-            throw new CannotModifyWhileRequiresInference();
+            throw new CannotModifyWhileRequiresInferenceException();

         // do whatever the modification is
         _end = modifier.Invoke(_end, new KvAccessor(this));


+4 -4   LLama/Batched/Exceptions.cs

@@ -59,10 +59,10 @@ public class CannotSampleRequiresPromptException
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotForkWhileRequiresInference
+public class CannotForkWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotForkWhileRequiresInference()
+    internal CannotForkWhileRequiresInferenceException()
         : base("Cannot `Fork()` a conversation while RequiresInference is true")
     {
     }
@@ -71,10 +71,10 @@ public class CannotForkWhileRequiresInference
 /// <summary>
 /// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
 /// </summary>
-public class CannotModifyWhileRequiresInference
+public class CannotModifyWhileRequiresInferenceException
     : ExperimentalBatchedExecutorException
 {
-    internal CannotModifyWhileRequiresInference()
+    internal CannotModifyWhileRequiresInferenceException()
         : base("Cannot `Modify()` a conversation while RequiresInference is true")
     {
     }
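
The renames follow the .NET convention that exception type names end in `Exception`, which is likely one of the warnings this commit cleans up. Any caller catching these types by name needs the new identifiers; an illustrative sketch (assumes a `Conversation conversation` in scope):

    try
    {
        var fork = conversation.Fork();
    }
    catch (CannotForkWhileRequiresInferenceException)
    {
        // The conversation has pending tokens; run the executor's
        // inference first, then fork.
    }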

+3 -1   LLama/LLamaInteractExecutor.cs

@@ -155,7 +155,7 @@ namespace LLama
         }

         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -238,6 +238,8 @@ namespace LLama
                     }
                 }
             }
+
+            return Task.CompletedTask;
         }

         /// <summary>
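
Dropping `async` from a method that never awaits silences compiler warning CS1998 and avoids generating an async state machine; the synchronous body satisfies the `Task` signature by returning a completed task. The general shape:

    // Instead of 'async Task' with no awaits (CS1998), return directly:
    Task DoWorkAsync()
    {
        // ...purely synchronous work...
        return Task.CompletedTask;   // cached singleton, no allocation
    }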


+22 -0   LLama/Native/LLamaTokenType.cs

@@ -1,12 +1,34 @@
 namespace LLama.Native;

+/// <summary>
+/// Token Types
+/// </summary>
+/// <remarks>C# equivalent of llama_token_get_type</remarks>
 public enum LLamaTokenType
 {
+    /// <summary>
+    /// No specific type has been set for this token
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNDEFINED = 0,

+    /// <summary>
+    /// This is a "normal" token
+    /// </summary>
     LLAMA_TOKEN_TYPE_NORMAL = 1,

+    /// <summary>
+    /// An "unknown" character/text token e.g. &lt;unk&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_UNKNOWN = 2,

+    /// <summary>
+    /// A special control token e.g. &lt;/s&gt;
+    /// </summary>
     LLAMA_TOKEN_TYPE_CONTROL = 3,

     LLAMA_TOKEN_TYPE_USER_DEFINED = 4,

     LLAMA_TOKEN_TYPE_UNUSED = 5,

     LLAMA_TOKEN_TYPE_BYTE = 6,
 }
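
A hedged sketch of consuming these values (the predicate and its policy are illustrative only; token types come from the native llama_token_get_type call mentioned in the remarks):

    // Illustrative: decide whether a token should be rendered as text.
    static bool IsRenderable(LLamaTokenType type) => type switch
    {
        LLamaTokenType.LLAMA_TOKEN_TYPE_NORMAL => true,
        LLamaTokenType.LLAMA_TOKEN_TYPE_BYTE => true,    // raw byte fallback tokens
        _ => false,                                      // control, unknown, etc.
    };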

+5 -0   LLama/Native/NativeApi.cs

@@ -172,6 +172,11 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);

+        /// <summary>
+        /// Get the batch size for this context
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
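
A short usage sketch of the newly documented binding next to the existing `llama_n_ctx` (assumes a valid `SafeLLamaContextHandle ctx` is already in hand):

    // Query properties of an existing context via the native bindings.
    uint contextSize = NativeApi.llama_n_ctx(ctx);    // KV context size, in tokens
    uint batchSize   = NativeApi.llama_n_batch(ctx);  // batch size for this context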



+14 -1   LLama/Native/RopeScalingType.cs

@@ -1,17 +1,30 @@
 namespace LLama.Native
 {
     /// <summary>
-    /// RoPE scaling type. C# equivalent of llama_rope_scaling_type
+    /// RoPE scaling type.
     /// </summary>
+    /// <remarks>C# equivalent of llama_rope_scaling_type</remarks>
     public enum RopeScalingType
         : sbyte
     {
+        /// <summary>
+        /// No particular scaling type has been specified
+        /// </summary>
         LLAMA_ROPE_SCALING_UNSPECIFIED = -1,

+        /// <summary>
+        /// Do not apply any RoPE scaling
+        /// </summary>
         LLAMA_ROPE_SCALING_NONE = 0,

+        /// <summary>
+        /// Positional linear interpolation, as described by kaiokendev: https://kaiokendev.github.io/til#extending-context-to-8k
+        /// </summary>
         LLAMA_ROPE_SCALING_LINEAR = 1,

+        /// <summary>
+        /// YaRN scaling: https://arxiv.org/pdf/2309.00071.pdf
+        /// </summary>
         LLAMA_ROPE_SCALING_YARN = 2,
     }
 }
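
Note the `sbyte` backing: it allows `LLAMA_ROPE_SCALING_UNSPECIFIED = -1` and presumably matches the width of the corresponding native field, so values round-trip through interop casts. For example:

    // The sbyte backing makes the cast round-trip exact.
    var scaling = RopeScalingType.LLAMA_ROPE_SCALING_YARN;
    sbyte raw = (sbyte)scaling;          // 2, as passed to native code
    var back = (RopeScalingType)raw;     // LLAMA_ROPE_SCALING_YARN again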

+6 -5   LLama/StreamingTokenDecoder.cs

@@ -142,20 +142,21 @@ namespace LLama
         /// Add all tokens in the given enumerable
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<int> tokens)
+        public void AddRange<T>(T tokens)
+            where T : IEnumerable<LLamaToken>
         {
             foreach (var item in tokens)
-                Add(item);
+                Add((int)item);
         }

         /// <summary>
-        /// Add all tokens in the given enumerable
+        /// Add all tokens in the given span
         /// </summary>
         /// <param name="tokens"></param>
-        public void AddRange(IEnumerable<LLamaToken> tokens)
+        public void AddRange(ReadOnlySpan<LLamaToken> tokens)
         {
             foreach (var item in tokens)
-                Add((int)item);
+                Add(item);
         }

         /// <summary>
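
With this pair, spans take the allocation-free `ReadOnlySpan` overload, while lists and other enumerables bind to the generic one, whose constraint lets collections with struct enumerators be enumerated without boxing. A sketch of both call sites (names illustrative; `context` as in the test above):

    var decoder = new StreamingTokenDecoder(context);

    var list = new List<LLamaToken>();    // any IEnumerable<LLamaToken>
    decoder.AddRange(list);               // generic overload; List's struct
                                          // enumerator is used without boxing

    LLamaToken[] array = list.ToArray();
    decoder.AddRange(array.AsSpan());     // span overload; no enumerator at all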

