Browse Source

Code cleanup driven by R# suggestions:

- Made `NativeApi` into a `static class` (it's not intended to be instantiated)
 - Moved `LLamaTokenType` enum out into a separate file
 - Made `LLamaSeqId` and `LLamaPos` into `record struct`, convenient to have equality etc
tags/0.9.1
Martin Evans 1 year ago
parent
commit
f860f88c36
22 changed files with 126 additions and 140 deletions
  1. +1
    -1
      LLama/Abstractions/IModelParams.cs
  2. +2
    -7
      LLama/Common/ChatHistory.cs
  3. +0
    -1
      LLama/Common/FixedSizeQueue.cs
  4. +2
    -0
      LLama/Common/InferenceParams.cs
  5. +1
    -0
      LLama/Extensions/DictionaryExtensions.cs
  6. +7
    -1
      LLama/Grammars/Grammar.cs
  7. +1
    -10
      LLama/LLamaContext.cs
  8. +8
    -10
      LLama/LLamaEmbedder.cs
  9. +1
    -1
      LLama/LLamaWeights.cs
  10. +2
    -2
      LLama/Native/LLamaKvCacheView.cs
  11. +2
    -2
      LLama/Native/LLamaPos.cs
  12. +2
    -2
      LLama/Native/LLamaSeqId.cs
  13. +12
    -0
      LLama/Native/LLamaTokenType.cs
  14. +1
    -1
      LLama/Native/NativeApi.BeamSearch.cs
  15. +2
    -2
      LLama/Native/NativeApi.Grammar.cs
  16. +24
    -40
      LLama/Native/NativeApi.Load.cs
  17. +1
    -1
      LLama/Native/NativeApi.Quantize.cs
  18. +2
    -2
      LLama/Native/NativeApi.Sampling.cs
  19. +48
    -48
      LLama/Native/NativeApi.cs
  20. +7
    -4
      LLama/Native/NativeLibraryConfig.cs
  21. +0
    -3
      LLama/Native/SafeLLamaContextHandle.cs
  22. +0
    -2
      LLama/Native/SafeLlamaModelHandle.cs

+ 1
- 1
LLama/Abstractions/IModelParams.cs View File

@@ -214,7 +214,7 @@ namespace LLama.Abstractions
/// <summary> /// <summary>
/// Get the key being overridden by this override /// Get the key being overridden by this override
/// </summary> /// </summary>
public string Key { get; init; }
public string Key { get; }


internal LLamaModelKvOverrideType Type { get; } internal LLamaModelKvOverrideType Type { get; }




+ 2
- 7
LLama/Common/ChatHistory.cs View File

@@ -1,5 +1,4 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;


@@ -37,6 +36,7 @@ namespace LLama.Common
/// </summary> /// </summary>
public class ChatHistory public class ChatHistory
{ {
private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true };


/// <summary> /// <summary>
/// Chat message representation /// Chat message representation
@@ -96,12 +96,7 @@ namespace LLama.Common
/// <returns></returns> /// <returns></returns>
public string ToJson() public string ToJson()
{ {
return JsonSerializer.Serialize(
this,
new JsonSerializerOptions()
{
WriteIndented = true
});
return JsonSerializer.Serialize(this, _jsonOptions);
} }


/// <summary> /// <summary>


+ 0
- 1
LLama/Common/FixedSizeQueue.cs View File

@@ -2,7 +2,6 @@
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using LLama.Extensions;


namespace LLama.Common namespace LLama.Common
{ {


+ 2
- 0
LLama/Common/InferenceParams.cs View File

@@ -18,11 +18,13 @@ namespace LLama.Common
/// number of tokens to keep from initial prompt /// number of tokens to keep from initial prompt
/// </summary> /// </summary>
public int TokensKeep { get; set; } = 0; public int TokensKeep { get; set; } = 0;

/// <summary> /// <summary>
/// how many new tokens to predict (n_predict), set to -1 to infinitely generate response /// how many new tokens to predict (n_predict), set to -1 to infinitely generate response
/// until it completes. /// until it completes.
/// </summary> /// </summary>
public int MaxTokens { get; set; } = -1; public int MaxTokens { get; set; } = -1;

/// <summary> /// <summary>
/// logit bias for specific tokens /// logit bias for specific tokens
/// </summary> /// </summary>


+ 1
- 0
LLama/Extensions/DictionaryExtensions.cs View File

@@ -15,6 +15,7 @@ namespace LLama.Extensions


internal static TValue GetValueOrDefaultImpl<TKey, TValue>(IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue) internal static TValue GetValueOrDefaultImpl<TKey, TValue>(IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
{ {
// ReSharper disable once CanSimplifyDictionaryTryGetValueWithGetValueOrDefault (this is a shim for that method!)
return dictionary.TryGetValue(key, out var value) ? value : defaultValue; return dictionary.TryGetValue(key, out var value) ? value : defaultValue;
} }
} }


+ 7
- 1
LLama/Grammars/Grammar.cs View File

@@ -15,7 +15,7 @@ namespace LLama.Grammars
/// <summary> /// <summary>
/// Index of the initial rule to start from /// Index of the initial rule to start from
/// </summary> /// </summary>
public ulong StartRuleIndex { get; set; }
public ulong StartRuleIndex { get; }


/// <summary> /// <summary>
/// The rules which make up this grammar /// The rules which make up this grammar
@@ -121,6 +121,12 @@ namespace LLama.Grammars
case LLamaGrammarElementType.CHAR_ALT: case LLamaGrammarElementType.CHAR_ALT:
case LLamaGrammarElementType.CHAR_RNG_UPPER: case LLamaGrammarElementType.CHAR_RNG_UPPER:
break; break;

case LLamaGrammarElementType.END:
case LLamaGrammarElementType.ALT:
case LLamaGrammarElementType.RULE_REF:
case LLamaGrammarElementType.CHAR:
case LLamaGrammarElementType.CHAR_NOT:
default: default:
output.Append("] "); output.Append("] ");
break; break;


+ 1
- 10
LLama/LLamaContext.cs View File

@@ -43,7 +43,7 @@ namespace LLama
/// <summary> /// <summary>
/// The context params set for this context /// The context params set for this context
/// </summary> /// </summary>
public IContextParams Params { get; set; }
public IContextParams Params { get; }


/// <summary> /// <summary>
/// The native handle, which is used to be passed to the native APIs /// The native handle, which is used to be passed to the native APIs
@@ -56,15 +56,6 @@ namespace LLama
/// </summary> /// </summary>
public Encoding Encoding { get; } public Encoding Encoding { get; }


internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;

_logger = logger;
Encoding = @params.Encoding;
NativeHandle = nativeContext;
}

/// <summary> /// <summary>
/// Create a new LLamaContext for the given LLamaWeights /// Create a new LLamaContext for the given LLamaWeights
/// </summary> /// </summary>


+ 8
- 10
LLama/LLamaEmbedder.cs View File

@@ -12,17 +12,15 @@ namespace LLama
public sealed class LLamaEmbedder public sealed class LLamaEmbedder
: IDisposable : IDisposable
{ {
private readonly LLamaContext _ctx;

/// <summary> /// <summary>
/// Dimension of embedding vectors /// Dimension of embedding vectors
/// </summary> /// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
public int EmbeddingSize => Context.EmbeddingSize;


/// <summary> /// <summary>
/// LLama Context /// LLama Context
/// </summary> /// </summary>
public LLamaContext Context => this._ctx;
public LLamaContext Context { get; }


/// <summary> /// <summary>
/// Create a new embedder, using the given LLamaWeights /// Create a new embedder, using the given LLamaWeights
@@ -33,7 +31,7 @@ namespace LLama
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{ {
@params.EmbeddingMode = true; @params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params, logger);
Context = weights.CreateContext(@params, logger);
} }


/// <summary> /// <summary>
@@ -72,20 +70,20 @@ namespace LLama
/// <exception cref="RuntimeError"></exception> /// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos) public float[] GetEmbeddings(string text, bool addBos)
{ {
var embed_inp_array = _ctx.Tokenize(text, addBos);
var embed_inp_array = Context.Tokenize(text, addBos);


// TODO(Rinne): deal with log of prompt // TODO(Rinne): deal with log of prompt


if (embed_inp_array.Length > 0) if (embed_inp_array.Length > 0)
_ctx.Eval(embed_inp_array, 0);
Context.Eval(embed_inp_array, 0);


unsafe unsafe
{ {
var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null) if (embeddings == null)
return Array.Empty<float>(); return Array.Empty<float>();


return new Span<float>(embeddings, EmbeddingSize).ToArray();
return embeddings.ToArray();
} }
} }


@@ -94,7 +92,7 @@ namespace LLama
/// </summary> /// </summary>
public void Dispose() public void Dispose()
{ {
_ctx.Dispose();
Context.Dispose();
} }


} }


+ 1
- 1
LLama/LLamaWeights.cs View File

@@ -64,7 +64,7 @@ namespace LLama
/// </summary> /// </summary>
public IReadOnlyDictionary<string, string> Metadata { get; set; } public IReadOnlyDictionary<string, string> Metadata { get; set; }


internal LLamaWeights(SafeLlamaModelHandle weights)
private LLamaWeights(SafeLlamaModelHandle weights)
{ {
NativeHandle = weights; NativeHandle = weights;
Metadata = weights.ReadMetadata(); Metadata = weights.ReadMetadata();


+ 2
- 2
LLama/Native/LLamaKvCacheView.cs View File

@@ -14,7 +14,7 @@ public struct LLamaKvCacheViewCell
/// May be negative if the cell is not populated. /// May be negative if the cell is not populated.
/// </summary> /// </summary>
public LLamaPos pos; public LLamaPos pos;
};
}


/// <summary> /// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view) /// An updateable view of the KV cache (llama_kv_cache_view)
@@ -130,7 +130,7 @@ public class LLamaKvCacheViewSafeHandle
} }
} }


partial class NativeApi
public static partial class NativeApi
{ {
/// <summary> /// <summary>
/// Create an empty KV cache view. (use only for debugging purposes) /// Create an empty KV cache view. (use only for debugging purposes)


+ 2
- 2
LLama/Native/LLamaPos.cs View File

@@ -6,7 +6,7 @@ namespace LLama.Native;
/// Indicates position in a sequence /// Indicates position in a sequence
/// </summary> /// </summary>
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public struct LLamaPos
public record struct LLamaPos
{ {
/// <summary> /// <summary>
/// The raw value /// The raw value
@@ -17,7 +17,7 @@ public struct LLamaPos
/// Create a new LLamaPos /// Create a new LLamaPos
/// </summary> /// </summary>
/// <param name="value"></param> /// <param name="value"></param>
public LLamaPos(int value)
private LLamaPos(int value)
{ {
Value = value; Value = value;
} }


+ 2
- 2
LLama/Native/LLamaSeqId.cs View File

@@ -6,7 +6,7 @@ namespace LLama.Native;
/// ID for a sequence in a batch /// ID for a sequence in a batch
/// </summary> /// </summary>
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public struct LLamaSeqId
public record struct LLamaSeqId
{ {
/// <summary> /// <summary>
/// The raw value /// The raw value
@@ -17,7 +17,7 @@ public struct LLamaSeqId
/// Create a new LLamaSeqId /// Create a new LLamaSeqId
/// </summary> /// </summary>
/// <param name="value"></param> /// <param name="value"></param>
public LLamaSeqId(int value)
private LLamaSeqId(int value)
{ {
Value = value; Value = value;
} }


+ 12
- 0
LLama/Native/LLamaTokenType.cs View File

@@ -0,0 +1,12 @@
namespace LLama.Native;

public enum LLamaTokenType
{
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
LLAMA_TOKEN_TYPE_CONTROL = 3,
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
LLAMA_TOKEN_TYPE_UNUSED = 5,
LLAMA_TOKEN_TYPE_BYTE = 6,
}

+ 1
- 1
LLama/Native/NativeApi.BeamSearch.cs View File

@@ -3,7 +3,7 @@ using System.Runtime.InteropServices;


namespace LLama.Native; namespace LLama.Native;


public partial class NativeApi
public static partial class NativeApi
{ {
/// <summary> /// <summary>
/// Type of pointer to the beam_search_callback function. /// Type of pointer to the beam_search_callback function.


+ 2
- 2
LLama/Native/NativeApi.Grammar.cs View File

@@ -5,7 +5,7 @@ namespace LLama.Native
{ {
using llama_token = Int32; using llama_token = Int32;


public unsafe partial class NativeApi
public static partial class NativeApi
{ {
/// <summary> /// <summary>
/// Create a new grammar from the given set of grammar rules /// Create a new grammar from the given set of grammar rules
@@ -15,7 +15,7 @@ namespace LLama.Native
/// <param name="start_rule_index"></param> /// <param name="start_rule_index"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);
public static extern unsafe IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index);


/// <summary> /// <summary>
/// Free all memory from the given SafeLLamaGrammarHandle /// Free all memory from the given SafeLLamaGrammarHandle


+ 24
- 40
LLama/Native/NativeApi.Load.cs View File

@@ -4,13 +4,12 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text.Json; using System.Text.Json;


namespace LLama.Native namespace LLama.Native
{ {
public partial class NativeApi
public static partial class NativeApi
{ {
static NativeApi() static NativeApi()
{ {
@@ -97,22 +96,13 @@ namespace LLama.Native
} }


if (string.IsNullOrEmpty(version)) if (string.IsNullOrEmpty(version))
{
return -1; return -1;
}
else
{
version = version.Split('.')[0];
bool success = int.TryParse(version, out var majorVersion);
if (success)
{
return majorVersion;
}
else
{
return -1;
}
}

version = version.Split('.')[0];
if (int.TryParse(version, out var majorVersion))
return majorVersion;

return -1;
} }


private static string GetCudaVersionFromPath(string cudaPath) private static string GetCudaVersionFromPath(string cudaPath)
@@ -129,7 +119,7 @@ namespace LLama.Native
{ {
return string.Empty; return string.Empty;
} }
return versionNode.GetString();
return versionNode.GetString() ?? "";
} }
} }
catch (Exception) catch (Exception)
@@ -169,18 +159,14 @@ namespace LLama.Native
{ {
platform = OSPlatform.OSX; platform = OSPlatform.OSX;
suffix = ".dylib"; suffix = ".dylib";
if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
{
prefix = "runtimes/osx-arm64/native/";
}
else
{
prefix = "runtimes/osx-x64/native/";
}

prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
? "runtimes/osx-arm64/native/"
: "runtimes/osx-x64/native/";
} }
else else
{ {
throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp.");
throw new RuntimeError("Your system plarform is not supported, please open an issue in LLamaSharp.");
} }
Log($"Detected OS Platform: {platform}", LogLevel.Information); Log($"Detected OS Platform: {platform}", LogLevel.Information);


@@ -275,15 +261,15 @@ namespace LLama.Native


var libraryTryLoadOrder = GetLibraryTryOrder(configuration); var libraryTryLoadOrder = GetLibraryTryOrder(configuration);


string[] preferredPaths = configuration.SearchDirectories;
string[] possiblePathPrefix = new string[] {
System.AppDomain.CurrentDomain.BaseDirectory,
var preferredPaths = configuration.SearchDirectories;
var possiblePathPrefix = new[] {
AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
}; };


var tryFindPath = (string filename) =>
string TryFindPath(string filename)
{ {
foreach(var path in preferredPaths)
foreach (var path in preferredPaths)
{ {
if (File.Exists(Path.Combine(path, filename))) if (File.Exists(Path.Combine(path, filename)))
{ {
@@ -291,7 +277,7 @@ namespace LLama.Native
} }
} }


foreach(var path in possiblePathPrefix)
foreach (var path in possiblePathPrefix)
{ {
if (File.Exists(Path.Combine(path, filename))) if (File.Exists(Path.Combine(path, filename)))
{ {
@@ -300,21 +286,19 @@ namespace LLama.Native
} }


return filename; return filename;
};
}


foreach (var libraryPath in libraryTryLoadOrder) foreach (var libraryPath in libraryTryLoadOrder)
{ {
var fullPath = tryFindPath(libraryPath);
var fullPath = TryFindPath(libraryPath);
var result = TryLoad(fullPath, true); var result = TryLoad(fullPath, true);
if (result is not null && result != IntPtr.Zero) if (result is not null && result != IntPtr.Zero)
{ {
Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information); Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information);
return result ?? IntPtr.Zero;
}
else
{
Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
return (IntPtr)result;
} }

Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
} }


if (!configuration.AllowFallback) if (!configuration.AllowFallback)


+ 1
- 1
LLama/Native/NativeApi.Quantize.cs View File

@@ -2,7 +2,7 @@


namespace LLama.Native namespace LLama.Native
{ {
public partial class NativeApi
public static partial class NativeApi
{ {
/// <summary> /// <summary>
/// Returns 0 on success /// Returns 0 on success


+ 2
- 2
LLama/Native/NativeApi.Sampling.cs View File

@@ -5,7 +5,7 @@ namespace LLama.Native
{ {
using llama_token = Int32; using llama_token = Int32;


public unsafe partial class NativeApi
public static partial class NativeApi
{ {
/// <summary> /// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -19,7 +19,7 @@ namespace LLama.Native
/// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param> /// <param name="penalty_freq">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
/// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param> /// <param name="penalty_present">Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
ref LLamaTokenDataArrayNative candidates, ref LLamaTokenDataArrayNative candidates,
llama_token* last_tokens, ulong last_tokens_size, llama_token* last_tokens, ulong last_tokens_size,
float penalty_repeat, float penalty_repeat,


+ 48
- 48
LLama/Native/NativeApi.cs View File

@@ -9,17 +9,6 @@ namespace LLama.Native
{ {
using llama_token = Int32; using llama_token = Int32;


public enum LLamaTokenType
{
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
LLAMA_TOKEN_TYPE_CONTROL = 3,
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
LLAMA_TOKEN_TYPE_UNUSED = 5,
LLAMA_TOKEN_TYPE_BYTE = 6,
}

/// <summary> /// <summary>
/// Callback from llama.cpp with log messages /// Callback from llama.cpp with log messages
/// </summary> /// </summary>
@@ -30,7 +19,7 @@ namespace LLama.Native
/// <summary> /// <summary>
/// Direct translation of the llama.cpp API /// Direct translation of the llama.cpp API
/// </summary> /// </summary>
public unsafe partial class NativeApi
public static partial class NativeApi
{ {
/// <summary> /// <summary>
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
@@ -165,7 +154,7 @@ namespace LLama.Native
/// <param name="dest"></param> /// <param name="dest"></param>
/// <returns>the number of bytes copied</returns> /// <returns>the number of bytes copied</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
public static extern unsafe ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);


/// <summary> /// <summary>
/// Set the state reading from the specified address /// Set the state reading from the specified address
@@ -174,7 +163,7 @@ namespace LLama.Native
/// <param name="src"></param> /// <param name="src"></param>
/// <returns>the number of bytes read</returns> /// <returns>the number of bytes read</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
public static extern unsafe ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);


/// <summary> /// <summary>
/// Load session file /// Load session file
@@ -186,7 +175,7 @@ namespace LLama.Native
/// <param name="n_token_count_out"></param> /// <param name="n_token_count_out"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);


/// <summary> /// <summary>
/// Save session file /// Save session file
@@ -211,7 +200,7 @@ namespace LLama.Native
/// <returns>Returns 0 on success</returns> /// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("use llama_decode() instead")] [Obsolete("use llama_decode() instead")]
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);


/// <summary> /// <summary>
/// Convert the provided text into tokens. /// Convert the provided text into tokens.
@@ -228,34 +217,37 @@ namespace LLama.Native
/// </returns> /// </returns>
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special) public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
{ {
// Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
var byteCount = encoding.GetByteCount(text);
var array = ArrayPool<byte>.Shared.Rent(byteCount + 1);
try
unsafe
{ {
// Convert to bytes
fixed (char* textPtr = text)
fixed (byte* arrayPtr = array)
// Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
var byteCount = encoding.GetByteCount(text);
var array = ArrayPool<byte>.Shared.Rent(byteCount + 1);
try
{ {
encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
// Convert to bytes
fixed (char* textPtr = text)
fixed (byte* arrayPtr = array)
{
encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length);
}

// Add a zero byte to the end to terminate the string
array[byteCount] = 0;

// Do the actual tokenization
fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens)
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
}
finally
{
ArrayPool<byte>.Shared.Return(array);
} }

// Add a zero byte to the end to terminate the string
array[byteCount] = 0;

// Do the actual tokenization
fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens)
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
}
finally
{
ArrayPool<byte>.Shared.Return(array);
} }
} }


[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token);


[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token); public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token);
@@ -281,7 +273,7 @@ namespace LLama.Native
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits(SafeLLamaContextHandle ctx);
public static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx);


/// <summary> /// <summary>
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
@@ -290,16 +282,24 @@ namespace LLama.Native
/// <param name="i"></param> /// <param name="i"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);


/// <summary> /// <summary>
/// Get the embeddings for the input /// Get the embeddings for the input
/// shape: [n_embd] (1-dimensional)
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);
public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
{
unsafe
{
var ptr = llama_get_embeddings_native(ctx);
return new Span<float>(ptr, ctx.EmbeddingSize);
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
}


/// <summary> /// <summary>
/// Get the "Beginning of sentence" token /// Get the "Beginning of sentence" token
@@ -426,7 +426,7 @@ namespace LLama.Native
/// <param name="buf_size"></param> /// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns> /// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);


/// <summary> /// <summary>
/// Get the number of metadata key/value pairs /// Get the number of metadata key/value pairs
@@ -445,7 +445,7 @@ namespace LLama.Native
/// <param name="buf_size"></param> /// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns> /// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);


/// <summary> /// <summary>
/// Get metadata value as a string by index /// Get metadata value as a string by index
@@ -456,7 +456,7 @@ namespace LLama.Native
/// <param name="buf_size"></param> /// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns> /// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);


/// <summary> /// <summary>
/// Get a string describing the model type /// Get a string describing the model type
@@ -466,7 +466,7 @@ namespace LLama.Native
/// <param name="buf_size"></param> /// <param name="buf_size"></param>
/// <returns>The length of the string on success, or -1 on failure</returns> /// <returns>The length of the string on success, or -1 on failure</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);


/// <summary> /// <summary>
/// Get the size of the model in bytes /// Get the size of the model in bytes
@@ -493,7 +493,7 @@ namespace LLama.Native
/// <param name="length">size of the buffer</param> /// <param name="length">size of the buffer</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns> /// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);


/// <summary> /// <summary>
/// Convert text into tokens /// Convert text into tokens
@@ -509,7 +509,7 @@ namespace LLama.Native
/// Returns a negative number on failure - the number of tokens that would have been returned /// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns> /// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);


/// <summary> /// <summary>
/// Register a callback to receive llama log messages /// Register a callback to receive llama log messages


+ 7
- 4
LLama/Native/NativeLibraryConfig.cs View File

@@ -29,10 +29,11 @@ namespace LLama.Native
private bool _allowFallback = true; private bool _allowFallback = true;
private bool _skipCheck = false; private bool _skipCheck = false;
private bool _logging = false; private bool _logging = false;

/// <summary> /// <summary>
/// search directory -> priority level, 0 is the lowest. /// search directory -> priority level, 0 is the lowest.
/// </summary> /// </summary>
private List<string> _searchDirectories = new List<string>();
private readonly List<string> _searchDirectories = new List<string>();


private static void ThrowIfLoaded() private static void ThrowIfLoaded()
{ {
@@ -159,9 +160,8 @@ namespace LLama.Native
internal static Description CheckAndGatherDescription() internal static Description CheckAndGatherDescription()
{ {
if (Instance._allowFallback && Instance._skipCheck) if (Instance._allowFallback && Instance._skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed."); throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
return new Description( return new Description(
Instance._libraryPath, Instance._libraryPath,
Instance._useCuda, Instance._useCuda,
@@ -169,7 +169,8 @@ namespace LLama.Native
Instance._allowFallback, Instance._allowFallback,
Instance._skipCheck, Instance._skipCheck,
Instance._logging, Instance._logging,
Instance._searchDirectories.Concat(new string[] { "./" }).ToArray());
Instance._searchDirectories.Concat(new[] { "./" }).ToArray()
);
} }


internal static string AvxLevelToString(AvxLevel level) internal static string AvxLevelToString(AvxLevel level)
@@ -204,7 +205,9 @@ namespace LLama.Native
if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported) if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported)
return false; return false;


// ReSharper disable UnusedVariable (ebx is used when < NET8)
var (_, ebx, ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0); var (_, ebx, ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0);
// ReSharper restore UnusedVariable


var vnni = (ecx & 0b_1000_0000_0000) != 0; var vnni = (ecx & 0b_1000_0000_0000) != 0;




+ 0
- 3
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -1,6 +1,5 @@
using System; using System;
using System.Buffers; using System.Buffers;
using System.Collections.Generic;
using System.Text; using System.Text;
using LLama.Exceptions; using LLama.Exceptions;


@@ -51,8 +50,6 @@ namespace LLama.Native
_model.DangerousAddRef(ref success); _model.DangerousAddRef(ref success);
if (!success) if (!success)
throw new RuntimeError("Failed to increment model refcount"); throw new RuntimeError("Failed to increment model refcount");

} }


/// <inheritdoc /> /// <inheritdoc />


+ 0
- 2
LLama/Native/SafeLlamaModelHandle.cs View File

@@ -214,7 +214,6 @@ namespace LLama.Native
/// Get the metadata key for the given index /// Get the metadata key for the given index
/// </summary> /// </summary>
/// <param name="index">The index to get</param> /// <param name="index">The index to get</param>
/// <param name="buffer">A temporary buffer to store key characters in. Must be large enough to contain the key.</param>
/// <returns>The key, null if there is no such key or if the buffer was too small</returns> /// <returns>The key, null if there is no such key or if the buffer was too small</returns>
public Memory<byte>? MetadataKeyByIndex(int index) public Memory<byte>? MetadataKeyByIndex(int index)
{ {
@@ -243,7 +242,6 @@ namespace LLama.Native
/// Get the metadata value for the given index /// Get the metadata value for the given index
/// </summary> /// </summary>
/// <param name="index">The index to get</param> /// <param name="index">The index to get</param>
/// <param name="buffer">A temporary buffer to store value characters in. Must be large enough to contain the value.</param>
/// <returns>The value, null if there is no such value or if the buffer was too small</returns> /// <returns>The value, null if there is no such value or if the buffer was too small</returns>
public Memory<byte>? MetadataValueByIndex(int index) public Memory<byte>? MetadataValueByIndex(int index)
{ {


Loading…
Cancel
Save