Browse Source

- Split parameters into two interfaces

- params now contains a list of LoRAs, instead of just one
tags/v0.6.0
Martin Evans 2 years ago
parent
commit
669ae47ef7
15 changed files with 178 additions and 257 deletions
  1. +2
    -2
      LLama.Examples/NewVersion/LoadAndSaveSession.cs
  2. +0
    -1
      LLama.Unittest/BasicTest.cs
  3. +5
    -2
      LLama.Unittest/LLamaContextTests.cs
  4. +6
    -3
      LLama.Unittest/ModelsParamsTests.cs
  5. +4
    -9
      LLama.Web/Common/ModelOptions.cs
  6. +60
    -0
      LLama/Abstractions/IContextParams.cs
  7. +11
    -0
      LLama/Abstractions/ILLamaParams.cs
  8. +51
    -63
      LLama/Abstractions/IModelParams.cs
  9. +8
    -10
      LLama/Common/ModelParams.cs
  10. +1
    -1
      LLama/Extensions/IModelParamsExtensions.cs
  11. +4
    -21
      LLama/LLamaContext.cs
  12. +12
    -9
      LLama/LLamaEmbedder.cs
  13. +2
    -16
      LLama/LLamaStatelessExecutor.cs
  14. +12
    -12
      LLama/LLamaWeights.cs
  15. +0
    -108
      LLama/Utils.cs

+ 2
- 2
LLama.Examples/NewVersion/LoadAndSaveSession.cs View File

@@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim();

var parameters = new ModelParams(modelPath)
{
@@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;

ex.Context.Dispose();
ex = new(new LLamaContext(parameters));
ex = new(new LLamaContext(model, parameters));
session = new ChatSession(ex);
session.LoadSession(statePath);



+ 0
- 1
LLama.Unittest/BasicTest.cs View File

@@ -29,7 +29,6 @@ namespace LLama.Unittest
Assert.Equal(32000, _model.VocabCount);
Assert.Equal(4096, _model.ContextSize);
Assert.Equal(4096, _model.EmbeddingSize);
Assert.Equal(Encoding.UTF8, _model.Encoding);
}
}
}

+ 5
- 2
LLama.Unittest/LLamaContextTests.cs View File

@@ -10,7 +10,10 @@ namespace LLama.Unittest

public LLamaContextTests()
{
var @params = new ModelParams(Constants.ModelPath);
var @params = new ModelParams(Constants.ModelPath)
{
ContextSize = 768,
};
_weights = LLamaWeights.LoadFromFile(@params);
_context = _weights.CreateContext(@params);
}
@@ -24,7 +27,7 @@ namespace LLama.Unittest
[Fact]
public void CheckProperties()
{
Assert.Equal(4096, _context.ContextSize);
Assert.Equal(768, _context.ContextSize);
Assert.Equal(4096, _context.EmbeddingSize);
Assert.Equal(32000, _context.VocabCount);
}


+ 6
- 3
LLama.Unittest/ModelsParamsTests.cs View File

@@ -13,7 +13,6 @@ namespace LLama.Unittest
{
BatchSize = 17,
ContextSize = 42,
LoraAdapter = "adapter",
Seed = 42,
GpuLayerCount = 111
};
@@ -31,9 +30,13 @@ namespace LLama.Unittest
{
BatchSize = 17,
ContextSize = 42,
LoraAdapter = "adapter",
Seed = 42,
GpuLayerCount = 111
GpuLayerCount = 111,
LoraAdapters =
{
new("abc", 1),
new("def", 0)
}
};

var settings = new Newtonsoft.Json.JsonSerializerSettings();


+ 4
- 9
LLama.Web/Common/ModelOptions.cs View File

@@ -4,7 +4,7 @@ using LLama.Abstractions;
namespace LLama.Web.Common
{
public class ModelOptions
: IModelParams
: ILLamaParams
{
public string Name { get; set; }
@@ -51,16 +51,11 @@ namespace LLama.Web.Common
/// Model path (model)
/// </summary>
public string ModelPath { get; set; }

/// <summary>
/// model alias
/// </summary>
public string ModelAlias { get; set; } = "unknown";
/// <summary>
/// lora adapter path (lora_adapter)
/// List of LoRAs to apply
/// </summary>
public string LoraAdapter { get; set; } = string.Empty;

public float LoraAdapterScale { get; set; } = 1;
public AdapterCollection LoraAdapters { get; set; } = new();

/// <summary>
/// base model path for the lora adapter (lora_base)


+ 60
- 0
LLama/Abstractions/IContextParams.cs View File

@@ -0,0 +1,60 @@
using System.Text;

namespace LLama.Abstractions;

/// <summary>
/// The parameters for initializing a LLama context from a model.
/// </summary>
public interface IContextParams
{
/// <summary>
/// Model context size (n_ctx)
/// </summary>
uint ContextSize { get; set; }

/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
uint BatchSize { get; set; }

/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
uint Seed { get; set; }

/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
bool UseFp16Memory { get; set; }

/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
bool Perplexity { get; set; }

/// <summary>
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
/// </summary>
bool EmbeddingMode { get; set; }

/// <summary>
/// RoPE base frequency
/// </summary>
float RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// </summary>
float RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>
bool MulMatQ { get; set; }

/// <summary>
/// The encoding to use for models
/// </summary>
Encoding Encoding { get; set; }
}

+ 11
- 0
LLama/Abstractions/ILLamaParams.cs View File

@@ -0,0 +1,11 @@
namespace LLama.Abstractions
{
/// <summary>
/// Convenience interface for implementing both type of parameters.
/// </summary>
/// <remarks>Mostly exists for backwards compatibility reasons, when these two were not split.</remarks>
public interface ILLamaParams
: IModelParams, IContextParams
{
}
}

+ 51
- 63
LLama/Abstractions/IModelParams.cs View File

@@ -1,4 +1,6 @@
using System.Text;
using System;
using System.Collections.Generic;
using System.Linq;

namespace LLama.Abstractions
{
@@ -7,36 +9,16 @@ namespace LLama.Abstractions
/// </summary>
public interface IModelParams
{
/// <summary>
/// Model context size (n_ctx)
/// </summary>
uint ContextSize { get; set; }

/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
int MainGpu { get; set; }

/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
bool LowVram { get; set; }

/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
int GpuLayerCount { get; set; }

/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
uint Seed { get; set; }

/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
bool UseFp16Memory { get; set; }

/// <summary>
/// Use mmap for faster loads (use_mmap)
/// </summary>
@@ -47,72 +29,78 @@ namespace LLama.Abstractions
/// </summary>
bool UseMemoryLock { get; set; }

/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
bool Perplexity { get; set; }

/// <summary>
/// Model path (model)
/// </summary>
string ModelPath { get; set; }

/// <summary>
/// lora adapter path (lora_adapter)
/// </summary>
string LoraAdapter { get; set; }

float LoraAdapterScale { get; set; }

/// <summary>
/// base model path for the lora adapter (lora_base)
/// </summary>
string LoraBase { get; set; }

/// <summary>
/// Number of threads (-1 = autodetect) (n_threads)
/// </summary>
int Threads { get; set; }

/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
uint BatchSize { get; set; }

/// <summary>
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
/// </summary>
bool EmbeddingMode { get; set; }

/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>
float[]? TensorSplits { get; set; }

/// <summary>
/// RoPE base frequency
/// Load vocab only (no weights)
/// </summary>
float RopeFrequencyBase { get; set; }
bool VocabOnly { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// List of LoRA adapters to apply
/// </summary>
float RopeFrequencyScale { get; set; }
AdapterCollection LoraAdapters { get; }

/// <summary>
/// Use experimental mul_mat_q kernels
/// base model path for the lora adapter (lora_base)
/// </summary>
bool MulMatQ { get; set; }
string LoraBase { get; set; }
}

/// <summary>
/// The encoding to use for models
/// </summary>
Encoding Encoding { get; set; }
/// <summary>
/// A LoRA adapter to apply to a model
/// </summary>
/// <param name="Path">Path to the LoRA file</param>
/// <param name="Scale">Strength of this LoRA</param>
public readonly record struct LoraAdapter(string Path, float Scale);

/// <summary>
/// Load vocab only (no weights)
/// </summary>
bool VocabOnly { get; set; }
/// <summary>
/// A list of LoraAdapter objects
/// </summary>
public sealed class AdapterCollection
: List<LoraAdapter>, IEquatable<AdapterCollection>
{
/// <inheritdoc />
public bool Equals(AdapterCollection? other)
{
if (other == null)
return false;

return this.SequenceEqual(other);
}

/// <inheritdoc/>
public override bool Equals(object? obj)
{
return Equals(obj as AdapterCollection);
}

/// <inheritdoc/>
public override int GetHashCode()
{
unchecked
{
var hash = 17;
for (var i = 0; i < Count; i++)
{
hash += this[i].GetHashCode();
hash *= 7823;
}
return hash;
}
}
}
}

+ 8
- 10
LLama/Common/ModelParams.cs View File

@@ -1,5 +1,6 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -10,7 +11,7 @@ namespace LLama.Common
/// The parameters for initializing a LLama model.
/// </summary>
public record ModelParams
: IModelParams
: ILLamaParams
{
/// <summary>
/// Model context size (n_ctx)
@@ -20,10 +21,7 @@ namespace LLama.Common
/// the GPU that is used for scratch and small tensors
/// </summary>
public int MainGpu { get; set; } = 0;
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
public bool LowVram { get; set; } = false;

/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
@@ -52,17 +50,17 @@ namespace LLama.Common
/// Model path (model)
/// </summary>
public string ModelPath { get; set; }

/// <summary>
/// lora adapter path (lora_adapter)
/// List of LoRAs to apply
/// </summary>
public string LoraAdapter { get; set; } = string.Empty;

public float LoraAdapterScale { get; set; } = 1;
public AdapterCollection LoraAdapters { get; set; } = new();

/// <summary>
/// base model path for the lora adapter (lora_base)
/// </summary>
public string LoraBase { get; set; } = string.Empty;

/// <summary>
/// Number of threads (-1 = autodetect) (n_threads)
/// </summary>
@@ -162,7 +160,6 @@ namespace LLama.Common
UseMemoryLock = useMemoryLock;
Perplexity = perplexity;
ModelPath = modelPath;
LoraAdapter = loraAdapter;
LoraBase = loraBase;
Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads;
BatchSize = batchSize;
@@ -171,6 +168,7 @@ namespace LLama.Common
RopeFrequencyScale = ropeFrequencyScale;
MulMatQ = mulMatQ;
Encoding = Encoding.GetEncoding(encoding);
LoraAdapters.Add(new LoraAdapter(loraAdapter, 1));
}
}



+ 1
- 1
LLama/Extensions/IModelParamsExtensions.cs View File

@@ -19,7 +19,7 @@ namespace LLama.Extensions
/// <returns></returns>
/// <exception cref="FileNotFoundException"></exception>
/// <exception cref="ArgumentException"></exception>
public static void ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
{
result = NativeApi.llama_context_default_params();
result.n_ctx = @params.ContextSize;


+ 4
- 21
LLama/LLamaContext.cs View File

@@ -42,9 +42,9 @@ namespace LLama
public int EmbeddingSize => _ctx.EmbeddingSize;

/// <summary>
/// The model params set for this model.
/// The context params set for this context
/// </summary>
public IModelParams Params { get; set; }
public IContextParams Params { get; set; }

/// <summary>
/// The native handle, which is used to be passed to the native APIs
@@ -57,24 +57,7 @@ namespace LLama
/// </summary>
public Encoding Encoding => _encoding;

/// <summary>
///
/// </summary>
/// <param name="params">Model params.</param>
/// <param name="logger">The logger.</param>
[Obsolete("Use the LLamaWeights.CreateContext instead")]
public LLamaContext(IModelParams @params, ILogger? logger = null)
{
Params = @params;

_logger = logger;
_encoding = @params.Encoding;

_logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}");
_ctx = Utils.InitLLamaContextFromModelParams(Params);
}

internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null)
internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;

@@ -90,7 +73,7 @@ namespace LLama
/// <param name="params"></param>
/// <param name="logger"></param>
/// <exception cref="ObjectDisposedException"></exception>
public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? logger = null)
public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null)
{
if (model.NativeHandle.IsClosed)
throw new ObjectDisposedException("Cannot create context, model weights have been disposed");


+ 12
- 9
LLama/LLamaEmbedder.cs View File

@@ -18,19 +18,22 @@ namespace LLama
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;

/// <summary>
///
/// </summary>
/// <param name="params"></param>
public LLamaEmbedder(IModelParams @params)
public LLamaEmbedder(ILLamaParams allParams)
: this(allParams, allParams)
{
@params.EmbeddingMode = true;
using var weights = LLamaWeights.LoadFromFile(@params);
_ctx = weights.CreateContext(@params);
}

public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
{
using var weights = LLamaWeights.LoadFromFile(modelParams);

contextParams.EmbeddingMode = true;
_ctx = weights.CreateContext(contextParams);
}

public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
{
@params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params);
}



+ 2
- 16
LLama/LLamaStatelessExecutor.cs View File

@@ -20,7 +20,7 @@ namespace LLama
: ILLamaExecutor
{
private readonly LLamaWeights _weights;
private readonly IModelParams _params;
private readonly IContextParams _params;

/// <summary>
/// The context used by the executor when running the inference.
@@ -32,7 +32,7 @@ namespace LLama
/// </summary>
/// <param name="weights"></param>
/// <param name="params"></param>
public StatelessExecutor(LLamaWeights weights, IModelParams @params)
public StatelessExecutor(LLamaWeights weights, IContextParams @params)
{
_weights = weights;
_params = @params;
@@ -41,20 +41,6 @@ namespace LLama
Context.Dispose();
}

/// <summary>
/// Create a new stateless executor which will use the model used to create the given context
/// </summary>
/// <param name="context"></param>
[Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
public StatelessExecutor(LLamaContext context)
{
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
_params = context.Params;

Context = _weights.CreateContext(_params);
Context.Dispose();
}

/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{


+ 12
- 12
LLama/LLamaWeights.cs View File

@@ -1,5 +1,4 @@
using System;
using System.Text;
using LLama.Abstractions;
using LLama.Extensions;
using LLama.Native;
@@ -20,11 +19,6 @@ namespace LLama
/// <remarks>Be careful how you use this!</remarks>
public SafeLlamaModelHandle NativeHandle => _weights;

/// <summary>
/// Encoding to use to convert text into bytes for the model
/// </summary>
public Encoding Encoding { get; }

/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
@@ -50,10 +44,9 @@ namespace LLama
/// </summary>
public int EmbeddingSize => NativeHandle.EmbeddingSize;

internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
internal LLamaWeights(SafeLlamaModelHandle weights)
{
_weights = weights;
Encoding = encoding;
}

/// <summary>
@@ -66,10 +59,17 @@ namespace LLama
using var pin = @params.ToLlamaModelParams(out var lparams);
var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

if (!string.IsNullOrEmpty(@params.LoraAdapter))
weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraAdapterScale, @params.LoraBase, @params.Threads);
foreach (var adapter in @params.LoraAdapters)
{
if (string.IsNullOrEmpty(adapter.Path))
continue;
if (adapter.Scale <= 0)
continue;

weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
}

return new LLamaWeights(weights, @params.Encoding);
return new LLamaWeights(weights);
}

/// <inheritdoc />
@@ -83,7 +83,7 @@ namespace LLama
/// </summary>
/// <param name="params"></param>
/// <returns></returns>
public LLamaContext CreateContext(IModelParams @params)
public LLamaContext CreateContext(IContextParams @params)
{
return new LLamaContext(this, @params);
}


+ 0
- 108
LLama/Utils.cs View File

@@ -1,108 +0,0 @@
using LLama.Abstractions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Extensions;

namespace LLama
{
using llama_token = Int32;

/// <summary>
/// Assorted llama utilities
/// </summary>
public static class Utils
{
[Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
using var weights = LLamaWeights.LoadFromFile(@params);

@params.ToLlamaContextParams(out var lparams);
return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
}

[Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
return ctx.Tokenize(text, add_bos, encoding);
}

[Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
if (length != ctx.VocabCount)
throw new ArgumentException("length must be the VocabSize");

return ctx.GetLogits();
}

[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
return ctx.Eval(slice, n_past) ? 0 : 1;
}

[Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
return ctx.TokenToString(token, encoding);
}

[Obsolete("No longer used internally by LlamaSharp")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static string PtrToString(IntPtr ptr, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
#if NET6_0_OR_GREATER
// ReSharper disable once PossibleUnintendedReferenceComparison
if(encoding == Encoding.UTF8)
{
return Marshal.PtrToStringUTF8(ptr)!;
}
// ReSharper disable once PossibleUnintendedReferenceComparison
else if(encoding == Encoding.Unicode)
{
return Marshal.PtrToStringUni(ptr)!;
}
else
{
return Marshal.PtrToStringAuto(ptr)!;
}
#else
unsafe
{
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
{
byte c = *tp++;
if (c == '\0')
{
break;
}
else
{
bytes.Add(c);
}
}
return encoding.GetString(bytes.ToArray());
}
#endif
}
}
}

Loading…
Cancel
Save