
Merge pull request #185 from martindevans/wip_major_api_change

Major llama.cpp API Change
tags/v0.6.0
Martin Evans (GitHub) · 2 years ago
commit d8434ea9d6 · GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
45 changed files with 1405 additions and 887 deletions
  1. LLama.Examples/NewVersion/LoadAndSaveSession.cs (+2 -2)
  2. LLama.Examples/NewVersion/SemanticKernelChat.cs (+2 -8)
  3. LLama.Examples/NewVersion/SemanticKernelMemory.cs (+1 -1)
  4. LLama.Examples/NewVersion/SemanticKernelPrompt.cs (+1 -1)
  5. LLama.Examples/NewVersion/TalkToYourself.cs (+1 -1)
  6. LLama.Examples/Program.cs (+2 -1)
  7. LLama.Unittest/BasicTest.cs (+1 -29)
  8. LLama.Unittest/LLamaContextTests.cs (+1 -2)
  9. LLama.Unittest/ModelsParamsTests.cs (+6 -3)
 10. LLama.Unittest/StatelessExecutorTest.cs (+3 -3)
 11. LLama.Unittest/TokenTests.cs (+4 -4)
 12. LLama.Web/Common/ModelOptions.cs (+29 -19)
 13. LLama.Web/Services/ConnectionSessionService.cs (+6 -6)
 14. LLama.WebAPI/Services/StatefulChatService.cs (+8 -3)
 15. LLama.WebAPI/Services/StatelessChatService.cs (+8 -2)
 16. LLama/Abstractions/IContextParams.cs (+70 -0)
 17. LLama/Abstractions/IInferenceParams.cs (+1 -1)
 18. LLama/Abstractions/ILLamaParams.cs (+11 -0)
 19. LLama/Abstractions/IModelParams.cs (+53 -58)
 20. LLama/Common/ModelParams.cs (+28 -17)
 21. LLama/Extensions/IContextParamsExtensions.cs (+46 -0)
 22. LLama/Extensions/IModelParamsExtensions.cs (+7 -18)
 23. LLama/LLamaContext.cs (+28 -91)
 24. LLama/LLamaEmbedder.cs (+12 -9)
 25. LLama/LLamaStatelessExecutor.cs (+10 -22)
 26. LLama/LLamaWeights.cs (+23 -13)
 27. LLama/Native/LLamaBatchSafeHandle.cs (+106 -0)
 28. LLama/Native/LLamaContextParams.cs (+7 -62)
 29. LLama/Native/LLamaModelParams.cs (+67 -0)
 30. LLama/Native/LLamaModelQuantizeParams.cs (+10 -0)
 31. LLama/Native/LLamaNativeBatch.cs (+45 -0)
 32. LLama/Native/LLamaPos.cs (+15 -0)
 33. LLama/Native/LLamaSeqId.cs (+15 -0)
 34. LLama/Native/NativeApi.cs (+139 -115)
 35. LLama/Native/SafeLLamaContextHandle.cs (+31 -61)
 36. LLama/Native/SafeLlamaModelHandle.cs (+31 -15)
 37. LLama/Utils.cs (+0 -108)
 38. LLama/runtimes/ggml-metal.metal (+575 -212)
 39. LLama/runtimes/libllama-cuda11.dll (BIN)
 40. LLama/runtimes/libllama-cuda11.so (BIN)
 41. LLama/runtimes/libllama-cuda12.dll (BIN)
 42. LLama/runtimes/libllama-cuda12.so (BIN)
 43. LLama/runtimes/libllama.dll (BIN)
 44. LLama/runtimes/libllama.dylib (BIN)
 45. LLama/runtimes/libllama.so (BIN)
LLama.Examples/NewVersion/LoadAndSaveSession.cs (+2 -2)

@@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion
 {
 Console.Write("Please input your model path: ");
 var modelPath = Console.ReadLine();
-var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
+var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim();

 var parameters = new ModelParams(modelPath)
 {
@@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion
 Console.ForegroundColor = ConsoleColor.White;

 ex.Context.Dispose();
-ex = new(new LLamaContext(parameters));
+ex = new(new LLamaContext(model, parameters));
 session = new ChatSession(ex);
 session.LoadSession(statePath);
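For illustration only (not part of this diff), the pattern the examples now follow is: load the weights once, then create contexts from them. The prompt file and variable names below come from this example; everything else is assumed.

// Sketch of the new two-step pattern replacing `new LLamaContext(parameters)`
using var model = LLamaWeights.LoadFromFile(parameters);      // parameters : ModelParams
using var context = new LLamaContext(model, parameters);      // context is created from the shared weights
var executor = new InteractiveExecutor(context);
var session = new ChatSession(executor);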




LLama.Examples/NewVersion/SemanticKernelChat.cs (+2 -8)

@@ -1,13 +1,7 @@
-using System.Reflection.Metadata;
-using System.Security.Cryptography;
-using System.Text;
-using LLama.Abstractions;
+using System.Security.Cryptography;
 using LLama.Common;
-using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.AI.ChatCompletion;
-using Microsoft.SemanticKernel.AI.TextCompletion;
 using LLamaSharp.SemanticKernel.ChatCompletion;
-using LLamaSharp.SemanticKernel.TextCompletion;

 namespace LLama.Examples.NewVersion
 {
@@ -22,7 +16,7 @@ namespace LLama.Examples.NewVersion
 // Load weights into memory
 var parameters = new ModelParams(modelPath)
 {
-Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
+Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
 };
 using var model = LLamaWeights.LoadFromFile(parameters);
 using var context = model.CreateContext(parameters);


LLama.Examples/NewVersion/SemanticKernelMemory.cs (+1 -1)

@@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion
 Console.Write("Please input your model path: ");
 var modelPath = Console.ReadLine();

-var seed = 1337;
+var seed = 1337u;
 // Load weights into memory
 var parameters = new ModelParams(modelPath)
 {


LLama.Examples/NewVersion/SemanticKernelPrompt.cs (+1 -1)

@@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion
 // Load weights into memory
 var parameters = new ModelParams(modelPath)
 {
-Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
+Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
 };
 using var model = LLamaWeights.LoadFromFile(parameters);
 var ex = new StatelessExecutor(model, parameters);


LLama.Examples/NewVersion/TalkToYourself.cs (+1 -1)

@@ -15,7 +15,7 @@ namespace LLama.Examples.NewVersion
 // Load weights into memory
 var @params = new ModelParams(modelPath)
 {
-Seed = RandomNumberGenerator.GetInt32(int.MaxValue)
+Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
 };
 using var weights = LLamaWeights.LoadFromFile(@params);




LLama.Examples/Program.cs (+2 -1)

@@ -1,4 +1,5 @@
 using LLama.Examples.NewVersion;
+using LLama.Native;

 Console.WriteLine("======================================================================================================");

@@ -7,7 +8,7 @@ Console.WriteLine(" __ __ ____ _
 Console.WriteLine("======================================================================================================");

+NativeApi.llama_empty_call();
 Console.WriteLine();

 await NewVersionTestRunner.Run();

LLama.Unittest/BasicTest.cs (+1 -29)

@@ -27,36 +27,8 @@ namespace LLama.Unittest
 public void BasicModelProperties()
 {
 Assert.Equal(32000, _model.VocabCount);
-Assert.Equal(2048, _model.ContextSize);
+Assert.Equal(4096, _model.ContextSize);
 Assert.Equal(4096, _model.EmbeddingSize);
-Assert.Equal(Encoding.UTF8, _model.Encoding);
-}
-
-[Fact]
-public void CloneContext()
-{
-var original = _model.CreateContext(_params);
-
-// Evaluate something (doesn't matter what, as long as it begins with token 1)
-original.Eval(new[] { 1, 42, 321 }, 0);
-
-// Clone current state
-var clone = original.Clone();
-
-// Now evaluate something more
-var reply1a = original.Eval(new[] { 4, 5, 6 }, 3);
-var reply2a = original.Eval(new[] { 7, 8, 9 }, 6);
-
-// Assert that the context replied differently each time
-Assert.NotEqual(reply1a, reply2a);
-
-// Give the same prompts to the cloned state
-var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3);
-var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6);
-
-// Assert that the cloned context replied in the same way as originally
-Assert.Equal(reply1a, reply1b);
-Assert.Equal(reply2a, reply2b);
 }
 }
 }

LLama.Unittest/LLamaContextTests.cs (+1 -2)

@@ -2,7 +2,7 @@

 namespace LLama.Unittest
 {
-public class LLamaContextTests
+public sealed class LLamaContextTests
 : IDisposable
 {
 private readonly LLamaWeights _weights;
@@ -30,7 +30,6 @@ namespace LLama.Unittest
 Assert.Equal(768, _context.ContextSize);
 Assert.Equal(4096, _context.EmbeddingSize);
 Assert.Equal(32000, _context.VocabCount);
-Assert.Equal(0, _context.KVCacheTokenCount);
 }

 [Fact]


LLama.Unittest/ModelsParamsTests.cs (+6 -3)

@@ -13,7 +13,6 @@ namespace LLama.Unittest
 {
 BatchSize = 17,
 ContextSize = 42,
-LoraAdapter = "adapter",
 Seed = 42,
 GpuLayerCount = 111
 };
@@ -31,9 +30,13 @@ namespace LLama.Unittest
 {
 BatchSize = 17,
 ContextSize = 42,
-LoraAdapter = "adapter",
 Seed = 42,
-GpuLayerCount = 111
+GpuLayerCount = 111,
+LoraAdapters =
+{
+new("abc", 1),
+new("def", 0)
+}
 };

 var settings = new Newtonsoft.Json.JsonSerializerSettings();


LLama.Unittest/StatelessExecutorTest.cs (+3 -3)

@@ -16,7 +16,7 @@ namespace LLama.Unittest
 _params = new ModelParams(Constants.ModelPath)
 {
 ContextSize = 60,
-Seed = 1754
+Seed = 1754,
 };
 _weights = LLamaWeights.LoadFromFile(_params);
 }
@@ -48,13 +48,13 @@ namespace LLama.Unittest
 {
 var executor = new StatelessExecutor(_weights, _params);

-const string question = " Question. why is a cat the best pet?\nAnswer: ";
+const string question = " Question. cats or dogs?\nAnswer: ";

 // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
 // with a modified context
 var @params = new InferenceParams()
 {
-MaxTokens = 100,
+MaxTokens = 65,
 TokensKeep = question.Length,
 };




LLama.Unittest/TokenTests.cs (+4 -4)

@@ -27,7 +27,7 @@ public sealed class TokenTests
 [Fact]
 public void TokensEndWith()
 {
-var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

 var result = tokens.TokensEndsWithAnyString(new[]
 {
@@ -41,7 +41,7 @@ public sealed class TokenTests
 [Fact]
 public void TokensEndSubstring()
 {
-var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

 var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
 {
@@ -53,7 +53,7 @@ public sealed class TokenTests
 [Fact]
 public void TokensNotEndWith()
 {
-var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

 var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
 {
@@ -67,7 +67,7 @@ public sealed class TokenTests
 [Fact]
 public void TokensNotEndWithNothing()
 {
-var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
+var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

 var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
 Assert.False(result);


LLama.Web/Common/ModelOptions.cs (+29 -19)

@@ -4,7 +4,7 @@ using LLama.Abstractions;
 namespace LLama.Web.Common
 {
 public class ModelOptions
-: IModelParams
+: ILLamaParams
 {
 public string Name { get; set; }
@@ -14,7 +14,7 @@ namespace LLama.Web.Common
 /// <summary>
 /// Model context size (n_ctx)
 /// </summary>
-public int ContextSize { get; set; } = 512;
+public uint ContextSize { get; set; } = 512;
 /// <summary>
 /// the GPU that is used for scratch and small tensors
 /// </summary>
@@ -30,7 +30,7 @@ namespace LLama.Web.Common
 /// <summary>
 /// Seed for the random number generator (seed)
 /// </summary>
-public int Seed { get; set; } = 1686349486;
+public uint Seed { get; set; } = 1686349486;
 /// <summary>
 /// Use f16 instead of f32 for memory kv (memory_f16)
 /// </summary>
@@ -51,26 +51,31 @@ namespace LLama.Web.Common
 /// Model path (model)
 /// </summary>
 public string ModelPath { get; set; }

 /// <summary>
-/// model alias
-/// </summary>
-public string ModelAlias { get; set; } = "unknown";
-/// <summary>
-/// lora adapter path (lora_adapter)
-/// </summary>
-public string LoraAdapter { get; set; } = string.Empty;
-/// <summary>
-/// base model path for the lora adapter (lora_base)
-/// </summary>
-public string LoraBase { get; set; } = string.Empty;
-/// <summary>
-/// Number of threads (-1 = autodetect) (n_threads)
+/// List of LoRAs to apply
 /// </summary>
-public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
+public AdapterCollection LoraAdapters { get; set; } = new();
+
+/// <summary>
+/// base model path for the lora adapter (lora_base)
+/// </summary>
+public string LoraBase { get; set; } = string.Empty;

 /// <summary>
-/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+/// Number of threads (null = autodetect) (n_threads)
 /// </summary>
-public int BatchSize { get; set; } = 512;
+public uint? Threads { get; set; }
+
+/// <summary>
+/// Number of threads to use for batch processing (null = autodetect) (n_threads)
+/// </summary>
+public uint? BatchThreads { get; set; }
+
+/// <summary>
+/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+/// </summary>
+public uint BatchSize { get; set; } = 512;

 /// <summary>
 /// Whether to convert eos to newline during the inference.
@@ -107,5 +112,10 @@ namespace LLama.Web.Common
 /// The encoding to use for models
 /// </summary>
 public Encoding Encoding { get; set; } = Encoding.UTF8;
+
+/// <summary>
+/// Load vocab only (no weights)
+/// </summary>
+public bool VocabOnly { get; set; }
 }
 }
} }

LLama.Web/Services/ConnectionSessionService.cs (+6 -6)

@@ -3,7 +3,6 @@ using LLama.Web.Common;
 using LLama.Web.Models;
 using Microsoft.Extensions.Options;
 using System.Collections.Concurrent;
-using System.Drawing;

 namespace LLama.Web.Services
 {
@@ -50,15 +49,16 @@ namespace LLama.Web.Services
 if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
 return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached"));

-// Create model
-var llamaModel = new LLamaContext(modelOption);
+// Load weights
+// todo: it would be better to have a central service which loads weights and shares them between all contexts that need them!
+using var weights = LLamaWeights.LoadFromFile(modelOption);

 // Create executor
 ILLamaExecutor executor = executorType switch
 {
-LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel),
-LLamaExecutorType.Instruct => new InstructExecutor(llamaModel),
-LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel),
+LLamaExecutorType.Interactive => new InteractiveExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
+LLamaExecutorType.Instruct => new InstructExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
+LLamaExecutorType.Stateless => new StatelessExecutor(weights, modelOption),
 _ => default
 };
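As the new todo notes, one LLamaWeights instance can back several executors. A rough sketch of that reuse, not part of this PR and with lifetime management left to the caller:

// Illustrative only: share one weights object across executor types.
var weights = LLamaWeights.LoadFromFile(modelOption);
var interactive = new InteractiveExecutor(new LLamaContext(weights, modelOption));
var stateless = new StatelessExecutor(weights, modelOption);
// Dispose each LLamaContext when its session ends; dispose `weights` only when no session still needs the model.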




LLama.WebAPI/Services/StatefulChatService.cs (+8 -3)

@@ -16,10 +16,15 @@ public class StatefulChatService : IDisposable

 public StatefulChatService(IConfiguration configuration)
 {
-_context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"])
+var @params = new Common.ModelParams(configuration["ModelPath"])
 {
-ContextSize = 512
-});
+ContextSize = 512,
+};
+
+// todo: share weights from a central service
+using var weights = LLamaWeights.LoadFromFile(@params);
+
+_context = new LLamaContext(weights, @params);
 _session = new ChatSession(new InteractiveExecutor(_context));
 }




LLama.WebAPI/Services/StatelessChatService.cs (+8 -2)

@@ -12,10 +12,16 @@ namespace LLama.WebAPI.Services

 public StatelessChatService(IConfiguration configuration)
 {
-_context = new LLamaContext(new ModelParams(configuration["ModelPath"])
+var @params = new Common.ModelParams(configuration["ModelPath"])
 {
 ContextSize = 512,
-});
+};
+
+// todo: share weights from a central service
+using var weights = LLamaWeights.LoadFromFile(@params);
+
+_context = new LLamaContext(weights, @params);

 // TODO: replace with a stateless executor
 _session = new ChatSession(new InteractiveExecutor(_context))
 .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8))


LLama/Abstractions/IContextParams.cs (+70 -0, new file)

@@ -0,0 +1,70 @@
using System.Text;

namespace LLama.Abstractions;

/// <summary>
/// The parameters for initializing a LLama context from a model.
/// </summary>
public interface IContextParams
{
/// <summary>
/// Model context size (n_ctx)
/// </summary>
uint ContextSize { get; set; }

/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
uint BatchSize { get; set; }

/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
uint Seed { get; set; }

/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
bool UseFp16Memory { get; set; }

/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
bool Perplexity { get; set; }

/// <summary>
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
/// </summary>
bool EmbeddingMode { get; set; }

/// <summary>
/// RoPE base frequency
/// </summary>
float RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// </summary>
float RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>
bool MulMatQ { get; set; }

/// <summary>
/// The encoding to use for models
/// </summary>
Encoding Encoding { get; set; }

/// <summary>
/// Number of threads (null = autodetect) (n_threads)
/// </summary>
uint? Threads { get; set; }

/// <summary>
/// Number of threads to use for batch processing (null = autodetect) (n_threads)
/// </summary>
uint? BatchThreads { get; set; }
}
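Since ModelParams implements ILLamaParams (defined below), one object can satisfy both halves of the model/context split. A minimal sketch, with the model path as a placeholder and not part of this diff:

// One parameters object feeds both the loader (IModelParams) and the context (IContextParams).
var p = new ModelParams("model.gguf") { ContextSize = 1024, Seed = 42u };
using var weights = LLamaWeights.LoadFromFile(p);   // consumes the IModelParams half
using var ctx = weights.CreateContext(p);           // consumes the IContextParams half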

LLama/Abstractions/IInferenceParams.cs (+1 -1)

@@ -36,7 +36,7 @@ namespace LLama.Abstractions
 /// </summary>
 public int TopK { get; set; }

-/// <summary>
+/// <summary>llama_eval
 /// 1.0 = disabled
 /// </summary>
 public float TopP { get; set; }


LLama/Abstractions/ILLamaParams.cs (+11 -0, new file)

@@ -0,0 +1,11 @@
namespace LLama.Abstractions
{
/// <summary>
/// Convenience interface for implementing both type of parameters.
/// </summary>
/// <remarks>Mostly exists for backwards compatibility reasons, when these two were not split.</remarks>
public interface ILLamaParams
: IModelParams, IContextParams
{
}
}

LLama/Abstractions/IModelParams.cs (+53 -58)

@@ -1,4 +1,6 @@
-using System.Text;
+using System;
+using System.Collections.Generic;
+using System.Linq;

 namespace LLama.Abstractions
 {
@@ -7,36 +9,16 @@ namespace LLama.Abstractions
 /// </summary>
 public interface IModelParams
 {
-/// <summary>
-/// Model context size (n_ctx)
-/// </summary>
-int ContextSize { get; set; }
-
 /// <summary>
 /// the GPU that is used for scratch and small tensors
 /// </summary>
 int MainGpu { get; set; }

-/// <summary>
-/// if true, reduce VRAM usage at the cost of performance
-/// </summary>
-bool LowVram { get; set; }
-
 /// <summary>
 /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
 /// </summary>
 int GpuLayerCount { get; set; }

-/// <summary>
-/// Seed for the random number generator (seed)
-/// </summary>
-int Seed { get; set; }
-
-/// <summary>
-/// Use f16 instead of f32 for memory kv (memory_f16)
-/// </summary>
-bool UseFp16Memory { get; set; }
-
 /// <summary>
 /// Use mmap for faster loads (use_mmap)
 /// </summary>
@@ -47,41 +29,15 @@ namespace LLama.Abstractions
 /// </summary>
 bool UseMemoryLock { get; set; }

-/// <summary>
-/// Compute perplexity over the prompt (perplexity)
-/// </summary>
-bool Perplexity { get; set; }
-
 /// <summary>
 /// Model path (model)
 /// </summary>
 string ModelPath { get; set; }

-/// <summary>
-/// lora adapter path (lora_adapter)
-/// </summary>
-string LoraAdapter { get; set; }
-
-/// <summary>
-/// base model path for the lora adapter (lora_base)
-/// </summary>
-string LoraBase { get; set; }
-
 /// <summary>
 /// Number of threads (-1 = autodetect) (n_threads)
 /// </summary>
-int Threads { get; set; }
-
-/// <summary>
-/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
-/// </summary>
-int BatchSize { get; set; }
-
-/// <summary>
-/// Whether to use embedding mode. (embedding) Note that if this is set to true,
-/// The LLamaModel won't produce text response anymore.
-/// </summary>
-bool EmbeddingMode { get; set; }
+uint? Threads { get; set; }

 /// <summary>
 /// how split tensors should be distributed across GPUs
@@ -89,23 +45,62 @@ namespace LLama.Abstractions
 float[]? TensorSplits { get; set; }

 /// <summary>
-/// RoPE base frequency
+/// Load vocab only (no weights)
 /// </summary>
-float RopeFrequencyBase { get; set; }
+bool VocabOnly { get; set; }

 /// <summary>
-/// RoPE frequency scaling factor
+/// List of LoRA adapters to apply
 /// </summary>
-float RopeFrequencyScale { get; set; }
+AdapterCollection LoraAdapters { get; }

 /// <summary>
-/// Use experimental mul_mat_q kernels
+/// base model path for the lora adapter (lora_base)
 /// </summary>
-bool MulMatQ { get; set; }
+string LoraBase { get; set; }
+}

-/// <summary>
-/// The encoding to use for models
-/// </summary>
-Encoding Encoding { get; set; }
+/// <summary>
+/// A LoRA adapter to apply to a model
+/// </summary>
+/// <param name="Path">Path to the LoRA file</param>
+/// <param name="Scale">Strength of this LoRA</param>
+public readonly record struct LoraAdapter(string Path, float Scale);
+
+/// <summary>
+/// A list of LoraAdapter objects
+/// </summary>
+public sealed class AdapterCollection
+: List<LoraAdapter>, IEquatable<AdapterCollection>
+{
+/// <inheritdoc />
+public bool Equals(AdapterCollection? other)
+{
+if (other == null)
+return false;
+
+return this.SequenceEqual(other);
+}
+
+/// <inheritdoc/>
+public override bool Equals(object? obj)
+{
+return Equals(obj as AdapterCollection);
+}
+
+/// <inheritdoc/>
+public override int GetHashCode()
+{
+unchecked
+{
+var hash = 17;
+for (var i = 0; i < Count; i++)
+{
+hash += this[i].GetHashCode();
+hash *= 7823;
+}
+return hash;
+}
+}
 }
 }
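A hedged example of the new LoRA collection (paths are placeholders). As shown later in this diff, LLamaWeights.LoadFromFile skips entries with an empty path or a non-positive scale:

var p = new ModelParams("model.gguf") { LoraBase = "base.gguf" };
p.LoraAdapters.Add(new LoraAdapter("style.lora", 0.8f));    // applied at 80% strength
p.LoraAdapters.Add(new LoraAdapter("domain.lora", 1.0f));   // applied at full strength
using var weights = LLamaWeights.LoadFromFile(p);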

LLama/Common/ModelParams.cs (+28 -17)

@@ -10,20 +10,17 @@ namespace LLama.Common
 /// The parameters for initializing a LLama model.
 /// </summary>
 public record ModelParams
-: IModelParams
+: ILLamaParams
 {
 /// <summary>
 /// Model context size (n_ctx)
 /// </summary>
-public int ContextSize { get; set; } = 512;
+public uint ContextSize { get; set; } = 512;
 /// <summary>
 /// the GPU that is used for scratch and small tensors
 /// </summary>
 public int MainGpu { get; set; } = 0;
-/// <summary>
-/// if true, reduce VRAM usage at the cost of performance
-/// </summary>
-public bool LowVram { get; set; } = false;

 /// <summary>
 /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
 /// </summary>
@@ -31,7 +28,7 @@ namespace LLama.Common
 /// <summary>
 /// Seed for the random number generator (seed)
 /// </summary>
-public int Seed { get; set; } = 1686349486;
+public uint Seed { get; set; } = 1686349486;
 /// <summary>
 /// Use f16 instead of f32 for memory kv (memory_f16)
 /// </summary>
@@ -52,22 +49,31 @@ namespace LLama.Common
 /// Model path (model)
 /// </summary>
 public string ModelPath { get; set; }

 /// <summary>
-/// lora adapter path (lora_adapter)
+/// List of LoRAs to apply
 /// </summary>
-public string LoraAdapter { get; set; } = string.Empty;
+public AdapterCollection LoraAdapters { get; set; } = new();

 /// <summary>
 /// base model path for the lora adapter (lora_base)
 /// </summary>
 public string LoraBase { get; set; } = string.Empty;

 /// <summary>
-/// Number of threads (-1 = autodetect) (n_threads)
+/// Number of threads (null = autodetect) (n_threads)
 /// </summary>
-public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
+public uint? Threads { get; set; }
+
+/// <summary>
+/// Number of threads to use for batch processing (null = autodetect) (n_threads)
+/// </summary>
+public uint? BatchThreads { get; set; }

 /// <summary>
 /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
 /// </summary>
-public int BatchSize { get; set; } = 512;
+public uint BatchSize { get; set; } = 512;

 /// <summary>
 /// Whether to use embedding mode. (embedding) Note that if this is set to true,
@@ -95,6 +101,11 @@ namespace LLama.Common
 /// </summary>
 public bool MulMatQ { get; set; }

+/// <summary>
+/// Load vocab only (no weights)
+/// </summary>
+public bool VocabOnly { get; set; }
+
 /// <summary>
 /// The encoding to use to convert text for the model
 /// </summary>
@@ -138,10 +149,10 @@ namespace LLama.Common
 /// <param name="mulMatQ">Use experimental mul_mat_q kernels</param>
 /// <param name="encoding">The encoding to use to convert text for the model</param>
 [Obsolete("Use object initializer to set all optional parameters")]
-public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
-int seed = 1337, bool useFp16Memory = true,
+public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount = 20,
+uint seed = 1337, bool useFp16Memory = true,
 bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
-string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
+string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
 bool embeddingMode = false,
 float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
 string encoding = "UTF-8")
@@ -154,15 +165,15 @@ namespace LLama.Common
 UseMemoryLock = useMemoryLock;
 Perplexity = perplexity;
 ModelPath = modelPath;
-LoraAdapter = loraAdapter;
 LoraBase = loraBase;
-Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads;
+Threads = threads < 1 ? null : (uint)threads;
 BatchSize = batchSize;
 EmbeddingMode = embeddingMode;
 RopeFrequencyBase = ropeFrequencyBase;
 RopeFrequencyScale = ropeFrequencyScale;
 MulMatQ = mulMatQ;
 Encoding = Encoding.GetEncoding(encoding);
+LoraAdapters.Add(new LoraAdapter(loraAdapter, 1));
 }
 }
} }




LLama/Extensions/IContextParamsExtensions.cs (+46 -0, new file)

@@ -0,0 +1,46 @@
using System;
using System.IO;
using LLama.Abstractions;
using LLama.Native;

namespace LLama.Extensions
{
/// <summary>
/// Extention methods to the IContextParams interface
/// </summary>
public static class IContextParamsExtensions
{
/// <summary>
/// Convert the given `IModelParams` into a `LLamaContextParams`
/// </summary>
/// <param name="params"></param>
/// <param name="result"></param>
/// <returns></returns>
/// <exception cref="FileNotFoundException"></exception>
/// <exception cref="ArgumentException"></exception>
public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
{
result = NativeApi.llama_context_default_params();
result.n_ctx = @params.ContextSize;
result.n_batch = @params.BatchSize;
result.seed = @params.Seed;
result.f16_kv = @params.UseFp16Memory;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase;
result.rope_freq_scale = @params.RopeFrequencyScale;
result.mul_mat_q = @params.MulMatQ;

result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
}

private static uint Threads(uint? value)
{
if (value is > 0)
return (uint)value;

return (uint)Math.Max(Environment.ProcessorCount / 2, 1);
}
}
}
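How the extension is consumed mirrors the LLamaContext constructor later in this diff; sketched here for orientation (variable names assumed):

contextParams.ToLlamaContextParams(out var lparams);                    // fill the native struct from IContextParams
var handle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
// Threads/BatchThreads left null fall back to Math.Max(Environment.ProcessorCount / 2, 1), per the helper above.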

LLama/Extensions/IModelParamsExtensions.cs (+7 -18)

@@ -12,41 +12,30 @@ namespace LLama.Extensions
 public static class IModelParamsExtensions
 {
 /// <summary>
-/// Convert the given `IModelParams` into a `LLamaContextParams`
+/// Convert the given `IModelParams` into a `LLamaModelParams`
 /// </summary>
 /// <param name="params"></param>
 /// <param name="result"></param>
 /// <returns></returns>
 /// <exception cref="FileNotFoundException"></exception>
 /// <exception cref="ArgumentException"></exception>
-public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
+public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
 {
-if (!File.Exists(@params.ModelPath))
-throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
-
 if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
 throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");

-result = NativeApi.llama_context_default_params();
-result.n_ctx = @params.ContextSize;
-result.n_batch = @params.BatchSize;
+result = NativeApi.llama_model_default_params();
+
 result.main_gpu = @params.MainGpu;
 result.n_gpu_layers = @params.GpuLayerCount;
-result.seed = @params.Seed;
-result.f16_kv = @params.UseFp16Memory;
-result.use_mmap = @params.UseMemorymap;
 result.use_mlock = @params.UseMemoryLock;
-result.logits_all = @params.Perplexity;
-result.embedding = @params.EmbeddingMode;
-result.low_vram = @params.LowVram;
-result.rope_freq_base = @params.RopeFrequencyBase;
-result.rope_freq_scale = @params.RopeFrequencyScale;
-result.mul_mat_q = @params.MulMatQ;
+result.use_mmap = @params.UseMemorymap;
+result.vocab_only = @params.VocabOnly;

 var pin = @params.TensorSplits.AsMemory().Pin();
 unsafe
 {
-result.tensor_split = (nint)pin.Pointer;
+result.tensor_split = (float*)pin.Pointer;
 }

 return pin;
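The returned MemoryHandle keeps TensorSplits pinned, so callers hold it across the native call; this is how LLamaWeights.LoadFromFile (later in this diff) uses it. Sketch only, names assumed:

using var pin = modelParams.ToLlamaModelParams(out var mparams);            // pin stays alive to the end of scope
var handle = SafeLlamaModelHandle.LoadFromFile(modelParams.ModelPath, mparams);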


LLama/LLamaContext.cs (+28 -91)

@@ -42,14 +42,9 @@ namespace LLama
 public int EmbeddingSize => _ctx.EmbeddingSize;

 /// <summary>
-/// Get the number of tokens in the KV Cache for this context
+/// The context params set for this context
 /// </summary>
-public int KVCacheTokenCount => _ctx.KVCacheTokenCount;
-
-/// <summary>
-/// The model params set for this model.
-/// </summary>
-public IModelParams Params { get; set; }
+public IContextParams Params { get; set; }

 /// <summary>
 /// The native handle, which is used to be passed to the native APIs
@@ -62,24 +57,7 @@ namespace LLama
 /// </summary>
 public Encoding Encoding => _encoding;

-/// <summary>
-///
-/// </summary>
-/// <param name="params">Model params.</param>
-/// <param name="logger">The logger.</param>
-[Obsolete("Use the LLamaWeights.CreateContext instead")]
-public LLamaContext(IModelParams @params, ILogger? logger = null)
-{
-Params = @params;
-
-_logger = logger;
-_encoding = @params.Encoding;
-
-_logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}");
-_ctx = Utils.InitLLamaContextFromModelParams(Params);
-}
-
-internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null)
+internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
 {
 Params = @params;

@@ -95,7 +73,7 @@ namespace LLama
 /// <param name="params"></param>
 /// <param name="logger"></param>
 /// <exception cref="ObjectDisposedException"></exception>
-public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? logger = null)
+public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null)
 {
 if (model.NativeHandle.IsClosed)
 throw new ObjectDisposedException("Cannot create context, model weights have been disposed");
@@ -105,30 +83,20 @@ namespace LLama
 _logger = logger;
 _encoding = @params.Encoding;

-using var pin = @params.ToLlamaContextParams(out var lparams);
+@params.ToLlamaContextParams(out var lparams);
 _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
 }

-/// <summary>
-/// Create a copy of the current state of this context
-/// </summary>
-/// <returns></returns>
-public LLamaContext Clone()
-{
-using var pin = Params.ToLlamaContextParams(out var lparams);
-var clone = _ctx.Clone(lparams);
-return new LLamaContext(clone, Params);
-}
-
 /// <summary>
 /// Tokenize a string.
 /// </summary>
 /// <param name="text"></param>
 /// <param name="addBos">Whether to add a bos to the text.</param>
+/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
 /// <returns></returns>
-public llama_token[] Tokenize(string text, bool addBos = true)
+public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
 {
-return _ctx.Tokenize(text, addBos, _encoding);
+return _ctx.Tokenize(text, addBos, special, _encoding);
 }

 /// <summary>
@@ -177,19 +145,6 @@ namespace LLama
 fileStream.SetLength(writtenBytes);
 }

-/// <summary>
-/// Get the state data as a byte array.
-/// </summary>
-/// <returns></returns>
-[Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
-public byte[] GetStateData()
-{
-var stateSize = NativeApi.llama_get_state_size(_ctx);
-byte[] stateMemory = new byte[stateSize];
-NativeApi.llama_copy_state_data(_ctx, stateMemory);
-return stateMemory;
-}
-
 /// <summary>
 /// Get the state data as an opaque handle
 /// </summary>
@@ -198,31 +153,28 @@ namespace LLama
 {
 var stateSize = _ctx.GetStateSize();

-unsafe
+// Allocate a chunk of memory large enough to hold the entire state
+var memory = Marshal.AllocHGlobal((nint)stateSize);
+try
 {
-    // Allocate a chunk of memory large enough to hold the entire state
-    var memory = Marshal.AllocHGlobal((nint)stateSize);
-    try
-    {
-        // Copy the state data into memory, discover the actual size required
-        var actualSize = _ctx.GetState(memory, stateSize);
+// Copy the state data into memory, discover the actual size required
+var actualSize = _ctx.GetState(memory, stateSize);

-        // Shrink to size
-        memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
+// Shrink to size
+memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);

-        // Wrap memory in a "state"
-        var state = new State(memory);
+// Wrap memory in a "state"
+var state = new State(memory);

-        // Set memory to zero, to prevent it being freed in finally block
-        memory = IntPtr.Zero;
+// Set memory to zero, to prevent it being freed in finally block
+memory = IntPtr.Zero;

-        return state;
-    }
-    finally
-    {
-        if (memory != IntPtr.Zero)
-            Marshal.FreeHGlobal(memory);
-    }
+return state;
+}
+finally
+{
+if (memory != IntPtr.Zero)
+Marshal.FreeHGlobal(memory);
 }
 }

@@ -247,21 +199,6 @@ namespace LLama
 }
 }

-/// <summary>
-/// Load the state from memory.
-/// </summary>
-/// <param name="stateData"></param>
-/// <exception cref="RuntimeError"></exception>
-public void LoadState(byte[] stateData)
-{
-int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
-if (stateData.Length > stateSize)
-{
-throw new RuntimeError("Failed to validate state size.");
-}
-NativeApi.llama_set_state_data(_ctx, stateData);
-}
-
 /// <summary>
 /// Load the state from memory.
 /// </summary>
@@ -463,15 +400,15 @@ namespace LLama
 public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
 {
 var total = tokens.Length;
-for(var i = 0; i < total; i += Params.BatchSize)
+for(var i = 0; i < total; i += (int)Params.BatchSize)
 {
 var n_eval = total - i;
 if (n_eval > Params.BatchSize)
 {
-n_eval = Params.BatchSize;
+n_eval = (int)Params.BatchSize;
 }

-if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
+if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
 {
 _logger?.LogError($"[LLamaContext] Failed to eval.");
 throw new RuntimeError("Failed to eval.");
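A small usage sketch of the reshaped surface (placeholder prompt and parameters; not part of the diff):

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

// `special: true` would let special/control tokens through the tokenizer.
var tokens = context.Tokenize("Hello, llama!", addBos: true, special: false);
var pastTokens = context.Eval(tokens, 0);   // Eval chunks the input by Params.BatchSize internally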


LLama/LLamaEmbedder.cs (+12 -9)

@@ -18,19 +18,22 @@ namespace LLama
 /// </summary>
 public int EmbeddingSize => _ctx.EmbeddingSize;

-/// <summary>
-///
-/// </summary>
-/// <param name="params"></param>
-public LLamaEmbedder(IModelParams @params)
+public LLamaEmbedder(ILLamaParams allParams)
+: this(allParams, allParams)
 {
-@params.EmbeddingMode = true;
-using var weights = LLamaWeights.LoadFromFile(@params);
-_ctx = weights.CreateContext(@params);
 }

-public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
+public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
 {
+using var weights = LLamaWeights.LoadFromFile(modelParams);
+
+contextParams.EmbeddingMode = true;
+_ctx = weights.CreateContext(contextParams);
+}
+
+public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
+{
+@params.EmbeddingMode = true;
 _ctx = weights.CreateContext(@params);
 }
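The constructor shapes added here, sketched with a placeholder model path (each sets EmbeddingMode = true on the context params before creating the context):

var p = new ModelParams("model.gguf");
using var fromParams = new LLamaEmbedder(p);            // ILLamaParams overload: loads weights itself
using var weights = LLamaWeights.LoadFromFile(p);
using var fromWeights = new LLamaEmbedder(weights, p);  // reuse weights loaded elsewhere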




LLama/LLamaStatelessExecutor.cs (+10 -22)

@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Extensions;
+using LLama.Native;

 namespace LLama
 {
@@ -20,7 +21,7 @@ namespace LLama
 : ILLamaExecutor
 {
 private readonly LLamaWeights _weights;
-private readonly IModelParams _params;
+private readonly IContextParams _params;

 /// <summary>
 /// The context used by the executor when running the inference.
@@ -32,7 +33,7 @@ namespace LLama
 /// </summary>
 /// <param name="weights"></param>
 /// <param name="params"></param>
-public StatelessExecutor(LLamaWeights weights, IModelParams @params)
+public StatelessExecutor(LLamaWeights weights, IContextParams @params)
 {
 _weights = weights;
 _params = @params;
@@ -41,20 +42,6 @@ namespace LLama
 Context.Dispose();
 }

-/// <summary>
-/// Create a new stateless executor which will use the model used to create the given context
-/// </summary>
-/// <param name="context"></param>
-[Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
-public StatelessExecutor(LLamaContext context)
-{
-_weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
-_params = context.Params;
-
-Context = _weights.CreateContext(_params);
-Context.Dispose();
-}
-
 /// <inheritdoc />
 public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
 {
@@ -114,15 +101,16 @@ namespace LLama
 break;

 // when run out of context
-// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
-if (n_past + tokens.Count > Context.ContextSize)
+// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
+if (n_past + tokens.Count >= Context.ContextSize)
 {
-var n_left = n_past - inferenceParams.TokensKeep;
+var n_left = n_past - inferenceParams.TokensKeep - 1;
+var n_discard = n_left / 2;

-n_past = Math.Max(1, inferenceParams.TokensKeep);
+NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

-tokens.Clear();
-tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
+n_past -= n_discard;
 }

 // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
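The new branch discards roughly half of the least recent KV cache instead of re-evaluating the prompt. A worked run with assumed numbers, following the arithmetic in the diff:

// Assume ContextSize = 60, TokensKeep = 10, n_past = 59 when the window fills.
// n_left    = 59 - 10 - 1 = 48
// n_discard = 48 / 2      = 24
// llama_kv_cache_seq_rm    drops cache cells [11, 35) of sequence 0
// llama_kv_cache_seq_shift slides cells [35, 59) back by 24 positions
// n_past becomes 59 - 24 = 35, leaving room to keep generating without rebuilding the prompt.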


LLama/LLamaWeights.cs (+23 -13)

@@ -1,5 +1,4 @@
 using System;
-using System.Text;
 using LLama.Abstractions;
 using LLama.Extensions;
 using LLama.Native;
@@ -20,11 +19,6 @@ namespace LLama
 /// <remarks>Be careful how you use this!</remarks>
 public SafeLlamaModelHandle NativeHandle => _weights;

-/// <summary>
-/// Encoding to use to convert text into bytes for the model
-/// </summary>
-public Encoding Encoding { get; }
-
 /// <summary>
 /// Total number of tokens in vocabulary of this model
 /// </summary>
@@ -35,15 +29,24 @@ namespace LLama
 /// </summary>
 public int ContextSize => NativeHandle.ContextSize;

+/// <summary>
+/// Get the size of this model in bytes
+/// </summary>
+public ulong SizeInBytes => NativeHandle.SizeInBytes;
+
+/// <summary>
+/// Get the number of parameters in this model
+/// </summary>
+public ulong ParameterCount => NativeHandle.ParameterCount;
+
 /// <summary>
 /// Dimension of embedding vectors
 /// </summary>
 public int EmbeddingSize => NativeHandle.EmbeddingSize;

-internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
+internal LLamaWeights(SafeLlamaModelHandle weights)
 {
 _weights = weights;
-Encoding = encoding;
 }

 /// <summary>
@@ -53,13 +56,20 @@ namespace LLama
 /// <returns></returns>
 public static LLamaWeights LoadFromFile(IModelParams @params)
 {
-using var pin = @params.ToLlamaContextParams(out var lparams);
+using var pin = @params.ToLlamaModelParams(out var lparams);
 var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

-if (!string.IsNullOrEmpty(@params.LoraAdapter))
-weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+foreach (var adapter in @params.LoraAdapters)
+{
+if (string.IsNullOrEmpty(adapter.Path))
+continue;
+if (adapter.Scale <= 0)
+continue;
+
+weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
+}

-return new LLamaWeights(weights, @params.Encoding);
+return new LLamaWeights(weights);
 }

 /// <inheritdoc />
@@ -73,7 +83,7 @@ namespace LLama
 /// </summary>
 /// <param name="params"></param>
 /// <returns></returns>
-public LLamaContext CreateContext(IModelParams @params)
+public LLamaContext CreateContext(IContextParams @params)
 {
 return new LLamaContext(this, @params);
 }
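The two new properties are straight pass-throughs to the model handle; for example (placeholder path, illustrative only):

using var weights = LLamaWeights.LoadFromFile(new ModelParams("model.gguf"));
Console.WriteLine($"{weights.ParameterCount:N0} parameters, {weights.SizeInBytes:N0} bytes");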


LLama/Native/LLamaBatchSafeHandle.cs (+106 -0, new file)

@@ -0,0 +1,106 @@
using System;

namespace LLama.Native;

using llama_token = Int32;

public sealed class LLamaBatchSafeHandle
: SafeLLamaHandleBase
{
private readonly int _embd;
public LLamaNativeBatch Batch { get; private set; }

/// <summary>
/// the token ids of the input (used when embd is NULL)
/// </summary>
public Span<llama_token> Token
{
get
{
unsafe
{
if (_embd != 0)
return new Span<int>(null, 0);
else
return new Span<int>(Batch.token, Batch.n_tokens);
}
}
}

/// <summary>
/// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
/// </summary>
public Span<llama_token> Embed
{
get
{
unsafe
{
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token

if (_embd != 0)
return new Span<llama_token>(Batch.embd, Batch.n_tokens * _embd);
else
return new Span<llama_token>(null, 0);
}
}
}

/// <summary>
/// the positions of the respective token in the sequence
/// </summary>
public Span<LLamaPos> Pos
{
get
{
unsafe
{
return new Span<LLamaPos>(Batch.pos, Batch.n_tokens);
}
}
}

/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
public Span<LLamaSeqId> Sequence_ID
{
get
{
unsafe
{
return new Span<LLamaSeqId>(Batch.seq_id, Batch.n_tokens);
}
}
}

/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public Span<byte> Logits
{
get
{
unsafe
{
return new Span<byte>(Batch.logits, Batch.n_tokens);
}
}
}

public LLamaBatchSafeHandle(int n_tokens, int embd)
: base((nint)1)
{
_embd = embd;
Batch = NativeApi.llama_batch_init(n_tokens, embd);
}

protected override bool ReleaseHandle()
{
NativeApi.llama_batch_free(Batch);
Batch = default;
SetHandle(IntPtr.Zero);
return true;
}
}
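A rough usage sketch of the new handle in token mode (the decode call that consumes the batch lives outside this file and is not shown):

// Illustrative only: the handle owns the native llama_batch and frees it on dispose.
using var batch = new LLamaBatchSafeHandle(512, 0);   // embd == 0 => token ids, not embeddings
LLamaNativeBatch native = batch.Batch;                // raw struct, released via llama_batch_free
// Token / Pos / Sequence_ID / Logits give span views over the native arrays.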

LLama/Native/LLamaContextParams.cs (+7 -62)

@@ -19,32 +19,27 @@ namespace LLama.Native
 /// <summary>
 /// RNG seed, -1 for random
 /// </summary>
-public int seed;
+public uint seed;

 /// <summary>
 /// text context
 /// </summary>
-public int n_ctx;
+public uint n_ctx;

 /// <summary>
 /// prompt processing batch size
 /// </summary>
-public int n_batch;
+public uint n_batch;

 /// <summary>
-/// number of layers to store in VRAM
+/// number of threads to use for generation
 /// </summary>
-public int n_gpu_layers;
+public uint n_threads;

 /// <summary>
-/// the GPU that is used for scratch and small tensors
+/// number of threads to use for batch processing
 /// </summary>
-public int main_gpu;
-
-/// <summary>
-/// how to split layers across multiple GPUs
-/// </summary>
-public nint tensor_split;
+public uint n_threads_batch;

 /// <summary>
 /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -58,26 +53,6 @@ namespace LLama.Native
 /// </summary>
 public float rope_freq_scale;

-/// <summary>
-/// called with a progress value between 0 and 1, pass NULL to disable
-/// </summary>
-public IntPtr progress_callback;
-
-/// <summary>
-/// context pointer passed to the progress callback
-/// </summary>
-public IntPtr progress_callback_user_data;
-
-/// <summary>
-/// if true, reduce VRAM usage at the cost of performance
-/// </summary>
-public bool low_vram
-{
-readonly get => Convert.ToBoolean(_low_vram);
-set => _low_vram = Convert.ToSByte(value);
-}
-private sbyte _low_vram;
-
 /// <summary>
 /// if true, use experimental mul_mat_q kernels
 /// </summary>
@@ -108,36 +83,6 @@ namespace LLama.Native
 }
 private sbyte _logits_all;

-/// <summary>
-/// only load the vocabulary, no weights
-/// </summary>
-public bool vocab_only
-{
-readonly get => Convert.ToBoolean(_vocab_only);
-set => _vocab_only = Convert.ToSByte(value);
-}
-private sbyte _vocab_only;
-
-/// <summary>
-/// use mmap if possible
-/// </summary>
-public bool use_mmap
-{
-readonly get => Convert.ToBoolean(_use_mmap);
-set => _use_mmap = Convert.ToSByte(value);
-}
-private sbyte _use_mmap;
-
-/// <summary>
-/// force system to keep model in RAM
-/// </summary>
-public bool use_mlock
-{
-readonly get => Convert.ToBoolean(_use_mlock);
-set => _use_mlock = Convert.ToSByte(value);
-}
-private sbyte _use_mlock;
-
 /// <summary>
 /// embedding mode only
 /// </summary>


LLama/Native/LLamaModelParams.cs (+67 -0, new file)

@@ -0,0 +1,67 @@
using System;
using System.Runtime.InteropServices;

namespace LLama.Native
{
/// <summary>
/// A C# representation of the llama.cpp `llama_model_params` struct
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaModelParams
{
/// <summary>
/// // number of layers to store in VRAM
/// </summary>
public int n_gpu_layers;

/// <summary>
/// // the GPU that is used for scratch and small tensors
/// </summary>
public int main_gpu;

/// <summary>
/// how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
/// </summary>
public float* tensor_split;

/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
LlamaProgressCallback progress_callback;

/// <summary>
/// context pointer passed to the progress callback
/// </summary>
void* progress_callback_user_data;

/// <summary>
/// only load the vocabulary, no weights
/// </summary>
public bool vocab_only
{
readonly get => Convert.ToBoolean(_vocab_only);
set => _vocab_only = Convert.ToSByte(value);
}
private sbyte _vocab_only;

/// <summary>
/// use mmap if possible
/// </summary>
public bool use_mmap
{
readonly get => Convert.ToBoolean(_use_mmap);
set => _use_mmap = Convert.ToSByte(value);
}
private sbyte _use_mmap;

/// <summary>
/// force system to keep model in RAM
/// </summary>
public bool use_mlock
{
readonly get => Convert.ToBoolean(_use_mlock);
set => _use_mlock = Convert.ToSByte(value);
}
private sbyte _use_mlock;
}
}
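Model-level settings (GPU layers, mmap/mlock, vocab_only) now travel in this struct instead of LLamaContextParams. A hedged sketch of loading a model with it; the file path is hypothetical:

using System;
using LLama.Native;

var mparams = NativeApi.llama_model_default_params();
mparams.n_gpu_layers = 20;   // layers to keep in VRAM
mparams.use_mmap = true;     // map the file rather than reading it fully into RAM

using var model = SafeLlamaModelHandle.LoadFromFile("model.gguf", mparams);
Console.WriteLine($"{model.ParameterCount} parameters, {model.SizeInBytes} bytes");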

+ 10
- 0
LLama/Native/LLamaModelQuantizeParams.cs View File

@@ -36,5 +36,15 @@ namespace LLama.Native
set => _quantize_output_tensor = Convert.ToSByte(value); set => _quantize_output_tensor = Convert.ToSByte(value);
} }
private sbyte _quantize_output_tensor; private sbyte _quantize_output_tensor;

/// <summary>
/// only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
/// </summary>
public bool only_copy
{
get => Convert.ToBoolean(_only_copy);
set => _only_copy = Convert.ToSByte(value);
}
private sbyte _only_copy;
} }
} }
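A small sketch of the new flag; only the struct configuration is shown, since how it is handed to the native quantize call is outside this excerpt:

var qparams = default(LLamaModelQuantizeParams);
qparams.only_copy = true;                // copy tensors as-is, no requantization
qparams.quantize_output_tensor = false;  // ignored while only_copy is set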

+ 45
- 0
LLama/Native/LLamaNativeBatch.cs View File

@@ -0,0 +1,45 @@
using System;
using System.Runtime.InteropServices;

namespace LLama.Native;

using llama_token = Int32;

/// <summary>
/// Input data for llama_decode
/// A llama_batch object can contain input about one or many sequences
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly unsafe struct LLamaNativeBatch
{
/// <summary>
/// The number of items pointed at by pos, seq_id and logits.
/// </summary>
public readonly int n_tokens;

/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
/// </summary>
public readonly llama_token* token;

/// <summary>
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
/// </summary>
public readonly float* embd;

/// <summary>
/// the positions of the respective token in the sequence
/// </summary>
public readonly LLamaPos* pos;

/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
public readonly LLamaSeqId* seq_id;

/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public readonly byte* logits;
}
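The struct mirrors llama_batch and is produced and released through the new native calls declared later in this diff; a hedged sketch of the allocate/free pattern:

// Room for up to 512 tokens, no embedding data (embd == 0).
LLamaNativeBatch batch = NativeApi.llama_batch_init(512, 0);
try
{
    // token, pos, seq_id and logits now point at uninitialized native arrays of
    // length 512 that must be filled before the batch is passed to llama_decode.
}
finally
{
    NativeApi.llama_batch_free(batch);
}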

+ 15
- 0
LLama/Native/LLamaPos.cs View File

@@ -0,0 +1,15 @@
namespace LLama.Native;

public record struct LLamaPos
{
public int Value;

public LLamaPos(int value)
{
Value = value;
}

public static explicit operator int(LLamaPos pos) => pos.Value;

public static implicit operator LLamaPos(int value) => new(value);
}

+ 15
- 0
LLama/Native/LLamaSeqId.cs View File

@@ -0,0 +1,15 @@
namespace LLama.Native;

public record struct LLamaSeqId
{
public int Value;

public LLamaSeqId(int value)
{
Value = value;
}

public static explicit operator int(LLamaSeqId pos) => pos.Value;

public static explicit operator LLamaSeqId(int value) => new(value);
}
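The two wrappers differ slightly: LLamaPos converts implicitly from int, while LLamaSeqId requires explicit casts in both directions. For example:

LLamaPos pos = 10;              // implicit int -> LLamaPos
int rawPos = (int)pos;          // explicit back to int
var seq = (LLamaSeqId)0;        // explicit in both directions for sequence ids
int rawSeq = (int)seq;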

+ 139
- 115
LLama/Native/NativeApi.cs View File

@@ -2,7 +2,6 @@
using System.Buffers; using System.Buffers;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using LLama.Common;
using LLama.Exceptions; using LLama.Exceptions;


#pragma warning disable IDE1006 // Naming Styles #pragma warning disable IDE1006 // Naming Styles
@@ -110,6 +109,13 @@ namespace LLama.Native
[DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call(); public static extern bool llama_empty_call();


/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaModelParams llama_model_default_params();

/// <summary> /// <summary>
/// Create a LLamaContextParams with default values /// Create a LLamaContextParams with default values
/// </summary> /// </summary>
@@ -138,18 +144,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mlock_supported(); public static extern bool llama_mlock_supported();


/// <summary>
/// Export a static computation graph for context of 511 and batch size of 1
/// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
/// parameters here to keep things simple
/// IMPORTANT: do not use for anything else other than debugging and testing!
/// </summary>
/// <param name="ctx"></param>
/// <param name="fname"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname);

/// <summary> /// <summary>
/// Various functions for loading a ggml llama model. /// Various functions for loading a ggml llama model.
/// Allocate (almost) all memory needed for the model. /// Allocate (almost) all memory needed for the model.
@@ -159,7 +153,7 @@ namespace LLama.Native
/// <param name="params"></param> /// <param name="params"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params);
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params);


/// <summary> /// <summary>
/// Create a new llama_context with the given model. /// Create a new llama_context with the given model.
@@ -192,7 +186,7 @@ namespace LLama.Native
/// <param name="model"></param> /// <param name="model"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free_model(IntPtr model); public static extern void llama_free_model(IntPtr model);
/// <summary> /// <summary>
/// Apply a LoRA adapter to a loaded model /// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for /// path_base_model is the path to a higher quality model to use as a base for
@@ -202,19 +196,12 @@ namespace LLama.Native
/// </summary> /// </summary>
/// <param name="model_ptr"></param> /// <param name="model_ptr"></param>
/// <param name="path_lora"></param> /// <param name="path_lora"></param>
/// <param name="scale"></param>
/// <param name="path_base_model"></param> /// <param name="path_base_model"></param>
/// <param name="n_threads"></param> /// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns> /// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);

/// <summary>
/// Returns the number of tokens in the KV cache
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);


/// <summary> /// <summary>
/// Sets the current rng seed. /// Sets the current rng seed.
@@ -222,7 +209,7 @@ namespace LLama.Native
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="seed"></param> /// <param name="seed"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, int seed);
public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed);


/// <summary> /// <summary>
/// Returns the maximum size in bytes of the state (rng, logits, embedding /// Returns the maximum size in bytes of the state (rng, logits, embedding
@@ -243,21 +230,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);


/// <summary>
/// Copies the state to the specified destination address.
/// Destination needs to have allocated enough memory (see llama_get_state_size)
/// </summary>
/// <param name="ctx"></param>
/// <param name="dest"></param>
/// <returns>the number of bytes copied</returns>
public static ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte[] dest)
{
fixed (byte* dstPtr = &dest[0])
{
return llama_copy_state_data(ctx, dstPtr);
}
}

/// <summary> /// <summary>
/// Set the state reading from the specified address /// Set the state reading from the specified address
/// </summary> /// </summary>
@@ -267,20 +239,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);


/// <summary>
/// Set the state reading from the specified address
/// </summary>
/// <param name="ctx"></param>
/// <param name="src"></param>
/// <returns>the number of bytes read</returns>
public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
{
fixed (byte* srcPtr = &src[0])
{
return llama_set_state_data(ctx, srcPtr);
}
}

/// <summary> /// <summary>
/// Load session file /// Load session file
/// </summary> /// </summary>
@@ -313,24 +271,9 @@ namespace LLama.Native
/// <param name="tokens"></param> /// <param name="tokens"></param>
/// <param name="n_tokens"></param> /// <param name="n_tokens"></param>
/// <param name="n_past"></param> /// <param name="n_past"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns> /// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);

/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// tokens + n_tokens is the provided batch of new tokens to process
/// n_past is the number of tokens to use from previous eval calls
/// </summary>
/// <param name="ctx"></param>
/// <param name="tokens"></param>
/// <param name="n_tokens"></param>
/// <param name="n_past"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);


/// <summary> /// <summary>
/// Convert the provided text into tokens. /// Convert the provided text into tokens.
@@ -341,10 +284,11 @@ namespace LLama.Native
/// <param name="tokens"></param> /// <param name="tokens"></param>
/// <param name="n_max_tokens"></param> /// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param> /// <param name="add_bos"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens. /// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned /// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns> /// </returns>
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
{ {
// Calculate number of bytes in text and borrow an array that large (+1 for nul byte) // Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
var byteCount = encoding.GetByteCount(text); var byteCount = encoding.GetByteCount(text);
@@ -364,7 +308,7 @@ namespace LLama.Native
// Do the actual tokenization // Do the actual tokenization
fixed (byte* arrayPtr = array) fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens) fixed (llama_token* tokensPtr = tokens)
return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos);
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
} }
finally finally
{ {
@@ -372,28 +316,6 @@ namespace LLama.Native
} }
} }


/// <summary>
/// Convert the provided text into tokens.
/// </summary>
/// <param name="ctx"></param>
/// <param name="text"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
[DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos);

/// <summary>
/// Get the number of tokens in the model vocabulary for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_vocab(SafeLLamaContextHandle ctx);

/// <summary> /// <summary>
/// Get the size of the context window for the model for this context /// Get the size of the context window for the model for this context
/// </summary> /// </summary>
@@ -402,14 +324,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); public static extern int llama_n_ctx(SafeLLamaContextHandle ctx);


/// <summary>
/// Get the dimension of embedding vectors from the model for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_embd(SafeLLamaContextHandle ctx);

/// <summary> /// <summary>
/// Token logits obtained from the last call to llama_eval() /// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row /// The logits for the last token are stored in the last row
@@ -423,22 +337,21 @@ namespace LLama.Native
public static extern float* llama_get_logits(SafeLLamaContextHandle ctx); public static extern float* llama_get_logits(SafeLLamaContextHandle ctx);


/// <summary> /// <summary>
/// Get the embeddings for the input
/// shape: [n_embd] (1-dimensional)
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);


/// <summary> /// <summary>
/// Token Id -> String. Uses the vocabulary in the provided context
/// Get the embeddings for the input
/// shape: [n_embd] (1-dimensional)
/// </summary> /// </summary>
/// <param name="ctx"></param> /// <param name="ctx"></param>
/// <param name="token"></param>
/// <returns>Pointer to a string.</returns>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token);
public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);


/// <summary> /// <summary>
/// Get the "Beginning of sentence" token /// Get the "Beginning of sentence" token
@@ -488,7 +401,7 @@ namespace LLama.Native
/// <param name="model"></param> /// <param name="model"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_n_vocab(SafeLlamaModelHandle model);
public static extern int llama_n_vocab(SafeLlamaModelHandle model);


/// <summary> /// <summary>
/// Get the size of the context window for the model /// Get the size of the context window for the model
@@ -496,7 +409,7 @@ namespace LLama.Native
/// <param name="model"></param> /// <param name="model"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_n_ctx(SafeLlamaModelHandle model);
public static extern int llama_n_ctx_train(SafeLlamaModelHandle model);


/// <summary> /// <summary>
/// Get the dimension of embedding vectors from this model /// Get the dimension of embedding vectors from this model
@@ -504,7 +417,23 @@ namespace LLama.Native
/// <param name="model"></param> /// <param name="model"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_n_embd(SafeLlamaModelHandle model);
public static extern int llama_n_embd(SafeLlamaModelHandle model);

/// <summary>
/// Get the size of the model in bytes
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_model_size(SafeLlamaModelHandle model);

/// <summary>
/// Get the number of parameters in this model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_model_n_params(SafeLlamaModelHandle model);


/// <summary> /// <summary>
/// Convert a single token into text /// Convert a single token into text
@@ -515,21 +444,23 @@ namespace LLama.Native
/// <param name="length">size of the buffer</param> /// <param name="length">size of the buffer</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns> /// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);


/// <summary> /// <summary>
/// Convert text into tokens /// Convert text into tokens
/// </summary> /// </summary>
/// <param name="model"></param> /// <param name="model"></param>
/// <param name="text"></param> /// <param name="text"></param>
/// <param name="text_len"></param>
/// <param name="tokens"></param> /// <param name="tokens"></param>
/// <param name="n_max_tokens"></param> /// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param> /// <param name="add_bos"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens. /// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned /// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns> /// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);


/// <summary> /// <summary>
/// Register a callback to receive llama log messages /// Register a callback to receive llama log messages
@@ -537,5 +468,98 @@ namespace LLama.Native
/// <param name="logCallback"></param> /// <param name="logCallback"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_log_set(LLamaLogCallback logCallback); public static extern void llama_log_set(LLamaLogCallback logCallback);
}

/// <summary>
/// Remove all tokens data of cells in [c0, c1)
/// </summary>
/// <param name="ctx"></param>
/// <param name="c0"></param>
/// <param name="c1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1);

/// <summary>
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);

/// <summary>
/// Copy all tokens that belong to the specified sequence to another sequence
/// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
/// </summary>
/// <param name="ctx"></param>
/// <param name="src"></param>
/// <param name="dest"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_cp(SafeLLamaContextHandle ctx, LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1);

/// <summary>
/// Removes all tokens that do not belong to the specified sequence
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_keep(SafeLLamaContextHandle ctx, LLamaSeqId seq);

/// <summary>
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
/// If the KV cache is RoPEd, the KV data is updated accordingly
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="delta"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);

/// <summary>
/// Allocates a batch of tokens on the heap
/// The batch has to be freed with llama_batch_free()
/// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
/// The rest of the llama_batch members are allocated with size n_tokens
/// All members are left uninitialized
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="embd"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd);

/// <summary>
/// Frees a batch of tokens allocated with llama_batch_init()
/// </summary>
/// <param name="batch"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_batch_free(LLamaNativeBatch batch);

/// <summary>
/// </summary>
/// <param name="ctx"></param>
/// <param name="batch"></param>
/// <returns>Positive return values do not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);

/// <summary>
/// Set the number of threads used for decoding
/// </summary>
/// <param name="ctx"></param>
/// <param name="n_threads">n_threads is the number of threads used for generation (single token)</param>
/// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
}
} }
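The new KV-cache and decode entry points replace the old eval/KV-count calls; a rough usage sketch (ctx and nativeBatch are assumed to exist and be populated already):

// Generation threads and prompt/batch threads are now set once on the context.
NativeApi.llama_set_n_threads(ctx, 8, 8);

int status = NativeApi.llama_decode(ctx, nativeBatch);
if (status == 1)
{
    // No free KV slot: drop sequence 0's positions [0, 64) and retry with a smaller batch.
    NativeApi.llama_kv_cache_seq_rm(ctx, (LLamaSeqId)0, 0, 64);
}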

+ 31
- 61
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Buffers; using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using LLama.Exceptions; using LLama.Exceptions;
@@ -21,26 +22,13 @@ namespace LLama.Native
/// <summary> /// <summary>
/// Total number of tokens in the context /// Total number of tokens in the context
/// </summary> /// </summary>
public int ContextSize => ThrowIfDisposed().ContextSize;
public int ContextSize => NativeApi.llama_n_ctx(this);


/// <summary> /// <summary>
/// Dimension of embedding vectors /// Dimension of embedding vectors
/// </summary> /// </summary>
public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;


/// <summary>
/// Get the number of tokens in the KV Cache for this context
/// </summary>
public int KVCacheTokenCount
{
get
{
if (IsClosed)
throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
return NativeApi.llama_get_kv_cache_token_count(this);
}
}

/// <summary> /// <summary>
/// Get the model which this context is using /// Get the model which this context is using
/// </summary> /// </summary>
@@ -64,17 +52,20 @@ namespace LLama.Native
_model.DangerousAddRef(ref success); _model.DangerousAddRef(ref success);
if (!success) if (!success)
throw new RuntimeError("Failed to increment model refcount"); throw new RuntimeError("Failed to increment model refcount");

} }


/// <inheritdoc /> /// <inheritdoc />
protected override bool ReleaseHandle() protected override bool ReleaseHandle()
{ {
NativeApi.llama_free(DangerousGetHandle());
SetHandle(IntPtr.Zero);

// Decrement refcount on model // Decrement refcount on model
_model?.DangerousRelease(); _model?.DangerousRelease();
_model = null!; _model = null!;


NativeApi.llama_free(handle);
SetHandle(IntPtr.Zero);
return true; return true;
} }


@@ -103,46 +94,38 @@ namespace LLama.Native


return new(ctx_ptr, model); return new(ctx_ptr, model);
} }
#endregion


/// <summary> /// <summary>
/// Create a new llama context with a clone of the current llama context state
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary> /// </summary>
/// <param name="lparams"></param>
/// <returns></returns> /// <returns></returns>
public SafeLLamaContextHandle Clone(LLamaContextParams lparams)
public Span<float> GetLogits()
{ {
// Allocate space to read the state of the current context
var stateSize = GetStateSize();
var stateMemory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy state from this context into memory
GetState(stateMemory, stateSize);

// Create a new context
var newCtx = Create(ModelHandle, lparams);

// Copy state into new context
newCtx.SetState(stateMemory);
var model = ThrowIfDisposed();


return newCtx;
}
finally
unsafe
{ {
Marshal.FreeHGlobal(stateMemory);
var logits = NativeApi.llama_get_logits(this);
return new Span<float>(logits, model.VocabCount);
} }
} }
#endregion


#region tokens
/// <summary> /// <summary>
/// Convert the given text into tokens /// Convert the given text into tokens
/// </summary> /// </summary>
/// <param name="text">The text to tokenize</param> /// <param name="text">The text to tokenize</param>
/// <param name="add_bos">Whether the "BOS" token should be added</param> /// <param name="add_bos">Whether the "BOS" token should be added</param>
/// <param name="encoding">Encoding to use for the text</param> /// <param name="encoding">Encoding to use for the text</param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="RuntimeError"></exception> /// <exception cref="RuntimeError"></exception>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{ {
ThrowIfDisposed(); ThrowIfDisposed();


@@ -158,7 +141,7 @@ namespace LLama.Native
try try
{ {
// Do the actual conversion // Do the actual conversion
var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special);
if (n < 0) if (n < 0)
{ {
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
@@ -177,25 +160,6 @@ namespace LLama.Native
} }
} }


/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary>
/// <returns></returns>
public Span<float> GetLogits()
{
var model = ThrowIfDisposed();

unsafe
{
var logits = NativeApi.llama_get_logits(this);
return new Span<float>(logits, model.VocabCount);
}
}

/// <summary> /// <summary>
/// Convert a token into a string /// Convert a token into a string
/// </summary> /// </summary>
@@ -228,25 +192,31 @@ namespace LLama.Native
{ {
return ThrowIfDisposed().TokenToSpan(token, dest); return ThrowIfDisposed().TokenToSpan(token, dest);
} }
#endregion


/// <summary> /// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token. /// Run the llama inference to obtain the logits and probabilities for the next token.
/// </summary> /// </summary>
/// <param name="tokens">The provided batch of new tokens to process</param> /// <param name="tokens">The provided batch of new tokens to process</param>
/// <param name="n_past">the number of tokens to use from previous eval calls</param> /// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <param name="n_threads"></param>
/// <returns>Returns true on success</returns> /// <returns>Returns true on success</returns>
public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
public bool Eval(ReadOnlySpan<int> tokens, int n_past)
{ {
unsafe unsafe
{ {
fixed (int* pinned = tokens) fixed (int* pinned = tokens)
{ {
return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
return ret == 0;
} }
} }
} }


public int Decode(LLamaBatchSafeHandle batch)
{
return NativeApi.llama_decode(this, batch.Batch);
}

#region state #region state
/// <summary> /// <summary>
/// Get the size of the state, when saved as bytes /// Get the size of the state, when saved as bytes
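Taken together, the context handle now tokenizes with a `special` flag, evaluates without a per-call thread count, and can decode a batch. A sketch, assuming `ctx` and an LLamaBatchSafeHandle `batch` have already been constructed:

int[] tokens = ctx.Tokenize("Hello", add_bos: true, special: false, encoding: Encoding.UTF8);
bool ok = ctx.Eval(tokens, n_past: 0);     // thread count no longer passed per call
Span<float> logits = ctx.GetLogits();      // n_vocab logits for the last token
int status = ctx.Decode(batch);            // forwards to llama_decode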


+ 31
- 15
LLama/Native/SafeLlamaModelHandle.cs View File

@@ -29,18 +29,30 @@ namespace LLama.Native
/// </summary> /// </summary>
public int EmbeddingSize { get; } public int EmbeddingSize { get; }


/// <summary>
/// Get the size of this model in bytes
/// </summary>
public ulong SizeInBytes { get; }

/// <summary>
/// Get the number of parameters in this model
/// </summary>
public ulong ParameterCount { get; }

internal SafeLlamaModelHandle(IntPtr handle) internal SafeLlamaModelHandle(IntPtr handle)
: base(handle) : base(handle)
{ {
VocabCount = NativeApi.llama_model_n_vocab(this);
ContextSize = NativeApi.llama_model_n_ctx(this);
EmbeddingSize = NativeApi.llama_model_n_embd(this);
VocabCount = NativeApi.llama_n_vocab(this);
ContextSize = NativeApi.llama_n_ctx_train(this);
EmbeddingSize = NativeApi.llama_n_embd(this);
SizeInBytes = NativeApi.llama_model_size(this);
ParameterCount = NativeApi.llama_model_n_params(this);
} }


/// <inheritdoc /> /// <inheritdoc />
protected override bool ReleaseHandle() protected override bool ReleaseHandle()
{ {
NativeApi.llama_free_model(handle);
NativeApi.llama_free_model(DangerousGetHandle());
SetHandle(IntPtr.Zero); SetHandle(IntPtr.Zero);
return true; return true;
} }
@@ -52,7 +64,7 @@ namespace LLama.Native
/// <param name="lparams"></param> /// <param name="lparams"></param>
/// <returns></returns> /// <returns></returns>
/// <exception cref="RuntimeError"></exception> /// <exception cref="RuntimeError"></exception>
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
{ {
var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
if (model_ptr == IntPtr.Zero) if (model_ptr == IntPtr.Zero)
@@ -62,21 +74,24 @@ namespace LLama.Native
} }


#region LoRA #region LoRA

/// <summary> /// <summary>
/// Apply a LoRA adapter to a loaded model /// Apply a LoRA adapter to a loaded model
/// </summary> /// </summary>
/// <param name="lora"></param> /// <param name="lora"></param>
/// <param name="scale"></param>
/// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
/// adapter. Can be NULL to use the current loaded model.</param> /// adapter. Can be NULL to use the current loaded model.</param>
/// <param name="threads"></param> /// <param name="threads"></param>
/// <exception cref="RuntimeError"></exception> /// <exception cref="RuntimeError"></exception>
public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null)
{ {
var err = NativeApi.llama_model_apply_lora_from_file( var err = NativeApi.llama_model_apply_lora_from_file(
this, this,
lora, lora,
scale,
string.IsNullOrEmpty(modelBase) ? null : modelBase, string.IsNullOrEmpty(modelBase) ? null : modelBase,
threads
(int?)threads ?? -1
); );


if (err != 0) if (err != 0)
@@ -97,7 +112,7 @@ namespace LLama.Native
{ {
fixed (byte* destPtr = dest) fixed (byte* destPtr = dest)
{ {
var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length);
var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
return Math.Abs(length); return Math.Abs(length);
} }
} }
@@ -113,7 +128,7 @@ namespace LLama.Native
{ {
unsafe unsafe
{ {
var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
if (length == 0) if (length == 0)
return ""; return "";


@@ -121,7 +136,7 @@ namespace LLama.Native


fixed (byte* bytePtr = bytes) fixed (byte* bytePtr = bytes)
{ {
var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length); Debug.Assert(written == bytes.Length);


return encoding.GetString(bytePtr, bytes.Length); return encoding.GetString(bytePtr, bytes.Length);
@@ -139,7 +154,7 @@ namespace LLama.Native
{ {
unsafe unsafe
{ {
var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
if (length == 0) if (length == 0)
return; return;


@@ -147,7 +162,7 @@ namespace LLama.Native
fixed (byte* bytePtr = bytes) fixed (byte* bytePtr = bytes)
{ {
// Decode into bytes // Decode into bytes
var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length); Debug.Assert(written == bytes.Length);


// Decode into chars // Decode into chars
@@ -256,8 +271,9 @@ namespace LLama.Native
/// <param name="text"></param> /// <param name="text"></param>
/// <param name="add_bos"></param> /// <param name="add_bos"></param>
/// <param name="encoding"></param> /// <param name="encoding"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns> /// <returns></returns>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{ {
// Convert string to bytes, adding one extra byte to the end (null terminator) // Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text); var bytesCount = encoding.GetByteCount(text);
@@ -276,13 +292,13 @@ namespace LLama.Native
fixed (byte* bytesPtr = &bytes[0]) fixed (byte* bytesPtr = &bytes[0])
{ {
// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special);


// Tokenize again, this time outputting into an array of exactly the right size // Tokenize again, this time outputting into an array of exactly the right size
var tokens = new int[count]; var tokens = new int[count];
fixed (int* tokensPtr = &tokens[0]) fixed (int* tokensPtr = &tokens[0])
{ {
NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
return tokens; return tokens;
} }
} }
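The model handle now takes the LoRA scale explicitly and exposes the `special` tokenizer flag; a sketch with hypothetical file names, assuming `model` is an already-loaded SafeLlamaModelHandle:

model.ApplyLoraFromFile("adapter.bin", scale: 1.0f, modelBase: null, threads: 4);
int[] ids = model.Tokenize("Hello", add_bos: true, special: false, encoding: Encoding.UTF8);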


+ 0
- 108
LLama/Utils.cs View File

@@ -1,108 +0,0 @@
using LLama.Abstractions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Extensions;

namespace LLama
{
using llama_token = Int32;

/// <summary>
/// Assorted llama utilities
/// </summary>
public static class Utils
{
[Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
using var weights = LLamaWeights.LoadFromFile(@params);

using (@params.ToLlamaContextParams(out var lparams))
return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
}

[Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
return ctx.Tokenize(text, add_bos, encoding);
}

[Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
if (length != ctx.VocabCount)
throw new ArgumentException("length must be the VocabSize");

return ctx.GetLogits();
}

[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
}

[Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
return ctx.TokenToString(token, encoding);
}

[Obsolete("No longer used internally by LlamaSharp")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static string PtrToString(IntPtr ptr, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
#if NET6_0_OR_GREATER
// ReSharper disable once PossibleUnintendedReferenceComparison
if(encoding == Encoding.UTF8)
{
return Marshal.PtrToStringUTF8(ptr)!;
}
// ReSharper disable once PossibleUnintendedReferenceComparison
else if(encoding == Encoding.Unicode)
{
return Marshal.PtrToStringUni(ptr)!;
}
else
{
return Marshal.PtrToStringAuto(ptr)!;
}
#else
unsafe
{
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
{
byte c = *tp++;
if (c == '\0')
{
break;
}
else
{
bytes.Add(c);
}
}
return encoding.GetString(bytes.ToArray());
}
#endif
}
}
}
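The Obsolete messages above point at the replacements; a rough equivalent of the removed InitLLamaContextFromModelParams, assuming LLamaWeights.CreateContext accepts the same parameter object (its exact signature is not shown in this excerpt):

using var weights = LLamaWeights.LoadFromFile(parameters);  // parameters: an IModelParams implementation such as ModelParams
using var context = weights.CreateContext(parameters);      // assumed signature, per the Obsolete message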

+ 575
- 212
LLama/runtimes/ggml-metal.metal
File diff suppressed because it is too large
View File


BIN
LLama/runtimes/libllama-cuda11.dll View File


BIN
LLama/runtimes/libllama-cuda11.so View File


BIN
LLama/runtimes/libllama-cuda12.dll View File


BIN
LLama/runtimes/libllama-cuda12.so View File


BIN
LLama/runtimes/libllama.dll View File


BIN
LLama/runtimes/libllama.dylib View File


BIN
LLama/runtimes/libllama.so View File

