using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama
{
    /// <summary>
    /// A set of model weights, loaded into memory.
    /// </summary>
    public sealed class LLamaWeights
        : IDisposable
    {
        /// <summary>
        /// The native handle, which is used in the native APIs
        /// </summary>
        /// <remarks>Be careful how you use this!</remarks>
        public SafeLlamaModelHandle NativeHandle { get; }

        /// <summary>
        /// Total number of tokens in the vocabulary of this model
        /// </summary>
        public int VocabCount => NativeHandle.VocabCount;

        /// <summary>
        /// Total number of tokens in the context
        /// </summary>
        public int ContextSize => NativeHandle.ContextSize;

        /// <summary>
        /// Get the size of this model in bytes
        /// </summary>
        public ulong SizeInBytes => NativeHandle.SizeInBytes;

        /// <summary>
        /// Get the number of parameters in this model
        /// </summary>
        public ulong ParameterCount => NativeHandle.ParameterCount;

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize => NativeHandle.EmbeddingSize;

        /// <summary>
        /// Get the special tokens of this model
        /// </summary>
        public SafeLlamaModelHandle.ModelTokens Tokens => NativeHandle.Tokens;

        /// <summary>
        /// All metadata keys in this model
        /// </summary>
        public IReadOnlyDictionary<string, string> Metadata { get; set; }

        private LLamaWeights(SafeLlamaModelHandle weights)
        {
            NativeHandle = weights;
            Metadata = weights.ReadMetadata();
        }

        /// <summary>
        /// Load weights into memory
        /// </summary>
        /// <param name="params"></param>
        /// <returns></returns>
        public static LLamaWeights LoadFromFile(IModelParams @params)
        {
            using var pin = @params.ToLlamaModelParams(out var lparams);
            var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

            foreach (var adapter in @params.LoraAdapters)
            {
                // Skip adapters with no path or a scale that would have no effect
                if (string.IsNullOrEmpty(adapter.Path))
                    continue;
                if (adapter.Scale <= 0)
                    continue;

                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
            }

            return new LLamaWeights(weights);
        }

        /// <summary>
        /// Load weights into memory
        /// </summary>
        /// <param name="params">Parameters to use to load the model</param>
        /// <param name="token">A cancellation token that can interrupt model loading</param>
        /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
        /// <returns></returns>
        /// <exception cref="LoadWeightsFailedException">Thrown if the weights failed to load for any reason, e.g. an invalid file format or a cancelled load.</exception>
        /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
        public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
        {
            // Don't touch the @params object inside the task, it might be changed
            // externally! Save a copy of everything that we need later.
            var modelPath = @params.ModelPath;
            var loraBase = @params.LoraBase;
            var loraAdapters = @params.LoraAdapters.ToArray();

            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
            // slightly smaller range to leave some space for reporting LoRA loading too.
            var modelLoadProgressRange = 1f;
            if (loraAdapters.Length > 0)
                modelLoadProgressRange = 0.9f;

            using (@params.ToLlamaModelParams(out var lparams))
            {
#if !NETSTANDARD2_0
                // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
                if (token.CanBeCanceled || progressReporter != null)
                {
                    var internalCallback = lparams.progress_callback;
                    lparams.progress_callback = (progress, ctx) =>
                    {
                        // Update the progress reporter (remapping the value into the smaller range).
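                        // llama.cpp reports progress in [0, 1]. When LoRA adapters are queued,
                        // modelLoadProgressRange is 0.9, so the raw value is scaled into [0, 0.9]
                        // and the final 0.1 is reserved for adapter application below. For example,
                        // a raw progress of 0.5 is reported as 0.45 if any adapters will be applied.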
                        progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);

                        // If the user set a callback in the model params, call that and see if we should cancel
                        if (internalCallback != null && !internalCallback(progress, ctx))
                            return false;

                        // Check the cancellation token
                        if (token.IsCancellationRequested)
                            return false;

                        return true;
                    };
                }
#endif

                var model = await Task.Run(() =>
                {
                    try
                    {
                        // Load the model
                        var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);

                        // Apply the LoRA adapters
                        for (var i = 0; i < loraAdapters.Length; i++)
                        {
                            // Interrupt applying LoRAs if the token is cancelled
                            if (token.IsCancellationRequested)
                            {
                                weights.Dispose();
                                token.ThrowIfCancellationRequested();
                            }

                            // Don't apply invalid adapters
                            var adapter = loraAdapters[i];
                            if (string.IsNullOrEmpty(adapter.Path))
                                continue;
                            if (adapter.Scale <= 0)
                                continue;

                            weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);

                            // Report progress. Model loading reported progress from 0 -> 0.9, use
                            // the last 0.1 to represent all of the LoRA adapters being applied.
                            progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
                        }

                        // Update progress reporter to indicate completion
                        progressReporter?.Report(1);

                        return new LLamaWeights(weights);
                    }
                    catch (LoadWeightsFailedException)
                    {
                        // Convert a LoadWeightsFailedException into a cancellation exception if possible.
                        token.ThrowIfCancellationRequested();

                        // Ok, the weights failed to load for some reason other than cancellation.
                        throw;
                    }
                }, token);

                return model;
            }
        }

        /// <inheritdoc />
        public void Dispose()
        {
            NativeHandle.Dispose();
        }

        /// <summary>
        /// Create a llama_context using this model
        /// </summary>
        /// <param name="params"></param>
        /// <param name="logger"></param>
        /// <returns></returns>
        public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null)
        {
            return new LLamaContext(this, @params, logger);
        }

        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text"></param>
        /// <param name="add_bos"></param>
        /// <param name="encoding"></param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <returns></returns>
        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            return NativeHandle.Tokenize(text, add_bos, special, encoding);
        }
    }
}
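
// Example usage (a minimal sketch, assuming LLamaSharp's standard `ModelParams` class,
// which implements both IModelParams and IContextParams; the model path is hypothetical):
//
//     var parameters = new ModelParams("models/llama-2-7b.Q4_K_M.gguf");
//     using var weights = LLamaWeights.LoadFromFile(parameters);
//     using var context = weights.CreateContext(parameters);
//     var tokens = weights.Tokenize("Hello, world!", add_bos: true, special: false, Encoding.UTF8);
//
// Or asynchronously, with cancellation and progress reporting:
//
//     var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));
//     using var weights = await LLamaWeights.LoadFromFileAsync(parameters, cancellationToken, progress);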